# Source code for firedrake.adjoint_utils.projection
from functools import wraps
from pyadjoint.tape import annotate_tape, stop_annotating, get_working_tape
from firedrake.adjoint_utils.blocks import ProjectBlock, SupermeshProjectBlock
from firedrake import function
from ufl.domain import extract_unique_domain
# [docs]
def annotate_project(project):
    """Decorate ``project`` so that its calls are recorded on the pyadjoint tape.

    The project call performs an equation solve, so it must be annotated for
    the adjoint and tangent linear models to be constructed automatically by
    pyadjoint. To disable the annotation of a particular call, pass
    :py:data:`annotate=False`; this is useful when the solve is irrelevant or
    purely diagnostic for the adjoint computation (for example, projecting
    fields to another function space for visualisation).

    The returned wrapper additionally accepts an ``ad_block_tag`` keyword,
    which is removed from the keyword arguments and handed to the block that
    is created, rather than being forwarded to ``project``.

    :arg project: the callable to wrap.
    :returns: a wrapper around ``project`` that returns whatever ``project``
        returns.
    """

    # functools.wraps copies ``project.__doc__`` onto the wrapper, so the
    # documentation for the wrapping behaviour lives on the outer function.
    @wraps(project)
    def wrapper(*args, **kwargs):
        tag = kwargs.pop("ad_block_tag", None)
        recording = annotate_tape(kwargs)

        if recording:
            # NOTE(review): the source expression and the target are read
            # positionally (args[0], args[1]) and ``bcs`` only from kwargs;
            # confirm callers never pass them any other way while annotating.
            bcs = kwargs.get("bcs", [])
            block_kwargs = ProjectBlock.pop_kwargs(kwargs)
            target_is_function = isinstance(args[1], function.Function)

            def make_block(space, result):
                # A Function living on a different mesh from the target space
                # gets the supermesh block; everything else gets ProjectBlock.
                source = args[0]
                crosses_meshes = (
                    isinstance(source, function.Function)
                    and extract_unique_domain(source) != space.mesh()
                )
                block_cls = SupermeshProjectBlock if crosses_meshes else ProjectBlock
                return block_cls(source, space, result, bcs,
                                 ad_block_tag=tag, **block_kwargs)

            if target_is_function:
                # The block has to exist before project runs, because the
                # output might also be an input that needs checkpointing.
                block = make_block(args[1].function_space(), args[1])

        # The projection itself must not be recorded on the tape.
        with stop_annotating():
            output = project(*args, **kwargs)

        if not recording:
            return output

        tape = get_working_tape()
        if not target_is_function:
            # The target was given as a space, so the result only exists now.
            block = make_block(args[1], output)
        tape.add_block(block)
        block.add_output(output.create_block_variable())
        return output

    return wrapper