Source code for firedrake.adjoint.ensemble_reduced_functional

from pyadjoint.reduced_functional import AbstractReducedFunctional, ReducedFunctional
from pyadjoint.enlisting import Enlist
from pyop2.mpi import MPI

from firedrake.function import Function
from firedrake.cofunction import Cofunction


class EnsembleReducedFunctional(AbstractReducedFunctional):
    """Enable solving simultaneously reduced functionals in parallel.

    Consider a functional :math:`J` and its gradient :math:`\\dfrac{dJ}{dm}`,
    where :math:`m` is the control parameter. Let us assume that :math:`J` is
    the sum of :math:`N` functionals :math:`J_i(m)`, i.e.,

    .. math::

        J = \\sum_{i=1}^{N} J_i(m).

    The gradient over a summation is a linear operation. Therefore, we can
    write the gradient :math:`\\dfrac{dJ}{dm}` as

    .. math::

        \\frac{dJ}{dm} = \\sum_{i=1}^{N} \\frac{dJ_i}{dm},

    The :class:`EnsembleReducedFunctional` allows simultaneous evaluation of
    :math:`J_i` and :math:`\\dfrac{dJ_i}{dm}`. After that, the allreduce
    :class:`~.ensemble.Ensemble` operation is employed to sum the functionals
    and their gradients over an ensemble communicator.

    If gather_functional is present, then all the values of J are communicated
    to all ensemble ranks, and passed in a list to gather_functional, which is
    a reduced functional that expects a list of that size of the relevant
    types.

    Parameters
    ----------
    functional : pyadjoint.OverloadedType
        An instance of an OverloadedType, usually :class:`pyadjoint.AdjFloat`.
        This should be the functional that we want to reduce.
    control : pyadjoint.Control or list of pyadjoint.Control
        A single or a list of Control instances, which you want to map to the
        functional.
    ensemble : Ensemble
        An instance of the :class:`~.ensemble.Ensemble`. It is used to
        communicate the functionals and their derivatives between the ensemble
        members.
    scatter_control : bool
        Whether scattering a control (or a list of controls) over the ensemble
        communicator ``Ensemble.ensemble_comm``.
    gather_functional : pyadjoint.ReducedFunctional
        An instance of the :class:`pyadjoint.ReducedFunctional` that takes in
        all of the Js.
    derivative_components : list of int
        The indices of the controls that the derivative should be computed
        with respect to. If present, it overwrites ``derivative_cb_pre`` and
        ``derivative_cb_post``.
    scale : float
        A scaling factor applied to the functional and its gradient (with
        respect to the control).
    tape : pyadjoint.Tape
        A tape object that the reduced functional will use to evaluate the
        functional and its gradients (or derivatives).
    eval_cb_pre : :func:
        Callback function before evaluating the functional. Input is a list of
        Controls.
    eval_cb_post : :func:
        Callback function after evaluating the functional. Inputs are the
        functional value and a list of Controls.
    derivative_cb_pre : :func:
        Callback function before evaluating gradients (or derivatives). Input
        is a list of gradients (or derivatives). Should return a list of
        Controls (usually the same list as the input) to be passed to
        :func:`pyadjoint.compute_gradient`.
    derivative_cb_post : :func:
        Callback function after evaluating derivatives. Inputs are the
        functional, a list of gradients (or derivatives), and controls. All of
        them are the checkpointed versions. Should return a list of gradients
        (or derivatives) (usually the same list as the input) to be returned
        from ``self.derivative``.
    hessian_cb_pre : :func:
        Callback function before evaluating the Hessian. Input is a list of
        Controls.
    hessian_cb_post : :func:
        Callback function after evaluating the Hessian. Inputs are the
        functional, a list of Hessian, and controls.

    See Also
    --------
    :class:`~.ensemble.Ensemble`, :class:`pyadjoint.ReducedFunctional`.

    Notes
    -----
    The functionals :math:`J_i` and the control must be defined over a common
    `ensemble.comm` communicator. To understand more about how ensemble
    parallelism works, please refer to the `Firedrake manual
    <https://www.firedrakeproject.org/parallelism.html#ensemble-parallelism>`_.
    """

    def __init__(self, functional, control, ensemble, scatter_control=True,
                 gather_functional=None, derivative_components=None,
                 scale=1.0, tape=None,
                 eval_cb_pre=lambda *args: None,
                 eval_cb_post=lambda *args: None,
                 derivative_cb_pre=lambda controls: controls,
                 derivative_cb_post=lambda checkpoint, derivative_components, controls: derivative_components,
                 hessian_cb_pre=lambda *args: None,
                 hessian_cb_post=lambda *args: None):
        # All tape-related options are forwarded unchanged to the reduced
        # functional that is evaluated on this ensemble member.
        self.local_reduced_functional = ReducedFunctional(
            functional, control,
            derivative_components=derivative_components,
            scale=scale, tape=tape,
            eval_cb_pre=eval_cb_pre,
            eval_cb_post=eval_cb_post,
            derivative_cb_pre=derivative_cb_pre,
            derivative_cb_post=derivative_cb_post,
            hessian_cb_pre=hessian_cb_pre,
            hessian_cb_post=hessian_cb_post
        )

        self.ensemble = ensemble
        self.scatter_control = scatter_control
        self.gather_functional = gather_functional

    @property
    def controls(self):
        """The controls of the local reduced functional."""
        return self.local_reduced_functional.controls

    def _allgather_J(self, J):
        """Collect the value ``J`` of every ensemble member on every member.

        Parameters
        ----------
        J : float or Function
            The local value to communicate.

        Returns
        -------
        list
            One entry per rank of ``ensemble.ensemble_comm``.

        Raises
        ------
        NotImplementedError
            If ``J`` is neither a ``float`` nor a ``Function``.
        """
        if isinstance(J, float):
            return self.ensemble.ensemble_comm.allgather(J)
        if isinstance(J, Function):
            # allgather not implemented in ensemble.py, so every rank takes a
            # turn as broadcast root; each broadcast is handed its own deep
            # copy of the local J.
            return [
                self.ensemble.bcast(J.copy(deepcopy=True), root=root)
                for root in range(self.ensemble.ensemble_comm.size)
            ]
        raise NotImplementedError(f"Functionals of type {type(J).__name__} are not supported.")

    def _allreduce_sum(self, value):
        """Sum ``value`` over the ensemble communicator.

        Parameters
        ----------
        value : float or Function
            The local contribution of this ensemble member.

        Returns
        -------
        float or Function
            The result of the sum allreduce, of the same type as ``value``.

        Raises
        ------
        NotImplementedError
            If ``value`` is neither a ``float`` nor a ``Function``.
        """
        if isinstance(value, float):
            return self.ensemble.ensemble_comm.allreduce(sendobj=value, op=MPI.SUM)
        if isinstance(value, Function):
            # A new object on the same function space is passed as the
            # receiving argument of the ensemble allreduce.
            total = type(value)(value.function_space())
            return self.ensemble.allreduce(value, total)
        raise NotImplementedError("This type of functional is not supported.")

    def __call__(self, values):
        """Computes the reduced functional with supplied control value.

        Parameters
        ----------
        values : pyadjoint.OverloadedType
            If you have multiple controls this should be a list of new values
            for each control in the order you listed the controls to the
            constructor. If you have a single control it can either be a list
            or a single object. Each new value should have the same type as
            the corresponding control.

        Returns
        -------
        pyadjoint.OverloadedType
            The computed value. Typically an instance of
            :class:`pyadjoint.AdjFloat`.

        Raises
        ------
        NotImplementedError
            If the local functional value has an unsupported type.
        """
        local_functional = self.local_reduced_functional(values)
        if self.gather_functional:
            controls_g = self._allgather_J(local_functional)
            return self.gather_functional(controls_g)
        # if gather_functional is None then we do a sum
        return self._allreduce_sum(local_functional)

    def derivative(self, adj_input=1.0, apply_riesz=False):
        """Compute derivatives of a functional with respect to the control parameters.

        When ``gather_functional`` is set, its derivative is evaluated first
        (without the Riesz map) and the entry belonging to this ensemble rank
        replaces ``adj_input`` for the local reduced functional.

        Parameters
        ----------
        adj_input : float
            The adjoint input.
        apply_riesz: bool
            If True, apply the Riesz map of each control in order to return
            a primal gradient rather than a derivative in the dual space.

        Returns
        -------
        dJdm_total : pyadjoint.OverloadedType
            The result of Allreduce operations of ``dJdm_local``
            into ``dJdm_total`` over the `Ensemble.ensemble_comm`. If
            ``scatter_control`` is false, the local derivative is returned
            without any reduction.

        Raises
        ------
        NotImplementedError
            If a local derivative is not a ``Cofunction``, ``Function`` or
            ``float``.

        See Also
        --------
        :meth:`~.ensemble.Ensemble.allreduce`, :meth:`pyadjoint.ReducedFunctional.derivative`.
        """
        if self.gather_functional:
            dJg_dmg = self.gather_functional.derivative(adj_input=adj_input,
                                                        apply_riesz=False)
            i = self.ensemble.ensemble_comm.rank
            # Enlist guards against a gather_functional built from a single,
            # non-listed control, whose derivative is not returned as a list.
            adj_input = Enlist(dJg_dmg)[i]

        dJdm_local = self.local_reduced_functional.derivative(adj_input=adj_input,
                                                              apply_riesz=apply_riesz)

        if not self.scatter_control:
            return dJdm_local

        dJdm_local = Enlist(dJdm_local)
        dJdm_total = []
        for dJdm in dJdm_local:
            if isinstance(dJdm, (Cofunction, Function)):
                summed = self.ensemble.allreduce(
                    dJdm, type(dJdm)(dJdm.function_space()))
            elif isinstance(dJdm, float):
                summed = self.ensemble.ensemble_comm.allreduce(sendobj=dJdm, op=MPI.SUM)
            else:
                raise NotImplementedError(
                    f"Gradients of type {type(dJdm).__name__} are not supported.")
            dJdm_total.append(summed)
        # Hand back a single object or a list, mirroring the local derivative.
        return dJdm_local.delist(dJdm_total)

    def tlm(self, m_dot):
        """Return the action of the tangent linear model of the functional.

        The tangent linear model is evaluated w.r.t. the control on a vector
        m_dot, around the last supplied value of the control.

        Parameters
        ----------
        m_dot : pyadjoint.OverloadedType
            The direction in which to compute the action of the tangent linear model.

        Returns
        -------
        pyadjoint.OverloadedType: The action of the tangent linear model in the
            direction m_dot. Should be an instance of the same type as the functional.

        Raises
        ------
        NotImplementedError
            If the local tangent linear value has an unsupported type.
        """
        local_tlm = self.local_reduced_functional.tlm(m_dot)
        if self.gather_functional:
            mdot_g = self._allgather_J(local_tlm)
            return self.gather_functional.tlm(mdot_g)
        # if gather_functional is None then we do a sum
        return self._allreduce_sum(local_tlm)

    def hessian(self, m_dot, hessian_input=None, evaluate_tlm=True, apply_riesz=False):
        """The Hessian is not yet implemented for ensemble reduced functional.

        Raises
        ------
        NotImplementedError
            This method is not yet implemented for ensemble reduced functional.
        """
        raise NotImplementedError("Hessian is not yet implemented for ensemble reduced functional.")