Source code for firedrake.adjoint_utils.constant

from functools import wraps
from pyadjoint.adjfloat import AdjFloat
from pyadjoint.tape import get_working_tape, annotate_tape
from pyadjoint.overloaded_type import OverloadedType, create_overloaded_object
from pyadjoint.reduced_functional_numpy import gather

from firedrake.functionspace import FunctionSpace
from firedrake.adjoint_utils.blocks import ConstantAssignBlock

from ufl.domain import extract_unique_domain

import numpy


class ConstantMixin(OverloadedType):
    """Mixin providing the pyadjoint ``OverloadedType`` hooks for a constant.

    The methods below rely on the host class supplying ``dat``,
    ``ufl_shape``, ``ufl_element()`` and ``assign()``, and on the host
    constructor accepting ``(values, domain=...)``.
    """

    @staticmethod
    def _ad_annotate_init(init):
        """Decorate a constructor so it also initialises ``OverloadedType``.

        Args:
            init: The ``__init__`` to wrap.

        Returns:
            The wrapped constructor. It accepts an extra ``annotate``
            keyword (default ``True``) plus the pyadjoint keywords popped
            below, none of which are forwarded to ``init``.
        """
        @wraps(init)
        def wrapper(self, *args, **kwargs):
            annotate = kwargs.pop("annotate", True)
            # The pops run before ``**kwargs`` is unpacked, so the wrapped
            # ``init`` never sees the pyadjoint-specific keywords.
            OverloadedType.__init__(
                self,
                *args,
                block_class=kwargs.pop("block_class", None),
                _ad_floating_active=kwargs.pop("_ad_floating_active", False),
                _ad_args=kwargs.pop("_ad_args", None),
                output_block_class=kwargs.pop("output_block_class", None),
                _ad_output_args=kwargs.pop("_ad_output_args", None),
                _ad_outputs=kwargs.pop("_ad_outputs", None),
                annotate=annotate,
                **kwargs
            )
            init(self, *args, **kwargs)
            # Only a positional first argument is inspected. Guarding on
            # ``args`` avoids an IndexError when the object is built purely
            # from keyword arguments.
            if args and isinstance(args[0], (type(self), AdjFloat)):
                self.assign(args[0], annotate=annotate)
        return wrapper

    @staticmethod
    def _ad_annotate_assign(assign):
        """Decorate ``assign`` so that the assignment is put on the tape.

        Args:
            init: n/a
            assign: The ``assign`` method to wrap.

        Returns:
            The wrapped method. It accepts an extra ``ad_block_tag`` keyword
            which is handed to the ``ConstantAssignBlock``.
        """
        @wraps(assign)
        def wrapper(self, *args, **kwargs):
            ad_block_tag = kwargs.pop("ad_block_tag", None)
            # NOTE(review): annotate_tape presumably consumes the
            # "annotate" entry of kwargs -- confirm against pyadjoint.
            annotate = annotate_tape(kwargs)
            if annotate:
                other = args[0]
                if not isinstance(other, OverloadedType):
                    other = create_overloaded_object(other)
                # The block is registered before the assignment is carried
                # out; its output is attached afterwards.
                block = ConstantAssignBlock(other, ad_block_tag=ad_block_tag)
                tape = get_working_tape()
                tape.add_block(block)

            ret = assign(self, *args, **kwargs)

            if annotate:
                block.add_output(self.create_block_variable())
            return ret
        return wrapper

    def get_derivative(self, options={}):
        """Return ``self.adj_value`` converted by ``_ad_convert_type``.

        Args:
            options (dict): Forwarded unchanged to ``_ad_convert_type``.
                The shared default dict is never mutated here.
        """
        return self._ad_convert_type(self.adj_value, options=options)

    def _ad_convert_type(self, value, options={}):
        """Turn ``value`` into an object of the same type as ``self``.

        Args:
            value: Values to convert, or ``None``.
            options (dict): Accepted for interface compatibility; unused.

        Returns:
            A zero-valued instance of ``type(self)`` with shape
            ``self.ufl_shape`` when ``value`` is ``None``; otherwise a new
            constant built from ``gather(value)``.
        """
        if value is None:
            # TODO: Should the default be 0 constant here or return just None?
            return type(self)(numpy.zeros(self.ufl_shape))
        value = gather(value)
        return self._constant_from_values(value)

    def _ad_function_space(self, mesh):
        """Build a ``FunctionSpace`` on ``mesh`` from this object's element.

        The element is reconstructed with ``mesh.ufl_cell()`` as its cell.
        """
        element = self.ufl_element()
        fs_element = element.reconstruct(cell=mesh.ufl_cell())
        return FunctionSpace(mesh, fs_element)

    def _ad_create_checkpoint(self):
        """Return a new constant holding the current values."""
        return self._constant_from_values()

    def _ad_restore_at_checkpoint(self, checkpoint):
        """Return ``checkpoint`` unchanged."""
        return checkpoint

    def _ad_mul(self, other):
        """Return a new constant holding the flattened values times ``other``."""
        return self._constant_from_values(self.dat.data_ro.reshape(-1) * other)

    def _ad_add(self, other):
        """Return a new constant holding the elementwise sum with ``other``."""
        return self._constant_from_values(
            self.dat.data_ro.reshape(-1) + other.dat.data_ro.reshape(-1)
        )

    def _ad_dot(self, other, options=None):
        """Sum the elementwise products of the flattened values and ``other``.

        ``other`` is used directly as a multiplier when it is exactly an
        ``AdjFloat``; otherwise its ``dat.data_ro`` is flattened and used.
        ``options`` is accepted but unused.
        """
        if type(other) is AdjFloat:
            return sum(self.dat.data_ro.reshape(-1) * other)
        else:
            return sum(self.dat.data_ro.reshape(-1) * other.dat.data_ro.reshape(-1))

    @staticmethod
    def _ad_assign_numpy(dst, src, offset):
        """Assign a slice of the flat array ``src`` to ``dst``.

        Args:
            dst: Object to assign into (without annotation).
            src: Flat array to read from.
            offset (int): Index of the first entry of ``src`` to use.

        Returns:
            tuple: ``(dst, new_offset)`` where ``new_offset`` is ``offset``
            advanced by ``dst.ufl_element().value_size``.
        """
        size = dst.ufl_element().value_size
        dst.assign(numpy.reshape(src[offset:offset + size], dst.ufl_shape), annotate=False)
        offset += size
        return dst, offset

    @staticmethod
    def _ad_to_list(m):
        """Return the flattened values of ``m`` as a Python list."""
        return m.dat.data_ro.reshape(-1).tolist()

    def _ad_copy(self):
        """Return a new constant holding the current values."""
        return self._constant_from_values()

    def _ad_dim(self):
        """Return ``self.dat.cdim``."""
        return self.dat.cdim

    def _ad_imul(self, other):
        """Multiply the values by ``other`` in place (via ``assign``)."""
        self.assign(self._constant_from_values(self.dat.data_ro.reshape(-1) * other))

    def _ad_iadd(self, other):
        """Add the values of ``other`` in place (via ``assign``)."""
        self.assign(self._constant_from_values(
            self.dat.data_ro.reshape(-1) + other.dat.data_ro.reshape(-1)
        ))

    def _reduce(self, r, r0):
        """Fold ``r`` over the flattened values, starting from ``r0``.

        ``r`` is called as ``r(value, accumulator)`` for each value in
        order; the final accumulator is returned.
        """
        for value in self.dat.data_ro.reshape(-1):
            r0 = r(value, r0)
        return r0

    def _applyUnary(self, f):
        """Replace every value ``x`` by ``f(x)`` (via ``assign``)."""
        npdata = self.dat.data_ro.reshape(-1)
        # Results are written into a copy so they keep the original dtype.
        npdatacopy = npdata.copy()
        for i, value in enumerate(npdata):
            npdatacopy[i] = f(value)
        self.assign(self._constant_from_values(npdatacopy))

    def _applyBinary(self, f, y):
        """Replace every value ``x`` by ``f(x, y_i)`` (via ``assign``).

        ``y_i`` is the entry of ``y``'s flattened values at the same index.
        """
        npdata = self.dat.data_ro.reshape(-1)
        npdatacopy = self.dat.data_ro.reshape(-1).copy()
        npdatay = y.dat.data_ro.reshape(-1)
        # ``npdatay`` is indexed (rather than zipped) so that a shorter
        # ``y`` still raises IndexError instead of being silently truncated.
        for i, value in enumerate(npdata):
            npdatacopy[i] = f(value, npdatay[i])
        self.assign(self._constant_from_values(npdatacopy))

    def __deepcopy__(self, memodict={}):
        """Return a new constant holding the current values.

        ``memodict`` is accepted for the ``copy.deepcopy`` protocol but is
        not used.
        """
        return self._constant_from_values()

    def _constant_from_values(self, values=None):
        """Returns a new Constant with self.dat.data_ro.reshape(-1) while preserving self.ufl_shape.

        If the optional argument `values` is provided, then `values` will be the values of the
        new Constant instead, still preserving the ufl_shape of self.

        Args:
            values (numpy.array): An optional argument to use instead of
                ``self.dat.data_ro.reshape(-1)``.

        Returns:
            Constant: The created Constant

        """
        values = self.dat.data_ro.reshape(-1) if values is None else values
        return type(self)(numpy.reshape(values, self.ufl_shape), domain=extract_unique_domain(self))