import numpy
from firedrake import Function
from firedrake import NonlinearVariationalProblem as NLVP
from firedrake import NonlinearVariationalSolver as NLVS
from firedrake import TestFunction, assemble, dx, inner, norm
from firedrake.dmhooks import pop_parent, push_parent
from .bcs import EmbeddedBCData, bc2space
from .dirk_stepper import DIRKTimeStepper
from .explicit_stepper import ExplicitTimeStepper
from .getForm import AI, getForm
from .imex import RadauIIAIMEXMethod, DIRKIMEXMethod
from .manipulation import extract_terms
from .stage import StageValueTimeStepper
def TimeStepper(F, butcher_tableau, t, dt, u0, **kwargs):
    """Helper function to dispatch between various back-end classes
    for doing time stepping.  Returns an instance of the
    appropriate class.

    :arg F: A :class:`ufl.Form` instance describing the semi-discrete problem
            F(t, u; v) == 0, where `u` is the unknown
            :class:`firedrake.Function` and `v` is the
            :class:`firedrake.TestFunction`.
    :arg butcher_tableau: A :class:`ButcherTableau` instance giving
            the Runge-Kutta method to be used for time marching.
    :arg t: a :class:`Function` on the Real space over the same mesh as
            `u0`.  This serves as a variable referring to the current time.
    :arg dt: a :class:`Function` on the Real space over the same mesh as
            `u0`.  This serves as a variable referring to the current time
            step.  The user may adjust this value between time steps.
    :arg u0: A :class:`firedrake.Function` containing the current
            state of the problem to be solved.
    :arg stage_type: Which back-end to build: one of "deriv" (the
            default), "value", "dirk", "explicit", "imex" or "dirkimex".
            Each stage type accepts only its own subset of the keyword
            arguments documented below.
    :arg bcs: An iterable of :class:`firedrake.DirichletBC` containing
            the strongly-enforced boundary conditions.  Irksome will
            manipulate these to obtain boundary conditions for each
            stage of the RK method.
    :arg nullspace: A list of tuples of the form (index, VSB) where
            index is an index into the function space associated with
            `u` and VSB is a :class:`firedrake.VectorSpaceBasis`
            instance to be passed to a
            `firedrake.MixedVectorSpaceBasis` over the larger space
            associated with the Runge-Kutta method
    :arg splitting: A callable used to factor the Butcher matrix
            (defaults to `AI`).
    :arg bc_type: For stage derivative formulation, how to manipulate
            the strongly-enforced boundary conditions (default "DAE").
    :arg appctx: An optional :class:`dict` of application context that
            is passed on to the back-end class.
    :arg solver_parameters: A :class:`dict` of solver parameters that
            will be used in solving the algebraic problem associated
            with each time step.
    :arg update_solver_parameters: A :class:`dict` of parameters for
            inverting the mass matrix at each step (only used if
            stage_type is "value")
    :arg basis_type, bc_constraints: Passed unchanged to
            :class:`StageValueTimeStepper` (stage type "value" only).
    :arg adaptive_parameters: A :class:`dict` of parameters for use with
            adaptive time stepping (only used if stage_type is "deriv").
            Valid keys are "tol", "dtmin", "dtmax", "KI", "KP",
            "max_reject", "onscale_factor", "safety_factor" and
            "gamma0_params".  When given, an
            :class:`AdaptiveTimeStepper` is returned.
    :arg Fexp: The form for the explicitly-treated part of the problem;
            required for the "imex" and "dirkimex" stage types.
    :arg it_solver_parameters, prop_solver_parameters, num_its_initial,
            num_its_per_step: Passed to :class:`RadauIIAIMEXMethod`
            (stage type "imex" only); both iteration counts default to 0.
    :arg mass_parameters: Passed to :class:`DIRKIMEXMethod`
            (stage type "dirkimex" only).

    :raises ValueError: if `stage_type` is not one of the values above.
    :raises AssertionError: if a keyword argument or adaptive parameter
            is not accepted by the selected stage type, or if an IMEX
            stage type is requested without `Fexp`.
    """
    valid_kwargs_per_stage_type = {
        "deriv": ["stage_type", "bcs", "nullspace", "solver_parameters", "appctx",
                  "bc_type", "splitting", "adaptive_parameters"],
        "value": ["stage_type", "basis_type", "bc_constraints", "bcs", "nullspace", "solver_parameters",
                  "update_solver_parameters", "appctx", "splitting"],
        "dirk": ["stage_type", "bcs", "nullspace", "solver_parameters", "appctx"],
        "explicit": ["stage_type", "bcs", "solver_parameters", "appctx"],
        "imex": ["Fexp", "stage_type", "bcs", "nullspace",
                 "it_solver_parameters", "prop_solver_parameters",
                 "splitting", "appctx",
                 "num_its_initial", "num_its_per_step"],
        "dirkimex": ["Fexp", "stage_type", "bcs", "nullspace", "solver_parameters", "mass_parameters", "appctx"]}

    valid_adapt_parameters = ["tol", "dtmin", "dtmax", "KI", "KP",
                              "max_reject", "onscale_factor",
                              "safety_factor", "gamma0_params"]

    stage_type = kwargs.get("stage_type", "deriv")
    # Validate up front: previously an unknown stage type only failed via
    # an accidental KeyError, and could otherwise fall through to None.
    if stage_type not in valid_kwargs_per_stage_type:
        raise ValueError(
            f"Unknown stage_type {stage_type!r}; expected one of "
            f"{', '.join(valid_kwargs_per_stage_type)}")

    adapt_params = kwargs.get("adaptive_parameters")
    if adapt_params is not None:
        assert stage_type == "deriv", "Adaptive time stepping is only implemented for derivative stage type"
        for param in adapt_params:
            assert param in valid_adapt_parameters, \
                f"Unknown adaptive parameter {param!r}"

    # Reject options the selected back-end would otherwise silently ignore.
    valid_kwargs = valid_kwargs_per_stage_type[stage_type]
    for cur_kwarg in kwargs:
        assert cur_kwarg in valid_kwargs, \
            f"Keyword argument {cur_kwarg!r} is not valid for stage_type {stage_type!r}"

    if stage_type == "deriv":
        bcs = kwargs.get("bcs")
        bc_type = kwargs.get("bc_type", "DAE")
        splitting = kwargs.get("splitting", AI)
        appctx = kwargs.get("appctx")
        solver_parameters = kwargs.get("solver_parameters")
        nullspace = kwargs.get("nullspace")
        if adapt_params is None:
            return StageDerivativeTimeStepper(
                F, butcher_tableau, t, dt, u0, bcs, appctx=appctx,
                solver_parameters=solver_parameters, nullspace=nullspace,
                bc_type=bc_type, splitting=splitting)
        tol = adapt_params.get("tol", 1e-3)
        dtmin = adapt_params.get("dtmin", 1.e-15)
        dtmax = adapt_params.get("dtmax", 1.0)
        KI = adapt_params.get("KI", 1/15)
        KP = adapt_params.get("KP", 0.13)
        max_reject = adapt_params.get("max_reject", 10)
        onscale_factor = adapt_params.get("onscale_factor", 1.2)
        safety_factor = adapt_params.get("safety_factor", 0.9)
        gamma0_params = adapt_params.get("gamma0_params")
        return AdaptiveTimeStepper(
            F, butcher_tableau, t, dt, u0, bcs, appctx=appctx,
            solver_parameters=solver_parameters, nullspace=nullspace,
            bc_type=bc_type, splitting=splitting,
            tol=tol, dtmin=dtmin, dtmax=dtmax, KI=KI, KP=KP,
            max_reject=max_reject, onscale_factor=onscale_factor,
            safety_factor=safety_factor, gamma0_params=gamma0_params)
    elif stage_type == "value":
        bcs = kwargs.get("bcs")
        bc_constraints = kwargs.get("bc_constraints")
        splitting = kwargs.get("splitting", AI)
        appctx = kwargs.get("appctx")
        solver_parameters = kwargs.get("solver_parameters")
        basis_type = kwargs.get("basis_type")
        update_solver_parameters = kwargs.get("update_solver_parameters")
        nullspace = kwargs.get("nullspace")
        return StageValueTimeStepper(
            F, butcher_tableau, t, dt, u0, bcs=bcs, appctx=appctx,
            solver_parameters=solver_parameters,
            splitting=splitting, basis_type=basis_type,
            bc_constraints=bc_constraints,
            update_solver_parameters=update_solver_parameters,
            nullspace=nullspace)
    elif stage_type == "dirk":
        bcs = kwargs.get("bcs")
        appctx = kwargs.get("appctx")
        solver_parameters = kwargs.get("solver_parameters")
        nullspace = kwargs.get("nullspace")
        return DIRKTimeStepper(
            F, butcher_tableau, t, dt, u0, bcs,
            solver_parameters, appctx, nullspace)
    elif stage_type == "explicit":
        bcs = kwargs.get("bcs")
        appctx = kwargs.get("appctx")
        solver_parameters = kwargs.get("solver_parameters")
        return ExplicitTimeStepper(
            F, butcher_tableau, t, dt, u0, bcs,
            solver_parameters, appctx)
    elif stage_type == "imex":
        Fexp = kwargs.get("Fexp")
        assert Fexp is not None, "Calling an IMEX scheme with no explicit form. Did you really mean to do this?"
        bcs = kwargs.get("bcs")
        appctx = kwargs.get("appctx")
        splitting = kwargs.get("splitting", AI)
        it_solver_parameters = kwargs.get("it_solver_parameters")
        prop_solver_parameters = kwargs.get("prop_solver_parameters")
        nullspace = kwargs.get("nullspace")
        num_its_initial = kwargs.get("num_its_initial", 0)
        num_its_per_step = kwargs.get("num_its_per_step", 0)
        return RadauIIAIMEXMethod(
            F, Fexp, butcher_tableau, t, dt, u0, bcs,
            it_solver_parameters, prop_solver_parameters,
            splitting, appctx, nullspace,
            num_its_initial, num_its_per_step)
    elif stage_type == "dirkimex":
        Fexp = kwargs.get("Fexp")
        assert Fexp is not None, "Calling an IMEX scheme with no explicit form. Did you really mean to do this?"
        bcs = kwargs.get("bcs")
        appctx = kwargs.get("appctx")
        solver_parameters = kwargs.get("solver_parameters")
        mass_parameters = kwargs.get("mass_parameters")
        nullspace = kwargs.get("nullspace")
        return DIRKIMEXMethod(
            F, Fexp, butcher_tableau, t, dt, u0, bcs,
            solver_parameters, mass_parameters, appctx, nullspace)
class StageDerivativeTimeStepper:
    """Front-end class for advancing a time-dependent PDE via a Runge-Kutta
    method formulated in terms of stage derivatives.

    :arg F: A :class:`ufl.Form` instance describing the semi-discrete problem
            F(t, u; v) == 0, where `u` is the unknown
            :class:`firedrake.Function` and `v` is the
            :class:`firedrake.TestFunction`.
    :arg butcher_tableau: A :class:`ButcherTableau` instance giving
            the Runge-Kutta method to be used for time marching.
    :arg t: a :class:`Function` on the Real space over the same mesh as
            `u0`.  This serves as a variable referring to the current time.
    :arg dt: a :class:`Function` on the Real space over the same mesh as
            `u0`.  This serves as a variable referring to the current time
            step.  The user may adjust this value between time steps.
    :arg u0: A :class:`firedrake.Function` containing the current
            state of the problem to be solved.
    :arg bcs: An iterable of :class:`firedrake.DirichletBC` containing
            the strongly-enforced boundary conditions.  Irksome will
            manipulate these to obtain boundary conditions for each
            stage of the RK method.
    :arg bc_type: How to manipulate the strongly-enforced boundary
            conditions to derive the stage boundary conditions.
            Should be a string, either "DAE", which implements BCs as
            constraints in the style of a differential-algebraic
            equation, or "ODE", which takes the time derivative of the
            boundary data and evaluates this for the stage values
    :arg solver_parameters: A :class:`dict` of solver parameters that
            will be used in solving the algebraic problem associated
            with each time step.
    :arg splitting: A callable used to factor the Butcher matrix
    :arg appctx: An optional :class:`dict` containing application context.
            This gets included with particular things that Irksome will
            pass into the nonlinear solver so that, say, user-defined
            preconditioners have access to it.
    :arg nullspace: A list of tuples of the form (index, VSB) where
            index is an index into the function space associated with
            `u` and VSB is a :class:`firedrake.VectorSpaceBasis`
            instance to be passed to a
            `firedrake.MixedVectorSpaceBasis` over the larger space
            associated with the Runge-Kutta method

    :raises NotImplementedError: if the factor `A2` produced by
            `splitting` is singular.
    """

    def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None,
                 solver_parameters=None, splitting=AI,
                 appctx=None, nullspace=None, bc_type="DAE"):
        self.u0 = u0
        self.F = F
        self.orig_bcs = bcs
        self.t = t
        self.dt = dt
        self.num_fields = len(u0.function_space())
        self.num_stages = len(butcher_tableau.b)
        self.butcher_tableau = butcher_tableau
        self.num_steps = 0
        self.num_nonlinear_iterations = 0
        self.num_linear_iterations = 0

        bigF, stages, bigBCs, bigNSP = \
            getForm(F, butcher_tableau, t, dt, u0, bcs, bc_type, splitting, nullspace)

        self.stages = stages
        self.bigBCs = bigBCs
        problem = NLVP(bigF, stages, bigBCs)

        appctx_irksome = {"F": F,
                          "butcher_tableau": butcher_tableau,
                          "t": t,
                          "dt": dt,
                          "u0": u0,
                          "bcs": bcs,
                          "bc_type": bc_type,
                          "splitting": splitting,
                          "nullspace": nullspace}
        # Irksome's own entries take precedence over user-supplied keys.
        if appctx is None:
            appctx = appctx_irksome
        else:
            appctx = {**appctx, **appctx_irksome}

        # Keep push_parent/pop_parent balanced even if building the solver
        # raises.
        push_parent(u0.function_space().dm, stages.function_space().dm)
        try:
            self.solver = NLVS(problem,
                               appctx=appctx,
                               solver_parameters=solver_parameters,
                               nullspace=bigNSP)
        finally:
            pop_parent(u0.function_space().dm, stages.function_space().dm)

        # Per-stage/per-field views of the stage unknowns, laid out
        # stage-major: field i of stage s is self.ws[num_fields*s + i].
        if self.num_stages == 1 and self.num_fields == 1:
            self.ws = (stages,)
        else:
            self.ws = stages.subfunctions

        _, A2 = splitting(butcher_tableau.A)
        try:
            self.updateb = numpy.linalg.solve(A2.T, butcher_tableau.b)
        except numpy.linalg.LinAlgError as err:
            raise NotImplementedError("A=A1 A2 splitting needs A2 invertible") from err

        # When A2^{-T} b is the last unit vector, the update only needs the
        # final stage, so pick the cheaper specialised update.
        e_last = numpy.zeros(self.updateb.shape, dtype=self.updateb.dtype)
        e_last[-1] = 1
        if numpy.allclose(self.updateb, e_last):
            self._update = self._update_A2Tmb
        else:
            self._update = self._update_general

    def _update_general(self):
        """Assuming the algebraic problem for the RK stages has been
        solved, updates the solution.  This will not typically be
        called by an end user."""
        dtc = float(self.dt)
        nf = self.num_fields
        ws = self.ws
        u0bits = self.u0.subfunctions
        # u0 += dt * sum_s updateb[s] * w_s, applied field by field.
        for s in range(self.num_stages):
            coeff = dtc * float(self.updateb[s])
            for i, u0bit in enumerate(u0bits):
                u0bit += coeff * ws[nf*s+i]

    def _update_A2Tmb(self):
        """Assuming the algebraic problem for the RK stages has been
        solved, updates the solution.  This will not typically be
        called by an end user.  This handles the common but highly
        specialized case of `w = Ak` or `A = I A` splitting where
        A2^{-T} b = e_{num_stages}"""
        dtc = float(self.dt)
        nf = self.num_fields
        last = nf*(self.num_stages-1)
        ws = self.ws
        for i, u0bit in enumerate(self.u0.subfunctions):
            u0bit += dtc * ws[last+i]

    def advance(self):
        """Advances the system from time `t` to time `t + dt`.
        Note: overwrites the value `u0`."""
        u0_dm = self.u0.function_space().dm
        stages_dm = self.stages.function_space().dm
        # Balance push_parent/pop_parent even when the solve raises.
        push_parent(u0_dm, stages_dm)
        try:
            self.solver.solve()
        finally:
            pop_parent(u0_dm, stages_dm)

        self.num_steps += 1
        self.num_nonlinear_iterations += self.solver.snes.getIterationNumber()
        self.num_linear_iterations += self.solver.snes.getLinearSolveIterations()
        self._update()

    def solver_stats(self):
        """Return a tuple ``(num_steps, num_nonlinear_iterations,
        num_linear_iterations)`` accumulated over all calls to
        :meth:`advance`."""
        return (self.num_steps, self.num_nonlinear_iterations, self.num_linear_iterations)
class AdaptiveTimeStepper(StageDerivativeTimeStepper):
    """Front-end class for advancing a time-dependent PDE via an adaptive
    Runge-Kutta method.

    :arg F: A :class:`ufl.Form` instance describing the semi-discrete problem
            F(t, u; v) == 0, where `u` is the unknown
            :class:`firedrake.Function` and `v` is the
            :class:`firedrake.TestFunction`.
    :arg butcher_tableau: A :class:`ButcherTableau` instance giving
            the Runge-Kutta method to be used for time marching.  It must
            provide an embedded scheme (`btilde` must not be None).
    :arg t: A :class:`firedrake.Constant` instance that always
            contains the time value at the beginning of a time step
    :arg dt: A :class:`firedrake.Constant` containing the size of the
            current time step.  The user may adjust this value between
            time steps; however, note that the adaptive time step
            controls may adjust this before the step is taken.
    :arg u0: A :class:`firedrake.Function` containing the current
            state of the problem to be solved.
    :arg tol: The temporal truncation error tolerance
    :arg dtmin: Minimal acceptable time step.  An exception is raised
            if the step size drops below this threshold.
    :arg dtmax: Maximal acceptable time step, imposed as a hard cap;
            this can be adjusted externally once the time-stepper is
            instantiated, by modifying `stepper.dt_max`
    :arg KI: Integration gain for step-size controller.  Should be less
            than 1/p, where p is the expected order of the scheme.  Larger
            values lead to faster (attempted) increases in time-step size
            when steps are accepted.  See Gustafsson, Lundh, and Soderlind,
            BIT 1988.
    :arg KP: Proportional gain for step-size controller.  Controls dependence
            on ratio of (error estimate)/(step size) in determining new
            time-step size when steps are accepted.  See Gustafsson, Lundh,
            and Soderlind, BIT 1988.
    :arg max_reject: Maximum number of rejected timesteps in a row that
            does not lead to a failure
    :arg onscale_factor: Allowable tolerance in determining initial
            timestep to be "on scale"
    :arg safety_factor: Safety factor used when shrinking timestep if
            a proposed step is rejected
    :arg gamma0_params: Solver parameters for mass matrix solve when using
            an embedded scheme with explicit first stage
    :arg bcs: An iterable of :class:`firedrake.DirichletBC` containing
            the strongly-enforced boundary conditions.  Irksome will
            manipulate these to obtain boundary conditions for each
            stage of the RK method.
    :arg bc_type, splitting, appctx: Passed unchanged to
            :class:`StageDerivativeTimeStepper`.
    :arg solver_parameters: A :class:`dict` of solver parameters that
            will be used in solving the algebraic problem associated
            with each time step.
    :arg nullspace: A list of tuples of the form (index, VSB) where
            index is an index into the function space associated with
            `u` and VSB is a :class:`firedrake.VectorSpaceBasis`
            instance to be passed to a
            `firedrake.MixedVectorSpaceBasis` over the larger space
            associated with the Runge-Kutta method
    """

    def __init__(self, F, butcher_tableau, t, dt, u0,
                 bcs=None, appctx=None, solver_parameters=None,
                 bc_type="DAE", splitting=AI, nullspace=None,
                 tol=1.e-3, dtmin=1.e-15, dtmax=1.0, KI=1/15, KP=0.13,
                 max_reject=10, onscale_factor=1.2, safety_factor=0.9,
                 gamma0_params=None):
        assert butcher_tableau.btilde is not None
        super().__init__(F, butcher_tableau,
                         t, dt, u0, bcs=bcs, appctx=appctx,
                         solver_parameters=solver_parameters,
                         bc_type=bc_type, splitting=splitting,
                         nullspace=nullspace)
        from firedrake.petsc import PETSc
        self.print = lambda x: PETSc.Sys.Print(x)

        self.dt_min = dtmin
        self.dt_max = dtmax
        self.dt_old = 0.0
        self.delb = butcher_tableau.btilde - butcher_tableau.b
        self.gamma0 = butcher_tableau.gamma0
        self.gamma0_params = gamma0_params
        self.KI = KI
        self.KP = KP
        self.max_reject = max_reject
        self.onscale_factor = onscale_factor
        self.safety_factor = safety_factor
        self.error_func = Function(u0.function_space())
        self.tol = tol
        self.err_old = 0.0
        self.contreject = 0

        split_form = extract_terms(F)
        self.dtless_form = -split_form.remainder

        # Set up and cache boundary conditions for error estimate
        embbc = []
        if self.gamma0 != 0:
            # Grab spaces for BCs
            V = F.arguments()[0].function_space()
            num_fields = len(V)
            # bcs defaults to None: treat that as "no boundary conditions"
            # instead of failing to iterate over None.
            for bc in bcs or ():
                gVsp = bc2space(bc, V)
                gdat = EmbeddedBCData(bc, self.t, self.dt, num_fields,
                                      butcher_tableau, self.ws, self.u0)
                embbc.append(bc.reconstruct(V=gVsp, g=gdat))
        self.embbc = embbc

    def _estimate_error(self):
        """Assuming that the RK stages have been evaluated, estimates
        the temporal truncation error by taking the norm of the
        difference between the new solutions computed by the two
        methods.  Typically will not be called by the end user."""
        dtc = float(self.dt)
        delb = self.delb
        ws = self.ws
        nf = self.num_fields
        V = self.u0.function_space()

        # Initialize e to be gamma*h*f(old value of u)
        error_func = Function(V)
        # Only do the hard stuff if gamma0 is not zero
        if self.gamma0 != 0.0:
            error_test = TestFunction(V)
            f_form = inner(error_func, error_test)*dx-self.gamma0*dtc*self.dtless_form
            f_problem = NLVP(f_form, error_func, bcs=self.embbc)
            f_solver = NLVS(f_problem, solver_parameters=self.gamma0_params)
            f_solver.solve()

        # Accumulate delta-b terms over stages
        error_func_bits = error_func.subfunctions
        for s in range(self.num_stages):
            coeff = dtc*float(delb[s])
            for i, e in enumerate(error_func_bits):
                e += coeff*ws[nf*s+i]
        return norm(assemble(error_func))

    def advance(self):
        """Attempts to advance the system from time `t` to time `t +
        dt`.  If the error threshold is exceeded, will adaptively
        decrease the time step until the step is accepted.  Also
        predicts new time step once the step is accepted.
        Note: overwrites the value `u0`.

        :returns: a tuple ``(err_current, dt_current)`` holding the error
            estimate and the step size of the accepted step.
        :raises RuntimeError: if the step size drops to `dt_min` or
            becomes non-finite, or if `max_reject` consecutive steps are
            rejected.
        """
        if float(self.dt) > self.dt_max:
            self.dt.assign(self.dt_max)
        self.print("\tTrying dt = %e" % (float(self.dt)))

        u0_dm = self.u0.function_space().dm
        stages_dm = self.stages.function_space().dm
        # Smallest positive normal double; floors the error estimate used in
        # the controller formulas so an exactly-zero estimate does not raise
        # ZeroDivisionError (the step then simply grows towards dt_max).
        err_floor = float(numpy.finfo(float).tiny)

        while True:
            # Same parent-DM bracketing as StageDerivativeTimeStepper.advance,
            # kept balanced even when the solve raises.
            push_parent(u0_dm, stages_dm)
            try:
                self.solver.solve()
            finally:
                pop_parent(u0_dm, stages_dm)
            self.num_nonlinear_iterations += self.solver.snes.getIterationNumber()
            self.num_linear_iterations += self.solver.snes.getLinearSolveIterations()

            err_current = float(self._estimate_error())
            err_old = float(self.err_old)
            dt_old = float(self.dt_old)
            dt_current = float(self.dt)
            tol = float(self.tol)
            order = self.butcher_tableau.embedded_order
            err_ctrl = max(err_current, err_floor)
            dt_pred = dt_current*((dt_current*tol)/err_ctrl)**(1/order)
            self.print("\tTruncation error is %e" % (err_current))

            # Rejected step shrinks the time-step
            if err_current >= dt_current*tol:
                dtnew = dt_current*(self.safety_factor*dt_current*tol/err_current)**(1./order)
                self.print("\tShrinking time-step to %e" % (dtnew))
                self.dt.assign(dtnew)
                self.contreject += 1
                # numpy.isfinite returns numpy.bool_, so an `is False`
                # comparison never matched; use `not` instead.
                if dtnew <= self.dt_min or not numpy.isfinite(dtnew):
                    raise RuntimeError("The time-step became an invalid number.")
                if self.contreject >= self.max_reject:
                    raise RuntimeError(f"The time-step was rejected {self.max_reject} times in a row. Please increase the tolerance or decrease the starting time-step.")

            # Initial time-step selector
            elif (self.num_steps == 0 and dt_current < self.dt_max
                  and dt_pred > self.onscale_factor*dt_current
                  and self.contreject <= self.max_reject):
                # Increase the initial time-step
                dtnew = min(dt_pred, self.dt_max)
                self.print("\tIncreasing time-step to %e" % (dtnew))
                self.dt.assign(dtnew)
                self.contreject += 1

            # Accepted step increases the time-step
            else:
                if dt_old != 0.0 and err_old != 0.0 and dt_current < self.dt_max:
                    dtnew = min(dt_current*((dt_current*tol)/err_ctrl)**self.KI
                                * (err_old/err_ctrl)**self.KP
                                * (dt_current/dt_old)**self.KP,
                                self.dt_max)
                    self.print("\tThe step was accepted and the new time-step is %e" % (dtnew))
                else:
                    dtnew = min(dt_current, self.dt_max)
                    self.print("\tThe step was accepted and the time-step remains at %e " % (dtnew))
                self._update()
                self.contreject = 0
                self.num_steps += 1
                # Record the accepted step size as a float.  Storing the
                # `dt` Function itself (as before) aliased it to the value
                # assigned just below, so dt_current/dt_old was always 1.
                self.dt_old = dt_current
                self.dt.assign(dtnew)
                self.err_old = err_current
                return (err_current, dt_current)