Source code for firedrake.ml.jax.fem_operator

import os
# JAX is imported defensively: a missing installation is only tolerated when
# the documentation is being built (signalled by the FIREDRAKE_BUILDING_DOCS
# environment variable); otherwise the ImportError is re-raised with a clearer message.
try:
    import jax
    import jax.numpy as jnp
except ImportError:
    if "FIREDRAKE_BUILDING_DOCS" in os.environ:
        # If building docs and jax is not installed, produce a mock `jax.custom_vjp` function.
        # This is sufficient for the intersphinx reference to resolve.
        # NOTE: `jnp` is not defined on this path, so only code that never touches
        # `jnp` at import time can be loaded with the mock.
        from types import SimpleNamespace
        jax = SimpleNamespace()

        # Stand-in for `jax.custom_vjp`: accepts the same call shape used in this module
        # (one positional argument plus keyword arguments) and returns None.
        def custom_vjp(_, **kwargs):
            pass

        jax.custom_vjp = custom_vjp
    else:
        raise ImportError("JAX is not installed and is required to use the FiredrakeJaxOperator.")

import collections
import numpy as np
from functools import partial
from typing import Union, Optional

from firedrake.function import Function
from firedrake.cofunction import Cofunction
from firedrake.vector import Vector
from firedrake.functionspaceimpl import WithGeometry
from firedrake.constant import Constant
from firedrake_citations import Citations

from pyadjoint.reduced_functional import ReducedFunctional


__all__ = ['FiredrakeJaxOperator', 'fem_operator', 'to_jax', 'from_jax']


class FiredrakeJaxOperator:
    """JAX custom operator representing a set of Firedrake operations expressed as a reduced functional `F`.

    `FiredrakeJaxOperator` executes forward and backward passes by directly calling the reduced functional `F`.

    Parameters
    ----------
    F
        The reduced functional to wrap.
    """

    def __init__(self, F: ReducedFunctional):
        super().__init__()
        self.F = F
        # Function spaces of the controls: used to map incoming JAX arrays to Firedrake objects
        self.V_controls = [c.control.function_space() for c in F.controls]
        # Function space of the functional (None when the functional is a float)
        self.V_output = _extract_function_space(F.functional)
        # Register forward and backward passes
        self.forward.defvjp(type(self).fwd, type(self).bwd)

    @partial(jax.custom_vjp, nondiff_argnums=(0,))
    def forward(self, *x_P: "jax.Array") -> "jax.Array":
        """Forward pass of the JAX custom operator.

        Parameters
        ----------
        *x_P
            JAX tensors representing the inputs to the Firedrake operator `F`.
            They are paired positionally with the function spaces of the controls of `F`.

        Returns
        -------
        jax.Array
            JAX tensor representing the output of the Firedrake operator `F`.
        """
        # Convert JAX input (i.e. controls) to Firedrake
        x_F = [from_jax(xi, Vi) for xi, Vi in zip(x_P, self.V_controls)]
        # Forward operator: delegated to pyadjoint.ReducedFunctional which recomputes the blocks on the tape
        y_F = self.F(x_F)
        # Convert Firedrake output to JAX
        y_P = to_jax(y_F)
        return y_P

    def fwd(self, *x_P: "jax.Array") -> "tuple[jax.Array, tuple]":
        """Forward rule registered with `defvjp`.

        Parameters
        ----------
        *x_P
            JAX tensors representing the inputs to the Firedrake operator `F`.

        Returns
        -------
        tuple
            The pair `(output, residuals)`, where `output` is the result of :meth:`forward`
            and `residuals` is an empty tuple: nothing is saved here for the backward pass.
        """
        return type(self).forward(self, *x_P), ()

    def bwd(self, _, grad_output: "jax.Array") -> "tuple[jax.Array, ...]":
        """Backward rule registered with `defvjp`.

        Parameters
        ----------
        _
            Residuals produced by :meth:`fwd` (unused).
        grad_output
            JAX tensor holding the adjoint input, i.e. the gradient with respect to the output of `F`.

        Returns
        -------
        tuple
            One JAX tensor per entry of the derivative returned by `F.derivative`.
        """
        V = self.V_output
        # Convert JAX gradient to Firedrake.
        # Compare against None explicitly: `V` is None only for float functionals, and the
        # dual space must be used whenever a function space is available.
        V_adj = V.dual() if V is not None else None
        adj_input = from_jax(grad_output, V_adj)
        if isinstance(adj_input, Constant) and adj_input.ufl_shape == ():
            # This will later on result in an `AdjFloat` adjoint input instead of a Constant
            adj_input = float(adj_input)

        # Compute adjoint model of `F`: delegated to pyadjoint.ReducedFunctional
        adj_output = self.F.derivative(adj_input=adj_input)

        # Tuplify adjoint output
        if not isinstance(adj_output, collections.abc.Sequence):
            adj_output = (adj_output,)

        return tuple(to_jax(di) for di in adj_output)
def fem_operator(F: ReducedFunctional) -> FiredrakeJaxOperator:
    """Wrap a Firedrake reduced functional so that it can be called with JAX tensors.

    The returned callable takes JAX tensors as inputs and returns JAX tensors as outputs.

    Parameters
    ----------
    F
        The reduced functional to wrap.

    Returns
    -------
    callable
        A `functools.partial` that binds :meth:`FiredrakeJaxOperator.forward` to a newly
        created :class:`~FiredrakeJaxOperator` wrapping `F`.

    Raises
    ------
    ValueError
        If `F` is not a `ReducedFunctional`.
    """
    Citations().register("Bouziani2024")

    if not isinstance(F, ReducedFunctional):
        raise ValueError("F must be a ReducedFunctional")

    operator = FiredrakeJaxOperator(F)
    # Accessing `forward` through the instance currently does not work: it causes issues
    # related to the function signature during JAX compilation. As a workaround, the
    # class-level `forward` is bound to the instance with `functools.partial`.
    class_level_forward = FiredrakeJaxOperator.forward
    return partial(class_level_forward, operator)
def _extract_function_space(x: Union[float, Function, Vector]) -> Union[WithGeometry, None]:
    """Extract the function space from a Firedrake object `x`.

    Parameters
    ----------
    x
        Firedrake object from which to extract the function space.
        Functions, cofunctions, vectors and floats are supported.

    Returns
    -------
    firedrake.functionspaceimpl.WithGeometry or None
        Extracted function space, or None when `x` is a float.

    Raises
    ------
    ValueError
        If `x` is of none of the supported types.
    """
    if isinstance(x, (Function, Cofunction)):
        return x.function_space()
    elif isinstance(x, Vector):
        # A vector carries no function space itself: defer to the function it wraps
        return _extract_function_space(x.function)
    elif isinstance(x, float):
        return None
    else:
        # Wrap `x` in a 1-tuple: `"%s" % x` would misbehave if `x` were itself a tuple
        # (TypeError for lengths other than 1, silent unwrapping for length 1).
        raise ValueError("Cannot infer the function space of %s" % (x,))
def to_jax(x: Union[Function, Vector, Constant], gather: Optional[bool] = False, batched: Optional[bool] = False, **kwargs) -> "jax.Array":
    """Convert a Firedrake object `x` into a JAX tensor.

    Parameters
    ----------
    x
        Firedrake object to convert. Functions, cofunctions, vectors, constants
        and Python scalars (float or int) are supported.
    gather
        If True, gather data from all processes
    batched
        If True, add a batch dimension to the tensor.
        Only applied when `x` is a function, cofunction or vector.
    kwargs
        Additional arguments to be passed to the :class:`jax.Array` constructor such as:
            - device: device on which the tensor is allocated
            - dtype: the desired data type of returned tensor (default: type of `x.dat.data`;
              for a Python float the default is `jnp.double`)

    Returns
    -------
    jax.Array
        JAX tensor representing the Firedrake object `x`.

    Raises
    ------
    ValueError
        If `x` is of none of the supported types.
    """
    if isinstance(x, (Function, Cofunction, Vector)):
        if gather:
            # Gather data from all processes
            x_P = jnp.array(x.vector().gather(), **kwargs)
        else:
            # Use local data
            x_P = jnp.array(x.vector().get_local(), **kwargs)
        if batched:
            # Default behaviour: add batch dimension after converting to JAX
            return x_P[None, :]
        return x_P
    elif isinstance(x, Constant):
        return jnp.array(x.values(), **kwargs)
    elif isinstance(x, (float, int)):
        if isinstance(x, float):
            # Default to double-precision, but honour a dtype explicitly requested by the caller
            kwargs.setdefault('dtype', jnp.double)
        return jnp.array(x, **kwargs)
    else:
        raise ValueError("Cannot convert %s to a JAX tensor" % str(type(x)))
def from_jax(x: "jax.Array", V: Optional[WithGeometry] = None) -> Union[Function, Constant]:
    """Build a Firedrake object from the values held by the JAX tensor `x`.

    Parameters
    ----------
    x
        JAX tensor to convert.
    V
        Function space in which the resulting :class:`.Function` lives, or None
        when `x` is to be mapped to a :class:`.Constant`.

    Returns
    -------
    firedrake.function.Function or firedrake.constant.Constant
        Firedrake object representing the JAX tensor `x`.

    Raises
    ------
    TypeError
        If `x` is an abstract JAX array that is not a `jax.core.ConcreteArray`.
    NotImplementedError
        If `x` is not a NumPy array and its device platform is not "cpu".
    """
    # Tracers are replaced by their abstract value before inspection
    if isinstance(x, jax.core.Tracer):
        x = jax.core.get_aval(x)

    if isinstance(x, jax.core.ShapedArray):
        # Only a concrete abstract array carries a value that can be unwrapped
        if isinstance(x, jax.core.ConcreteArray):
            x = x.val
        else:
            raise TypeError("Cannot convert a JAX abstract array to a Firedrake object.")

    # NumPy arrays are accepted as they are; anything else must report a CPU device
    host_resident = isinstance(x, np.ndarray) or x.device.platform == "cpu"
    if not host_resident:
        raise NotImplementedError("Firedrake does not support GPU/TPU tensors")

    values = np.asarray(x)

    if V is not None:
        # Fill a new function in `V` with the local values
        result = Function(V)
        result.vector().set_local(values)
        return result

    # No function space: build a constant, unwrapping single-entry 1D arrays to a scalar
    if values.shape == (1,):
        values = values[0]
    return Constant(values)