7. Calling external functions

It is sometimes desirable to call an external handwritten function from a code-generated one, and use the result as part of further symbolic expressions. This can enable a few useful behaviors:

  • Evaluate complex logic that cannot easily be expressed as a functional expression tree. For example, you might iteratively solve a small numerical optimization problem within a larger expression.

  • Insert custom error-checking or logging logic into generated functions.

  • Interface with user-provided types that are not adequately expressed as a dataclass. For example, you might wish to pass a dynamically sized buffer of values, and perform bilinear interpolation from within the generated code.

These use cases can be achieved by declaring an external function. There are a few caveats to keep in mind:

  • By necessity, wrenfold must assume all calls to external functions are pure (without any side effects). Any two identical calls are assumed to be interchangeable and will be de-duplicated during code-generation.

  • Because external functions are effectively a black box, we cannot propagate the derivative through them.
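To see why the purity assumption matters, consider the toy model of de-duplication below, written in plain Python. It is a hypothetical sketch, not wrenfold's actual code-generation logic: identical calls are keyed by function name and arguments, and each distinct key is evaluated only once. A handwritten function with a side effect, such as logging, therefore observes a single call even though the expression contains two. Any logging placed inside an external function should be written with the same caveat in mind.

```python
# Toy model (not wrenfold's implementation): identical external calls are
# keyed by (function name, arguments) and evaluated only once.
from typing import Callable, Dict, List, Tuple

call_log: List[float] = []


def noisy_lookup(x: float) -> float:
    """A hypothetical handwritten function with a side effect (logging)."""
    call_log.append(x)
    return 2.0 * x


def evaluate_deduplicated(calls: List[Tuple[str, tuple]],
                          functions: Dict[str, Callable[..., float]]) -> List[float]:
    """Evaluate (name, args) calls in order, reusing the result of identical calls."""
    cache: Dict[Tuple[str, tuple], float] = {}
    results = []
    for name, args in calls:
        key = (name, args)
        if key not in cache:
            cache[key] = functions[name](*args)
        results.append(cache[key])
    return results


# The expression `noisy_lookup(0.25) + noisy_lookup(0.25)` contains two identical calls.
results = evaluate_deduplicated(
    [("noisy_lookup", (0.25,)), ("noisy_lookup", (0.25,))],
    {"noisy_lookup": noisy_lookup},
)
print(sum(results))   # 1.0
print(len(call_log))  # 1: the logging side effect ran only once
```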

7.1. Declaring an external function

In this example, we will pass a lookup table to our generated function. To begin with, we need a type to represent the table. We do this by inheriting from wrenfold.type_annotations.Opaque:

from wrenfold import type_annotations


class LookupTable(type_annotations.Opaque):
    """
    A placeholder we will map to our actual type during the code-generation step.
    """

The type itself requires no further additions; it is merely a placeholder that we will later map to a real type in the target language.

Next, we declare a function to represent our lookup operation:

from wrenfold import external_functions, type_annotations

interpolate_table = external_functions.declare_external_function(
    name="interpolate_table",
    arguments=[("table", LookupTable), ("arg", type_annotations.FloatScalar)],
    return_type=type_annotations.FloatScalar)

interpolate_table is an instance of wrenfold.external_functions.ExternalFunction. We can call it with symbolic expressions, provided they match the expected types we specified in the arguments list.

Now we can define a symbolic function that uses interpolate_table. We will write a function that computes the vector between two points, \(\mathbf{v} = \mathbf{p}_1 - \mathbf{p}_0\), and uses its direction angle \(\theta = \text{atan2}\left(\mathbf{v}_y, \mathbf{v}_x\right)\) as the argument to the lookup table:
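Written out in full, including the final scaling by the squared length of \(\mathbf{v}\) that the code applies after the lookup, the function computes the following. Here \(T(\cdot)\) is notation introduced only for this summary, standing for the handwritten table lookup:

\[
\theta = \text{atan2}\left(\mathbf{v}_y, \mathbf{v}_x\right) \in [-\pi, \pi], \qquad
u = \frac{\theta + \pi}{2\pi} \in [0, 1], \qquad
f\left(\mathbf{p}_0, \mathbf{p}_1\right) = T(u)\,\lVert \mathbf{v} \rVert^2 = T(u)\left(\mathbf{v}_x^2 + \mathbf{v}_y^2\right)
\]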

from wrenfold import code_generation, sym, type_annotations


def lookup_angle(table: LookupTable, p_0: type_annotations.Vector2, p_1: type_annotations.Vector2):
    """
    Compute the bearing angle between two points, and use it as an argument to our lookup table.
    """
    v = p_1 - p_0
    angle = sym.atan2(v[1], v[0])

    # Map the angle from [-pi, pi] to [0, 1] (-pi maps to 0, and pi maps to 1).
    angle_normalized = (angle + sym.pi) / (2 * sym.pi)

    # Perform the lookup.
    table_value = interpolate_table(table=table, arg=angle_normalized)

    # Do some more symbolic operations with the result: scale by the squared length of v.
    result = table_value * v.squared_norm()
    return [
        code_generation.ReturnValue(result),
    ]
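Because the generated code depends on a handwritten function, it can help to keep a plain-Python mirror of the math for numerical testing. The sketch below assumes that interpolate_table performs linear interpolation over samples spaced uniformly on [0, 1], clamping its argument to that range; the source does not specify these semantics, and interpolate_table_reference and lookup_angle_reference are hypothetical names used only for illustration.

```python
import math
from typing import Sequence, Tuple


def interpolate_table_reference(table: Sequence[float], arg: float) -> float:
    """
    Hypothetical semantics for `interpolate_table`: linear interpolation over
    samples spaced uniformly on [0, 1], with `arg` clamped to that range.
    """
    if len(table) < 2:
        raise ValueError("table needs at least two samples")
    x = min(max(arg, 0.0), 1.0) * (len(table) - 1)
    i = min(int(x), len(table) - 2)  # index of the left-hand sample
    frac = x - i
    return (1.0 - frac) * table[i] + frac * table[i + 1]


def lookup_angle_reference(table: Sequence[float],
                           p_0: Tuple[float, float],
                           p_1: Tuple[float, float]) -> float:
    """Plain-Python mirror of the symbolic `lookup_angle` function."""
    vx, vy = p_1[0] - p_0[0], p_1[1] - p_0[1]
    angle = math.atan2(vy, vx)
    angle_normalized = (angle + math.pi) / (2.0 * math.pi)
    return interpolate_table_reference(table, angle_normalized) * (vx * vx + vy * vy)


table = [0.0, 1.0, 2.0, 3.0, 4.0]
# v = (0, 2): angle = pi/2, normalized = 0.75, lookup = 3.0, squared norm = 4.0.
print(lookup_angle_reference(table, (1.0, 1.0), (1.0, 3.0)))  # 12.0
```

Comparing this reference against the compiled C++ at a handful of random points is a cheap way to catch a mismatch between the symbolic model and the handwritten implementation.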

To emit actual code for our LookupTable type and interpolate_table function, we customize the code generator:

from wrenfold import ast, code_generation, type_info


class CustomCppGenerator(code_generation.CppGenerator):

    def format_call_external_function(self, element: ast.CallExternalFunction) -> str:
        """
        Place our external function in the ``utilities`` namespace.
        """
        if element.function == interpolate_table:
            args = ', '.join(self.format(x) for x in element.args)
            return f'utilities::{element.function.name}({args})'
        return self.super_format(element)

    def format_custom_type(self, element: type_info.CustomType) -> str:
        """
        Assume the lookup table is implemented as a std::vector<double>.
        """
        if element.python_type == LookupTable:
            return 'std::vector<double>'
        return self.super_format(element)


code = code_generation.generate_function(func=lookup_angle, generator=CustomCppGenerator())
print(code)

This produces the following C++:

// Our generated method correctly accepts a `std::vector<double>`, and invokes
// the appropriately-namespaced `interpolate_table`.
template <typename Scalar, typename T1, typename T2>
Scalar lookup_angle(const std::vector<double> &table, const T1 &p_0, const T2 &p_1)
{
    auto _p_0 = wf::make_input_span<2, 1>(p_0);
    auto _p_1 = wf::make_input_span<2, 1>(p_1);

    // ...

    const Scalar v002 = _p_0(0, 0);
    const Scalar v008 = _p_0(1, 0);
    const Scalar v000 = _p_1(0, 0);
    const Scalar v007 = _p_1(1, 0);
    const Scalar v005 = v000 + -v002;
    const Scalar v010 = v007 + -v008;
    return (v005 * v005 + v010 * v010) *
            utilities::interpolate_table(
                table, static_cast<Scalar>(0.5) *
                            (static_cast<Scalar>(M_PI) + std::atan2(v010, v005)) *
                            (static_cast<Scalar>(1) / static_cast<Scalar>(M_PI)));
}
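The generator rearranged the normalization into \(0.5\,(\pi + \theta)\,\pi^{-1}\), which is algebraically identical to \((\theta + \pi)/(2\pi)\). A quick floating-point spot check in Python confirms that the two forms agree to within rounding; the helper names below are illustrative only.

```python
import math


def normalized_as_written(angle: float) -> float:
    # Form used in the Python source: (angle + pi) / (2 * pi)
    return (angle + math.pi) / (2.0 * math.pi)


def normalized_as_generated(angle: float) -> float:
    # Form emitted in the C++ output: 0.5 * (pi + angle) * (1 / pi)
    return 0.5 * (math.pi + angle) * (1.0 / math.pi)


for angle in (-math.pi, -1.0, 0.0, 0.5, math.pi):
    written = normalized_as_written(angle)
    generated = normalized_as_generated(angle)
    print(f"{angle:+.4f}: {written:.12f} vs {generated:.12f}")
```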

Note that we never told wrenfold anything about the nature of std::vector<double>. For the most part we would prefer not to, since doing so would add a great deal of complexity to the code generator. Instead, we treat the table as an opaque type that the generated function simply passes through to a handwritten one.