6. Interfacing with existing types
The wrenfold code-generator can integrate with existing types from your codebase, both as function arguments and return values. This can unlock some useful functionality:
Leverage existing group or manifold types in your codebase. For example, if your optimization is implemented in GTSAM you might wish to pass gtsam::Pose3 directly to and from generated functions.
Pass parameter structs directly to a generated function without breaking them into scalar/vector pieces.
In general, this can improve type-safety and legibility by eliminating type conversion clutter where generated methods are invoked. Our overall philosophy is BYOT (Bring Your Own Types): Rather than code-generate new structs, wrenfold aims to facilitate relatively easy integration of structs that exist a priori in the user codebase.
There are two kinds of user-provided types:
dataclass types, which we will outline in this section.
Opaque types, which are primarily useful in combination with external functions.
6.1. Defining a custom type
In order to use an externally defined struct, wrenfold needs your assistance in two places:
We need a Python definition of the object that lays out the struct members and their types.
The code-generator may need to be customized to emit valid calls to member accessors and constructors.
In this example we will implement support for a simple vec2 type. We assume the existence of a C++ struct of the form:
namespace geo {
struct vec2 {
  constexpr vec2(double x, double y) noexcept : x_(x), y_(y) {}
  // ...
  constexpr double x() const noexcept { return x_; }
  constexpr double y() const noexcept { return y_; }
  // ...
 private:
  double x_;
  double y_;
};
}  // namespace geo
In Python, custom types are declared as dataclasses with type-annotated members. We declare a symbolic equivalent of our vec2 type:
import dataclasses

from wrenfold import sym, type_annotations

@dataclasses.dataclass
class Vec2:
    """Symbolic version of geo::vec2."""
    x: type_annotations.FloatScalar
    y: type_annotations.FloatScalar

    def to_vector(self) -> sym.MatrixExpr:
        return sym.vector(self.x, self.y)

    @classmethod
    def from_vector(cls, v: sym.MatrixExpr):
        # Only 2x1 column vectors map onto the two fields of Vec2.
        assert v.shape == (2, 1)
        return cls(v[0], v[1])
When defining a dataclass for use with wrenfold, all members must be type annotated with:
One of the types from the type annotations module, or a similarly declared type that inherits from sym.Expr or sym.MatrixExpr.
Another custom dataclass type. Thus nested structs are supported.
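To make the nesting rule concrete, here is a minimal sketch that uses only the standard library. FloatScalar below is a local stand-in for a scalar annotation (it is not the wrenfold type), Circle is a hypothetical struct, and leaf_fields is a hypothetical helper. The sketch only illustrates how a dataclass whose member is another dataclass bottoms out in scalar leaves; it says nothing about how wrenfold traverses types internally.

```python
import dataclasses


class FloatScalar:
    """Local stand-in for a scalar annotation (NOT the wrenfold type)."""


@dataclasses.dataclass
class Vec2:
    x: FloatScalar
    y: FloatScalar


@dataclasses.dataclass
class Circle:
    """Hypothetical struct that nests Vec2 as a member."""
    center: Vec2
    radius: FloatScalar


def leaf_fields(cls, prefix=""):
    """Return dotted names of every scalar leaf, recursing into nested dataclasses."""
    names = []
    for field in dataclasses.fields(cls):
        if dataclasses.is_dataclass(field.type):
            # A member annotated with another dataclass: descend into it.
            names.extend(leaf_fields(field.type, prefix + field.name + "."))
        else:
            names.append(prefix + field.name)
    return names


print(leaf_fields(Circle))  # ['center.x', 'center.y', 'radius']
```

Every leaf is a scalar annotation, which is the property the two rules above require of a custom dataclass.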
Next, we will create a simple example function that uses our vector type. We can accept Vec2 as an argument, and return it as well:
from wrenfold import code_generation

def rotate_vector(angle: type_annotations.FloatScalar, v: Vec2):
    """Rotate vector `v` by `angle` radians."""
    R = sym.matrix([[sym.cos(angle), -sym.sin(angle)], [sym.sin(angle), sym.cos(angle)]])
    v_rot = R * v.to_vector()

    # Compute the jacobian of `v_rot` wrt `angle`.
    # This is a 2x1 vector, so we'll put it in Vec2.
    v_rot_diff = v_rot.jacobian([angle])

    # We also want to return `Vec2`:
    return [
        code_generation.ReturnValue(Vec2.from_vector(v_rot)),
        code_generation.OutputArg(Vec2.from_vector(v_rot_diff), name="v_rot_D_angle")
    ]
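For reference, the two quantities returned by rotate_vector can be worked out by hand. Writing the angle as \theta and the components of v as x and y (symbols chosen here for illustration), the rotated vector and its derivative with respect to the angle are:

```latex
\begin{aligned}
v_{\mathrm{rot}} &=
\begin{bmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{bmatrix}
\begin{bmatrix} x \\ y \end{bmatrix}
=
\begin{bmatrix} x\cos\theta - y\sin\theta \\ x\sin\theta + y\cos\theta \end{bmatrix},
\\[4pt]
\frac{\partial v_{\mathrm{rot}}}{\partial \theta} &=
\begin{bmatrix} -x\sin\theta - y\cos\theta \\ x\cos\theta - y\sin\theta \end{bmatrix}.
\end{aligned}
```

The first component of the derivative is the negated second component of v_rot, and the second component of the derivative equals the first component of v_rot.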
6.2. Customizing code generation
With our symbolic rotate_vector method in hand, we are nearly ready to generate code. First we need to make some minor customizations to the code formatter:
from wrenfold import ast, type_info

class CustomCppGenerator(code_generation.CppGenerator):

    def format_get_field(self, element: ast.GetField) -> str:
        """
        geo::vec2 members are private, so call the accessor method instead:
        """
        if element.struct_type.python_type == Vec2:
            return f"{self.format(element.arg)}.{element.field_name}()"
        # Any other type falls back to the default formatting.
        return self.super_format(element)

    def format_custom_type(self, element: type_info.CustomType) -> str:
        """
        Place our custom type into the `geo` namespace.
        """
        if element.python_type == Vec2:
            return 'geo::vec2'
        return self.super_format(element)
The default C++ code-generation logic for constructors assumes initializer-list syntax, which is already valid for our geo::vec2 type, so we do not need to customize that. Now we leverage our new custom generator:
code = code_generation.generate_function(func=rotate_vector, generator=CustomCppGenerator())
print(code)

template <typename Scalar>
geo::vec2 rotate_vector(const Scalar angle, const geo::vec2& v, geo::vec2& v_rot_D_angle)
{
  const Scalar v001 = angle;
  const Scalar v004 = v.y();
  const Scalar v002 = std::sin(v001);
  const Scalar v008 = v.x();
  const Scalar v007 = std::cos(v001);
  const Scalar v013 = v004 * v007 + v002 * v008;
  const Scalar v010 = v007 * v008 + -(v002 * v004);
  v_rot_D_angle = geo::vec2{
    -v013,
    v010
  };
  return geo::vec2{
    v010,
    v013
  };
}
Voila: our generated method uses geo::vec2 for input arguments, output arguments, and the return value.
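As a sanity check on the generated arithmetic, the body above can be transcribed into plain Python and compared against a finite-difference derivative. This is only a sketch: rotate_vector_reference and finite_difference are hypothetical helper names, tuples stand in for geo::vec2, and the test values are arbitrary. It does not call wrenfold or the generated C++.

```python
import math


def rotate_vector_reference(angle, x, y):
    """Plain-Python transcription of the generated body (variable names kept)."""
    v002 = math.sin(angle)
    v007 = math.cos(angle)
    v013 = y * v007 + v002 * x     # second component of the rotated vector
    v010 = v007 * x - v002 * y     # first component of the rotated vector
    v_rot = (v010, v013)
    v_rot_D_angle = (-v013, v010)  # derivative of v_rot with respect to angle
    return v_rot, v_rot_D_angle


def finite_difference(angle, x, y, h=1e-6):
    """Central-difference approximation of d(v_rot)/d(angle)."""
    (xp, yp), _ = rotate_vector_reference(angle + h, x, y)
    (xm, ym), _ = rotate_vector_reference(angle - h, x, y)
    return ((xp - xm) / (2 * h), (yp - ym) / (2 * h))


angle, x, y = 0.3, 1.5, -2.0
v_rot, analytic = rotate_vector_reference(angle, x, y)
numeric = finite_difference(angle, x, y)
max_error = max(abs(a - n) for a, n in zip(analytic, numeric))
print(max_error < 1e-7)
```

The analytic output argument and the numerical derivative should agree to within the finite-difference error, and the rotated vector should keep the length of the input.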
6.3. Emitting a custom constructor call
Suppose we want to generate this method in Rust as well, and invoke a custom constructor geo::Vec2::new(...). To override the construction logic, we implement format_construct_custom_type:
class CustomRustGenerator(code_generation.RustGenerator):

    def format_get_field(self, element: ast.GetField) -> str:
        if element.struct_type.python_type == Vec2:
            return f"{self.format(element.arg)}.{element.field_name}()"
        return self.super_format(element)

    def format_custom_type(self, element: type_info.CustomType) -> str:
        """
        Place our custom type into the `geo` crate.
        """
        if element.python_type == Vec2:
            return 'geo::Vec2'
        return self.super_format(element)

    def format_construct_custom_type(self, element: ast.ConstructCustomType) -> str:
        if element.type.python_type == Vec2:
            x = self.format(element.get_field_value('x'))
            y = self.format(element.get_field_value('y'))
            return f"geo::Vec2::new({x}, {y})"
        return self.super_format(element)
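Passing CustomRustGenerator() as the generator argument of code_generation.generate_function, as we did for C++, the generated code now looks like: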
#[inline]
#[allow(non_snake_case, clippy::unused_unit, clippy::collapsible_else_if,
        clippy::needless_late_init, unused_variables)]
pub fn rotate_vector<>(angle: f64, v: &geo::Vec2, v_rot_D_angle: &mut geo::Vec2) -> geo::Vec2
{
    let v001: f64 = angle;
    let v004: f64 = v.y();
    let v002: f64 = (v001).sin();
    let v008: f64 = v.x();
    let v007: f64 = (v001).cos();
    let v013: f64 = v004 * v007 + v002 * v008;
    let v010: f64 = v007 * v008 + -(v002 * v004);
    *v_rot_D_angle = geo::Vec2::new(-v013, v010);
    geo::Vec2::new(v010, v013)
}
Note
For a more complicated demonstration, refer to the custom_types example.