# Source code for firedrake.adjoint_utils.assembly

import numbers
from functools import wraps
from pyadjoint.tape import annotate_tape, stop_annotating, get_working_tape
from pyadjoint.overloaded_type import create_overloaded_object
from firedrake.adjoint_utils.blocks import AssembleBlock


def annotate_assemble(assemble):
    """Decorate :func:`.assemble` so that the assembled object remembers its form.

    When a form is assembled, the information about its nonlinear
    dependencies is lost and the result is no longer easy to manipulate.
    This decorator *attaches the form to the assembled object*, which lets
    automatic annotation work even when the user calls the lower-level
    :py:data:`solve(A, x, b)`.

    The returned wrapper:

    * pops an optional ``ad_block_tag`` keyword (forwarded to
      :class:`AssembleBlock`) and passes the remaining keyword arguments to
      :func:`annotate_tape` to decide whether to record on the tape;
    * preprocesses non-Slate forms via
      ``BaseFormAssembler.preprocess_base_form`` (with annotation stopped)
      and sets ``is_base_form_preprocessed=True`` before calling
      ``assemble``;
    * for scalar, :class:`Function` or :class:`Cofunction` results (0- or
      1-forms), records an :class:`AssembleBlock` on the working tape when
      annotating, and returns the overloaded result;
    * for any other result (a 2-form), sets ``result.form = form``.

    :raises NotImplementedError: when annotating and the scalar result is a
        number that is not a ``float`` (e.g. a complex-valued 0-form).
    """

    @wraps(assemble)
    def wrapper(form, *args, **kwargs):
        # NOTE: no docstring here on purpose -- ``functools.wraps`` replaces
        # ``wrapper.__doc__`` with ``assemble.__doc__``.
        # ``ad_block_tag`` is not an argument of ``assemble``; strip it first,
        # then let ``annotate_tape`` inspect what remains.
        ad_block_tag = kwargs.pop("ad_block_tag", None)
        annotate = annotate_tape(kwargs)

        with stop_annotating():
            # Deferred imports, kept local as in the original implementation.
            from firedrake.assemble import BaseFormAssembler
            from firedrake.slate import slate

            is_slate_tensor = isinstance(form, slate.TensorBase)
            if not is_slate_tensor:
                # Preprocess the form at the annotation stage so that the
                # `AssembleBlock` records the preprocessed form, which
                # facilitates derivation of the tangent linear/adjoint models.
                # For example, `interp = Action(Interpolate(v1, v0), f)`
                # (with `v1` an `Argument` and `v0` a `Coargument`) cannot be
                # differentiated because the action's left slot is a 2-form,
                # but after preprocessing it becomes `Interpolate(f, v0)`,
                # which can.
                form = BaseFormAssembler.preprocess_base_form(form)
                kwargs['is_base_form_preprocessed'] = True
            result = assemble(form, *args, **kwargs)

        from firedrake.function import Function
        from firedrake.cofunction import Cofunction

        if not isinstance(result, (numbers.Complex, Function, Cofunction)):
            # Assembled a 2-form: keep a handle on the form that produced it.
            result.form = form
            return result

        # Assembled a 0-form or 1-form (e.g. Form or BaseFormOperator).
        if not annotate:
            return result
        if not isinstance(result, (float, Function, Cofunction)):
            raise NotImplementedError("Taping for complex-valued 0-forms not yet done!")

        result = create_overloaded_object(result)
        assemble_block = AssembleBlock(form, ad_block_tag=ad_block_tag)
        get_working_tape().add_block(assemble_block)

        # When the caller supplies the output tensor, that tensor may also be a
        # dependency of the block (e.g. self-interpolation
        # `u.interpolate(u + c)`), so a fresh block variable is needed to avoid
        # a cyclic dependency.
        tensor_supplied = kwargs.get("tensor") is not None
        block_variable = (
            result.create_block_variable() if tensor_supplied else result.block_variable
        )
        assemble_block.add_output(block_variable)
        return result

    return wrapper