6. Generating Python

When iterating on a function design, it is sometimes useful to evaluate the function in a Python script or notebook. To that end, wrenfold can generate Python functions that call into the NumPy, PyTorch, or JAX APIs.

To generate a Python callable directly, we can use generate_python():

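from wrenfold import code_generation

# Assumes `step_clamped` was defined earlier in this guide.
# `step_np_func` is a callable that accepts a scalar `x`.
# `step_code` is a string containing the generated Python code.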
step_np_func, step_code = code_generation.generate_python(
    func=step_clamped,
    target=code_generation.PythonGeneratorTarget.NumPy)

y, optional_outputs = step_np_func(x=0.32, compute_df=True)

print(y)                # prints: 0.241664
print(optional_outputs) # prints: {'df': array([[1.3055999], [2.16]], dtype=float32)}

The listing above produces a callable step_np_func that operates on NumPy types. When targeting Python, optional output arguments become optional return values. We pass compute_df=True to request that the output df be computed, and it is returned in the dictionary optional_outputs. When the function also has a return value (as is the case here), the callable returns a tuple of the form (return value, optional output dict).
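To make this calling convention concrete, here is a hand-written NumPy stand-in with the same shape as step_np_func. It is a sketch under assumptions, not wrenfold output: we assume step_clamped is the clamped smoothstep y = x**2 * (3 - 2*x) and that df holds its first and second derivatives as a 2x1 column, which is consistent with the values printed above. The name step_clamped_by_hand is hypothetical.

```python
import typing as T

import numpy as np


def step_clamped_by_hand(
    x: float, compute_df: bool
) -> T.Tuple[np.float32, T.Dict[str, np.ndarray]]:
    """Hypothetical hand-written analogue of the generated NumPy callable."""
    x32 = np.float32(x)
    # Clamp to [0, 1] and evaluate in float32, like the generated code.
    xc = np.clip(x32, np.float32(0), np.float32(1))
    y = xc * xc * (np.float32(3) - np.float32(2) * xc)

    optional_outputs: T.Dict[str, np.ndarray] = {}
    if compute_df:
        # The clamped function is flat outside (0, 1), so both derivatives vanish there.
        inside = np.float32(0) < x32 < np.float32(1)
        d1 = np.float32(6) * xc * (np.float32(1) - xc) if inside else np.float32(0)
        d2 = np.float32(6) - np.float32(12) * xc if inside else np.float32(0)
        # Stack [f', f''] as a 2x1 column, matching the printed `df`.
        optional_outputs["df"] = np.array([[d1], [d2]], dtype=np.float32)
    # Mirror the (return value, optional output dict) convention.
    return y, optional_outputs


y, optional_outputs = step_clamped_by_hand(x=0.32, compute_df=True)
print(y)
print(optional_outputs)
```

With compute_df=False this sketch returns an empty dictionary; that is a choice made for the sketch, not a statement about what wrenfold emits.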

Warning

By default, generated Python functions use float32 types internally. This minimizes impedance mismatch with frameworks like PyTorch and JAX, where 32-bit floating point is the default. You can change this behavior by specifying the generator_type argument and instantiating a PythonGenerator explicitly.
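To get a feel for what the float32 default costs, the sketch below evaluates the same smoothstep polynomial at both precisions with plain NumPy. It does not call wrenfold; it only illustrates that float32 results agree with float64 to within a few multiples of float32 machine epsilon (about 1.19e-07), which is usually acceptable for learning workloads but may be too coarse for tasks such as validating a numerical solver.

```python
import numpy as np


def smoothstep(x: np.ndarray) -> np.ndarray:
    # y = x^2 * (3 - 2x); Python int constants do not promote the dtype,
    # so the result keeps the dtype of `x`.
    return x * x * (3 - 2 * x)


xs = np.linspace(0.0, 1.0, 101)  # float64 sample points in [0, 1]
y64 = smoothstep(xs)
y32 = smoothstep(xs.astype(np.float32))

eps32 = np.finfo(np.float32).eps  # about 1.19e-07
max_abs_err = np.max(np.abs(y32.astype(np.float64) - y64))
print(y32.dtype, max_abs_err, eps32)
```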

When targeting JAX or PyTorch, it is advantageous to leave conditionals like wrenfold.sym.where() in ternary form and emit them as jax.numpy.where or torch.where calls, rather than as Python conditional logic. generate_python does this automatically:

step_jax_func, step_code = code_generation.generate_python(
    func=step_clamped,
    target=code_generation.PythonGeneratorTarget.JAX)

print(step_code)
# The generated JAX function:
def step_clamped(x: jnp.ndarray, compute_df: bool) -> T.Tuple[jnp.ndarray, T.Dict[str, jnp.ndarray]]:
    v002 = x
    v006 = jnp.where(
        (v002 < jnp.asarray(0, dtype=jnp.float32)),
        jnp.asarray(0, dtype=jnp.float32),
        v002,
    )
    # ... output truncated ...
    return (
        v009
        * v009
        * (
            jnp.asarray(3, dtype=jnp.float32) + jnp.asarray(2, dtype=jnp.float32) * v043
        ),
        dict(df=df),
    )
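The listing above expresses the lower clamp of x as a jnp.where call. To see why element selection matters, the plain-NumPy sketch below implements the same clamp twice: once with a Python if-statement and once with np.where, whose element-wise semantics match jax.numpy.where and torch.where. The helper names are hypothetical. Only the np.where version accepts an array of inputs; under jax.jit, an if-statement on a traced value fails in a similar way.

```python
import numpy as np


def clamp_with_if(x):
    # Python control flow needs a single truth value, so this only
    # works for scalar inputs.
    if x < 0:
        return np.zeros_like(x)
    return x


def clamp_with_where(x):
    # Element-wise selection: each entry picks its own branch.
    return np.where(x < 0, np.zeros_like(x), x)


xs = np.array([-0.5, 0.25, 1.5], dtype=np.float32)
print(clamp_with_where(xs))

try:
    clamp_with_if(xs)
except ValueError as err:
    # NumPy refuses to collapse a multi-element comparison into one bool.
    print("if-statement failed:", err)
```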

By leaving conditionals in ternary (or “element selection”) form, we retain the ability to differentiate through them during back-propagation. This behavior can be disabled, producing if-statements instead, by specifying convert_ternaries=True.
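Why does the element-selection form remain differentiable? Automatic differentiation treats where(c, a(x), b(x)) as having derivative where(c, a'(x), b'(x)): each element takes the derivative of the branch it selected. The NumPy sketch below writes that rule out by hand for the lower clamp in the generated listing and checks it against central finite differences. The helper names are hypothetical; JAX and PyTorch apply the equivalent rule automatically.

```python
import numpy as np


def clamp_lower(x: np.ndarray) -> np.ndarray:
    # Element-selection form of max(x, 0), as in the generated listing.
    return np.where(x < 0, 0.0, x)


def clamp_lower_grad(x: np.ndarray) -> np.ndarray:
    # Derivative rule for where(c, a(x), b(x)): select a'(x) or b'(x)
    # per element. Here a(x) = 0 (so a' = 0) and b(x) = x (so b' = 1).
    return np.where(x < 0, 0.0, 1.0)


xs = np.array([-2.0, -0.3, 0.4, 1.7])
h = 1e-6
# Central finite differences, sampled away from the kink at x = 0.
fd = (clamp_lower(xs + h) - clamp_lower(xs - h)) / (2 * h)
print(clamp_lower_grad(xs), fd)
```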

Tip

For longer-form examples of Python generation, see jax_camera_model and cart-pole.