from pyadjoint.reduced_functional import AbstractReducedFunctional, ReducedFunctional
from pyadjoint.enlisting import Enlist
from pyop2.mpi import MPI
from firedrake.function import Function
from firedrake.cofunction import Cofunction
class EnsembleReducedFunctional(AbstractReducedFunctional):
    """Enable solving simultaneously reduced functionals in parallel.

    Consider a functional :math:`J` and its gradient :math:`\\dfrac{dJ}{dm}`,
    where :math:`m` is the control parameter. Let us assume that :math:`J` is the sum of
    :math:`N` functionals :math:`J_i(m)`, i.e.,

    .. math::

        J = \\sum_{i=1}^{N} J_i(m).

    The gradient over a summation is a linear operation. Therefore, we can write the gradient
    :math:`\\dfrac{dJ}{dm}` as

    .. math::

        \\frac{dJ}{dm} = \\sum_{i=1}^{N} \\frac{dJ_i}{dm}.

    The :class:`EnsembleReducedFunctional` allows simultaneous evaluation of :math:`J_i` and
    :math:`\\dfrac{dJ_i}{dm}`. After that, the allreduce :class:`~.ensemble.Ensemble`
    operation is employed to sum the functionals and their gradients over an ensemble
    communicator.

    If ``gather_functional`` is given, the local values :math:`J_i` are not summed. Instead
    they are gathered onto every ensemble rank and passed, as a list ordered by ensemble
    rank, to ``gather_functional``, whose output is returned in place of the sum.
    ``gather_functional`` must therefore accept a list with one entry per ensemble member,
    each of the same type as ``functional``.

    Parameters
    ----------
    functional : pyadjoint.OverloadedType
        An instance of an OverloadedType, usually :class:`pyadjoint.AdjFloat`.
        This should be the functional that we want to reduce.
    control : pyadjoint.Control or list of pyadjoint.Control
        A single or a list of Control instances, which you want to map to the functional.
    ensemble : Ensemble
        An instance of the :class:`~.ensemble.Ensemble`. It is used to communicate the
        functionals and their derivatives between the ensemble members.
    scatter_control : bool
        Whether the control (or list of controls) is scattered over the ensemble
        communicator ``Ensemble.ensemble_comm``, i.e. every ensemble member holds the same
        control. If True, the local derivatives are summed over the ensemble communicator;
        otherwise each member returns only its local derivative.
    gather_functional : pyadjoint.ReducedFunctional or None
        A reduced functional that takes the list of all the local functional values (one
        per ensemble member, ordered by ensemble rank). If None, the local functionals are
        summed.
    derivative_components : list of int
        The indices of the controls that the derivative should be computed with respect to.
        If present, it overwrites ``derivative_cb_pre`` and ``derivative_cb_post``.
    scale : float
        A scaling factor applied to the functional and its gradient (with respect to the
        control).
    tape : pyadjoint.Tape
        A tape object that the reduced functional will use to evaluate the functional and
        its gradients (or derivatives).
    eval_cb_pre : :func:
        Callback function before evaluating the functional. Input is a list of Controls.
    eval_cb_post : :func:
        Callback function after evaluating the functional. Inputs are the functional value
        and a list of Controls.
    derivative_cb_pre : :func:
        Callback function before evaluating gradients (or derivatives). Input is a list of
        Controls. Should return a list of Controls (usually the same list as the input) to
        be passed to :func:`pyadjoint.compute_gradient`.
    derivative_cb_post : :func:
        Callback function after evaluating derivatives. Inputs are the functional, a list of
        gradients (or derivatives), and controls. All of them are the checkpointed versions.
        Should return a list of gradients (or derivatives) (usually the same list as the input)
        to be returned from ``self.derivative``.
    hessian_cb_pre : :func:
        Callback function before evaluating the Hessian. Input is a list of Controls.
    hessian_cb_post : :func:
        Callback function after evaluating the Hessian. Inputs are the functional, a list of
        Hessian, and controls.

    See Also
    --------
    :class:`~.ensemble.Ensemble`, :class:`pyadjoint.ReducedFunctional`.

    Notes
    -----
    The functionals :math:`J_i` and the control must be defined over a common
    ``ensemble.comm`` communicator. To understand more about how ensemble parallelism
    works, please refer to the `Firedrake manual
    <https://www.firedrakeproject.org/ensemble_parallelism.html>`_.
    """

    def __init__(self, functional, control, ensemble, scatter_control=True,
                 gather_functional=None,
                 derivative_components=None,
                 scale=1.0, tape=None,
                 eval_cb_pre=lambda *args: None,
                 eval_cb_post=lambda *args: None,
                 derivative_cb_pre=lambda controls: controls,
                 derivative_cb_post=lambda checkpoint, derivative_components, controls: derivative_components,
                 hessian_cb_pre=lambda *args: None,
                 hessian_cb_post=lambda *args: None):
        # The local functional J_i is evaluated by an ordinary ReducedFunctional;
        # this class only adds the ensemble communication on top of it.
        self.local_reduced_functional = ReducedFunctional(
            functional, control,
            derivative_components=derivative_components,
            scale=scale, tape=tape,
            eval_cb_pre=eval_cb_pre,
            eval_cb_post=eval_cb_post,
            derivative_cb_pre=derivative_cb_pre,
            derivative_cb_post=derivative_cb_post,
            hessian_cb_pre=hessian_cb_pre,
            hessian_cb_post=hessian_cb_post
        )
        self.ensemble = ensemble
        self.scatter_control = scatter_control
        self.gather_functional = gather_functional

    @property
    def controls(self):
        """The controls of the local reduced functional."""
        return self.local_reduced_functional.controls

    def _allgather_J(self, J):
        """Gather the local value ``J`` from every ensemble member onto every member.

        Parameters
        ----------
        J : float or Function
            The value computed on this ensemble member.

        Returns
        -------
        list
            One entry per ensemble member, ordered by ensemble rank.

        Raises
        ------
        NotImplementedError
            If ``J`` is neither a float nor a Function.
        """
        ensemble_comm = self.ensemble.ensemble_comm
        if isinstance(J, float):
            return ensemble_comm.allgather(J)
        if isinstance(J, Function):
            # allgather is not implemented in ensemble.py, so emulate it with one
            # broadcast per ensemble rank. Each broadcast gets a fresh copy of J
            # because bcast is assumed to write the root's data into its argument.
            return [self.ensemble.bcast(J.copy(deepcopy=True), root=rank)
                    for rank in range(ensemble_comm.size)]
        raise NotImplementedError(f"Functionals of type {type(J).__name__} are not supported.")

    def _ensemble_sum(self, local_value, error_message):
        """Sum ``local_value`` over ``self.ensemble.ensemble_comm``.

        Parameters
        ----------
        local_value : float, Function or Cofunction
            The value computed on this ensemble member.
        error_message : str
            Message of the :class:`NotImplementedError` raised for unsupported types.

        Returns
        -------
        float, Function or Cofunction
            The sum over all ensemble members. Functions and Cofunctions are reduced
            into a newly created object of the same type on the same function space.

        Raises
        ------
        NotImplementedError
            If ``local_value`` is not a float, Function or Cofunction.
        """
        if isinstance(local_value, float):
            return self.ensemble.ensemble_comm.allreduce(sendobj=local_value, op=MPI.SUM)
        if isinstance(local_value, (Function, Cofunction)):
            reduced = type(local_value)(local_value.function_space())
            return self.ensemble.allreduce(local_value, reduced)
        raise NotImplementedError(error_message)

    def __call__(self, values):
        """Computes the reduced functional with supplied control value.

        Parameters
        ----------
        values : pyadjoint.OverloadedType
            If you have multiple controls this should be a list of
            new values for each control in the order you listed the controls to the constructor.
            If you have a single control it can either be a list or a single object.
            Each new value should have the same type as the corresponding control.

        Returns
        -------
        pyadjoint.OverloadedType
            The output of ``gather_functional`` applied to the gathered local functionals
            if ``gather_functional`` is set, otherwise the sum of the local functionals over
            the ensemble. Typically an instance of :class:`pyadjoint.AdjFloat`.

        Raises
        ------
        NotImplementedError
            If the local functional is of an unsupported type.
        """
        local_functional = self.local_reduced_functional(values)
        if self.gather_functional is not None:
            return self.gather_functional(self._allgather_J(local_functional))
        # Without a gather functional the local functionals are summed.
        return self._ensemble_sum(local_functional,
                                  "This type of functional is not supported.")

    def derivative(self, adj_input=1.0, apply_riesz=False):
        """Compute derivatives of a functional with respect to the control parameters.

        Parameters
        ----------
        adj_input : float or pyadjoint.OverloadedType
            The adjoint input. If ``gather_functional`` is set, this is the adjoint input of
            ``gather_functional``, and the entry of its derivative that belongs to this
            ensemble rank is used as the adjoint input of the local reduced functional.
        apply_riesz : bool
            If True, apply the Riesz map of each control in order to return
            a primal gradient rather than a derivative in the dual space.

        Returns
        -------
        pyadjoint.OverloadedType or list of pyadjoint.OverloadedType
            If ``scatter_control`` is True, the result of the allreduce (sum) of the local
            derivatives over ``Ensemble.ensemble_comm``; otherwise the local derivative of
            this ensemble member only.

        Raises
        ------
        NotImplementedError
            If ``scatter_control`` is True and a local derivative is not a float,
            Function or Cofunction.

        See Also
        --------
        :meth:`~.ensemble.Ensemble.allreduce`, :meth:`pyadjoint.ReducedFunctional.derivative`.
        """
        if self.gather_functional is not None:
            # Differentiate the gather functional with respect to its inputs (the
            # gathered local functionals); this rank's entry seeds the local adjoint.
            # It is kept in the dual space (no Riesz map) because it is an adjoint input.
            dJg_dmg = self.gather_functional.derivative(adj_input=adj_input,
                                                        apply_riesz=False)
            adj_input = dJg_dmg[self.ensemble.ensemble_comm.rank]
        dJdm_local = self.local_reduced_functional.derivative(adj_input=adj_input,
                                                              apply_riesz=apply_riesz)
        if not self.scatter_control:
            return dJdm_local
        # Every member shares the control, so the total derivative is the sum
        # of the local derivatives, one reduction per control.
        dJdm_local = Enlist(dJdm_local)
        dJdm_total = [
            self._ensemble_sum(
                dJdm, f"Gradients of type {type(dJdm).__name__} are not supported.")
            for dJdm in dJdm_local
        ]
        return dJdm_local.delist(dJdm_total)

    def tlm(self, m_dot):
        """Return the action of the tangent linear model of the functional.

        The tangent linear model is evaluated w.r.t. the control on a vector
        m_dot, around the last supplied value of the control.

        Parameters
        ----------
        m_dot : pyadjoint.OverloadedType
            The direction in which to compute the action of the tangent linear model.

        Returns
        -------
        pyadjoint.OverloadedType
            The action of the tangent linear model in the direction m_dot. If
            ``gather_functional`` is set, the gathered local actions are passed through
            ``gather_functional.tlm``; otherwise the local actions are summed over the
            ensemble. Should be an instance of the same type as the functional.

        Raises
        ------
        NotImplementedError
            If the local tangent linear action is of an unsupported type.
        """
        local_tlm = self.local_reduced_functional.tlm(m_dot)
        if self.gather_functional is not None:
            return self.gather_functional.tlm(self._allgather_J(local_tlm))
        # Without a gather functional the local actions are summed.
        return self._ensemble_sum(local_tlm,
                                  "This type of functional is not supported.")

    def hessian(self, m_dot, hessian_input=None, evaluate_tlm=True, apply_riesz=False):
        """The Hessian is not yet implemented for ensemble reduced functional.

        Raises
        ------
        NotImplementedError
            This method is not yet implemented for ensemble reduced functional.
        """
        raise NotImplementedError("Hessian is not yet implemented for ensemble reduced functional.")