Source code for firedrake.ml.pytorch.ml_operator

import os
import numpy as np
try:
    import torch
    import torch.autograd.functional as torch_func
except ImportError:
    if "FIREDRAKE_BUILDING_DOCS" in os.environ:
        # If building docs and pytorch is not installed, produce a mock
        # torch.autograd.Function class with the correct `__module__`
        # attribute. This is sufficient for the intersphinx reference to
        # resolve.
        from types import SimpleNamespace, new_class
        torch = SimpleNamespace()
        torch.autograd = SimpleNamespace()
        torch.autograd.Function = new_class("Function")
        torch.autograd.Function.__module__ = "torch.autograd"
    else:
        raise ImportError("PyTorch is not installed and is required to use the FiredrakeTorchOperator.")


from functools import partial

from firedrake.external_operators import MLOperator
from firedrake import utils
from firedrake.ml.pytorch import to_torch, from_torch
from firedrake.petsc import PETSc


class PytorchOperator(MLOperator):

    def __init__(self, *operands, function_space, derivatives=None, argument_slots=(), operator_data):
        """External operator class representing machine learning models implemented in PyTorch.

        The :class:`.PytorchOperator` allows users to embed machine learning models implemented in
        PyTorch into PDE systems implemented in Firedrake. The actual evaluation of the
        :class:`.PytorchOperator` is delegated to the specified PyTorch model. Similarly,
        differentiation through the :class:`.PytorchOperator` class is achieved via the
        `torch.autograd` module, which provides automatic differentiation capabilities that can be
        applied on the PyTorch model associated with the :class:`.PytorchOperator` object.

        Parameters
        ----------
        *operands : ufl.core.expr.Expr or ufl.form.BaseForm
            Operands of the :class:`.PytorchOperator`.
        function_space : firedrake.functionspaceimpl.WithGeometryBase
            The function space the ML operator is mapping to.
        derivatives : tuple
            Tuple specifying the derivative multiindex.
        argument_slots : ufl.coefficient.BaseCoefficient or ufl.argument.BaseArgument
            Tuple containing the arguments of the linear form associated with the ML operator,
            i.e. the arguments with respect to which the ML operator is linear. Those arguments can
            be ufl.Argument objects, as a result of differentiation, or ufl.Coefficient objects, as
            a result of taking the action on a given function.
        operator_data : dict
            Dictionary to stash external data specific to the ML operator. This dictionary must at
            least contain the following:
            (i) 'model': The machine learning model implemented in PyTorch.
            (ii) 'inputs_format': The format of the inputs to the ML model: `0` for models acting
            globally on the inputs, `1` when acting locally/pointwise on the inputs.
            Other strategies can also be considered by subclassing the :class:`.PytorchOperator`
            class.
        """
        MLOperator.__init__(self, *operands, function_space=function_space, derivatives=derivatives,
                            argument_slots=argument_slots, operator_data=operator_data)

        # Convert the default Firedrake data type (a numpy type) to the corresponding PyTorch data
        # type using torch's built-in type conversion capabilities.
        default_type = torch.tensor(np.empty(0, dtype=utils.ScalarType)).dtype
        # If no data type is specified, default to the Firedrake data type (float64, or complex128
        # in complex mode).
        dtype = operator_data.get("dtype", default_type)
        # Cast the model to `dtype`.
        self.model.type(dtype)

    # Stash the output of the neural network to conserve the PyTorch tape.
    # -> This enables the graph to be traversed only once instead of running multiple forward
    #    passes for evaluation and backpropagation.
    @property
    def model_output(self):
        return self.operator_data.get('model_output')

    @model_output.setter
    def model_output(self, output):
        self.operator_data['model_output'] = output

    @utils.cached_property
    def torch_grad_enabled(self):
        # Default: PyTorch annotation is on, unless otherwise specified.
        return self.operator_data.get('torch_grad_enabled', True)

    # --- Callbacks --- #

    def _pre_forward_callback(self, *operands, unsqueeze=False):
        """Callback function to convert the Firedrake operand(s) to form the PyTorch input of the ML model."""
        # Default: concatenate the operands to form the model inputs.
        # -> For more complex cases, the user needs to override this function to state how the
        #    operands are used to form the inputs.
        inputs = torch.cat([to_torch(op, requires_grad=True, batched=False) for op in operands])
        if unsqueeze:
            return torch.unsqueeze(inputs, self.inputs_format)
        return inputs

    def _post_forward_callback(self, y_P):
        """Callback function to convert the PyTorch output of the ML model to a Firedrake function."""
        space = self.ufl_function_space()
        return from_torch(y_P, space)

    # -- PyTorch routines for computing AD-based quantities via `torch.autograd.functional` -- #

    # One could also extend the assembly to hessian, hvp (Hessian-vector product) and
    # vhp (vector-Hessian product) using `torch.autograd.functional.{hessian, hvp, vhp}`.

    def _vjp(self, y):
        """Implement the vector-Jacobian product (VJP) for a given vector `y`."""
        model = self.model
        x = self._pre_forward_callback(*self.ufl_operands)
        y_P = self._pre_forward_callback(y)
        _, vjp = torch_func.vjp(lambda x: model(x), x, y_P)
        vjp_F = self._post_forward_callback(vjp)
        return vjp_F

    def _jvp(self, z):
        """Implement the Jacobian-vector product (JVP) for a given vector `z`."""
        model = self.model
        x = self._pre_forward_callback(*self.ufl_operands)
        z_P = self._pre_forward_callback(z)
        _, jvp = torch_func.jvp(lambda x: model(x), x, z_P)
        jvp_F = self._post_forward_callback(jvp)
        return jvp_F

    def _jac(self):
        """Compute the Jacobian of the PyTorch model."""
        # Get the model.
        model = self.model
        # Don't unsqueeze so that we end up with a rank-2 tensor.
        x = self._pre_forward_callback(*self.ufl_operands, unsqueeze=False)
        jac = torch_func.jacobian(lambda x: model(x), x)

        # For big matrices, assembling the Jacobian is not a good idea and one should instead
        # compute the Jacobian action (e.g. via matrix-free methods), which in turn calls `_jvp`.
        n, m = jac.shape
        J = PETSc.Mat().create()
        J.setSizes([n, m])
        J.setType("dense")
        J.setUp()
        # Set values using the Jacobian computed by PyTorch.
        J.setValues(range(n), range(m), jac.numpy().flatten())
        J.assemble()
        return J

    def _forward(self):
        """Perform the forward pass through the PyTorch model."""
        model = self.model

        # Get the input operands.
        ops = self.ufl_operands

        # By default, PyTorch annotation is on (i.e. equivalent to `with torch.enable_grad()`).
        with torch.set_grad_enabled(self.torch_grad_enabled):
            # Pre-forward callback.
            x_P = self._pre_forward_callback(*ops)

            # Vectorized forward pass.
            y_P = model(x_P)

            # Stash the model output.
            self.model_output = y_P

            # Post-forward callback.
            y_F = self._post_forward_callback(y_P)

        return y_F
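
# Hedged illustration (not part of this module): `_vjp`, `_jvp` and `_jac` above wrap
# `torch.autograd.functional`. The standalone sketch below shows the underlying PyTorch calls on a
# toy model; the model, input `x` and seed vector `v` are assumptions made purely for the example.
def _autograd_functional_sketch():  # pragma: no cover
    model = torch.nn.Linear(3, 3).double()
    x = torch.randn(3, dtype=torch.float64)
    v = torch.randn(3, dtype=torch.float64)
    # Vector-Jacobian product v^T J(x): the adjoint action computed by `_vjp`.
    _, vjp = torch_func.vjp(model, x, v)
    # Jacobian-vector product J(x) v: the matrix-free action computed by `_jvp`.
    _, jvp = torch_func.jvp(model, x, v)
    # Dense Jacobian J(x) of shape (3, 3), as assembled into a PETSc matrix by `_jac`.
    jac = torch_func.jacobian(model, x)
    return vjp, jvp, jac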
# Helper functions #
def ml_operator(model, function_space, inputs_format=0):
    """Helper function for instantiating the :class:`~.PytorchOperator` class.

    This function facilitates a two-stage instantiation which dissociates between class arguments
    that are fixed, such as the function space or the ML model, and the operands of the operator,
    which may change, e.g. when the operator is used in a time-loop.

    Example
    -------

    .. code-block:: python

        # Stage 1: Partially initialise the operator.
        N = ml_operator(model, function_space=V)
        # Stage 2: Define the operands and use the operator in a UFL expression.
        F = (inner(grad(u), grad(v)) + inner(N(u), v) - inner(f, v)) * dx

    Parameters
    ----------
    model : collections.abc.Callable
        The PyTorch model to embed in Firedrake.
    function_space : firedrake.functionspaceimpl.WithGeometryBase
        The function space into which the machine learning model is mapping.
    inputs_format : int
        The format of the input data of the ML model: `0` for models acting globally on the
        inputs, `1` when acting locally/pointwise on the inputs.
        Other strategies can also be considered by subclassing the :class:`.PytorchOperator` class.

    Returns
    -------
    collections.abc.Callable
        The partially initialised :class:`~.PytorchOperator` class.
    """
    from firedrake_citations import Citations
    Citations().register("Bouziani2021")
    Citations().register("Bouziani2024")

    if inputs_format not in (0, 1):
        raise ValueError('Expecting inputs_format to be 0 or 1')

    operator_data = {'model': model, 'inputs_format': inputs_format}
    return partial(PytorchOperator, function_space=function_space, operator_data=operator_data)
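
# Hedged usage sketch (illustrative only): a pointwise model embedded via `ml_operator`. The mesh,
# function space and network below are assumptions made for this example, not part of the module.
def _ml_operator_sketch():  # pragma: no cover
    from firedrake import UnitIntervalMesh, FunctionSpace, Function
    mesh = UnitIntervalMesh(10)
    V = FunctionSpace(mesh, "CG", 1)
    u = Function(V)
    # A model assumed to act pointwise on scalar nodal values, hence `inputs_format=1`.
    model = torch.nn.Sequential(torch.nn.Linear(1, 10), torch.nn.Tanh(), torch.nn.Linear(10, 1))
    # Stage 1: partially initialise the operator with the fixed arguments.
    N = ml_operator(model, function_space=V, inputs_format=1)
    # Stage 2: apply it to its operand, yielding a `PytorchOperator` usable in UFL expressions.
    return N(u)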
def neuralnet(model, function_space, inputs_format=0):
    import warnings
    warnings.warn('`neuralnet` is deprecated, use `ml_operator` instead', FutureWarning)

    return ml_operator(model, function_space, inputs_format=inputs_format)
neuralnet.__doc__ = ml_operator.__doc__