import numpy
import ufl
from ufl import replace
from ufl.formatting.ufl2unicode import ufl2unicode
from enum import Enum
from pyadjoint import Block, stop_annotating, get_working_tape
from pyadjoint.enlisting import Enlist
import firedrake
from firedrake.adjoint_utils.checkpointing import maybe_disk_checkpoint
from .block_utils import isconstant
[docs]
class Solver(Enum):
    """Enum for solver types.

    Used to select whether the forward or the adjoint solver is being
    addressed.
    """

    # Solver of the forward problem.
    FORWARD = 0
    # Solver of the adjoint problem.
    ADJOINT = 1
[docs]
class GenericSolveBlock(Block):
    """Tape block recording the solve of ``lhs == rhs`` for ``func``."""

    # NOTE(review): not read anywhere in this file; presumably used by callers
    # to strip block-only keyword arguments before the real solve - confirm.
    pop_kwargs_keys = ["adj_cb", "adj_bdy_cb", "adj2_cb", "adj2_bdy_cb",
                       "forward_args", "forward_kwargs", "adj_args",
                       "adj_kwargs"]

    def __init__(self, lhs, rhs, func, bcs, *args, **kwargs):
        """Record the equation and register everything it depends on.

        Parameters
        ----------
        lhs, rhs
            Left- and right-hand side of the equation. The problem is treated
            as linear when ``lhs`` is a ``ufl.Form`` and ``rhs`` is a
            ``ufl.Form`` or ``ufl.Cofunction``.
        func
            The solution function; its function space is stored on the block.
        bcs
            A boundary condition, a collection of them, or ``None``.
        *args, **kwargs
            Forwarded to ``_init_solver_parameters``. The callback keys
            ``adj_cb``, ``adj_bdy_cb``, ``adj2_cb``, ``adj2_bdy_cb`` and
            ``ad_block_tag`` are consumed here.
        """
        super().__init__(ad_block_tag=kwargs.pop('ad_block_tag', None))
        self.adj_cb = kwargs.pop("adj_cb", None)
        self.adj_bdy_cb = kwargs.pop("adj_bdy_cb", None)
        self.adj2_cb = kwargs.pop("adj2_cb", None)
        self.adj2_bdy_cb = kwargs.pop("adj2_bdy_cb", None)
        self.forward_args = []
        self.forward_kwargs = {}
        self.adj_args = []
        self.adj_kwargs = {}
        self.assemble_kwargs = {}
        # Adjoint solution of the latest adjoint solve. Initialised here so
        # that the ``adj_sol`` property and the "run before adjoint" check in
        # the Hessian preparation see ``None`` instead of raising
        # AttributeError when no adjoint solve has happened yet.
        self.adj_state = None

        # Equation LHS
        self.lhs = lhs
        # Equation RHS
        self.rhs = rhs
        # Solution function
        self.func = func
        self.function_space = self.func.function_space()
        # Boundary conditions
        self.bcs = []
        if bcs is not None:
            self.bcs = Enlist(bcs)

        if isinstance(self.lhs, ufl.Form) and isinstance(self.rhs, (ufl.Form, ufl.Cofunction)):
            self.linear = True
            for c in self.rhs.coefficients():
                self.add_dependency(c, no_duplicates=True)
        else:
            self.linear = False

        for c in self.lhs.coefficients():
            self.add_dependency(c, no_duplicates=True)

        for bc in self.bcs:
            self.add_dependency(bc, no_duplicates=True)

        # The mesh is a dependency as well (shape derivatives).
        mesh = self.lhs.ufl_domain()
        self.add_dependency(mesh)
        self._init_solver_parameters(args, kwargs)
def _init_solver_parameters(self, args, kwargs):
    """Pull the solver options out of ``kwargs`` and store them on the block.

    ``forward_args``, ``forward_kwargs``, ``adj_args`` and ``adj_kwargs`` are
    popped from ``kwargs`` (which is modified in place); missing entries
    default to a fresh empty list or dict. ``assemble_kwargs`` is reset to an
    empty dict. ``args`` is not used here.
    """
    option_defaults = (
        ("forward_args", []),
        ("forward_kwargs", {}),
        ("adj_args", []),
        ("adj_kwargs", {}),
    )
    for name, fallback in option_defaults:
        setattr(self, name, kwargs.pop(name, fallback))
    self.assemble_kwargs = {}
def __str__(self):
    """Return the equation rendered as ``solve(<lhs> = <rhs>)``."""
    lhs_text = ufl2unicode(self.lhs)
    rhs_text = ufl2unicode(self.rhs)
    return "solve({} = {})".format(lhs_text, rhs_text)
def _create_F_form(self):
    """Return the residual form with checkpointed values substituted.

    For a linear problem the residual is ``action(lhs, u) - rhs`` with a
    placeholder function ``u``; otherwise it is ``lhs`` itself with
    ``self.func`` as the unknown. Dependencies appearing in the residual are
    replaced by their saved outputs, and the unknown is replaced by the
    saved output of the block's first output.
    """
    if not self.linear:
        unknown = self.func
        residual = self.lhs
    else:
        unknown = firedrake.Function(self.function_space)
        residual = firedrake.action(self.lhs, unknown) - self.rhs

    substitutions = self._replace_map(residual)
    substitutions[unknown] = self.get_outputs()[0].saved_output
    return ufl.replace(residual, substitutions)
def _homogenize_bcs(self):
    """Return the block's boundary conditions with Dirichlet values zeroed.

    Every ``DirichletBC`` is replaced by ``bc.reconstruct(g=0)``; any other
    entry is passed through unchanged. ``self.bcs`` itself is not modified.
    """
    return [
        bc.reconstruct(g=0) if isinstance(bc, firedrake.DirichletBC) else bc
        for bc in self.bcs
    ]
def _create_initial_guess(self):
    """Return a new ``Function`` on the block's function space."""
    guess = firedrake.Function(self.function_space)
    return guess
def _recover_bcs(self):
    """Return the saved outputs of all ``DirichletBC`` dependencies."""
    recovered = []
    for dependency in self.get_dependencies():
        original, checkpointed = dependency.output, dependency.saved_output
        if isinstance(original, firedrake.DirichletBC):
            recovered.append(checkpointed)
    return recovered
def _replace_map(self, form):
    """Map each dependency that appears in ``form`` to its saved output.

    Parameters
    ----------
    form
        Object exposing ``coefficients()``.

    Returns
    -------
    dict
        ``{dependency output: dependency saved output}`` restricted to the
        dependencies found among ``form.coefficients()``.
    """
    # The coefficient collection does not depend on the dependency being
    # tested, so query it once rather than once per dependency.
    form_coefficients = form.coefficients()
    replace_coeffs = {}
    for block_variable in self.get_dependencies():
        coeff = block_variable.output
        if coeff in form_coefficients:
            replace_coeffs[coeff] = block_variable.saved_output
    return replace_coeffs
def _replace_form(self, form, func=None):
    """Replace the form coefficients with checkpointed values.

    ``func`` represents the initial guess if relevant: when it is given and
    ``self.func`` occurs in ``form``, ``func`` is assigned the checkpointed
    value of ``self.func`` and is then substituted for ``self.func``.
    """
    substitutions = self._replace_map(form)
    use_guess = func is not None and self.func in substitutions
    if use_guess:
        firedrake.Function.assign(func, substitutions[self.func])
        substitutions[self.func] = func
    return ufl.replace(form, substitutions)
def _should_compute_boundary_adjoint(self, relevant_dependencies):
    """Return True when any relevant dependency is a ``DirichletBC``.

    ``relevant_dependencies`` is an iterable of ``(index, block_variable)``
    pairs; only the block variable's ``output`` is inspected.
    """
    return any(
        isinstance(dependency.output, firedrake.DirichletBC)
        for _, dependency in relevant_dependencies
    )
@property
def adj_sol(self):
    """The adjoint solution currently stored in ``self.adj_state``."""
    state = self.adj_state
    return state
[docs]
def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies):
    """Solve the adjoint equation once for all dependency evaluations.

    Builds the adjoint of dF/du from the residual form, solves it against a
    copy of ``adj_inputs[0]``, stores the solution in ``self.adj_state`` and
    invokes the ``adj_cb`` / ``adj_bdy_cb`` callbacks when they are set.

    Returns a dict with keys ``"form"`` (the residual form), ``"adj_sol"``
    and ``"adj_sol_bdy"`` (``None`` unless a ``DirichletBC`` is among the
    relevant dependencies).
    """
    output_variable = self.get_outputs()[0]
    solution = output_variable.output
    seed = adj_inputs[0]

    residual = self._create_F_form()
    jacobian = firedrake.derivative(
        residual,
        output_variable.saved_output,
        firedrake.TrialFunction(solution.function_space()),
    )
    adjoint_jacobian = firedrake.adjoint(jacobian)

    rhs = seed.copy()
    needs_bdy = self._should_compute_boundary_adjoint(relevant_dependencies)
    adj_sol, adj_sol_bdy = self._assemble_and_solve_adj_eq(
        adjoint_jacobian, rhs, needs_bdy
    )
    self.adj_state = adj_sol

    if self.adj_cb is not None:
        self.adj_cb(adj_sol)
    if self.adj_bdy_cb is not None and needs_bdy:
        self.adj_bdy_cb(adj_sol_bdy)

    return {
        "form": residual,
        "adj_sol": adj_sol,
        "adj_sol_bdy": adj_sol_bdy,
    }
def _assemble_dFdu_adj(self, dFdu_adj_form, **kwargs):
    """Assemble ``dFdu_adj_form``, passing ``kwargs`` to ``firedrake.assemble``."""
    assembled = firedrake.assemble(dFdu_adj_form, **kwargs)
    return assembled
def _assemble_and_solve_adj_eq(self, dFdu_adj_form, dJdu, compute_bdy):
    """Assemble and solve the adjoint system for ``dJdu``.

    The operator is assembled with homogenized boundary conditions and the
    same conditions are zeroed on ``dJdu`` (which is modified in place).

    Returns ``(adj_sol, adj_sol_bdy)``; ``adj_sol_bdy`` is ``None`` unless
    ``compute_bdy`` is true.
    """
    # The boundary adjoint needs the right-hand side as it was before the
    # boundary entries are zeroed below.
    untouched_rhs = dJdu.copy()

    homogeneous_bcs = self._homogenize_bcs()
    operator = firedrake.assemble(
        dFdu_adj_form, bcs=homogeneous_bcs, **self.assemble_kwargs
    )
    for bc in homogeneous_bcs:
        bc.zero(dJdu)

    adj_sol = firedrake.Function(self.function_space)
    firedrake.solve(operator, adj_sol, dJdu, *self.adj_args, **self.adj_kwargs)

    if not compute_bdy:
        return adj_sol, None
    boundary_part = self._compute_adj_bdy(
        adj_sol, None, dFdu_adj_form, untouched_rhs
    )
    return adj_sol, boundary_part
def _compute_adj_bdy(self, adj_sol, adj_sol_bdy, dFdu_adj_form, dJdu):
    """Return the l2 Riesz representation of ``dJdu - dFdu_adj_form(adj_sol)``.

    The incoming ``adj_sol_bdy`` value is not read; the parameter exists
    only as part of the signature.
    """
    residual = dJdu - firedrake.action(dFdu_adj_form, adj_sol)
    assembled = firedrake.assemble(residual)
    return assembled.riesz_representation("l2")
[docs]
def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
                           prepared=None):
    """Return the adjoint contribution for one dependency.

    ``prepared`` is the dict produced by ``prepare_evaluate_adj``.

    Returns ``None`` for the initial guess of a nonlinear solve, a
    one-element list holding a reconstructed boundary condition for a
    ``DirichletBC`` dependency, and otherwise the assembled derivative of
    the negated residual contracted with the adjoint solution.
    """
    control = block_variable.output
    if not self.linear and self.func == control:
        # We are not able to calculate derivatives wrt initial guess.
        return None

    residual = prepared["form"]
    adj_sol = prepared["adj_sol"]
    adj_sol_bdy = prepared["adj_sol_bdy"]
    control_value = block_variable.saved_output

    if isconstant(control):
        direction = firedrake.TrialFunction(
            control._ad_function_space(residual.ufl_domain())
        )
    elif isinstance(control, (firedrake.Function, firedrake.Cofunction)):
        direction = firedrake.TrialFunction(control.function_space())
    elif isinstance(control, firedrake.DirichletBC):
        boundary_value = extract_subfunction(
            adj_sol_bdy, control.function_space()
        )
        return [control.reconstruct(g=boundary_value)]
    elif isinstance(control, firedrake.MeshGeometry):
        # Using CoordinateDerivative requires us to do action before
        # differentiating, might change in the future.
        contracted = firedrake.action(residual, adj_sol)
        coordinates = firedrake.SpatialCoordinate(control_value)
        shape_derivative = firedrake.derivative(
            -contracted, coordinates,
            firedrake.TestFunction(control._ad_function_space())
        )
        if shape_derivative == 0:
            return firedrake.Function(control._ad_function_space().dual())
        return firedrake.assemble(shape_derivative, **self.assemble_kwargs)

    sensitivity = -firedrake.derivative(residual, control_value, direction)
    if isinstance(sensitivity, ufl.Form):
        sensitivity = firedrake.action(
            firedrake.adjoint(sensitivity), adj_sol
        )
    else:
        sensitivity = sensitivity(adj_sol)
    return firedrake.assemble(sensitivity, **self.assemble_kwargs)
[docs]
def prepare_evaluate_tlm(self, inputs, tlm_inputs, relevant_outputs):
    """Build the forms shared by all tangent-linear evaluations.

    Returns a dict with the residual form under ``"form"`` and its
    derivative with respect to the solution under ``"dFdu"``.
    """
    output_variable = self.get_outputs()[0]
    solution = output_variable.output
    residual = self._create_F_form()

    # Obtain dFdu.
    jacobian = firedrake.derivative(
        residual,
        output_variable.saved_output,
        firedrake.TrialFunction(solution.function_space()),
    )
    return {"form": residual, "dFdu": jacobian}
[docs]
def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx,
                           prepared=None):
    """Solve the tangent-linear equation for output ``idx``.

    The right-hand side accumulates ``derivative(-F, m, tlm_value)`` over
    every dependency ``m`` that carries a TLM value. ``DirichletBC``
    dependencies contribute boundary conditions instead: their TLM value
    when present, a homogenized copy otherwise. ``prepared`` is the dict
    from ``prepare_evaluate_tlm``.
    """
    residual = prepared["form"]
    jacobian = prepared["dFdu"]
    solution_space = self.get_outputs()[idx].output.function_space()

    tlm_bcs = []
    forcing = 0.
    for dependency in self.get_dependencies():
        direction = dependency.tlm_value
        control = dependency.output
        control_value = dependency.saved_output

        if isinstance(control, firedrake.DirichletBC):
            tlm_bcs.append(
                control.reconstruct(g=0) if direction is None else direction
            )
            continue
        if isinstance(control, firedrake.MeshGeometry):
            # Differentiate with respect to the mesh coordinates.
            control_value = firedrake.SpatialCoordinate(control)

        if direction is None:
            continue
        if control == self.func and not self.linear:
            # Initial guess of a nonlinear solve: no contribution.
            continue
        forcing += firedrake.derivative(-residual, control_value, direction)

    if isinstance(forcing, float):
        # Nothing was accumulated: use an explicit zero form instead.
        test = jacobian.arguments()[0]
        zero = firedrake.Constant(numpy.zeros(test.ufl_shape))
        forcing = firedrake.inner(zero, test) * firedrake.dx

    forcing = ufl.algorithms.expand_derivatives(forcing)
    forcing = firedrake.assemble(forcing)
    increment = firedrake.Function(solution_space)
    operator = firedrake.assemble(
        jacobian, bcs=tlm_bcs, **self.assemble_kwargs
    )
    return self._assemble_and_solve_tlm_eq(
        operator, forcing, increment, tlm_bcs
    )
def _assemble_and_solve_tlm_eq(self, dFdu, dFdm, dudm, bcs):
    """Solve the tangent-linear system by delegating to ``_assembled_solve``."""
    solution = self._assembled_solve(dFdu, dFdm, dudm, bcs)
    return solution
def _assemble_soa_eq_rhs(self, dFdu_form, adj_sol, hessian_input, d2Fdu2):
    """Assemble the right-hand side of the second-order adjoint equation.

    Starts from a copy of ``hessian_input`` and subtracts the assembled sum
    of second-derivative terms contracted with ``adj_sol``: the ``d2Fdu2``
    term plus one mixed term per dependency that carries a TLM value.
    ``DirichletBC`` dependencies and the initial guess of a nonlinear solve
    add nothing.
    """
    # Start piecing together the rhs of the soa equation
    b = hessian_input.copy()
    # Only contract d2Fdu2 with the adjoint solution when it has integrals.
    if len(d2Fdu2.integrals()) > 0:
        b_form = firedrake.action(firedrake.adjoint(d2Fdu2), adj_sol)
    else:
        b_form = d2Fdu2
    for bo in self.get_dependencies():
        c = bo.output
        c_rep = bo.saved_output
        tlm_input = bo.tlm_value
        # Skip the nonlinear initial guess and dependencies without a TLM value.
        if (c == self.func and not self.linear) or tlm_input is None:
            continue
        if isinstance(c, firedrake.MeshGeometry):
            # Mesh dependency: differentiate w.r.t. the spatial coordinates.
            X = firedrake.SpatialCoordinate(c)
            dFdu_adj = firedrake.action(firedrake.adjoint(dFdu_form),
                                        adj_sol)
            d2Fdudm = ufl.algorithms.expand_derivatives(
                firedrake.derivative(dFdu_adj, X, tlm_input))
            # An expanded derivative without integrals is not added.
            if len(d2Fdudm.integrals()) > 0:
                b_form += d2Fdudm
        elif not isinstance(c, firedrake.DirichletBC):
            dFdu_adj = firedrake.action(firedrake.adjoint(dFdu_form),
                                        adj_sol)
            b_form += firedrake.derivative(dFdu_adj, c_rep, tlm_input)
    b_form = ufl.algorithms.expand_derivatives(b_form)
    # Nothing to subtract when the accumulated form has no integrals.
    if len(b_form.integrals()) > 0:
        b -= firedrake.assemble(b_form)
    return b
def _assemble_and_solve_soa_eq(self, dFdu_form, adj_sol, hessian_input,
                               d2Fdu2, compute_bdy):
    """Assemble and solve the second-order adjoint equation.

    The right-hand side comes from ``_assemble_soa_eq_rhs``; the system is
    solved with the adjoint of ``dFdu_form`` through
    ``_assemble_and_solve_adj_eq``. The ``adj2_cb`` / ``adj2_bdy_cb``
    callbacks are invoked when set.

    Returns ``(adj_sol2, adj_sol2_bdy)``.
    """
    rhs = self._assemble_soa_eq_rhs(
        dFdu_form, adj_sol, hessian_input, d2Fdu2
    )
    adjoint_form = firedrake.adjoint(dFdu_form)
    solution = self._assemble_and_solve_adj_eq(adjoint_form, rhs, compute_bdy)
    adj_sol2, adj_sol2_bdy = solution

    if self.adj2_cb is not None:
        self.adj2_cb(adj_sol2)
    if self.adj2_bdy_cb is not None and compute_bdy:
        self.adj2_bdy_cb(adj_sol2_bdy)
    return adj_sol2, adj_sol2_bdy
[docs]
def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs,
                             relevant_dependencies):
    """Solve the second-order adjoint equation once for all dependencies.

    Returns
    -------
    dict or None
        ``None`` when ``hessian_inputs[0]`` or the TLM value of the block's
        first output is ``None``. Otherwise a dict with keys ``"adj_sol2"``,
        ``"adj_sol2_bdy"``, ``"form"`` and ``"adj_sol"``.

    Raises
    ------
    RuntimeError
        If no adjoint solution has been stored on the block yet.
    """
    # First fetch all relevant values
    fwd_block_variable = self.get_outputs()[0]
    hessian_input = hessian_inputs[0]
    tlm_output = fwd_block_variable.tlm_value

    if hessian_input is None or tlm_output is None:
        return None

    # ``adj_state`` may not exist at all when no adjoint solve has run on
    # this block, so read it defensively: a missing attribute must produce
    # the RuntimeError below, not an AttributeError. The check comes before
    # any form manipulation so that the failure is immediate.
    adj_sol = getattr(self, "adj_state", None)
    if adj_sol is None:
        raise RuntimeError("Hessian computation was run before adjoint.")

    F_form = self._create_F_form()

    # Using the equation Form derive dF/du, d^2F/du^2 * du/dm * direction.
    dFdu_form = firedrake.derivative(F_form,
                                     fwd_block_variable.saved_output)
    d2Fdu2 = ufl.algorithms.expand_derivatives(
        firedrake.derivative(dFdu_form, fwd_block_variable.saved_output,
                             tlm_output))

    bdy = self._should_compute_boundary_adjoint(relevant_dependencies)
    adj_sol2, adj_sol2_bdy = self._assemble_and_solve_soa_eq(
        dFdu_form, adj_sol, hessian_input, d2Fdu2, bdy
    )

    return {
        "adj_sol2": adj_sol2,
        "adj_sol2_bdy": adj_sol2_bdy,
        "form": F_form,
        "adj_sol": adj_sol,
    }
[docs]
def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs,
                               block_variable, idx, relevant_dependencies,
                               prepared=None):
    """Return the Hessian contribution for one dependency.

    ``prepared`` is the dict produced by ``prepare_evaluate_hessian``.

    Returns ``None`` for the initial guess of a nonlinear solve and a
    one-element list holding a reconstructed boundary condition for a
    ``DirichletBC`` dependency. Otherwise the result is the negated assembly
    of three groups of terms: second derivatives of the residual contracted
    with ``adj_sol``, the first derivative contracted with ``adj_sol2``, and
    the mixed solution/dependency derivative along the TLM direction of the
    output.
    """
    c = block_variable.output
    if c == self.func and not self.linear:
        return None
    adj_sol2 = prepared["adj_sol2"]
    adj_sol2_bdy = prepared["adj_sol2_bdy"]
    F_form = prepared["form"]
    adj_sol = prepared["adj_sol"]
    fwd_block_variable = self.get_outputs()[0]
    tlm_output = fwd_block_variable.tlm_value
    c_rep = block_variable.saved_output
    # If m = DirichletBC then d^2F(u,m)/dm^2 = 0 and d^2F(u,m)/dudm = 0,
    # so we only have the term dF(u,m)/dm * adj_sol2
    if isinstance(c, firedrake.DirichletBC):
        tmp_bc = c.reconstruct(
            g=extract_subfunction(adj_sol2_bdy, c.function_space())
        )
        return [tmp_bc]
    # Pick the space W in which the derivative direction lives.
    if isconstant(c_rep):
        mesh = F_form.ufl_domain()
        W = c._ad_function_space(mesh)
    elif isinstance(c, firedrake.MeshGeometry):
        X = firedrake.SpatialCoordinate(c)
        W = c._ad_function_space()
    else:
        W = c.function_space()
    dc = firedrake.TestFunction(W)
    # Contract the residual with the first- and second-order adjoint solutions.
    form_adj = firedrake.action(F_form, adj_sol)
    form_adj2 = firedrake.action(F_form, adj_sol2)
    if isinstance(c, firedrake.MeshGeometry):
        dFdm_adj = firedrake.derivative(form_adj, X, dc)
        dFdm_adj2 = firedrake.derivative(form_adj2, X, dc)
    else:
        dFdm_adj = firedrake.derivative(form_adj, c_rep, dc)
        dFdm_adj2 = firedrake.derivative(form_adj2, c_rep, dc)
    # TODO: Old comment claims this might break on split. Confirm if true
    # or not.
    d2Fdudm = ufl.algorithms.expand_derivatives(
        firedrake.derivative(dFdm_adj, fwd_block_variable.saved_output,
                             tlm_output))
    d2Fdm2 = 0
    # We need to add terms from every other dependency
    # i.e. the terms d^2F/dm_1dm_2
    for _, bv in relevant_dependencies:
        c2 = bv.output
        c2_rep = bv.saved_output
        if isinstance(c2, firedrake.DirichletBC):
            continue
        tlm_input = bv.tlm_value
        if tlm_input is None:
            continue
        if c2 == self.func and not self.linear:
            continue
        # TODO: If tlm_input is a Sum, this crashes in some instances?
        if isinstance(c2_rep, firedrake.MeshGeometry):
            X = firedrake.SpatialCoordinate(c2_rep)
            d2Fdm2 += ufl.algorithms.expand_derivatives(
                firedrake.derivative(dFdm_adj, X, tlm_input)
            )
        else:
            d2Fdm2 += ufl.algorithms.expand_derivatives(
                firedrake.derivative(dFdm_adj, c2_rep, tlm_input)
            )
    hessian_form = ufl.algorithms.expand_derivatives(
        d2Fdm2 + dFdm_adj2 + d2Fdudm
    )
    # An empty form is never assembled; the scalar 0 is negated and returned.
    hessian_output = 0
    if not hessian_form.empty():
        hessian_output = firedrake.assemble(hessian_form)
        hessian_output *= -1.
    return hessian_output
[docs]
def prepare_recompute_component(self, inputs, relevant_outputs):
    """Return the ``(lhs, rhs, func, bcs)`` tuple used to redo the solve.

    Delegates to ``_replace_recompute_form``; both arguments are unused.
    """
    prepared = self._replace_recompute_form()
    return prepared
def _replace_recompute_form(self):
    """Assemble the ingredients of a forward re-solve.

    Returns ``(lhs, rhs, func, bcs)`` where ``func`` is a fresh initial
    guess, ``bcs`` are the recovered boundary conditions, ``lhs`` has its
    coefficients replaced by checkpointed values and ``rhs`` is the
    replaced right-hand side for linear problems or ``0`` otherwise.
    """
    guess = self._create_initial_guess()
    boundary_conditions = self._recover_bcs()
    replaced_lhs = self._replace_form(self.lhs, func=guess)
    replaced_rhs = self._replace_form(self.rhs) if self.linear else 0
    return replaced_lhs, replaced_rhs, guess, boundary_conditions
def _forward_solve(self, lhs, rhs, func, bcs):
    """Solve ``lhs == rhs`` into ``func`` with the stored forward options.

    Returns ``func``.
    """
    equation = lhs == rhs
    firedrake.solve(
        equation, func, bcs, *self.forward_args, **self.forward_kwargs
    )
    return func
def _assembled_solve(self, lhs, rhs, func, bcs, **kwargs):
    """Solve an already assembled system after applying ``bcs`` to ``rhs``.

    The boundary conditions are applied to the l2 Riesz representation of
    ``rhs``; the result is mapped back and assigned to ``rhs`` in place
    before solving. Returns ``func``.
    """
    rhs_riesz = rhs.riesz_representation(riesz_map="l2")
    for bc in bcs:
        bc.apply(rhs_riesz)
    rhs.assign(rhs_riesz.riesz_representation(riesz_map="l2"))
    firedrake.solve(lhs, func, rhs, **kwargs)
    return func
[docs]
def recompute_component(self, inputs, block_variable, idx, prepared):
    """Redo the forward solve from the prepared ``(lhs, rhs, func, bcs)``.

    When the block variable's checkpoint is a ``firedrake.Function`` the
    solution is assigned into that checkpoint. The result is passed through
    ``maybe_disk_checkpoint`` before being returned.
    """
    lhs, rhs = prepared[0], prepared[1]
    func, bcs = prepared[2], prepared[3]
    solution = self._forward_solve(lhs, rhs, func, bcs)

    checkpoint = block_variable.checkpoint
    if isinstance(checkpoint, firedrake.Function):
        solution = checkpoint.assign(solution)
    return maybe_disk_checkpoint(solution)
[docs]
def solve_init_params(self, args, kwargs, varform):
    """Fill in the solver options of a block that were left empty.

    Empty ``forward_args`` / ``forward_kwargs`` fall back to ``args`` and a
    copy of ``kwargs``; empty ``adj_args`` / ``adj_kwargs`` fall back to the
    forward values. For variational forms (``varform`` true) the ``"J"`` and
    ``"Jp"`` forward options are replaced by their adjoints in
    ``adj_kwargs``, ``"appctx"`` is dropped from ``adj_kwargs`` and an
    ``"M"`` option raises ``NotImplementedError``. ``mat_type`` (from
    ``solver_parameters``) and, for variational forms, ``appctx`` are copied
    into ``assemble_kwargs``.
    """
    if not self.forward_args:
        self.forward_args = args
    if not self.forward_kwargs:
        self.forward_kwargs = kwargs.copy()
    if not self.adj_args:
        self.adj_args = self.forward_args
    if not self.adj_kwargs:
        self.adj_kwargs = self.forward_kwargs.copy()

    if varform:
        for key in ("J", "Jp"):
            if key in self.forward_kwargs:
                self.adj_kwargs[key] = firedrake.adjoint(
                    self.forward_kwargs[key]
                )
        if "M" in self.forward_kwargs:
            raise NotImplementedError(
                "Annotation of adaptive solves not implemented."
            )
        self.adj_kwargs.pop("appctx", None)

    solver_parameters = kwargs.get("solver_parameters", {})
    if "mat_type" in solver_parameters:
        self.assemble_kwargs["mat_type"] = solver_parameters["mat_type"]

    if varform and "appctx" in kwargs:
        self.assemble_kwargs["appctx"] = kwargs["appctx"]
[docs]
class SolveLinearSystemBlock(GenericSolveBlock):
    """Solve block for an assembled linear system ``A u = b``.

    The forms are read from ``A.form`` and ``b.form``, the solution from
    ``u.function`` and the boundary conditions from ``A.bcs`` when present.
    """

    def __init__(self, A, u, b, *args, **kwargs):
        operator_form = A.form
        solution = u.function
        load_form = b.form
        boundary_conditions = getattr(A, "bcs", [])
        super().__init__(operator_form, load_form, solution,
                         boundary_conditions, *args, **kwargs)

        # Set up parameters initialization
        self.ident_zeros_tol = getattr(A, "ident_zeros_tol", None)
        self.assemble_system = getattr(A, "assemble_system", False)

    def _init_solver_parameters(self, args, kwargs):
        """Initialise the solver options for a non-variational solve."""
        super()._init_solver_parameters(args, kwargs)
        solve_init_params(self, args, kwargs, varform=False)
[docs]
class NonlinearVariationalSolveBlock(GenericSolveBlock):
    """Solve block that reuses the solvers stored in ``self._ad_solvers``."""

    def __init__(self, equation, func, bcs, adj_cache, problem_J,
                 solver_params, solver_kwargs, **kwargs):
        """Record the variational problem.

        ``adj_cache`` is kept on the block and receives a ``"dFdm_cache"``
        entry if it has none. ``solver_params`` is copied; ``solver_kwargs``
        is stored as is and merged with ``kwargs`` (``kwargs`` wins) for the
        base-class constructor. Coefficients of ``problem_J`` are added as
        dependencies when it is given.
        """
        lhs, rhs = equation.lhs, equation.rhs
        self._adj_cache = adj_cache
        self._dFdm_cache = adj_cache.setdefault("dFdm_cache", {})
        self.problem_J = problem_J
        self.solver_params = solver_params.copy()
        self.solver_kwargs = solver_kwargs

        merged_kwargs = {**solver_kwargs, **kwargs}
        super().__init__(lhs, rhs, func, bcs, **merged_kwargs)

        if self.problem_J is None:
            return
        for coeff in self.problem_J.coefficients():
            self.add_dependency(coeff, no_duplicates=True)
def _init_solver_parameters(self, args, kwargs):
    """Initialise the solver options for a variational-form solve."""
    # Base-class defaults first, then the variational-form specific handling.
    super()._init_solver_parameters(args, kwargs)
    solve_init_params(self, args, kwargs, varform=True)
[docs]
def recompute_component(self, inputs, block_variable, idx, prepared):
    """Redo the forward solve, updating the solver bookkeeping first.

    When the stored recompute count is exactly one behind the tape's, it is
    brought up to date, the forward Jacobian is invalidated if the problem
    has a constant Jacobian, and the adjoint solver is flagged for update.
    """
    tape = get_working_tape()
    solvers = self._ad_solvers
    if solvers["recompute_count"] == tape.recompute_count - 1:
        # Update how many times the block has been recomputed.
        solvers["recompute_count"] = tape.recompute_count
        forward_solver = solvers["forward_nlvs"]
        if forward_solver._problem._constant_jacobian:
            forward_solver.invalidate_jacobian()
        solvers["update_adjoint"] = True
    return super().recompute_component(inputs, block_variable, idx, prepared)
def _forward_solve(self, lhs, rhs, func, bcs, **kwargs):
    """Re-run the stored forward solver and copy its solution into ``func``.

    ``lhs``, ``rhs``, ``bcs`` and ``kwargs`` are not used: the stored
    solver's own forms are refreshed through ``_ad_solver_replace_forms``.
    Returns ``func``.
    """
    self._ad_solver_replace_forms()
    forward_solver = self._ad_solvers["forward_nlvs"]
    forward_solver.parameters.update(self.solver_params)
    forward_solver.solve()
    func.assign(forward_solver._problem.u)
    return func
def _adjoint_solve(self, dJdu, compute_bdy):
    """Solve the adjoint problem with the stored adjoint solver.

    The block's boundary conditions are zeroed on ``dJdu`` in place, the
    adjoint operator is refreshed when required, and ``dJdu`` is assigned
    into the adjoint problem's right-hand side before solving.

    Returns ``(adj_sol, adj_sol_bdy)``; ``adj_sol_bdy`` is ``None`` unless
    ``compute_bdy`` is true.
    """
    # The boundary adjoint needs the right-hand side as it was before the
    # boundary entries are zeroed.
    untouched_rhs = dJdu.copy()
    for bc in self.bcs:
        bc.zero(dJdu)

    solvers = self._ad_solvers
    constant_jacobian = solvers["forward_nlvs"]._problem._constant_jacobian
    if not constant_jacobian:
        # Update left hand side of the adjoint equation.
        self._ad_solver_replace_forms(Solver.ADJOINT)
    elif solvers["update_adjoint"]:
        # Update left hand side of the adjoint equation.
        self._ad_solver_replace_forms(Solver.ADJOINT)
        solvers["adjoint_lvs"].invalidate_jacobian()
        solvers["update_adjoint"] = False

    adjoint_solver = solvers["adjoint_lvs"]
    # Update the right hand side of the adjoint equation.
    # problem.F._component[1] is the right hand side of the adjoint.
    adjoint_solver._problem.F._components[1].assign(dJdu)

    # Solve the adjoint linear variational solver.
    adjoint_solver.solve()
    adj_sol = adjoint_solver._problem.u

    if not compute_bdy:
        return adj_sol, None
    adjoint_jacobian = adjoint_solver._problem.J
    boundary_part = self._compute_adj_bdy(
        adj_sol, None, adjoint_jacobian, untouched_rhs
    )
    return adj_sol, boundary_part
def _ad_assign_map(self, form, solver):
    """Map coefficients of ``form`` to the checkpoints they should receive.

    Coefficients of ``form`` are matched to the block's dependencies through
    the count stored in the selected problem's ``_ad_count_map``. For the
    adjoint solver of a problem without a constant Jacobian, the block's
    first output is matched the same way.
    """
    solver_key = "forward_nlvs" if solver == Solver.FORWARD else "adjoint_lvs"
    count_map = self._ad_solvers[solver_key]._problem._ad_count_map

    coeff_by_count = {
        count_map[coeff]: coeff for coeff in form.coefficients()
    }
    assignable_types = (
        firedrake.Coefficient, firedrake.Constant, firedrake.Cofunction
    )

    assign_map = {}
    for dependency in self.get_dependencies():
        original = dependency.output
        if not isinstance(original, assignable_types):
            continue
        original_count = original.count()
        if original_count in coeff_by_count:
            assign_map[coeff_by_count[original_count]] = \
                dependency.saved_output

    if (
        solver == Solver.ADJOINT
        and not self._ad_solvers["forward_nlvs"]._problem._constant_jacobian
    ):
        output_variable = self.get_outputs()[0]
        output_count = output_variable.output.count()
        if output_count in coeff_by_count:
            assign_map[coeff_by_count[output_count]] = \
                output_variable.saved_output
    return assign_map
def _ad_assign_coefficients(self, form, solver):
    """Assign each mapped checkpoint into its coefficient of ``form``."""
    for target, checkpoint in self._ad_assign_map(form, solver).items():
        target.assign(checkpoint)
def _ad_solver_replace_forms(self, solver=Solver.FORWARD):
    """Refresh the coefficients of the selected solver's forms.

    For the forward solver both ``F`` and ``J`` of its problem are updated;
    for any other value only ``J`` of the adjoint problem is.
    """
    if solver == Solver.FORWARD:
        forward_problem = self._ad_solvers["forward_nlvs"]._problem
        self._ad_assign_coefficients(forward_problem.F, solver)
        self._ad_assign_coefficients(forward_problem.J, solver)
        return
    adjoint_problem = self._ad_solvers["adjoint_lvs"]._problem
    self._ad_assign_coefficients(adjoint_problem.J, solver)
[docs]
def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies):
    """Solve the adjoint problem once for all dependency evaluations.

    Stores the adjoint solution in ``self.adj_state``, invokes the
    ``adj_cb`` / ``adj_bdy_cb`` callbacks when set, and returns a dict with
    keys ``"form"``, ``"adj_sol"`` and ``"adj_sol_bdy"``.
    """
    needs_bdy = self._should_compute_boundary_adjoint(relevant_dependencies)
    adj_sol, adj_sol_bdy = self._adjoint_solve(adj_inputs[0], needs_bdy)
    self.adj_state = adj_sol

    if self.adj_cb is not None:
        self.adj_cb(adj_sol)
    if self.adj_bdy_cb is not None and needs_bdy:
        self.adj_bdy_cb(adj_sol_bdy)

    return {
        "form": self._create_F_form(),
        "adj_sol": self.adj_state,
        "adj_sol_bdy": adj_sol_bdy,
    }
[docs]
def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
                           prepared=None):
    """Return the adjoint contribution for one dependency.

    ``prepared`` is the dict produced by ``prepare_evaluate_adj``.

    Returns ``None`` for the initial guess of a nonlinear solve and a
    one-element list holding a reconstructed boundary condition for a
    ``DirichletBC`` dependency. For ``Function`` and ``Constant``
    dependencies the symbolic derivative of ``self.lhs`` is kept in
    ``self._dFdm_cache`` and reused on later calls.
    """
    if not self.linear and self.func == block_variable.output:
        # We are not able to calculate derivatives wrt initial guess.
        return None
    F_form = prepared["form"]
    adj_sol = prepared["adj_sol"]
    adj_sol_bdy = prepared["adj_sol_bdy"]
    c = block_variable.output
    c_rep = block_variable.saved_output
    # NOTE(review): a dependency matching none of the branches below and not
    # already cached reaches the derivative with ``trial_function`` unbound.
    if isinstance(c, firedrake.Function):
        trial_function = firedrake.TrialFunction(c.function_space())
    elif isinstance(c, firedrake.Constant):
        mesh = F_form.ufl_domain()
        trial_function = firedrake.TrialFunction(
            c._ad_function_space(mesh)
        )
    elif isinstance(c, firedrake.DirichletBC):
        tmp_bc = c.reconstruct(
            g=extract_subfunction(adj_sol_bdy, c.function_space())
        )
        return [tmp_bc]
    elif isinstance(c, firedrake.MeshGeometry):
        # Using CoordinateDerivative requires us to do action before
        # differentiating, might change in the future.
        F_form_tmp = firedrake.action(F_form, adj_sol)
        X = firedrake.SpatialCoordinate(c_rep)
        dFdm = firedrake.derivative(-F_form_tmp, X, firedrake.TestFunction(
            c._ad_function_space())
        )
        dFdm = firedrake.assemble(dFdm, **self.assemble_kwargs)
        return dFdm
    # dFdm_cache works with original variables, not block saved outputs.
    if c in self._dFdm_cache:
        dFdm = self._dFdm_cache[c]
    else:
        dFdm = -firedrake.derivative(self.lhs, c, trial_function)
        dFdm = firedrake.adjoint(dFdm)
        self._dFdm_cache[c] = dFdm
    # Replace the form coefficients with checkpointed values.
    replace_map = self._replace_map(dFdm)
    replace_map[self.func] = self.get_outputs()[0].saved_output
    dFdm = replace(dFdm, replace_map)
    # Apply the substituted form to the adjoint solution, then assemble.
    dFdm = dFdm * adj_sol
    dFdm = firedrake.assemble(dFdm, **self.assemble_kwargs)
    return dFdm
[docs]
class ProjectBlock(SolveVarFormBlock):
    """Solve block for the projection of ``v`` into the space ``V``.

    The projection is expressed as the variational problem
    ``inner(Pv, w) * dx == inner(v, w) * dx`` for a trial function ``Pv``
    and a test function ``w`` on ``V``.
    """

    def __init__(self, v, V, output, bcs=None, *args, **kwargs):
        """Set up the projection equation.

        ``bcs`` defaults to no boundary conditions. The optional ``mesh``
        keyword selects the integration domain and defaults to ``V.mesh()``;
        it is removed from ``kwargs`` before they are passed on.
        """
        # ``None`` replaces the former shared mutable ``[]`` default; the
        # parent still receives an empty list when nothing is supplied.
        if bcs is None:
            bcs = []
        mesh = kwargs.pop("mesh", None)
        if mesh is None:
            mesh = V.mesh()
        dx = firedrake.dx(mesh)
        w = firedrake.TestFunction(V)
        Pv = firedrake.TrialFunction(V)
        a = firedrake.inner(Pv, w) * dx
        L = firedrake.inner(v, w) * dx

        super().__init__(a == L, output, bcs, *args, **kwargs)

    def _init_solver_parameters(self, args, kwargs):
        """Initialise the solver options for a variational-form solve."""
        super()._init_solver_parameters(args, kwargs)
        solve_init_params(self, args, kwargs, varform=True)
[docs]
class SupermeshProjectBlock(Block):
    r"""
    Annotates supermesh projection.

    Suppose we have a source space, :math:`V_A`, and a target space,
    :math:`V_B`. Projecting a source from :math:`V_A` to :math:`V_B` amounts to
    solving the linear system

    .. math::
        M_B * v_B = M_{AB} * v_A,

    where

      * :math:`M_B` is the mass matrix on :math:`V_B`,
      * :math:`M_{AB}` is the mixed mass matrix for :math:`V_A` and
        :math:`V_B`,
      * :math:`v_A` and :math:`v_B` are vector representations of the source
        and target :class:`.Function` s.

    This can be broken into two steps:
      Step 1. form RHS, multiplying the source with the mixed mass matrix;
      Step 2. solve linear system.
    """

    def __init__(self, source, target_space, target, bcs=None, **kwargs):
        """Set up the projector and the mixed mass matrix.

        Raises ``NotImplementedError`` when ``source`` is not a
        ``firedrake.Function`` or when any boundary condition is supplied.
        ``target`` is not used here. Remaining ``kwargs`` (after removing
        ``ad_block_tag`` and ``mesh``) are passed to ``firedrake.Projector``.
        """
        super().__init__(ad_block_tag=kwargs.pop("ad_block_tag", None))
        import firedrake.supermeshing as supermesh

        # Process args and kwargs
        if not isinstance(source, firedrake.Function):
            raise NotImplementedError(
                f"Source function must be a Function, not {type(source)}."
            )
        # ``None``, ``[]`` and ``()`` all mean "no boundary conditions";
        # anything non-empty is unsupported. Past this point there is
        # nothing to register as a dependency for boundary conditions.
        if bcs:
            raise NotImplementedError(
                "Boundary conditions not yet considered."
            )

        # ``mesh`` is accepted but unused; it is removed so that it is not
        # forwarded to the Projector below.
        kwargs.pop("mesh", None)

        # Store spaces
        self.source_space = source.function_space()
        self.target_space = target_space
        self.projector = firedrake.Projector(source, target_space, **kwargs)

        # Assemble mixed mass matrix
        with stop_annotating():
            self.mixed_mass = supermesh.assemble_mixed_mass_matrix(
                source.function_space(), target_space
            )

        # Add dependencies
        self.add_dependency(source, no_duplicates=True)
[docs]
def apply_mixedmass(self, a):
    """Multiply ``a`` by the mixed mass matrix.

    Returns a new ``Function`` on the dual of the target space holding the
    product.
    """
    product = firedrake.Function(self.target_space.dual())
    with a.dat.vec_ro as source_vec, product.dat.vec_wo as product_vec:
        self.mixed_mass.mult(source_vec, product_vec)
    return product
[docs]
def recompute_component(self, inputs, block_variable, idx, prepared):
    """Project ``inputs[0]`` into the target space.

    Raises ``NotImplementedError`` when ``inputs[0]`` is not a
    ``firedrake.Function``. The result is passed through
    ``maybe_disk_checkpoint`` before being returned.
    """
    source = inputs[0]
    if not isinstance(source, firedrake.Function):
        raise NotImplementedError(
            f"Source function must be a Function, not {type(inputs[0])}."
        )
    projected = firedrake.Function(self.target_space)
    mixed_rhs = self.apply_mixedmass(source)  # Step 1
    self.projector.apply_massinv(projected, mixed_rhs)  # Step 2
    return maybe_disk_checkpoint(projected)
def _recompute_component_transpose(self, inputs):
    """Apply the transpose of the projection to ``inputs[0]``.

    Raises ``NotImplementedError`` when ``inputs[0]`` is not a
    ``firedrake.Cofunction``. Returns a new ``Cofunction`` on the dual of
    the source space.
    """
    seed = inputs[0]
    if not isinstance(seed, firedrake.Cofunction):
        raise NotImplementedError(
            f"Source function must be a Cofunction, not {type(inputs[0])}."
        )
    result = firedrake.Cofunction(self.source_space.dual())
    intermediate = firedrake.Function(self.target_space)

    # Adjoint of step 2 (mass is self-adjoint)
    self.projector.apply_massinv(intermediate, seed)

    # Adjoint of step 1
    with intermediate.dat.vec_ro as vec_in, result.dat.vec_wo as vec_out:
        self.mixed_mass.multTranspose(vec_in, vec_out)
    return result
[docs]
def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
                           prepared=None):
    """
    Evaluate the adjoint to one output of the block

    Recall that the forward propagation can be broken down as:
      Step 1. multiply :math:`w := M_{AB} * v_A`;
      Step 2. solve :math:`M_B * v_B = w`.

    For a seed vector :math:`v_B^{seed}` from the target space, the adjoint
    is given by:
      Adjoint of step 2. solve :math:`M_B^T * w = v_B^{seed}` for `w`;
      Adjoint of step 1. multiply :math:`v_A^{adj} := M_{AB}^T * w`.

    Raises ``NotImplementedError`` unless exactly one adjoint input is
    given.
    """
    if len(adj_inputs) == 1:
        return self._recompute_component_transpose(adj_inputs)
    raise NotImplementedError(
        "SupermeshProjectBlock must have a single output"
    )
[docs]
def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx,
                           prepared=None):
    """
    Given that the input is a `Function`, we just have a linear operation.
    As such, the tlm is just the sum of each tlm input projected into the
    target space.

    Entries of ``tlm_inputs`` that are ``None`` are skipped.
    """
    total = firedrake.Function(self.target_space)
    for direction in tlm_inputs:
        if direction is not None:
            total += self.recompute_component([direction], block_variable,
                                              idx, prepared)
    return total
[docs]
def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs,
                               block_variable, idx,
                               relevant_dependencies, prepared=None):
    """Evaluate the Hessian action by reusing the adjoint evaluation.

    ``hessian_inputs`` is handed to ``evaluate_adj_component`` in place of
    the adjoint inputs. Raises ``NotImplementedError`` unless exactly one
    Hessian input is given.
    """
    if len(hessian_inputs) == 1:
        return self.evaluate_adj_component(inputs, hessian_inputs,
                                           block_variable, idx)
    raise NotImplementedError(
        "SupermeshProjectBlock must have a single output"
    )
def __str__(self):
    """Return ``project(<source>, 〈<target element>〉)``.

    The source is the block's first dependency; the target is described by
    the short string of the target space's UFL element.
    """
    target_string = f"〈{self.target_space.ufl_element().shortstr()}〉"
    # Exactly one closing parenthesis: the previous text ended in "))",
    # leaving the rendered call unbalanced.
    return f"project({self.get_dependencies()[0]}, {target_string})"