# Source code for irksome.stage_value
# Formulate RK methods to solve for stage values rather than the stage derivatives.
import numpy
from FIAT import Bernstein, ufc_simplex
from firedrake import (Function, NonlinearVariationalProblem,
NonlinearVariationalSolver, TestFunction, dx,
inner)
from ufl import zero
from ufl.constantvalue import as_ufl
from .bcs import stage2spaces4bc
from .ButcherTableaux import CollocationButcherTableau
from .deriv import expand_time_derivatives
from .manipulation import extract_terms, strip_dt_form
from .tools import AI, is_ode, replace, vecconst
from .base_time_stepper import StageCoupledTimeStepper
# [docs]
def to_value(u0, stages, vandermonde):
    """Map the stage unknowns to stage values (Lagrange representation).

    The unknowns in another basis (e.g. Bernstein) are the coefficients
    [u0; ZZ], while the Lagrange representation is [u0; UU]; the entry
    for the left endpoint is the same in both.  Because u0 is not part of
    the unknown stage vector, only the rows of the Vandermonde matrix
    after the first (which is [1, 0, ...]) are applied.

    :arg u0: the current state; its `ufl_shape` fixes the trailing
        dimensions of each stage.
    :arg stages: the stage unknowns, reshaped here to one row per stage.
    :arg vandermonde: change-of-basis matrix, or `None` when the unknowns
        already are the stage values.
    :returns: an array with one leading entry per stage.
    """
    # One leading "stage" axis followed by the value shape of u0.
    per_stage_shape = (-1, *u0.ufl_shape)
    coeffs = numpy.reshape(stages, per_stage_shape)

    if vandermonde is not None:
        # Prepend u0 so the coefficient vector matches the matrix columns,
        # then drop the first (trivial) row of the matrix.
        left_endpoint = numpy.reshape(u0, per_stage_shape)
        stacked = numpy.concatenate((left_endpoint, coeffs))
        return vandermonde[1:] @ stacked

    # No change of basis: the unknowns are the stage values themselves.
    return coeffs
# [docs]
def getFormStage(F, butch, t, dt, u0, stages, bcs=None, splitting=None, vandermonde=None):
    """Given a time-dependent variational form and a
    :class:`ButcherTableau`, produce UFL for the s-stage RK method.

    :arg F: UFL form for the semidiscrete ODE/DAE
    :arg butch: the :class:`ButcherTableau` for the RK method being used to
         advance in time.
    :arg t: a :class:`Function` on the Real space over the same mesh as
         `u0`.  This serves as a variable referring to the current time.
    :arg dt: a :class:`Function` on the Real space over the same mesh as
         `u0`.  This serves as a variable referring to the current time step.
         The user may adjust this value between time steps.
    :arg u0: a :class:`Function` referring to the state of
         the PDE system at time `t`
    :arg stages: a :class:`Function` representing the stages to be solved for.
         It lives in a :class:`firedrake.FunctionSpace` corresponding to the
         s-way tensor product of the space on which the semidiscrete
         form lives.
    :arg bcs: optionally, a :class:`DirichletBC` object (or iterable thereof)
         containing (possibly time-dependent) boundary conditions imposed
         on the system.
    :arg splitting: a callable that maps the (floating point) Butcher matrix
         a to a pair of matrices `A1, A2` such that `butch.A = A1 A2`.  This is used
         to vary between the classical RK formulation and Butcher's reformulation
         that leads to a denser mass matrix with block-diagonal stiffness.
         Only `AI` and `IA` are currently supported.  If `None`, `AI` is used.
    :arg vandermonde: a numpy array encoding a change of basis to the Lagrange
         polynomials associated with the collocation nodes from some other
         (e.g. Bernstein or Chebyshev) basis.  This allows us to solve for the
         coefficients in some basis rather than the values at particular stages,
         which can be useful for satisfying bounds constraints.
         If none is provided, we assume it is the identity, working in the
         Lagrange basis.

    :returns: a tuple `(Fnew, bcnew)` consisting of

       - `Fnew`, the :class:`Form`
       - `bcnew`, a list of :class:`firedrake.DirichletBC` objects to be posed
         on the stages.

    :raises NotImplementedError: if the second factor returned by
         `splitting` is not invertible.
    """
    # The signature defaults `splitting` to None; fall back to the same
    # splitting that StageValueTimeStepper uses by default.
    if splitting is None:
        splitting = AI

    # preprocess time derivatives (done once; the expanded form is reused below)
    F = expand_time_derivatives(F, t=t, timedep_coeffs=(u0,))
    v = F.arguments()[0]
    V = v.function_space()
    assert V == u0.function_space()

    c = vecconst(butch.c)
    bA1, bA2 = splitting(butch.A)
    try:
        bA2inv = numpy.linalg.inv(bA2)
    except numpy.linalg.LinAlgError as err:
        raise NotImplementedError("We require A = A1 A2 with A2 invertible") from err
    A1 = vecconst(bA1)
    A2inv = vecconst(bA2inv)

    # s-way product space for the stage variables
    num_stages = butch.num_stages
    Vbig = stages.function_space()
    test = TestFunction(Vbig)

    # set up the pieces we need to work with to do our substitutions
    v_np = numpy.reshape(test, (num_stages, *u0.ufl_shape))
    w_np = to_value(u0, stages, vandermonde)
    A1Tv = A1.T @ v_np
    A2invTv = A2inv.T @ v_np

    # first, process terms with a time derivative.  I'm
    # assuming we have something of the form inner(Dt(g(u0)), v)*dx.
    # With the time derivative stripped, u0 is replaced by
    # (stage value i - u0) for each stage i.  There is no division by dt
    # here; instead the remaining terms below are multiplied by dt.
    split_form = extract_terms(F)
    F_dtless = strip_dt_form(split_form.time)
    F_remainder = split_form.remainder

    Fnew = zero()

    # Terms with time derivatives
    for i in range(num_stages):
        repl = {t: t + c[i] * dt,
                v: A2invTv[i],
                u0: w_np[i] - u0}
        Fnew += replace(F_dtless, repl)

    # Handle the rest of the terms
    for i in range(num_stages):
        # replace the solution with stage values
        repl = {t: t + c[i] * dt,
                v: A1Tv[i] * dt,
                u0: w_np[i]}
        Fnew += replace(F_remainder, repl)

    if bcs is None:
        bcs = []
    bcsnew = []

    if vandermonde is not None:
        Vander_inv = vecconst(numpy.linalg.inv(vandermonde.astype(float)))

    # For each BC, we need a new BC for each stage
    # so we need to figure out how the function is indexed (mixed + vec)
    # and then set it to have the value of the original argument at
    # time t+C[i]*dt.
    for bc in bcs:
        bcarg = as_ufl(bc._original_arg)
        g_np = numpy.array([replace(bcarg, {t: t + ci * dt}) for ci in c])
        if vandermonde is not None:
            # Convert the stage boundary values into coefficients of the
            # alternate basis, accounting for the fixed left-endpoint value.
            g_np -= vandermonde[1:, 0] * bcarg
            g_np = Vander_inv[1:, 1:] @ g_np

        for i in range(num_stages):
            Vbigi = stage2spaces4bc(bc, V, Vbig, i)
            # One reconstructed BC per stage; append (not extend) so a
            # single BC object is stored rather than iterated over.
            bcsnew.append(bc.reconstruct(V=Vbigi, g=g_np[i]))

    return Fnew, bcsnew
# [docs]
class StageValueTimeStepper(StageCoupledTimeStepper):
    """Stage-coupled time stepper whose unknowns are the RK stage values
    (or their coefficients in another basis) rather than stage derivatives.

    :arg F: UFL form for the semidiscrete ODE/DAE.
    :arg butcher_tableau: the tableau of the RK method.  Unless `F` is an
        ODE (as judged by `is_ode`), it must be stiffly accurate.
    :arg t: the time variable used in `F`.
    :arg dt: the time step variable.
    :arg u0: the current state of the system.
    :arg bcs: boundary conditions, forwarded to the base class.
    :arg solver_parameters: forwarded to the base class.
    :arg update_solver_parameters: solver parameters for the separate
        update solve (only used when that solve is needed).
    :arg splitting: callable factoring the Butcher matrix; defaults to `AI`.
    :arg basis_type: `None` to solve for stage values directly, or
        `"Bernstein"` to solve for Bernstein coefficients.
    :arg nullspace: forwarded to the base class.
    :arg appctx: forwarded to the base class.
    :arg bounds: forwarded to the base class.

    NOTE(review): attributes such as `stages`, `num_stages`, `orig_bcs`,
    `F`, `t`, `dt`, `u0`, `splitting` and `appctx` are read here but not
    assigned in this class; presumably the base class sets them -- confirm.
    """

    def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None,
                 solver_parameters=None, update_solver_parameters=None,
                 splitting=AI, basis_type=None,
                 nullspace=None, appctx=None, bounds=None):
        # we can only do DAE-type problems correctly if one assumes a
        # stiffly-accurate method.
        assert is_ode(F, u0) or butcher_tableau.is_stiffly_accurate

        self.num_fields = len(u0.function_space())
        self.butcher_tableau = butcher_tableau
        degree = butcher_tableau.num_stages

        # Build the change-of-basis matrix, if a non-default basis is requested.
        change_of_basis = None
        if basis_type == "Bernstein":
            assert isinstance(butcher_tableau, CollocationButcherTableau), "Need collocation for Bernstein conversion"
            element = Bernstein(ufc_simplex(1), degree)
            # Tabulate at the left endpoint followed by the collocation nodes.
            nodes = numpy.reshape(numpy.append(0, butcher_tableau.c), (-1, 1))
            change_of_basis = vecconst(element.tabulate(0, nodes)[(0, )].T)
        elif basis_type is not None:
            raise ValueError("Unknown or unimplemented basis transformation type")

        # Must be set before the base-class constructor runs, since
        # get_form_and_bcs reads it.
        self.vandermonde = change_of_basis

        super().__init__(F, t, dt, u0, butcher_tableau.num_stages, bcs=bcs,
                         solver_parameters=solver_parameters,
                         appctx=appctx, nullspace=nullspace,
                         splitting=splitting, butcher_tableau=butcher_tableau, bounds=bounds)
        self.appctx["stage_type"] = "value"
        self.appctx["vandermonde"] = change_of_basis

        if butcher_tableau.is_stiffly_accurate or basis_type == "Bernstein":
            # The new solution is read directly off the final stage.
            self._update = self._update_stiff_acc
        else:
            # Otherwise a separate variational solve produces the update.
            self.unew, self.update_solver = self.get_update_solver(update_solver_parameters)
            self._update = self._update_general

    def _update_stiff_acc(self):
        """Copy the final stage, field by field, into `u0`."""
        last_stage_start = self.num_fields * (self.num_stages - 1)
        for k, target in enumerate(self.u0.subfunctions):
            target.assign(self.stages.subfunctions[last_stage_start + k])

    def get_update_solver(self, update_solver_parameters):
        """Build the variational solver that computes the new solution
        from the stage values.

        :arg update_solver_parameters: solver parameters passed to the
            :class:`NonlinearVariationalSolver` for the update.
        :returns: a pair `(unew, update_solver)`.
        """
        # only form update stuff if we need it
        # which means neither stiffly accurate nor Vandermonde
        unew = Function(self.u0.function_space())
        v, = self.F.arguments()

        nodes = vecconst(self.butcher_tableau.c)
        weights = vecconst(self.butcher_tableau.b)
        t = self.t
        dt = self.dt
        u0 = self.u0
        remainder = extract_terms(self.F).remainder
        stage_values = to_value(self.u0, self.stages, self.vandermonde)

        # Mass term plus the weighted remainder evaluated at each stage.
        Fupdate = inner(unew - self.u0, v) * dx
        for i in range(self.num_stages):
            substitution = {t: t + nodes[i] * dt,
                            u0: stage_values[i]}
            Fupdate += dt * weights[i] * replace(remainder, substitution)

        # And the BC's for the update -- just the original BC at t+dt
        update_bcs = [
            bc.reconstruct(g=replace(as_ufl(bc._original_arg), {t: t + dt}))
            for bc in self.orig_bcs
        ]

        problem = NonlinearVariationalProblem(Fupdate, unew, update_bcs)
        solver = NonlinearVariationalSolver(
            problem, solver_parameters=update_solver_parameters)
        return unew, solver

    def _update_general(self):
        """Run the update solve and copy its result into `u0`."""
        self.update_solver.solve()
        self.u0.assign(self.unew)

    def get_form_and_bcs(self, stages, butcher_tableau=None):
        """Return the stage-value form and stage boundary conditions.

        :arg stages: the stage unknowns to build the form for.
        :arg butcher_tableau: tableau to use; defaults to the one given
            at construction.
        :returns: whatever `getFormStage` returns for these inputs.
        """
        tableau = self.butcher_tableau if butcher_tableau is None else butcher_tableau
        return getFormStage(self.F, tableau,
                            self.t, self.dt, self.u0,
                            stages, bcs=self.orig_bcs,
                            splitting=self.splitting,
                            vandermonde=self.vandermonde)