firedrake.ml.jax package

Submodules

firedrake.ml.jax.fem_operator module

class firedrake.ml.jax.fem_operator.FiredrakeJaxOperator(F: ReducedFunctional)

Bases: object

JAX custom operator representing a set of Firedrake operations expressed as a reduced functional F.

FiredrakeJaxOperator executes forward and backward passes by directly calling the reduced functional F.

Parameters:

F – The reduced functional to wrap.

bwd(_, grad_output: jax.Array) → jax.Array

Backward pass of the JAX custom operator.

forward = None

fwd(*x_P: jax.Array) → jax.Array

Forward pass of the JAX custom operator.
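The forward/backward structure described above can be sketched with JAX's jax.custom_vjp. This is a minimal illustration under stated assumptions, not Firedrake's implementation: ToyReducedFunctional is a hypothetical stand-in for a pyadjoint ReducedFunctional (here J(m) = 0.5 * sum(m**2)), and apply_F, apply_F_fwd and apply_F_bwd are illustrative names. The toy is written with jax.numpy so it can be traced. The documented bwd(_, grad_output) ignores its first argument, whereas this sketch passes m through as the residual for simplicity.

```python
import jax
import jax.numpy as jnp


class ToyReducedFunctional:
    """Hypothetical stand-in for a reduced functional J(m) = 0.5 * sum(m**2).

    It mimics only two capabilities: evaluating J at m, and returning the
    gradient dJ/dm scaled by an incoming adjoint value.
    """

    def __call__(self, m):
        return 0.5 * jnp.sum(m ** 2)

    def derivative(self, m, adj_input=1.0):
        return adj_input * m


F = ToyReducedFunctional()


@jax.custom_vjp
def apply_F(m):
    # Forward pass: evaluate the reduced functional directly.
    return F(m)


def apply_F_fwd(m):
    # Forward rule used under differentiation: output plus the residual (m).
    return F(m), m


def apply_F_bwd(m, grad_output):
    # Backward pass: ask the reduced functional for its derivative,
    # weighted by the incoming cotangent. One cotangent per primal input.
    return (F.derivative(m, adj_input=grad_output),)


apply_F.defvjp(apply_F_fwd, apply_F_bwd)

m = jnp.array([1.0, 2.0, 3.0])
value = apply_F(m)               # 0.5 * (1 + 4 + 9)
grad_m = jax.grad(apply_F)(m)    # computed by apply_F_bwd
```

Because the backward rule is registered with defvjp, jax.grad(apply_F) calls ToyReducedFunctional.derivative rather than differentiating the forward code itself.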

firedrake.ml.jax.fem_operator.fem_operator(F: ReducedFunctional) → FiredrakeJaxOperator

Cast a Firedrake reduced functional to a JAX operator.

The resulting FiredrakeJaxOperator will take JAX tensors as inputs and return JAX tensors as outputs.

Parameters:

F – The reduced functional to wrap.

Returns:

A JAX custom operator that wraps the reduced functional F.

Return type:

firedrake.ml.jax.fem_operator.FiredrakeJaxOperator

firedrake.ml.jax.fem_operator.from_jax(x: jax.Array, V: WithGeometry | None = None) → Function | Constant

Convert a JAX tensor x into a Firedrake object.

Parameters:
  • x – JAX tensor to convert.

  • V – Function space of the corresponding Function or None when x is to be mapped to a Constant.

Returns:

Firedrake object representing the JAX tensor x.

Return type:

firedrake.function.Function or firedrake.constant.Constant

firedrake.ml.jax.fem_operator.to_jax(x: Function | Vector | Constant, gather: bool | None = False, batched: bool | None = False, **kwargs) → jax.Array

Convert a Firedrake object x into a JAX tensor.

Parameters:
  • x – Firedrake object to convert.

  • gather – If True, gather data from all processes.

  • batched – If True, add a batch dimension to the tensor.

  • kwargs – Additional arguments to be passed to the jax.Array constructor, such as:
    • device: device on which the tensor is allocated

    • dtype: the desired data type of the returned tensor (default: type of x.dat.data)

Returns:

JAX tensor representing the Firedrake object x.

Return type:

jax.Array
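A round trip through to_jax and from_jax can be sketched without Firedrake. This sketch assumes a hypothetical ToyFunction that holds a flat array of local degrees of freedom in place of a Firedrake Function; toy_to_jax and toy_from_jax are illustrative names that mimic only the batched and **kwargs behaviour described above, and omit gather, the Constant branch and the function-space argument V.

```python
import numpy as np
import jax.numpy as jnp


class ToyFunction:
    """Hypothetical stand-in for a Firedrake Function: a flat array of local dofs."""

    def __init__(self, values):
        self.values = np.asarray(values, dtype=np.float32)


def toy_to_jax(x, batched=False, **kwargs):
    # Copy the local dof values; batched=True prepends a batch axis of size 1.
    # Extra keyword arguments (e.g. dtype) are forwarded to jnp.asarray.
    x_P = x.values[None, :] if batched else x.values
    return jnp.asarray(x_P, **kwargs)


def toy_from_jax(x):
    # Flatten away any batch axis and copy the values into a new "Function".
    return ToyFunction(np.ravel(np.asarray(x)))


u = ToyFunction([1.0, 2.0, 3.0])
t = toy_to_jax(u, batched=True)   # shape (1, 3)
u_back = toy_from_jax(t)          # values [1.0, 2.0, 3.0]
```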

firedrake.ml.jax.ml_operator module

class firedrake.ml.jax.ml_operator.JaxOperator(*operands: Expr | BaseForm, function_space: WithGeometryBase, derivatives: tuple | None = None, argument_slots: tuple[BaseCoefficient | BaseArgument] | None, operator_data: dict | None = {})

Bases: MLOperator

External operator class representing machine learning models implemented in JAX.

The JaxOperator allows users to embed machine learning models implemented in JAX into PDE systems implemented in Firedrake. The actual evaluation of the JaxOperator is delegated to the specified JAX model. Similarly, differentiation through the JaxOperator is performed by applying JAX automatic differentiation to the JAX model associated with the JaxOperator object.

Parameters:
  • *operands – Operands of the JaxOperator.

  • function_space – The function space the ML operator is mapping to.

  • derivatives – Tuple specifying the derivative multiindex.

  • argument_slots – Tuple containing the arguments of the linear form associated with the ML operator, i.e. the arguments with respect to which the ML operator is linear. These arguments can be ufl.Argument objects, as a result of differentiation, or ufl.Coefficient objects, as a result of taking the action on a given function.

  • operator_data – Dictionary to stash external data specific to the ML operator. This dictionary must contain at least the following keys: (i) ‘model’: the machine learning model implemented in JAX; (ii) ‘inputs_format’: the format of the inputs to the ML model, 0 for models acting globally on the inputs and 1 for models acting locally/pointwise on the inputs. Other strategies can also be considered by subclassing the JaxOperator class.
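The difference between the two inputs_format values, and the use of JAX differentiation on the model, can be illustrated with plain JAX. This is a conceptual sketch only: global_model, pointwise_model and the matrix W are hypothetical, and using jax.vmap for the pointwise case is an assumption about how a pointwise model could be applied, not a description of how JaxOperator handles inputs_format internally.

```python
import jax
import jax.numpy as jnp

# Hypothetical "global" model: an output entry may depend on several inputs.
W = jnp.array([[2.0, 0.0, 0.0],
               [1.0, 1.0, 0.0],
               [0.0, 0.0, 3.0]])


def global_model(u):
    return jnp.tanh(W @ u)


# Hypothetical "pointwise" model: a scalar map applied independently per dof.
def pointwise_model(u_i):
    return u_i ** 2 + 1.0


u = jnp.array([0.5, -1.0, 2.0])

# inputs_format = 0: the model acts on the whole vector of dofs at once.
y_global = global_model(u)

# inputs_format = 1: the model acts dof by dof; jax.vmap maps it over the vector.
y_pointwise = jax.vmap(pointwise_model)(u)

# JAX differentiation of each model: the pointwise Jacobian is diagonal.
J_global = jax.jacfwd(global_model)(u)
J_pointwise = jax.jacfwd(jax.vmap(pointwise_model))(u)
```

The Jacobian of the global model couples different dofs (for example J_global[1, 0] is non-zero), whereas the pointwise model's Jacobian is diagonal.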

ufl_operands

firedrake.ml.jax.ml_operator.custom_vjp(_, **kwargs)

firedrake.ml.jax.ml_operator.ml_operator(model: Callable, function_space: WithGeometryBase, inputs_format: int | None = 0) → Callable

Helper function for instantiating the JaxOperator class.

This function enables a two-stage instantiation that separates the class arguments that stay fixed, such as the function space or the ML model, from the operands of the operator, which may change, e.g. when the operator is used in a time loop.

Example

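from firedrake import inner, grad, dx
from firedrake.ml.jax.ml_operator import ml_operator
# Assumes `model`, the function space `V`, `u`, `v` and `f` are already defined.
# Stage 1: Partially initialise the operator.
N = ml_operator(model, function_space=V)
# Stage 2: Define the operands and use the operator in a UFL expression.
F = (inner(grad(u), grad(v)) + inner(N(u), v) - inner(f, v)) * dx

Parameters: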
  • model – The JAX model to embed in Firedrake.

  • function_space – The function space into which the machine learning model is mapping.

  • inputs_format – The format of the input data of the ML model: 0 for models acting globally on the inputs, 1 when acting locally/pointwise on the inputs. Other strategies can also be considered by subclassing the JaxOperator class.

Returns:

The partially initialised JaxOperator class.

Return type:

Callable
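The two-stage instantiation performed by ml_operator resembles partial application of a constructor. The sketch below assumes this pattern and uses functools.partial; ToyOperator, its evaluate method and toy_ml_operator are hypothetical stand-ins for JaxOperator and ml_operator, and the string "V" replaces a real function space.

```python
import functools

import jax.numpy as jnp


class ToyOperator:
    """Hypothetical stand-in for JaxOperator: stores operands plus fixed data."""

    def __init__(self, *operands, function_space, operator_data):
        self.operands = operands
        self.function_space = function_space
        self.operator_data = operator_data

    def evaluate(self):
        # Delegate evaluation to the stored model.
        model = self.operator_data["model"]
        return model(*self.operands)


def toy_ml_operator(model, function_space, inputs_format=0):
    # Stage 1: fix the arguments that do not change between uses.
    operator_data = {"model": model, "inputs_format": inputs_format}
    return functools.partial(ToyOperator, function_space=function_space,
                             operator_data=operator_data)


N = toy_ml_operator(lambda u: 2.0 * u, function_space="V")

# Stage 2: supply new operands at each time step; the fixed data is reused.
for step in range(3):
    u_n = jnp.full(4, float(step))
    out = N(u_n).evaluate()
```

Freezing the model, function space and inputs_format once means that only the operands need to be supplied at each step.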

Module contents