firedrake.ml.jax package
Submodules
firedrake.ml.jax.fem_operator module
- class firedrake.ml.jax.fem_operator.FiredrakeJaxOperator(F: ReducedFunctional)
Bases: object
JAX custom operator representing a set of Firedrake operations expressed as a reduced functional F.
FiredrakeJaxOperator executes forward and backward passes by directly calling the reduced functional F.
- Parameters:
F – The reduced functional to wrap.
- forward = None
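FiredrakeJaxOperator is described as running the forward and backward passes by calling the reduced functional F directly. The sketch below shows one way such a bridge can be written in plain JAX with `jax.custom_vjp` and `jax.pure_callback`. It is not Firedrake's implementation: `ToyReducedFunctional` and `wrap_reduced_functional` are hypothetical names, and the stand-in's `derivative(m)` interface is a simplified assumption chosen so the example runs without Firedrake.

```python
import jax
import jax.numpy as jnp
import numpy as np


class ToyReducedFunctional:
    """Hypothetical stand-in for a reduced functional (not pyadjoint's class).

    J(m) = 0.5 * sum(m**2), so dJ/dm = m. A real reduced functional would
    run Firedrake (and adjoint) solves instead of NumPy arithmetic.
    """

    def __call__(self, m):
        return 0.5 * np.sum(m ** 2)

    def derivative(self, m):
        # Simplified, assumed interface: gradient evaluated at m.
        return np.array(m, copy=True)


def wrap_reduced_functional(F):
    """Hypothetical helper: expose F to JAX as a differentiable function."""

    def _evaluate(m):
        # Forward pass: call F on concrete host data, outside JAX tracing.
        return jax.pure_callback(
            lambda x: np.asarray(F(np.asarray(x)), dtype=x.dtype),
            jax.ShapeDtypeStruct((), m.dtype),
            m,
        )

    @jax.custom_vjp
    def op(m):
        return _evaluate(m)

    def op_fwd(m):
        # Keep the input as the residual needed by the backward pass.
        return _evaluate(m), m

    def op_bwd(m, g):
        # Backward pass: ask F for dJ/dm and scale by the incoming cotangent g.
        dJdm = jax.pure_callback(
            lambda x: np.asarray(F.derivative(np.asarray(x)), dtype=x.dtype),
            jax.ShapeDtypeStruct(m.shape, m.dtype),
            m,
        )
        return (g * dJdm,)

    op.defvjp(op_fwd, op_bwd)
    return op


J = wrap_reduced_functional(ToyReducedFunctional())
m = jnp.array([1.0, 2.0, 3.0])
print(J(m))            # 7.0
print(jax.grad(J)(m))  # [1. 2. 3.]
```

Routing the external call through `jax.pure_callback` keeps it usable under `jax.jit`. The hand-written backward rule is what lets `jax.grad` differentiate through code that JAX itself cannot trace.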
- firedrake.ml.jax.fem_operator.fem_operator(F: ReducedFunctional) → FiredrakeJaxOperator
Cast a Firedrake reduced functional to a JAX operator.
The resulting FiredrakeJaxOperator will take JAX tensors as inputs and return JAX tensors as outputs.
- Parameters:
F – The reduced functional to wrap.
- Returns:
A JAX custom operator that wraps the reduced functional F.
- Return type:
FiredrakeJaxOperator
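The following sketch illustrates the intended usage pattern (JAX tensors in, JAX tensors out) by composing a stand-in operator with a small JAX model and differentiating the result with `jax.grad`. `G`, `model` and `loss` are hypothetical. In real code `G` would be the object returned by `fem_operator(F)`, and this sketch assumes that object can be called like an ordinary function on a JAX array.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for G = fem_operator(F): a scalar "PDE-based" loss
# J(m) = 0.5 * ||m||^2 with a hand-written backward pass (dJ/dm = m).
@jax.custom_vjp
def G(m):
    return 0.5 * jnp.sum(m ** 2)


def G_fwd(m):
    return G(m), m


def G_bwd(m, g):
    return (g * m,)


G.defvjp(G_fwd, G_bwd)


def model(params, x):
    # A tiny affine "network" that predicts the input m passed to G.
    return params["W"] @ x + params["b"]


def loss(params, x):
    # The operator composes with the model like any other JAX function.
    return G(model(params, x))


params = {"W": jnp.array([[1.0, 0.0], [0.0, 2.0]]), "b": jnp.array([0.5, -1.0])}
x = jnp.array([1.0, 1.0])

# Chain rule across the operator: with m = W x + b,
# dL/dW = outer(m, x) and dL/db = m.
grads = jax.grad(loss)(params, x)
```

Gradients with respect to the model parameters flow through the operator's backward rule, which is what allows a Firedrake-based loss to sit at the end of a JAX training step.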
- firedrake.ml.jax.fem_operator.from_jax(x: jax.Array, V: WithGeometry | None = None) → Function | Constant
Convert a JAX tensor x into a Firedrake object.
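- Parameters:
x – JAX tensor to convert.
V – Function space of the corresponding Function, or None when x is to be converted to a Constant.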
- Returns:
Firedrake object representing the JAX tensor x.
- Return type:
Function or Constant
- firedrake.ml.jax.fem_operator.to_jax(x: Function | Vector | Constant, gather: bool | None = False, batched: bool | None = False, **kwargs) → jax.Array
Convert a Firedrake object x into a JAX tensor.
- Parameters:
x – Firedrake object to convert.
gather – If True, gather data from all processes.
batched – If True, add a batch dimension to the tensor.
kwargs – Additional arguments to be passed to the jax.Array constructor, such as device (the device on which the tensor is allocated) and dtype (the desired data type of the returned tensor; default: type of x.dat.data).
- Returns:
JAX tensor representing the Firedrake object x.
- Return type:
jax.Array
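The effect of the batched and kwargs options can be pictured with a small stand-in. The sketch below is hypothetical: `to_jax_sketch` takes a NumPy array in place of a Firedrake Function's local data, does not perform the MPI gather, and assumes that batched=True adds a leading batch axis of size 1. Check the Firedrake source for the exact behaviour of to_jax.

```python
import jax
import jax.numpy as jnp
import numpy as np


def to_jax_sketch(local_data, batched=False, dtype=None, device=None):
    """Hypothetical illustration of to_jax's batched/dtype/device options.

    local_data stands in for the process-local dof values of a Firedrake
    Function; gathering data across MPI ranks is not shown.
    """
    arr = jnp.asarray(local_data, dtype=dtype)
    if device is not None:
        # Place the array on the requested device.
        arr = jax.device_put(arr, device)
    if batched:
        # Assumed convention: a leading batch axis of size 1.
        arr = jnp.expand_dims(arr, 0)
    return arr


dofs = np.linspace(0.0, 1.0, 5)
print(to_jax_sketch(dofs).shape)                     # (5,)
print(to_jax_sketch(dofs, batched=True).shape)       # (1, 5)
print(to_jax_sketch(dofs, dtype=jnp.float32).dtype)  # float32
```

In this sketch, jnp.asarray follows JAX's default 32-bit mode, so float64 data becomes float32 unless the jax_enable_x64 configuration option is enabled.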
firedrake.ml.jax.ml_operator module
- class firedrake.ml.jax.ml_operator.JaxOperator(*operands: Expr | BaseForm, function_space: WithGeometryBase, derivatives: tuple | None = None, argument_slots: tuple[BaseCoefficient | BaseArgument] | None, operator_data: dict | None = {})
Bases: MLOperator
External operator class representing machine learning models implemented in JAX.
The JaxOperator allows users to embed machine learning models implemented in JAX into PDE systems implemented in Firedrake. The actual evaluation of the JaxOperator is delegated to the specified JAX model. Similarly, differentiation through the JaxOperator class is achieved by applying JAX's automatic differentiation to the JAX model associated with the JaxOperator object.
- Parameters:
*operands – Operands of the JaxOperator.
function_space – The function space the ML operator is mapping to.
derivatives – Tuple specifying the derivative multi-index.
argument_slots – Tuple containing the arguments of the linear form associated with the ML operator, i.e. the arguments with respect to which the ML operator is linear. These arguments can be ufl.Argument objects, as a result of differentiation, or ufl.Coefficient objects, as a result of taking the action on a given function.
operator_data – Dictionary to stash external data specific to the ML operator. This dictionary must contain at least the following entries: (i) ‘model’: the machine learning model implemented in JAX; (ii) ‘inputs_format’: the format of the inputs to the ML model, 0 for models acting globally on the inputs and 1 for models acting locally/pointwise on the inputs. Other strategies can also be considered by subclassing the JaxOperator class.
- ufl_operands
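The difference between the two inputs_format values, and the kinds of derivatives an external operator relies on, can be illustrated in plain JAX without Firedrake. In the sketch below, `pointwise_model`, `global_model` and the matrix `A` are hypothetical. `jax.vmap` stands in for applying a model locally to each dof value, while `jax.jacfwd`, `jax.jvp` and `jax.vjp` give the Jacobian, its action on a direction (loosely, the case where argument_slots holds a ufl.Coefficient) and its adjoint action. This is a conceptual illustration only, not a description of how JaxOperator works internally.

```python
import jax
import jax.numpy as jnp


# Hypothetical model acting on one scalar dof value (inputs_format=1 style).
def pointwise_model(u_i):
    return jnp.tanh(u_i) + u_i ** 2


# Hypothetical model acting on the whole dof vector (inputs_format=0 style):
# each output depends on neighbouring inputs through the matrix A.
A = jnp.array([[2.0, -1.0, 0.0], [-1.0, 2.0, -1.0], [0.0, -1.0, 2.0]])


def global_model(u):
    return jnp.tanh(A @ u)


u = jnp.array([0.1, 0.2, 0.3])

# Local/pointwise evaluation: map the scalar model over every dof value.
local_model = jax.vmap(pointwise_model)
y = local_model(u)

# Jacobians obtained by JAX automatic differentiation.
J_local = jax.jacfwd(local_model)(u)    # diagonal: output i only sees u_i
J_global = jax.jacfwd(global_model)(u)  # couples neighbouring dofs

# Action of the Jacobian on a direction w (Jacobian-vector product) and
# its adjoint action (vector-Jacobian product).
w = jnp.array([1.0, 0.0, -1.0])
_, Jw = jax.jvp(global_model, (u,), (w,))
_, global_vjp = jax.vjp(global_model, u)
(JTw,) = global_vjp(w)
```

A pointwise model always yields a diagonal Jacobian, which is one reason why treating local and global models separately is useful.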
- firedrake.ml.jax.ml_operator.ml_operator(model: Callable, function_space: WithGeometryBase, inputs_format: int | None = 0) → Callable
Helper function for instantiating the JaxOperator class.
This function enables a two-stage instantiation that separates the class arguments that are fixed, such as the function space or the ML model, from the operands of the operator, which may change, e.g. when the operator is used in a time loop.
Example
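    from firedrake import dx, grad, inner
    from firedrake.ml.jax.ml_operator import ml_operator

    # model, V, u, v and f are assumed to be defined beforehand.
    # Stage 1: Partially initialise the operator.
    N = ml_operator(model, function_space=V)
    # Stage 2: Define the operands and use the operator in a UFL expression.
    F = (inner(grad(u), grad(v)) + inner(N(u), v) - inner(f, v)) * dx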
- Parameters:
model – The JAX model to embed in Firedrake.
function_space – The function space into which the machine learning model is mapping.
inputs_format – The format of the input data of the ML model: 0 for models acting globally on the inputs, 1 when acting locally/pointwise on the inputs. Other strategies can also be considered by subclassing the JaxOperator class.
- Returns:
The partially initialised JaxOperator class.
- Return type:
Callable
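The two-stage instantiation described above can be sketched with `functools.partial`. Stage 1 fixes the model, function space and input format, and stage 2 supplies the operands. `SketchOperator` and `ml_operator_sketch` are hypothetical stand-ins, and plain strings replace Firedrake function spaces and operands so that the sketch runs without Firedrake. It shows the pattern only and makes no claim about ml_operator's internals.

```python
import functools


class SketchOperator:
    """Hypothetical stand-in for an ML external operator such as JaxOperator."""

    def __init__(self, *operands, function_space, operator_data):
        self.operands = operands
        self.function_space = function_space
        self.operator_data = operator_data


def ml_operator_sketch(model, function_space, inputs_format=0):
    # Stage 1: bind the arguments that stay fixed (model, space, format).
    operator_data = {"model": model, "inputs_format": inputs_format}
    return functools.partial(
        SketchOperator, function_space=function_space, operator_data=operator_data
    )


# Stage 1 runs once, e.g. before a time loop.
N = ml_operator_sketch(model=lambda x: 2 * x, function_space="V")

# Stage 2: build a fresh operator for each set of operands, e.g. each time step.
ops = [N(f"u_{n}") for n in range(3)]
```

Because the fixed arguments are bound only once, the same `N` can be reused at every step while the operands change.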