from gusto.rexi.rexi_coefficients import *
from firedrake import Function, DirichletBC, \
LinearVariationalProblem, LinearVariationalSolver
from gusto.core.labels import time_derivative, prognostic, linearisation
from firedrake.fml import (
Term, all_terms, drop, subject,
replace_subject, replace_test_function, replace_trial_function
)
from firedrake.formmanipulation import split_form
NullTerm = Term(None)
[docs]
class Rexi(object):
"""
Class defining the solver for the system
(A_n + tau L)V_n = U
required for computing the matrix exponential.
"""
def __init__(self, equation, rexi_parameters, *, solver_parameters=None,
manager=None, cpx_type='mixed'):
"""
Args:
equation (:class:`PrognosticEquation`): the model's equation
rexi_parameters (:class:`RexiParameters`): Rexi configuration
parameters
solver_parameters (dict, optional): dictionary of parameters to
pass to the solver. Defaults to None.
manager (:class:`.Ensemble`): the space and ensemble sub-
communicators. Defaults to None.
cpx_type (str, optional): implementation of complex-valued space,
can be 'mixed' or 'vector'.
"""
if cpx_type == 'mixed':
from gusto.complex_proxy import mixed as cpx
elif cpx_type == 'vector':
from gusto.complex_proxy import vector as cpx
else:
raise ValueError("cpx_type must be 'mixed' or 'vector'")
self.cpx = cpx
residual = equation.residual.label_map(
lambda t: t.has_label(linearisation),
map_if_true=lambda t: Term(t.get(linearisation).form, t.labels),
map_if_false=drop)
residual = residual.label_map(
all_terms,
lambda t: replace_trial_function(t.get(subject))(t))
# Get the Rexi Coefficients, given the values of h and M in
# rexi_parameters
self.alpha, self.beta, self.beta2 = RexiCoefficients(rexi_parameters)
self.manager = manager
# define the start point of the solver loop (idx) and the
# number of solvers (N) for this process depending on the
# total number of solvers (nsolvers) and how many ensemble
# processes (neprocs) there are.
nsolvers = len(self.alpha)
if manager is None:
# if running in serial we loop over all the solvers, from
# 0: nsolvers
self.N = nsolvers
self.idx = 0
else:
rank = manager.ensemble_comm.rank
neprocs = manager.ensemble_comm.size
m = int(nsolvers/neprocs)
p = nsolvers - m*neprocs
if rank < p:
self.N = m+1
self.idx = rank*(m+1)
else:
self.N = m
self.idx = rank*m + p
# set up complex function space
W_ = equation.function_space
W = cpx.FunctionSpace(W_)
self.U0 = Function(W_) # right hand side function
self.w = Function(W) # solution
self.wrk = Function(W_) # working buffer
ncpts = len(W_)
# split equation into mass matrix and linear operator
mass = residual.label_map(
lambda t: t.has_label(time_derivative),
map_if_false=drop)
function = residual.label_map(
lambda t: t.has_label(time_derivative),
map_if_true=drop)
# generate ufl for mass matrix over given trial/tests
def form_mass(*trials_and_tests):
trials = trials_and_tests[:ncpts]
tests = trials_and_tests[ncpts:]
m = mass.label_map(
all_terms,
replace_test_function(tests))
m = m.label_map(
all_terms,
replace_subject(trials))
return m
# generate ufl for linear operator over given trial/tests
def form_function(*trials_and_tests):
trials = trials_and_tests[:ncpts]
tests = trials_and_tests[ncpts:]
f = NullTerm
for i in range(ncpts):
fi = function.label_map(
lambda t: t.get(prognostic) == equation.field_names[i],
lambda t: Term(
split_form(t.form)[i].form,
t.labels),
map_if_false=drop)
fi = fi.label_map(
all_terms,
replace_test_function(tests[i]))
fi = fi.label_map(
all_terms,
replace_subject(trials))
f += fi
f = f.label_map(lambda t: t is NullTerm, drop)
return f
# generate ufl for right hand side over given trial/tests
def form_rhs(*tests):
return form_mass(*self.U0.subfunctions, *tests)
# complex Constants for alpha and beta values
self.ac = cpx.ComplexConstant(1)
self.bc = cpx.ComplexConstant(1)
# alpha*M and tau*L
aM = cpx.BilinearForm(W, self.ac, form_mass)
aL, self.tau, _ = cpx.BilinearForm(W, 1, form_function, return_z=True)
a = aM - aL
# right hand side is just U0
b = cpx.LinearForm(W, 1, form_rhs)
if hasattr(equation, "aP"):
aP = equation.aP(trial, self.ai, self.tau)
else:
aP = None
# BCs are declared for the plain velocity space.
# First we need to transfer the velocity boundary conditions to the
# velocity component of the mixed space.
uidx = equation.field_names.index('u')
ubcs = (DirichletBC(W_.sub(uidx), bc.function_arg, bc.sub_domain)
for bc in equation.bcs['u'])
# now we can transfer the velocity boundary conditions to the complex space
bcs = tuple(cb for bc in ubcs for cb in cpx.DirichletBC(W, W_, bc))
rexi_prob = LinearVariationalProblem(a.form, b.form, self.w, aP=aP,
bcs=bcs,
constant_jacobian=False)
# if solver_parameters is None:
# solver_parameters = equation.solver_parameters
self.solver = LinearVariationalSolver(
rexi_prob, solver_parameters=solver_parameters)
[docs]
def solve(self, x_out, x_in, dt):
"""
Solve method for approximating the matrix exponential by a
rational sum. Solves
(A_n + tau L)V_n = U
multiplies by the corresponding B_n and sums over n.
:arg x_in: the mixed function on the rhs.
:arg dt: the value of tau
"""
cpx = self.cpx
# assign tau and U0 and initialise solution to 0.
self.tau.assign(dt)
self.U0.assign(x_in)
x_out.assign(0.)
# loop over solvers, assigning a_i, solving and accumulating the sum
for i in range(self.N):
j = self.idx + i
self.ac.real.assign(self.alpha[j].real)
self.ac.imag.assign(self.alpha[j].imag)
self.bc.real.assign(self.beta[j].real)
self.bc.imag.assign(self.beta[j].imag)
self.solver.solve()
# accumulate real part of beta*w
cpx.get_real(self.w, self.wrk)
x_out += self.bc.real*self.wrk
cpx.get_imag(self.w, self.wrk)
x_out -= self.bc.imag*self.wrk
# in parallel we have to accumulate the sum over all processes
if self.manager is not None:
self.wrk.assign(x_out)
self.manager.allreduce(self.wrk, x_out)