7. Calling external functions
It is sometimes desirable to call an external handwritten function from a code-generated one, and use the result as part of further symbolic expressions. This can enable a few useful behaviors:
- Evaluate complex logic that cannot easily be expressed as a functional expression tree. For example, iteratively solving a small numerical optimization problem as part of a larger expression.
- Insert custom error-checking or logging logic into generated functions.
- Interface with user-provided types that are not adequately expressed as a dataclass. For example, you might wish to pass a dynamically sized buffer of values and perform bilinear interpolation on it from within the generated code.
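To make the last use case more concrete, the sketch below shows the kind of handwritten helper that such a buffer might be paired with. It is plain Python for illustration only. In practice the helper would be written in the target language next to the generated code. The name bilinear_interpolate, the row-major layout, and the clamping behavior are all assumptions, not anything wrenfold prescribes.

```python
from typing import Sequence


def bilinear_interpolate(buffer: Sequence[float], width: int, height: int,
                         x: float, y: float) -> float:
    """
    Hypothetical handwritten helper: bilinear interpolation on a row-major grid.

    buffer holds width * height samples, where buffer[row * width + col] is the
    value at grid coordinate (col, row). The query point (x, y) is clamped to
    [0, width - 1] x [0, height - 1].
    """
    if width < 2 or height < 2 or len(buffer) != width * height:
        raise ValueError("expected a row-major grid of at least 2x2 samples")
    # Clamp the query point to the grid.
    x = min(max(x, 0.0), width - 1.0)
    y = min(max(y, 0.0), height - 1.0)
    # Index of the cell containing (x, y); a point on the far edge uses the last cell.
    col = min(int(x), width - 2)
    row = min(int(y), height - 2)
    fx = x - col
    fy = y - row

    def at(r: int, c: int) -> float:
        return buffer[r * width + c]

    # Interpolate along x on the two rows bounding the cell, then along y.
    lower = (1.0 - fx) * at(row, col) + fx * at(row, col + 1)
    upper = (1.0 - fx) * at(row + 1, col) + fx * at(row + 1, col + 1)
    return (1.0 - fy) * lower + fy * upper


grid = [0.0, 1.0,
        2.0, 3.0]  # 2x2 grid holding col + 2 * row
print(bilinear_interpolate(grid, 2, 2, 0.25, 0.75))  # 1.75
```

Bilinear interpolation reproduces functions that are linear in x and y exactly. That makes a grid holding col + 2 * row a convenient sanity check.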
These use cases can be achieved by declaring an external function. There are a few caveats to keep in mind:
- By necessity, wrenfold must assume that all calls to external functions are pure (free of side effects). Any two identical calls are assumed to be interchangeable and will be de-duplicated during code generation.
- Because external functions are effectively black boxes, wrenfold cannot propagate derivatives through them.
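The de-duplication caveat matters most for functions with side effects, such as the logging use case above. The sketch below is plain Python that mimics the effect by hand. The names logged_sqrt, as_written, and after_deduplication are hypothetical, and the code does not reflect the exact code wrenfold emits.

```python
# Hypothetical illustration (plain Python, not wrenfold output): an "external
# function" with a side effect, and what de-duplicating identical calls does to it.
call_log = []  # records every argument the function receives


def logged_sqrt(x: float) -> float:
    """An impure function: it returns sqrt(x) but also appends to call_log."""
    call_log.append(x)
    return x ** 0.5


def as_written(x: float) -> float:
    # The expression as written contains two identical calls.
    return logged_sqrt(x) + logged_sqrt(x)


def after_deduplication(x: float) -> float:
    # Identical calls are treated as interchangeable, so generated code may
    # evaluate the call once and reuse the result.
    shared = logged_sqrt(x)
    return shared + shared


call_log.clear()
print(as_written(4.0), len(call_log))           # 4.0 2
call_log.clear()
print(after_deduplication(4.0), len(call_log))  # 4.0 1
```

Both versions return the same value, but the side effect runs a different number of times. Any external function whose side effects you rely on should therefore be treated with care.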
7.1. Declaring an external function
In this example, we will pass a lookup table to our generated function. To begin with, we need a type to represent the table. We do this by inheriting from wrenfold.type_annotations.Opaque:
from wrenfold import type_annotations


class LookupTable(type_annotations.Opaque):
    """
    A placeholder we will map to our actual type during the code-generation step.
    """
The type itself requires no further additions; it is merely a placeholder that we will later map to a real type in the target language.
Next, we declare a function to represent our lookup operation:
from wrenfold import external_functions, type_annotations

interpolate_table = external_functions.declare_external_function(
    name="interpolate_table",
    arguments=[("table", LookupTable), ("arg", type_annotations.FloatScalar)],
    return_type=type_annotations.FloatScalar)
interpolate_table is an instance of wrenfold.external_functions.ExternalFunction. We can call it with symbolic expressions, provided they match the types we specified in the arguments list.
Now we can define a symbolic function that uses interpolate_table. We will write a function that computes the bearing vector between two points, \(\mathbf{v} = \mathbf{p}_1 - \mathbf{p}_0\), takes its direction angle \(\theta = \text{atan2}\left(\mathbf{v}_y, \mathbf{v}_x\right)\), and passes that angle, rescaled to \(\left(\theta + \pi\right) / \left(2\pi\right) \in [0, 1]\), as the argument to the lookup table:
from wrenfold import code_generation, sym, type_annotations


def lookup_angle(table: LookupTable, p_0: type_annotations.Vector2, p_1: type_annotations.Vector2):
    """
    Compute bearing angle between two points, and use it as an argument to our lookup table.
    """
    v = p_1 - p_0
    angle = sym.atan2(v[1], v[0])
    # Normalize to [0, 1] (where 0 corresponds to -pi, and 1 corresponds to pi).
    angle_normalized = (angle + sym.pi) / (2 * sym.pi)
    # Perform the lookup.
    table_value = interpolate_table(table=table, arg=angle_normalized)
    # Do some more symbolic operations with the result:
    result = table_value * (p_1 - p_0).squared_norm()
    return [
        code_generation.ReturnValue(result),
    ]
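When testing the generated code, it can help to have a plain numerical reference for lookup_angle. The sketch below mirrors the symbolic function in ordinary Python. wrenfold leaves the table semantics entirely to us, so interpolate_table_ref is a hypothetical stand-in. It assumes that the table holds samples evenly spaced over [0, 1] and interpolates linearly between them.

```python
import math
from typing import Sequence, Tuple


def interpolate_table_ref(table: Sequence[float], arg: float) -> float:
    """
    Hypothetical stand-in for the handwritten interpolate_table: linear interpolation
    over samples assumed to be evenly spaced on [0, 1].
    """
    if len(table) < 2:
        raise ValueError("table needs at least two samples")
    arg = min(max(arg, 0.0), 1.0)  # clamp to the table's domain
    position = arg * (len(table) - 1)
    index = min(int(position), len(table) - 2)  # arg == 1 uses the last interval
    frac = position - index
    return (1.0 - frac) * table[index] + frac * table[index + 1]


def lookup_angle_ref(table: Sequence[float], p_0: Tuple[float, float],
                     p_1: Tuple[float, float]) -> float:
    """Numerical mirror of the symbolic lookup_angle function."""
    vx = p_1[0] - p_0[0]
    vy = p_1[1] - p_0[1]
    angle = math.atan2(vy, vx)
    # Same normalization as the symbolic version: -pi -> 0, pi -> 1.
    angle_normalized = (angle + math.pi) / (2 * math.pi)
    return interpolate_table_ref(table, angle_normalized) * (vx * vx + vy * vy)


# A linear-ramp table returns the normalized angle itself: pi/2 -> 0.75, and |v|^2 = 4.
print(lookup_angle_ref([0.0, 1.0], (0.0, 0.0), (0.0, 2.0)))  # approximately 3.0
```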
To emit actual code for our LookupTable type and interpolate_table function, we customize the code generator:
from wrenfold import ast, code_generation, type_info


class CustomCppGenerator(code_generation.CppGenerator):
    def format_call_external_function(self, element: ast.CallExternalFunction) -> str:
        """
        Place our external function in the ``utilities`` namespace.
        """
        if element.function == interpolate_table:
            args = ', '.join(self.format(x) for x in element.args)
            return f'utilities::{element.function.name}({args})'
        return self.super_format(element)

    def format_custom_type(self, element: type_info.CustomType) -> str:
        """
        Assume the lookup table is implemented as a std::vector<double>.
        """
        if element.python_type == LookupTable:
            return 'std::vector<double>'
        return self.super_format(element)


code = code_generation.generate_function(func=lookup_angle, generator=CustomCppGenerator())
print(code)
This produces the following C++:
// Our generated method correctly accepts a `std::vector<double>`, and invokes
// the appropriately-namespaced `interpolate_table`.
template <typename Scalar, typename T1, typename T2>
Scalar lookup_angle(const std::vector<double> &table, const T1 &p_0, const T2 &p_1)
{
  auto _p_0 = wf::make_input_span<2, 1>(p_0);
  auto _p_1 = wf::make_input_span<2, 1>(p_1);
  // ...
  const Scalar v002 = _p_0(0, 0);
  const Scalar v008 = _p_0(1, 0);
  const Scalar v000 = _p_1(0, 0);
  const Scalar v007 = _p_1(1, 0);
  const Scalar v005 = v000 + -v002;
  const Scalar v010 = v007 + -v008;
  return (v005 * v005 + v010 * v010) *
         utilities::interpolate_table(
             table, static_cast<Scalar>(0.5) *
                        (static_cast<Scalar>(M_PI) + std::atan2(v010, v005)) *
                        (static_cast<Scalar>(1) / static_cast<Scalar>(M_PI)));
}
Note that we never told wrenfold anything about the nature of std::vector<double>. For the most part we prefer not to, since doing so would add considerable complexity to the code generator. Instead, we treat it as an opaque type that is passed through the generated function to a handwritten one.
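As a sanity check on the generated C++ above, note that the code generator rearranged the normalization (θ + π) / (2π) into 0.5 * (π + θ) * (1 / π). It also expanded the squared norm into v005 * v005 + v010 * v010. The Python snippet below is an illustrative check and not part of wrenfold. It evaluates both forms of the normalization at a few angles to confirm that they agree up to floating-point rounding.

```python
import math


def normalized_as_written(angle: float) -> float:
    # The form used in lookup_angle: (angle + pi) / (2 * pi).
    return (angle + math.pi) / (2 * math.pi)


def normalized_as_generated(angle: float) -> float:
    # The form emitted in the generated C++: 0.5 * (pi + angle) * (1 / pi).
    return 0.5 * (math.pi + angle) * (1.0 / math.pi)


for angle in (-math.pi, -1.0, 0.0, 1.0, math.pi):
    written = normalized_as_written(angle)
    generated = normalized_as_generated(angle)
    # The two forms may differ in the last few bits due to rounding.
    print(f"{angle:+.4f}: {written:.17f} {generated:.17f}")
```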