import FIAT
import numpy as np
from firedrake import (Function, NonlinearVariationalProblem,
NonlinearVariationalSolver, TestFunction)
from firedrake.dmhooks import pop_parent, push_parent
from ufl.classes import Zero
from .ButcherTableaux import RadauIIA
from .stage import getBits, getFormStage
from .tools import AI, IA, MeshConstant, replace
[docs]
def riia_explicit_coeffs(k):
    """Computes the coefficients needed for the explicit part
    of a RadauIIA-IMEX method.

    A degree ``k - 1`` :class:`FIAT.GaussRadau` element is built on the
    one-dimensional UFC reference simplex, and the nodes ``c`` are read
    off its dual basis.  Row ``i`` of the result is obtained by applying
    a quadrature rule, rescaled by ``c[i]`` and shifted by one, to the
    element's basis functions.  Assuming FIAT's reference interval is
    ``[0, 1]``, entry ``A[i, j]`` is therefore the quadrature value of
    the integral of basis function ``j`` over ``[1, 1 + c[i]]``.

    :arg k: The size of the returned matrix (presumably the number of
        RadauIIA stages -- confirm against the caller).
    :returns: A ``(k, k)`` :class:`numpy.ndarray` of coefficients.
    """
    U = FIAT.ufc_simplex(1)
    L = FIAT.GaussRadau(U, k - 1)

    # First coordinate of the first point stored in each dual
    # functional's ``pt_dict``; these are used as the nodes ``c``.
    c = np.asarray([list(ell.pt_dict.keys())[0][0]
                    for ell in L.dual.nodes])

    # The reference rule does not depend on the row, so build it once.
    Q = FIAT.make_quadrature(L.ref_el, 2*k)
    qpts = Q.get_points()
    qwts = Q.get_weights()

    A = np.zeros((k, k))
    for i in range(k):
        # Scale the rule by c[i] and shift it by one, so the points lie
        # past the right end of the reference interval.
        qpts_i = 1 + qpts * c[i]
        qwts_i = qwts * c[i]
        # Key ``(0,)`` of the tabulation holds the order-zero values.
        Lvals_i = L.tabulate(0, qpts_i)[(0,)]
        A[i, :] = Lvals_i @ qwts_i

    return A
[docs]
class RadauIIAIMEXMethod:
    """Class for advancing a time-dependent PDE via a polynomial
    IMEX/RadauIIA method.  This requires one to split the PDE into
    an implicit and explicit part.

    The class provides three stepping methods: `propagate`, `iterate`
    and `advance`.  `propagate` moves the solution forward in time.
    `iterate` is used both to start the method (filling up the initial
    stage values) and can be used at each time step to increase the
    accuracy/stability.  `advance` performs one `propagate` followed by
    `num_its_per_step` calls to `iterate`.  In the limit as the iterator
    is applied many times per time step, one expects convergence to the
    solution that would have been obtained from the fully-implicit
    RadauIIA method.

    :arg F: A :class:`ufl.Form` instance describing the implicit part
            of the semi-discrete problem
            F(t, u; v) == 0, where `u` is the unknown
            :class:`firedrake.Function` and `v` is the
            :class:`firedrake.TestFunction`.
    :arg Fexp: A :class:`ufl.Form` instance describing the part of the
            PDE that is explicitly split off.
    :arg butcher_tableau: A :class:`ButcherTableau` instance giving
            the Runge-Kutta method to be used for time marching.
            Only RadauIIA is allowed here (but it can be any number
            of stages).
    :arg t: a :class:`Function` on the Real space over the same mesh as
            `u0`.  This serves as a variable referring to the current
            time.
    :arg dt: a :class:`Function` on the Real space over the same mesh as
            `u0`.  This serves as a variable referring to the current
            time step.  The user may adjust this value between time
            steps.
    :arg u0: A :class:`firedrake.Function` containing the current
            state of the problem to be solved.
    :arg bcs: An iterable of :class:`firedrake.DirichletBC` containing
            the strongly-enforced boundary conditions.  Irksome will
            manipulate these to obtain boundary conditions for each
            stage of the RK method.
    :arg it_solver_parameters: A :class:`dict` of solver parameters that
            will be used in solving the algebraic problem associated
            with the iterator.
    :arg prop_solver_parameters: A :class:`dict` of solver parameters
            that will be used in solving the algebraic problem
            associated with the propagator.
    :arg splitting: A callable used to factor the Butcher matrix;
            currently, only AI is supported.
    :arg appctx: An optional :class:`dict` containing application
            context.  Entries set by this class ("F", "Fexp",
            "butcher_tableau", "t", "dt", "u0", "bcs", "stage_type",
            "splitting", "nullspace") take precedence over
            user-supplied entries with the same key.
    :arg nullspace: An optional null space object.
    :arg num_its_initial: The number of times `iterate` is called at
            the end of construction.
    :arg num_its_per_step: The number of times `iterate` is called
            after `propagate` in each call to `advance`.
    """

    def __init__(self, F, Fexp, butcher_tableau,
                 t, dt, u0, bcs=None,
                 it_solver_parameters=None,
                 prop_solver_parameters=None,
                 splitting=AI,
                 appctx=None,
                 nullspace=None,
                 num_its_initial=0,
                 num_its_per_step=0):
        assert isinstance(butcher_tableau, RadauIIA), \
            "RadauIIAIMEXMethod requires a RadauIIA Butcher tableau"

        self.u0 = u0
        self.t = t
        self.dt = dt
        self.num_fields = len(u0.function_space())
        self.num_stages = len(butcher_tableau.b)
        self.butcher_tableau = butcher_tableau
        self.num_its_initial = num_its_initial
        self.num_its_per_step = num_its_per_step

        # solver statistics
        self.num_steps = 0
        self.num_props = 0
        self.num_its = 0
        self.num_nonlinear_iterations_prop = 0
        self.num_nonlinear_iterations_it = 0
        self.num_linear_iterations_prop = 0
        self.num_linear_iterations_it = 0

        # Since this assumes stiff accuracy, we drop
        # the update information on the floor.
        Fbig, _, UU, bigBCs, nsp = getFormStage(
            F, butcher_tableau, u0, t, dt, bcs,
            splitting=splitting, nullspace=nullspace)

        self.UU = UU
        # Stage values from the previous solve, on the same space as UU.
        self.UU_old = UU_old = Function(UU.function_space())
        self.UU_old_split = UU_old.subfunctions
        self.bigBCs = bigBCs

        Fit, Fprop = getFormExplicit(
            Fexp, butcher_tableau, u0, UU_old, t, dt, splitting)

        # Both problems share the implicit stage form and the unknown UU;
        # they differ only in the explicit contribution that is added.
        self.itprob = NonlinearVariationalProblem(
            Fbig + Fit, UU, bcs=bigBCs)
        self.propprob = NonlinearVariationalProblem(
            Fbig + Fprop, UU, bcs=bigBCs)

        appctx_irksome = {"F": F,
                          "Fexp": Fexp,
                          "butcher_tableau": butcher_tableau,
                          "t": t,
                          "dt": dt,
                          "u0": u0,
                          "bcs": bcs,
                          "stage_type": "value",
                          "splitting": splitting,
                          "nullspace": nullspace}
        if appctx is None:
            appctx = appctx_irksome
        else:
            appctx = {**appctx, **appctx_irksome}

        dm = self.u0.function_space().dm
        parent_dm = self.UU.function_space().dm
        push_parent(dm, parent_dm)
        try:
            self.it_solver = NonlinearVariationalSolver(
                self.itprob, appctx=appctx,
                solver_parameters=it_solver_parameters,
                nullspace=nsp)
            self.prop_solver = NonlinearVariationalSolver(
                self.propprob, appctx=appctx,
                solver_parameters=prop_solver_parameters,
                nullspace=nsp)
        finally:
            # Undo the push even if solver construction fails.
            pop_parent(dm, parent_dm)

        # Seed every stage of the "old" stage function with the initial
        # state; stage s of field i lives at index s * num_fields + i.
        for i, u0bit in enumerate(u0.subfunctions):
            for s in range(self.num_stages):
                ii = s * self.num_fields + i
                self.UU_old_split[ii].assign(u0bit)

        for _ in range(num_its_initial):
            self.iterate()

    def _solve_with_parent(self, solver):
        """Call ``solver.solve()`` between a matching `push_parent` /
        `pop_parent` pair on the DMs of the spaces of `u0` and `UU`.

        The pop is performed even if the solve raises, so that a failed
        solve does not leave the pushed parent in place.
        """
        dm = self.u0.function_space().dm
        parent_dm = self.UU.function_space().dm
        push_parent(dm, parent_dm)
        try:
            solver.solve()
        finally:
            pop_parent(dm, parent_dm)

    def iterate(self):
        """Called 1 or more times to set up the initial state of the
        system before time-stepping.  Can also be called after each
        call to `advance`.

        Solves the iterator problem, copies the new stage values into
        `UU_old`, and updates the iterator statistics."""
        self._solve_with_parent(self.it_solver)
        self.UU_old.assign(self.UU)
        self.num_its += 1
        snes = self.it_solver.snes
        self.num_nonlinear_iterations_it += snes.getIterationNumber()
        self.num_linear_iterations_it += snes.getLinearSolveIterations()

    def propagate(self):
        """Moves the solution forward in time, to be followed by 0 or
        more calls to `iterate`.

        `u0` is overwritten with the final stage of `UU_old`, the
        propagator problem is solved, the new stage values are copied
        into `UU_old`, and the propagator statistics are updated."""
        ns = self.num_stages
        nf = self.num_fields

        # The last stage of the previous solve becomes the new state.
        for i, u0bit in enumerate(self.u0.subfunctions):
            u0bit.assign(self.UU_old_split[(ns-1)*nf + i])

        ps = self.prop_solver
        self._solve_with_parent(ps)
        self.UU_old.assign(self.UU)
        self.num_props += 1
        self.num_nonlinear_iterations_prop += ps.snes.getIterationNumber()
        self.num_linear_iterations_prop += ps.snes.getLinearSolveIterations()

    def advance(self):
        """Takes one time step: a call to `propagate` followed by
        `num_its_per_step` calls to `iterate`.

        Increments `num_steps`.  This method does not modify `t`."""
        self.propagate()
        for _ in range(self.num_its_per_step):
            self.iterate()
        self.num_steps += 1

    def solver_stats(self):
        """Returns the accumulated solver statistics.

        :returns: The tuple ``(num_steps, num_props, num_its,
            num_nonlinear_iterations_prop, num_linear_iterations_prop,
            num_nonlinear_iterations_it, num_linear_iterations_it)``.
        """
        return (self.num_steps, self.num_props, self.num_its,
                self.num_nonlinear_iterations_prop,
                self.num_linear_iterations_prop,
                self.num_nonlinear_iterations_it,
                self.num_linear_iterations_it)