# Source code for firedrake.adjoint_utils.dirichletbc
from functools import wraps
from pyadjoint.overloaded_type import FloatingType
from .blocks import DirichletBCBlock
from pyadjoint.tape import stop_annotating, annotate_tape
# [docs]
class DirichletBCMixin(FloatingType):
    """Pyadjoint overloading hooks for a Dirichlet boundary condition class.

    Provides decorators for ``__init__`` and ``apply`` plus the checkpoint
    create/restore hooks. Construction sets the instance up as a
    :class:`~pyadjoint.overloaded_type.FloatingType` backed by a
    ``DirichletBCBlock``. Applying the condition tags each target with it
    (via a ``bcs`` list attribute) while the application itself runs with
    annotation suspended.
    """

    @staticmethod
    def _ad_annotate_init(init):
        """Wrap an ``__init__`` so FloatingType state is set up first.

        :arg init: the undecorated ``__init__``.
        :returns: a wrapper that calls ``FloatingType.__init__`` with block
            class ``DirichletBCBlock``, the positional arguments as
            ``_ad_args`` and ``_ad_floating_active=True``, and then calls
            ``init`` with the same arguments.
        """
        @wraps(init)
        def wrapper(self, *args, **kwargs):
            FloatingType.__init__(self,
                                  *args,
                                  block_class=DirichletBCBlock,
                                  _ad_args=args,
                                  _ad_floating_active=True,
                                  **kwargs)
            init(self, *args, **kwargs)
        return wrapper

    @staticmethod
    def _ad_annotate_apply(apply):
        """Wrap an ``apply`` method to record the condition on its targets.

        :arg apply: the undecorated ``apply`` method.
        :returns: a wrapper that does two things. When annotation is enabled,
            it adds this condition to the ``bcs`` list of every positional
            argument, creating the list if absent and skipping the add if
            this exact condition object is already recorded. It then calls
            ``apply`` with annotation suspended and returns its result.
        """
        @wraps(apply)
        def wrapper(self, *args, **kwargs):
            # annotate_tape is handed the kwargs dict itself; presumably it
            # consumes an ``annotate=`` keyword so that keyword is not
            # forwarded to ``apply`` -- confirm against pyadjoint.
            annotate = annotate_tape(kwargs)
            if annotate:
                for arg in args:
                    if not hasattr(arg, "bcs"):
                        arg.bcs = []
                    # Compare by identity. Re-applying the same condition to
                    # the same object (e.g. once per time step) must not grow
                    # the list without bound; a duplicate entry would only
                    # repeat a reference to the same object.
                    if not any(bc is self for bc in arg.bcs):
                        arg.bcs.append(self)
            # The application itself is not recorded on the tape.
            with stop_annotating():
                ret = apply(self, *args, **kwargs)
            return ret
        return wrapper

    def _ad_create_checkpoint(self):
        """Return a checkpoint of this condition's current value.

        :returns: the first dependency of ``self.block``, or ``None`` when
            the block has no dependencies.
        """
        deps = self.block.get_dependencies()
        if not deps:
            # We don't have any dependencies so the supplied value was not an
            # OverloadedType. Most probably it was just a float that is
            # immutable, so it will never change and there is nothing to save.
            return None
        return deps[0]

    def _ad_restore_at_checkpoint(self, checkpoint):
        """Restore this condition's value from ``checkpoint``.

        :arg checkpoint: a value produced by ``_ad_create_checkpoint``. When
            it is not ``None``, its ``saved_output`` is passed to
            ``self.set_value``; ``None`` leaves the condition unchanged.
        :returns: ``self``.
        """
        if checkpoint is not None:
            self.set_value(checkpoint.saved_output)
        return self