# Source code for firedrake.mg.embedded

import firedrake
import ufl
import finat.ufl
import weakref
from functools import reduce
from enum import IntEnum
from operator import and_
from firedrake.petsc import PETSc
from firedrake.embedding import get_embedding_dg_element


# Public API of this module: only the transfer manager is exported.
__all__ = ("TransferManager",)


# Element families whose grid transfers are handled directly by
# firedrake.prolong / firedrake.restrict / firedrake.inject (see
# TransferManager.is_native and TransferManager._native_transfer);
# other families go through an embedding DG space.
native = frozenset({
    "Lagrange",
    "Discontinuous Lagrange",
    "Real",
    "Q",
    "DQ",
})


class Op(IntEnum):
    """Kind of grid transfer.

    The integer values are used as indices into a
    ``(prolong, restrict, inject)`` tuple of native transfer callables.
    """

    PROLONG = 0
    RESTRICT = 1
    INJECT = 2


class TransferManager(object):
    """Manage grid transfers, natively or via embedding in DG spaces."""

    class Cache(object):
        """A caching object for work vectors and matrices.

        :arg element: The element to use for the caching."""

        def __init__(self, element):
            # DG element into which ``element`` is embedded for
            # non-native transfers (used by ``DG_work``).
            self.embedding_element = get_embedding_dg_element(element)
            # (transfer_op, weakref(source.dat), weakref(target.dat)) ->
            # (source dat version, target dat version) at the last transfer.
            self._dat_versions = {}
            # The caches below are keyed by the space dimension ``V.dim()``
            # (``_DG_work`` additionally keys on whether V is a dual space).
            self._V_DG_mass = {}
            self._DG_inv_mass = {}
            self._V_approx_inv_mass = {}
            self._V_inv_mass_ksp = {}
            self._DG_work = {}
            self._work_vec = {}
            self._V_dof_weights = {}

    def __init__(self, *, native_transfers=None, use_averaging=True):
        """
        An object for managing transfers between levels in a multigrid
        hierarchy (possibly via embedding in DG spaces).

        :arg native_transfers: dict mapping UFL element
           to "natively supported" transfer operators. This should be
           a three-tuple of (prolong, restrict, inject).
        :arg use_averaging: Use averaging to approximate the
           projection out of the embedded DG space? If False, a global
           L2 projection will be performed.
        """
        self.native_transfers = native_transfers or {}
        self.use_averaging = use_averaging
        # Maps UFL element -> TransferManager.Cache.
        self.caches = {}

    def is_native(self, element):
        """Return whether ``element`` can be transferred natively.

        True if ``element`` has an explicitly registered transfer, or if it
        lives on a tensor-product cell and all its sub-elements are native,
        or if its family is one of the module-level ``native`` families.
        """
        if element in self.native_transfers:
            return True
        if isinstance(element.cell, ufl.TensorProductCell) and len(element.sub_elements) > 0:
            return reduce(and_, map(self.is_native, element.sub_elements))
        return element.family() in native

    def _native_transfer(self, element, op):
        """Return the native transfer callable for ``element`` and ``op``.

        If no operators are registered for a native ``element``, the
        default ``firedrake.prolong``/``restrict``/``inject`` triple is
        registered for it. Returns None when ``element`` is not native.
        """
        try:
            return self.native_transfers[element][op]
        except KeyError:
            if self.is_native(element):
                ops = firedrake.prolong, firedrake.restrict, firedrake.inject
                return self.native_transfers.setdefault(element, ops)[op]
        return None

    def cache(self, element):
        """Return the :class:`Cache` for ``element``, creating it on first use."""
        try:
            return self.caches[element]
        except KeyError:
            return self.caches.setdefault(element, TransferManager.Cache(element))

    def V_dof_weights(self, V):
        """Dof weights for averaging projection.

        :arg V: function space to compute weights for.
        :returns: A PETSc Vec.
        """
        cache = self.cache(V.ufl_element())
        key = V.dim()
        try:
            return cache._V_dof_weights[key]
        except KeyError:
            # Compute dof multiplicity for V
            # Spin over all (owned) cells incrementing visible dofs by 1.
            # After halo exchange, the Vec representation is the
            # global Vector counting the number of cells that see each
            # dof.
            f = firedrake.Function(V)
            firedrake.par_loop(("{[i, j]: 0 <= i < A.dofs and 0 <= j < %d}" % V.value_size,
                                "A[i, j] = A[i, j] + 1"),
                               firedrake.dx,
                               {"A": (f, firedrake.INC)})
            with f.dat.vec_ro as fv:
                return cache._V_dof_weights.setdefault(key, fv.copy())

    def V_DG_mass(self, V, DG):
        """
        Mass matrix between V and DG spaces.

        :arg V: a function space
        :arg DG: the DG space
        :returns: A PETSc Mat mapping from V -> DG
        """
        cache = self.cache(V.ufl_element())
        key = V.dim()
        try:
            return cache._V_DG_mass[key]
        except KeyError:
            M = firedrake.assemble(firedrake.inner(firedrake.TrialFunction(V),
                                                   firedrake.TestFunction(DG))*firedrake.dx)
            return cache._V_DG_mass.setdefault(key, M.petscmat)

    def DG_inv_mass(self, DG):
        """
        Inverse DG mass matrix

        :arg DG: the DG space
        :returns: A PETSc Mat.
        """
        # Go through self.cache (like every other accessor) so this works
        # even if no cache exists yet for the DG element; indexing
        # self.caches directly relied on callers having created it first.
        cache = self.cache(DG.ufl_element())
        key = DG.dim()
        try:
            return cache._DG_inv_mass[key]
        except KeyError:
            M = firedrake.assemble(firedrake.Tensor(firedrake.inner(firedrake.TrialFunction(DG),
                                                                    firedrake.TestFunction(DG))*firedrake.dx).inv)
            return cache._DG_inv_mass.setdefault(key, M.petscmat)

    def V_approx_inv_mass(self, V, DG):
        """
        Approximate inverse mass.  Computes (cellwise) (V, V)^{-1} (V, DG).

        :arg V: a function space
        :arg DG: the DG space
        :returns: A PETSc Mat mapping from V -> DG.
        """
        cache = self.cache(V.ufl_element())
        key = V.dim()
        try:
            return cache._V_approx_inv_mass[key]
        except KeyError:
            a = firedrake.Tensor(firedrake.inner(firedrake.TrialFunction(V),
                                                 firedrake.TestFunction(V))*firedrake.dx)
            b = firedrake.Tensor(firedrake.inner(firedrake.TrialFunction(DG),
                                                 firedrake.TestFunction(V))*firedrake.dx)
            M = firedrake.assemble(a.inv * b)
            return cache._V_approx_inv_mass.setdefault(key, M.petscmat)

    def V_inv_mass_ksp(self, V):
        """
        A KSP inverting a mass matrix

        :arg V: a function space.
        :returns: A PETSc KSP for inverting (V, V).
        """
        cache = self.cache(V.ufl_element())
        key = V.dim()
        try:
            return cache._V_inv_mass_ksp[key]
        except KeyError:
            M = firedrake.assemble(firedrake.inner(firedrake.TrialFunction(V),
                                                   firedrake.TestFunction(V))*firedrake.dx)
            ksp = PETSc.KSP().create(comm=V._comm)
            ksp.setOperators(M.petscmat)
            ksp.setOptionsPrefix("{}_prolongation_mass_".format(V.ufl_element()._short_name))
            # Default to a direct Cholesky solve; setFromOptions below lets
            # the user override this through the options prefix.
            ksp.setType("preonly")
            ksp.pc.setType("cholesky")
            ksp.setFromOptions()
            ksp.setUp()
            return cache._V_inv_mass_ksp.setdefault(key, ksp)

    def DG_work(self, V):
        """A DG work Function matching V

        :arg V: a function space.
        :returns: A Function in the embedding DG space.
        """
        needs_dual = ufl.duals.is_dual(V)
        cache = self.cache(V.ufl_element())
        key = (V.dim(), needs_dual)
        try:
            return cache._DG_work[key]
        except KeyError:
            if needs_dual:
                # The dual work object is derived from the primal work
                # Function via its l2 Riesz representation.
                primal = self.DG_work(V.dual())
                dual = primal.riesz_representation(riesz_map="l2")
                return cache._DG_work.setdefault(key, dual)
            DG = firedrake.FunctionSpace(V.mesh(), cache.embedding_element)
            return cache._DG_work.setdefault(key, firedrake.Function(DG))

    def work_vec(self, V):
        """A work Vec for V

        :arg V: a function space.
        :returns: A PETSc Vec for V.
        """
        cache = self.cache(V.ufl_element())
        key = V.dim()
        try:
            return cache._work_vec[key]
        except KeyError:
            return cache._work_vec.setdefault(key, V.dof_dset.layout_vec.duplicate())

    def requires_transfer(self, element, transfer_op, source, target):
        """Determine whether either the source or target have been modified since
        the last time a grid transfer was executed with them."""
        # Weak references so the cache key does not keep the dats alive.
        key = (transfer_op, weakref.ref(source.dat), weakref.ref(target.dat))
        dat_versions = (source.dat.dat_version, target.dat.dat_version)
        try:
            return self.cache(element)._dat_versions[key] != dat_versions
        except KeyError:
            return True

    def cache_dat_versions(self, element, transfer_op, source, target):
        """Record the returned dat_versions of the source and target."""
        key = (transfer_op, weakref.ref(source.dat), weakref.ref(target.dat))
        dat_versions = (source.dat.dat_version, target.dat.dat_version)
        self.cache(element)._dat_versions[key] = dat_versions

    @PETSc.Log.EventDecorator()
    def op(self, source, target, transfer_op):
        """Primal transfer (either prolongation or injection).

        Native elements use the native operator directly; mixed elements
        are transferred component by component; otherwise the source is
        projected into the embedding DG space, transferred there, and
        projected back into the target space. The transfer is skipped if
        neither source nor target changed since the last identical call.

        :arg source: The source :class:`.Function`.
        :arg target: The target :class:`.Function`.
        :arg transfer_op: The transfer operation for the DG space.
        """
        Vs = source.function_space()
        Vt = target.function_space()
        source_element = Vs.ufl_element()
        target_element = Vt.ufl_element()
        if not self.requires_transfer(source_element, transfer_op, source, target):
            return

        if self.is_native(source_element) and self.is_native(target_element):
            self._native_transfer(source_element, transfer_op)(source, target)
        elif type(source_element) is finat.ufl.MixedElement:
            # Exact type match: subclasses of MixedElement are not split here.
            assert type(target_element) is finat.ufl.MixedElement
            for source_, target_ in zip(source.subfunctions, target.subfunctions):
                self.op(source_, target_, transfer_op=transfer_op)
        else:
            # Get some work vectors
            dgsource = self.DG_work(Vs)
            dgtarget = self.DG_work(Vt)
            VDGs = dgsource.function_space()
            VDGt = dgtarget.function_space()
            dgwork = self.work_vec(VDGs)

            # Project into DG space
            # u \in Vs -> u \in VDGs
            with source.dat.vec_ro as sv, dgsource.dat.vec_wo as dgv:
                self.V_DG_mass(Vs, VDGs).mult(sv, dgwork)
                self.DG_inv_mass(VDGs).mult(dgwork, dgv)

            # Transfer
            # u \in VDGs -> u \in VDGt
            self.op(dgsource, dgtarget, transfer_op)

            # Project back
            # u \in VDGt -> u \in Vt
            with dgtarget.dat.vec_ro as dgv, target.dat.vec_wo as t:
                if self.use_averaging:
                    self.V_approx_inv_mass(Vt, VDGt).mult(dgv, t)
                    # Average contributions from every cell sharing a dof.
                    t.pointwiseDivide(t, self.V_dof_weights(Vt))
                else:
                    work = self.work_vec(Vt)
                    self.V_DG_mass(Vt, VDGt).multTranspose(dgv, work)
                    self.V_inv_mass_ksp(Vt).solve(work, t)
        self.cache_dat_versions(source_element, transfer_op, source, target)

    def prolong(self, uc, uf):
        """Prolong a function.

        :arg uc: The source (coarse grid) function.
        :arg uf: The target (fine grid) function.
        """
        self.op(uc, uf, transfer_op=Op.PROLONG)

    def inject(self, uf, uc):
        """Inject a function (primal restriction)

        :arg uf: The source (fine grid) function.
        :arg uc: The target (coarse grid) function.
        """
        self.op(uf, uc, transfer_op=Op.INJECT)

    def restrict(self, source, target):
        """Restrict a dual function.

        Applies the transposes of the steps used by :meth:`op`: native
        restriction, component-wise restriction for mixed elements, or a
        round trip through the embedding DG dual spaces.

        :arg source: The source (fine grid) :class:`.Cofunction`.
        :arg target: The target (coarse grid) :class:`.Cofunction`.
        """
        Vs_star = source.function_space()
        Vt_star = target.function_space()
        source_element = Vs_star.ufl_element()
        target_element = Vt_star.ufl_element()
        if not self.requires_transfer(source_element, Op.RESTRICT, source, target):
            return

        if self.is_native(source_element) and self.is_native(target_element):
            self._native_transfer(source_element, Op.RESTRICT)(source, target)
        elif type(source_element) is finat.ufl.MixedElement:
            # Exact type match: subclasses of MixedElement are not split here.
            assert type(target_element) is finat.ufl.MixedElement
            for source_, target_ in zip(source.subfunctions, target.subfunctions):
                self.restrict(source_, target_)
        else:
            Vs = Vs_star.dual()
            Vt = Vt_star.dual()
            # Get some work vectors
            dgsource = self.DG_work(Vs_star)
            dgtarget = self.DG_work(Vt_star)
            VDGs = dgsource.function_space().dual()
            VDGt = dgtarget.function_space().dual()
            work = self.work_vec(Vs)
            dgwork = self.work_vec(VDGt)

            # g \in Vs^* -> g \in VDGs^*
            with source.dat.vec_ro as sv, dgsource.dat.vec_wo as dgv:
                if self.use_averaging:
                    work.pointwiseDivide(sv, self.V_dof_weights(Vs))
                    self.V_approx_inv_mass(Vs, VDGs).multTranspose(work, dgv)
                else:
                    self.V_inv_mass_ksp(Vs).solve(sv, work)
                    self.V_DG_mass(Vs, VDGs).mult(work, dgv)

            # g \in VDGs^* -> g \in VDGt^*
            self.restrict(dgsource, dgtarget)

            # g \in VDGt^* -> g \in Vt^*
            with dgtarget.dat.vec_ro as dgv, target.dat.vec_wo as t:
                self.DG_inv_mass(VDGt).mult(dgv, dgwork)
                self.V_DG_mass(Vt, VDGt).multTranspose(dgwork, t)
        self.cache_dat_versions(source_element, Op.RESTRICT, source, target)