# Source code for irksome.deriv

from ufl.differentiation import Derivative
from ufl.core.ufl_type import ufl_type
from ufl.corealg.multifunction import MultiFunction
from ufl.algorithms.map_integrands import map_integrand_dags, map_expr_dag
from ufl.algorithms.apply_derivatives import GenericDerivativeRuleset


@ufl_type(num_ops=1, inherit_shape_from_operand=0, inherit_indices_from_operand=0)
class TimeDerivative(Derivative):
    """UFL node for the time derivative of a quantity or field.

    The node has exactly one operand. Its shape and free indices are
    inherited from that operand (see the ``ufl_type`` arguments).

    Note: form compilers do not currently know how to process these
    nodes. Irksome instead pre-processes forms that contain
    ``TimeDerivative`` nodes.
    """
    __slots__ = ()

    def __new__(cls, f):
        # The operand is consumed by ``__init__``; allocation ignores it.
        return super().__new__(cls)

    def __init__(self, f):
        # UFL operators store their operands as a tuple.
        super().__init__((f,))

    def __str__(self):
        operand = self.ufl_operands[0]
        return "d{%s}/dt" % (operand,)
def Dt(f):
    """Wrap ``f`` in a :class:`TimeDerivative` node.

    Shorthand for ``TimeDerivative(f)``.
    """
    node = TimeDerivative(f)
    return node
class TimeDerivativeRuleset(GenericDerivativeRuleset):
    """Apply automatic-differentiation rules to time-derivative expressions.

    Work in progress. Coefficients listed in ``timedep_coeffs`` are
    differentiated to a :class:`TimeDerivative` of themselves. Every other
    coefficient is handled by ``independent_terminal``.

    :arg t: stored on the instance; not otherwise used by this class.
    :arg timedep_coeffs: container of coefficients treated as time-dependent.
        Membership is tested with ``in``.
    """

    def __init__(self, t, timedep_coeffs):
        # An empty var_shape ``()`` means differentiation is with respect to
        # a scalar-shaped variable.
        super().__init__(())
        self.t = t
        self.timedep_coeffs = timedep_coeffs

    def coefficient(self, o):
        # Time-dependent coefficients keep a symbolic d/dt; all others are
        # treated as independent of the differentiation variable.
        is_timedep = o in self.timedep_coeffs
        return TimeDerivative(o) if is_timedep else self.independent_terminal(o)


# Mapping rules that expand ("splat out") time derivatives, so that later
# replacement also works on more complex problems.
class TimeDerivativeRuleDispatcher(MultiFunction):
    """Walk an expression and expand each :class:`TimeDerivative` node.

    Each ``TimeDerivative`` encountered is replaced by the result of mapping
    :class:`TimeDerivativeRuleset` over its operand. The following nodes are
    returned unchanged, and their operands are not visited, because their
    handlers take no processed operands:

    * ``grad``, ``div`` and ``reference_grad`` nodes;
    * coefficient-derivative and coordinate-derivative nodes.

    Any other derivative node type raises ``NotImplementedError``.

    :arg t: stored and passed on to :class:`TimeDerivativeRuleset`.
    :arg timedep_coeffs: coefficients treated as time-dependent. Passed on
        to :class:`TimeDerivativeRuleset`.
    """

    def __init__(self, t, timedep_coeffs):
        MultiFunction.__init__(self)
        self.t = t
        self.timedep_coeffs = timedep_coeffs

    def terminal(self, o):
        # Terminals have no operands to rewrite.
        return o

    def derivative(self, o):
        # Fallback for derivative node types without a dedicated handler.
        raise NotImplementedError("Missing derivative handler for {0}.".format(type(o).__name__))

    # Any other operator: rebuild it only if one of its operands changed.
    expr = MultiFunction.reuse_if_untouched

    def grad(self, o):
        # MultiFunction dispatches only ``Grad`` nodes here. ``TimeDerivative``
        # nodes are routed to ``time_derivative`` instead, so ``o`` can never
        # be a ``TimeDerivative``. The previous isinstance branch, and the
        # firedrake import it needed, were unreachable and have been removed.
        # Time derivatives nested inside a gradient are left untouched.
        return o

    def div(self, o):
        return o

    def reference_grad(self, o):
        return o

    def coefficient_derivative(self, o):
        return o

    def coordinate_derivative(self, o):
        return o

    def time_derivative(self, o):
        # Expand d/dt of the (single) operand with the AD ruleset.
        f, = o.ufl_operands
        rules = TimeDerivativeRuleset(self.t, self.timedep_coeffs)
        return map_expr_dag(rules, f)
def apply_time_derivatives(expression, t, timedep_coeffs=None):
    """Expand the :class:`TimeDerivative` nodes in ``expression``.

    :arg expression: a UFL expression or form; mapped via
        ``map_integrand_dags``.
    :arg t: passed on to :class:`TimeDerivativeRuleDispatcher`.
    :arg timedep_coeffs: coefficients to treat as time-dependent. Defaults
        to none (an empty list).
    :returns: the rewritten expression or form.
    """
    # Use a None sentinel rather than a shared mutable default list.
    if timedep_coeffs is None:
        timedep_coeffs = []
    rules = TimeDerivativeRuleDispatcher(t, timedep_coeffs)
    return map_integrand_dags(rules, expression)