# Source code for gusto.rexi.rexi

from gusto.rexi.rexi_coefficients import *
from firedrake import Function, DirichletBC, \
    LinearVariationalProblem, LinearVariationalSolver
from gusto.core.labels import time_derivative, prognostic, linearisation
from firedrake.fml import (
    Term, all_terms, drop, subject,
    replace_subject, replace_test_function, replace_trial_function
)
from firedrake.formmanipulation import split_form

# Placeholder Term with no form. It is used as the starting accumulator when
# summing labelled forms (``f = NullTerm; f += ...``). It is identified by
# identity (``t is NullTerm``) and dropped once the sum is complete.
NullTerm = Term(None)


class Rexi(object):
    """
    Class defining the solver for the complex linear systems

        (alpha_n M - tau L) w_n = M U0

    required for approximating the matrix exponential by a rational sum
    (REXI). M is the linearised time-derivative (mass) part of the
    equation's residual and L is the remaining linearised part.
    """

    def __init__(self, equation, rexi_parameters, *, solver_parameters=None,
                 manager=None, cpx_type='mixed'):
        """
        Args:
            equation (:class:`PrognosticEquation`): the model's equation
            rexi_parameters (:class:`RexiParameters`): Rexi configuration
                parameters
            solver_parameters (dict, optional): dictionary of parameters to
                pass to the solver. Defaults to None.
            manager (:class:`.Ensemble`): the space and ensemble sub-
                communicators. Defaults to None.
            cpx_type (str, optional): implementation of complex-valued space,
                can be 'mixed' or 'vector'.

        Raises:
            ValueError: if ``cpx_type`` is not 'mixed' or 'vector'.
            NotImplementedError: if ``equation`` defines an ``aP``
                preconditioning form, which this class cannot build yet.
        """
        if cpx_type == 'mixed':
            from gusto.complex_proxy import mixed as cpx
        elif cpx_type == 'vector':
            from gusto.complex_proxy import vector as cpx
        else:
            raise ValueError("cpx_type must be 'mixed' or 'vector'")
        self.cpx = cpx

        # Keep only terms carrying a linearisation label, each replaced by
        # its linearised form; every other term is dropped.
        residual = equation.residual.label_map(
            lambda t: t.has_label(linearisation),
            map_if_true=lambda t: Term(t.get(linearisation).form, t.labels),
            map_if_false=drop)

        # Replace each term's subject with a trial function.
        residual = residual.label_map(
            all_terms,
            lambda t: replace_trial_function(t.get(subject))(t))

        # Get the Rexi Coefficients, given the values of h and M in
        # rexi_parameters
        self.alpha, self.beta, self.beta2 = RexiCoefficients(rexi_parameters)

        self.manager = manager

        # Choose the first solver index (idx) and the number of solvers (N)
        # handled by this process. In serial this process handles all of
        # them; in parallel they are split across the ensemble processes.
        nsolvers = len(self.alpha)
        if manager is None:
            self.N = nsolvers
            self.idx = 0
        else:
            self.N, self.idx = self._partition_solvers(
                nsolvers,
                manager.ensemble_comm.rank,
                manager.ensemble_comm.size)

        # set up complex function space
        W_ = equation.function_space
        W = cpx.FunctionSpace(W_)

        self.U0 = Function(W_)  # right hand side function
        self.w = Function(W)    # solution
        self.wrk = Function(W_)  # working buffer

        ncpts = len(W_)

        # split equation into mass matrix and linear operator
        mass = residual.label_map(
            lambda t: t.has_label(time_derivative),
            map_if_false=drop)

        function = residual.label_map(
            lambda t: t.has_label(time_derivative),
            map_if_true=drop)

        # generate ufl for mass matrix over given trial/tests
        def form_mass(*trials_and_tests):
            trials = trials_and_tests[:ncpts]
            tests = trials_and_tests[ncpts:]
            m = mass.label_map(
                all_terms,
                replace_test_function(tests))
            m = m.label_map(
                all_terms,
                replace_subject(trials))
            return m

        # generate ufl for linear operator over given trial/tests
        def form_function(*trials_and_tests):
            trials = trials_and_tests[:ncpts]
            tests = trials_and_tests[ncpts:]

            f = NullTerm
            for i in range(ncpts):
                # Terms whose prognostic is the i-th field, restricted to
                # component i of split_form (presumably the part tested
                # against the i-th test function -- confirm). The lambdas
                # are evaluated immediately, so capturing `i` is safe.
                fi = function.label_map(
                    lambda t: t.get(prognostic) == equation.field_names[i],
                    lambda t: Term(
                        split_form(t.form)[i].form,
                        t.labels),
                    map_if_false=drop)

                fi = fi.label_map(
                    all_terms,
                    replace_test_function(tests[i]))

                fi = fi.label_map(
                    all_terms,
                    replace_subject(trials))

                f += fi
            # remove the placeholder used to start the sum
            f = f.label_map(lambda t: t is NullTerm, drop)
            return f

        # generate ufl for right hand side over given trial/tests
        def form_rhs(*tests):
            return form_mass(*self.U0.subfunctions, *tests)

        # complex Constants for alpha and beta values
        self.ac = cpx.ComplexConstant(1)
        self.bc = cpx.ComplexConstant(1)

        # alpha*M and tau*L. The second value returned with return_z=True is
        # kept as self.tau and set to dt in solve().
        aM = cpx.BilinearForm(W, self.ac, form_mass)
        aL, self.tau, _ = cpx.BilinearForm(W, 1, form_function,
                                           return_z=True)
        a = aM - aL

        # right hand side is just U0 (through the mass form)
        b = cpx.LinearForm(W, 1, form_rhs)

        if hasattr(equation, "aP"):
            # The previous call, equation.aP(trial, self.ai, self.tau),
            # referenced names that are never defined and so could only
            # raise NameError. Fail with an explicit message instead.
            raise NotImplementedError(
                "Rexi does not support equations that define an 'aP' "
                "preconditioning form")
        aP = None

        # BCs are declared for the plain velocity space. First transfer the
        # velocity boundary conditions to the velocity component of the mixed
        # space, then to the complex space. Equations without a velocity
        # field have no velocity BCs to transfer.
        bcs = ()
        if 'u' in equation.field_names:
            uidx = equation.field_names.index('u')
            ubcs = (DirichletBC(W_.sub(uidx), bc.function_arg, bc.sub_domain)
                    for bc in equation.bcs['u'])
            bcs = tuple(cb for bc in ubcs
                        for cb in cpx.DirichletBC(W, W_, bc))

        rexi_prob = LinearVariationalProblem(a.form, b.form, self.w, aP=aP,
                                             bcs=bcs, constant_jacobian=False)

        self.solver = LinearVariationalSolver(
            rexi_prob, solver_parameters=solver_parameters)

    @staticmethod
    def _partition_solvers(nsolvers, rank, nprocs):
        """
        Split ``nsolvers`` REXI terms into contiguous chunks, one per
        process. Chunk sizes differ by at most one.

        Args:
            nsolvers (int): total number of terms.
            rank (int): this process's rank, ``0 <= rank < nprocs``.
            nprocs (int): number of processes.

        Returns:
            tuple: ``(count, start)``, i.e. the number of terms handled by
            ``rank`` and the index of its first term.
        """
        # divmod is exact, unlike int(nsolvers/nprocs), which goes via float.
        base, extra = divmod(nsolvers, nprocs)
        # The first `extra` ranks each take one additional term.
        if rank < extra:
            return base + 1, rank * (base + 1)
        return base, rank * base + extra

    def solve(self, x_out, x_in, dt):
        """
        Approximate the matrix exponential by a rational sum.

        For each term n handled by this process, solves
        (alpha_n M - tau L) w_n = M U and accumulates the real part of
        beta_n * w_n into ``x_out``. With an ensemble manager, the partial
        sums are then all-reduced over the ensemble.

        Args:
            x_out (:class:`Function`): overwritten with the accumulated sum.
            x_in (:class:`Function`): the mixed function on the rhs; it is
                assigned into a Function on the equation's function space.
            dt (float): the value of tau.
        """
        cpx = self.cpx

        # assign tau and U0 and initialise solution to 0.
        self.tau.assign(dt)
        self.U0.assign(x_in)
        x_out.assign(0.)

        # loop over this process's solvers: set alpha_j and beta_j, solve,
        # and accumulate the sum
        for j in range(self.idx, self.idx + self.N):
            self.ac.real.assign(self.alpha[j].real)
            self.ac.imag.assign(self.alpha[j].imag)
            self.bc.real.assign(self.beta[j].real)
            self.bc.imag.assign(self.beta[j].imag)

            self.solver.solve()

            # accumulate real part of beta*w:
            # Re(beta*w) = Re(beta)*Re(w) - Im(beta)*Im(w)
            cpx.get_real(self.w, self.wrk)
            x_out += self.bc.real*self.wrk

            cpx.get_imag(self.w, self.wrk)
            x_out -= self.bc.imag*self.wrk

        # in parallel we have to accumulate the sum over all processes
        if self.manager is not None:
            self.wrk.assign(x_out)
            self.manager.allreduce(self.wrk, x_out)