Source code for firedrake.external_operators.ml_operator
from firedrake.external_operators import AbstractExternalOperator, assemble_method
from firedrake.matrix import AssembledMatrix
[docs]
class MLOperator(AbstractExternalOperator):

    def __init__(self, *operands, function_space, derivatives=None, argument_slots=(), operator_data):
        """External operator base class representing machine learning models implemented in a given
        machine learning framework.

        The :class:`.MLOperator` allows users to embed machine learning models implemented in a given
        machine learning framework into PDE systems implemented in Firedrake. The actual evaluation of
        the :class:`.MLOperator` subclass is delegated to the specified ML model using the ML framework considered.

        Parameters
        ----------
        *operands : ufl.core.expr.Expr or ufl.form.BaseForm
            Operands of the ML operator.
        function_space : firedrake.functionspaceimpl.WithGeometryBase
            The function space the ML operator is mapping to.
        derivatives : tuple, optional
            Tuple specifying the derivative multi-index.
        argument_slots : tuple of ufl.coefficient.BaseCoefficient or ufl.argument.BaseArgument, optional
            Tuple containing the arguments of the linear form associated with the ML operator,
            i.e. the arguments with respect to which the ML operator is linear. Those arguments
            can be ufl.Argument objects, as a result of differentiation, or ufl.Coefficient objects,
            as a result of taking the action on a given function.
        operator_data : dict
            Dictionary to stash external data specific to the ML operator. This dictionary must
            at least contain the following:

            (i) ``'model'``: The machine learning model implemented in the ML framework considered.
            (ii) ``'inputs_format'``: The format of the inputs to the ML model: ``0`` for models acting
                 globally on the inputs, ``1`` when acting locally/pointwise on the inputs.

            Other strategies can also be considered by subclassing the :class:`.MLOperator` class.
        """
        # Explicit base-class call kept as in the original (rather than super()) so the
        # initialisation path for subclasses is unchanged.
        AbstractExternalOperator.__init__(self, *operands, function_space=function_space, derivatives=derivatives,
                                          argument_slots=argument_slots, operator_data=operator_data)

    @property
    def model(self):
        """The ML model stored under ``operator_data['model']``."""
        return self.operator_data['model']

    @property
    def inputs_format(self):
        """The inputs format stored under ``operator_data['inputs_format']``.

        ``0`` means the model acts globally on the inputs, ``1`` means it acts
        locally/pointwise on the inputs (see the ``operator_data`` description in
        :meth:`__init__`).
        """
        return self.operator_data['inputs_format']

    # -- Assembly methods -- #

    @assemble_method(0, (0,))
    def assemble_model(self, *args, **kwargs):
        """Assemble the operator via a forward pass through the ML model."""
        return self._forward()

    @assemble_method(1, (0, 1))
    def assemble_jacobian(self, *args, **kwargs):
        """Assemble the Jacobian using the AD engine of the ML framework."""
        # Delegate computation to the ML framework.
        J = self._jac()
        # No boundary conditions are attached to the assembled Jacobian.
        bcs = ()
        return AssembledMatrix(self, bcs, J)

    @assemble_method(1, (1, 0))
    def assemble_jacobian_adjoint(self, *args, **kwargs):
        """Assemble the Jacobian Hermitian transpose using the AD engine of the ML framework."""
        # Delegate computation to the ML framework.
        J = self._jac()
        # No boundary conditions are attached to the assembled Jacobian adjoint.
        bcs = ()
        # Take the adjoint (Hermitian transpose).
        # NOTE(review): assuming J is a petsc4py Mat, hermitianTranspose() called without an
        # output matrix transposes J in place, i.e. it mutates the object returned by ``_jac``.
        # Confirm that ``_jac`` always hands back a fresh matrix (not a cached/shared one).
        J.hermitianTranspose()
        return AssembledMatrix(self, bcs, J)

    @assemble_method(1, (0, None))
    def assemble_jacobian_action(self, *args, **kwargs):
        """Assemble the Jacobian action using the AD engine of the ML framework."""
        # The Jacobian is applied to the last argument slot.
        w = self.argument_slots()[-1]
        return self._jvp(w)

    @assemble_method(1, (None, 0))
    def assemble_jacobian_adjoint_action(self, *args, **kwargs):
        """Assemble the action of the Jacobian adjoint using the AD engine of the ML framework."""
        # The Jacobian adjoint is applied to the first argument slot.
        w = self.argument_slots()[0]
        return self._vjp(w)

    # -- ML framework-specific methods -- #
    # These hooks are meant to be overridden by subclasses targeting a given ML framework.

    def _forward(self):
        """Perform the forward pass through the ML model (used by :meth:`assemble_model`)."""
        raise NotImplementedError("Forward pass not implemented.")

    def _jvp(self, *args, **kwargs):
        """Compute a Jacobian-vector product.

        Called by :meth:`assemble_jacobian_action` with the last argument slot as its
        single positional argument. The base signature accepts arbitrary arguments so
        that an unimplemented hook raises ``NotImplementedError`` instead of a
        ``TypeError`` about the argument count.
        """
        raise NotImplementedError("Jacobian-vector product not implemented.")

    def _vjp(self, *args, **kwargs):
        """Compute a vector-Jacobian product.

        Called by :meth:`assemble_jacobian_adjoint_action` with the first argument slot
        as its single positional argument. The base signature accepts arbitrary arguments
        so that an unimplemented hook raises ``NotImplementedError`` instead of a
        ``TypeError`` about the argument count.
        """
        raise NotImplementedError("Vector-Jacobian product not implemented.")

    def _jac(self):
        """Compute the Jacobian of the ML model.

        The result is wrapped in an ``AssembledMatrix`` by :meth:`assemble_jacobian` and,
        after an in-place Hermitian transpose, by :meth:`assemble_jacobian_adjoint`.
        """
        raise NotImplementedError("Jacobian not implemented.")