# Source code for firedrake.adjoint_utils.solving
from functools import wraps
from pyadjoint.tape import get_working_tape, stop_annotating, annotate_tape
from firedrake.adjoint_utils.blocks import SolveVarFormBlock, SolveLinearSystemBlock, GenericSolveBlock, ProjectBlock
import ufl
# [docs]
def annotate_solve(solve):
    """Decorate Firedrake's :func:`.solve` so that each call is recorded on the pyadjoint tape.

    The returned wrapper records which solve took place and which forms were
    involved, which is what pyadjoint needs in order to derive the adjoint and
    tangent linear models automatically.

    Passing ``annotate=False`` switches the recording off, in which case the
    wrapper behaves exactly like the plain Firedrake solve. That is the right
    choice for solves that are irrelevant to, or purely diagnostic for, the
    adjoint computation (for example projecting a field into another function
    space for visualisation).

    The wrapped solve also accepts optional callbacks for extracting adjoint
    solutions. Every callback has the same signature: it takes a single
    argument of type Function.

    Keyword Args:
        adj_cb (:obj:`firedrake.function`, optional):
            callback receiving the adjoint solution in the interior.
            The boundary values are zero.
        adj_bdy_cb (:obj:`firedrake.function`, optional):
            callback receiving the adjoint solution on the boundary.
            The interior values are not guaranteed to be zero.
        adj2_cb (:obj:`firedrake.function`, optional):
            callback receiving the second-order adjoint solution in the
            interior. The boundary values are zero.
        adj2_bdy_cb (:obj:`firedrake.function`, optional):
            callback receiving the second-order adjoint solution on the
            boundary. The interior values are not guaranteed to be zero.
        ad_block_tag (:obj:`string`, optional):
            tag used to label the resulting block on the Pyadjoint tape, so
            that a block can be matched to its equation in the forward model.

    """

    @wraps(solve)
    def wrapper(*args, **kwargs):
        # ``ad_block_tag`` is consumed here so it never reaches the real solve.
        tag = kwargs.pop("ad_block_tag", None)
        recording = annotate_tape(kwargs)

        if recording:
            working_tape = get_working_tape()
            # A ufl Equation as first argument selects the variational-form
            # block; anything else is recorded as a linear-system solve.
            block_cls = (
                SolveVarFormBlock
                if isinstance(args[0], ufl.equation.Equation)
                else SolveLinearSystemBlock
            )
            # Block-specific keywords are popped out of ``kwargs`` first; the
            # remaining solve keywords are then forwarded to the block as well.
            block_options = block_cls.pop_kwargs(kwargs)
            block_options.update(kwargs)
            block = block_cls(*args, ad_block_tag=tag, **block_options)
            working_tape.add_block(block)

        # The forward solve itself always runs with annotation switched off.
        with stop_annotating():
            result = solve(*args, **kwargs)

        if not recording:
            return result

        # The second positional argument is registered as the block's output.
        # If it cannot create a block variable itself, its ``function``
        # attribute is used instead.
        solution = args[1]
        if not hasattr(solution, "create_block_variable"):
            solution = solution.function
        block.add_output(solution.create_block_variable())
        return result

    return wrapper
# [docs]
def get_solve_blocks():
    """Collect the PDE-solve blocks recorded on the working tape.

    A block is included when its type is a subclass of ``GenericSolveBlock``.
    Blocks whose type is a subclass of ``ProjectBlock`` (calls of the
    ``project`` operator) are left out.

    Returns:
        list: the matching blocks, in the order the tape yields them.
    """
    solve_blocks = []
    for block in get_working_tape().get_blocks():
        block_type = type(block)
        if not issubclass(block_type, GenericSolveBlock):
            continue
        if issubclass(block_type, ProjectBlock):
            continue
        solve_blocks.append(block)
    return solve_blocks