# Source code for firedrake.tsfc_interface

"""
Provides the interface to TSFC for compiling a form, and
transforms the TSFC-generated code to make it suitable for
passing to the backends.

"""
from os import path, environ, getuid, makedirs
import tempfile
import collections
import cachetools

import ufl
import finat.ufl
from ufl import Form, conj
from .ufl_expr import TestFunction

from tsfc import compile_form as original_tsfc_compile_form
from tsfc.parameters import PARAMETERS as tsfc_default_parameters
from tsfc.ufl_utils import extract_firedrake_constants

from pyop2 import op2
from pyop2.caching import memory_and_disk_cache, default_parallel_hashkey
from pyop2.mpi import COMM_WORLD

from firedrake.formmanipulation import split_form
from firedrake.parameters import parameters as default_parameters
from firedrake.petsc import PETSc
from firedrake import utils

# Set TSFC default scalar type at load time.
# ``tsfc_default_parameters`` is the ``PARAMETERS`` object imported from
# ``tsfc.parameters``; these item assignments mutate it in place when this
# module is imported, using the scalar types exposed by ``firedrake.utils``.
tsfc_default_parameters["scalar_type"] = utils.ScalarType
tsfc_default_parameters["scalar_type_c"] = utils.ScalarType_c


# Record describing one compiled kernel plus the metadata stored with it.
# Instances are built in ``TSFCKernel.__init__``; the field order is part of
# the interface because namedtuples can be unpacked positionally.
KernelInfo = collections.namedtuple(
    "KernelInfo",
    (
        "kernel",               # object returned by as_pyop2_local_kernel
        "integral_type",        # copied from the TSFC kernel
        "oriented",             # copied from the TSFC kernel
        "subdomain_id",         # copied from the TSFC kernel
        "domain_number",        # copied from the TSFC kernel
        "coefficient_numbers",  # (original-form number, subindices) pairs
        "constant_numbers",     # constant numbers in the original form
        "needs_cell_facets",    # set to False by TSFCKernel
        "pass_layer_arg",       # set to False by TSFCKernel
        "needs_cell_sizes",     # copied from the TSFC kernel
        "arguments",            # copied from the TSFC kernel
        "events",               # tuple holding the TSFC kernel's event
    ),
)


# Directory used for the on-disk TSFC kernel cache.  The environment variable
# FIREDRAKE_TSFC_KERNEL_CACHE_DIR takes precedence; otherwise a per-user
# directory (suffixed with the numeric uid) under the system temporary
# directory is used.
_cachedir = environ.get(
    "FIREDRAKE_TSFC_KERNEL_CACHE_DIR",
    path.join(
        tempfile.gettempdir(),
        "firedrake-tsfc-kernel-cache-uid" + str(getuid()),
    ),
)


def tsfc_compile_form_hashkey(form, prefix, parameters, interface, diagonal, log):
    """Build the cache key for a cached call to ``tsfc.compile_form``.

    The parameters mirror the arguments of the cached call.

    :arg form: the form being compiled; only ``form.signature()`` enters
        the key.
    :arg prefix: kernel name prefix (part of the key).
    :arg parameters: form compiler parameters (part of the key).
    :arg interface: kernel builder interface (part of the key).
    :arg diagonal: diagonal-assembly flag (part of the key).
    :arg log: accepted so the signature matches the cached call, but not
        used when building the key.
    :returns: whatever ``default_parallel_hashkey`` returns for the
        selected arguments.
    """
    # ``log`` is left out of the key.
    # NOTE(review): the original comment read "Drop prefix as it's only used
    # for naming and log", yet ``prefix`` is passed to the hash function.
    # The behaviour is kept as found; confirm which one is intended.
    return default_parallel_hashkey(
        form.signature(), prefix, parameters, interface, diagonal
    )
def tsfc_compile_form_comm_fetcher(*args, **kwargs):
    """Return the communicator to use for a cached ``tsfc.compile_form`` call.

    The signature is deliberately generic (``*args, **kwargs``); only the
    first positional argument is inspected.

    :arg args: positional arguments of the cached call; ``args[0]`` is the
        form.
    :arg kwargs: keyword arguments of the cached call (ignored).
    :returns: the ``comm`` attribute of the first entry of
        ``form.ufl_domains()``.
    """
    form = args[0]
    first_domain = form.ufl_domains()[0]
    return first_domain.comm
# Decorate the original tsfc.compile_form with a cache.  Keys come from
# ``tsfc_compile_form_hashkey``, the communicator from
# ``tsfc_compile_form_comm_fetcher``, and on-disk entries live in
# ``_cachedir``.
tsfc_compile_form = memory_and_disk_cache(
    hashkey=tsfc_compile_form_hashkey,
    comm_fetcher=tsfc_compile_form_comm_fetcher,
    cachedir=_cachedir,
)(original_tsfc_compile_form)
class TSFCKernel:
    """Container for the kernels TSFC generates from a single form.

    After construction, ``self.kernels`` is a tuple of :class:`KernelInfo`
    records, one per kernel returned by the (cached) TSFC form compiler.
    """

    def __init__(
        self,
        form,
        name,
        parameters,
        coefficient_numbers,
        constant_numbers,
        interface,
        diagonal=False
    ):
        """A wrapper object for one or more TSFC kernels compiled from a given :class:`~ufl.classes.Form`.

        :arg form: the :class:`~ufl.classes.Form` from which to compile the kernels.
        :arg name: a prefix to be applied to the compiled kernel names. This is primarily useful for debugging.
        :arg parameters: a dict of parameters to pass to the form compiler.
        :arg coefficient_numbers: Map from coefficient numbers in the provided (split) form to coefficient numbers in the original form.
        :arg constant_numbers: Map from local constant numbers in the provided (split) form to constant numbers in the original form.
        :arg interface: the KernelBuilder interface for TSFC (may be None)
        :arg diagonal: If assembling a matrix is it diagonal?
        """
        tree = tsfc_compile_form(
            form,
            prefix=name,
            parameters=parameters,
            interface=interface,
            diagonal=diagonal,
            log=PETSc.Log.isActive(),
        )

        kernels = []
        for kernel in tree:
            # Individual kernels do not have to use all of the coefficients
            # provided by the (split) form. Here we combine the numberings
            # of (kernel coefficients -> split form coefficients) and
            # (split form coefficients -> original form coefficients) to give
            # the map (kernel coefficients -> original form coefficients).
            coefficient_numbers_per_kernel = tuple(
                (coefficient_numbers[index], subindices)
                for index, subindices in kernel.coefficient_numbers
            )
            # Constants from the split form are currently passed to all of
            # the kernels so the numbering is trivial.
            constant_numbers_per_kernel = constant_numbers

            # The same one-element events tuple is handed to the PyOP2 kernel
            # and stored on the KernelInfo record.
            events = (kernel.event,)
            pyop2_kernel = as_pyop2_local_kernel(
                kernel.ast,
                kernel.name,
                len(kernel.arguments),
                flop_count=kernel.flop_count,
                events=events,
            )
            kernels.append(
                KernelInfo(
                    kernel=pyop2_kernel,
                    integral_type=kernel.integral_type,
                    oriented=kernel.oriented,
                    subdomain_id=kernel.subdomain_id,
                    domain_number=kernel.domain_number,
                    coefficient_numbers=coefficient_numbers_per_kernel,
                    constant_numbers=constant_numbers_per_kernel,
                    needs_cell_facets=False,
                    pass_layer_arg=False,
                    needs_cell_sizes=kernel.needs_cell_sizes,
                    arguments=kernel.arguments,
                    events=events,
                )
            )
        self.kernels = tuple(kernels)
# Pairs the block indices of a (possibly split) form with the KernelInfo of
# one kernel compiled for that block.
SplitKernel = collections.namedtuple("SplitKernel", ["indices", "kinfo"])


def _compile_form_hashkey(*args, **kwargs):
    """Build the cache key for :func:`compile_form`.

    Expected call shape: form, name, parameters, split, diagonal.

    The form is represented by ``form.signature()`` and the ``parameters``
    keyword argument is passed through ``utils.tuplify`` before being handed
    to ``cachetools.keys.hashkey``.  All other positional and keyword
    arguments are forwarded unchanged.

    :returns: the key produced by ``cachetools.keys.hashkey``.
    """
    # NOTE(review): only a *keyword* ``parameters`` is tuplified; a value
    # passed positionally is forwarded as-is inside ``args[1:]``.
    parameters = kwargs.get("parameters")
    remaining_kwargs = {
        key: value for key, value in kwargs.items() if key != "parameters"
    }
    return cachetools.keys.hashkey(
        args[0].signature(),
        *args[1:],
        utils.tuplify(parameters),
        **remaining_kwargs
    )


def _compile_form_comm(*args, **kwargs):
    """Return the communicator for a :func:`compile_form` call.

    Only ``args[0]`` (the form) is inspected; the result is the ``comm``
    attribute of the first entry of ``form.ufl_domains()``.
    """
    form = args[0]
    return form.ufl_domains()[0].comm
@memory_and_disk_cache(
    hashkey=_compile_form_hashkey,
    comm_fetcher=_compile_form_comm,
    cachedir=_cachedir
)
@PETSc.Log.EventDecorator()
def compile_form(form, name, parameters=None, split=True, interface=None, diagonal=False):
    """Compile a form using TSFC.

    :arg form: the :class:`~ufl.classes.Form` to compile.
    :arg name: a prefix for the generated kernel functions.
    :arg parameters: optional dict of parameters to pass to the form
         compiler. If not provided, parameters are read from the
         ``form_compiler`` slot of the Firedrake
         :data:`~.parameters` dictionary (which see). If provided, the
         given values override a copy of those defaults.
    :arg split: If ``False``, then don't split mixed forms.
    :arg interface: the KernelBuilder interface handed on to
         :class:`TSFCKernel` (may be ``None``).
    :arg diagonal: flag handed on to ``split_form`` and
         :class:`TSFCKernel`. With ``split=False`` the form must have
         exactly two arguments and a single ``None`` index is used.

    :returns: a tuple of :class:`SplitKernel` objects, each pairing the
         block indices with the :class:`KernelInfo` of one compiled kernel.
         Blocks whose form has no integrals contribute nothing.
    :raises RuntimeError: if ``form`` is not a UFL ``Form``.
    """
    # Check that we get a Form
    if not isinstance(form, Form):
        raise RuntimeError("Unable to convert object to a UFL form: %s" % repr(form))

    # Start from a copy of the defaults so neither the global defaults nor
    # the caller's dict are modified; user-specified values win.
    compiler_parameters = default_parameters["form_compiler"].copy()
    if parameters is not None:
        compiler_parameters.update(parameters)

    numbering = form.terminal_numbering()
    if split:
        iterable = split_form(form, diagonal=diagonal)
    else:
        nargs = len(form.arguments())
        if diagonal:
            assert nargs == 2
            nargs = 1
        # A single unsplit block, indexed by ``None`` for every argument.
        iterable = ([(None, )*nargs, form], )

    kernels = []
    for idx, f in iterable:
        f = _real_mangle(f)
        if not f.integrals():
            # If we're assembling the R space component of a mixed argument,
            # and that component doesn't actually appear in the form then we
            # have an empty form, which we should not attempt to assemble.
            continue
        # Map local coefficient/constant numbers (as seen inside the
        # compiler) to the global coefficient/constant numbers
        coefficient_numbers = tuple(
            numbering[c] for c in f.coefficients()
        )
        constant_numbers = tuple(
            numbering[c] for c in extract_firedrake_constants(f)
        )
        # Append the non-None block indices to the kernel name prefix.
        prefix = name + "".join(map(str, (i for i in idx if i is not None)))
        tsfc_kernel = TSFCKernel(
            f,
            prefix,
            compiler_parameters,
            coefficient_numbers,
            constant_numbers,
            interface,
            diagonal
        )
        kernels.extend(SplitKernel(idx, kinfo) for kinfo in tsfc_kernel.kernels)

    return tuple(kernels)
def _real_mangle(form):
    """If the form contains arguments in the Real function space, replace these with literal 1 before passing to tsfc.

    :arg form: the form to inspect.
    :returns: ``form`` itself when none of its arguments has element family
        ``"Real"``; otherwise the result of ``ufl.replace`` with each Real
        argument replaced by ``1``.
    """
    arguments = form.arguments()
    reals = [arg.ufl_element().family() == "Real" for arg in arguments]
    if not any(reals):
        return form

    replacements = {
        arg: 1 for arg, is_real in zip(arguments, reals) if is_real
    }
    # If only the test space is Real, we need to turn the trial function into a test function.
    if reals == [True, False]:
        trial = arguments[1]
        replacements[trial] = conj(TestFunction(trial.function_space()))
    return ufl.replace(form, replacements)
def clear_cache(comm=None):
    """Clear the Firedrake TSFC kernel cache.

    The cache directory is removed and recreated by rank 0 of ``comm``
    only; other ranks do no filesystem work.

    :arg comm: communicator whose rank 0 performs the deletion. A falsy
        value (including the default ``None``) selects ``COMM_WORLD``.
    """
    comm = comm or COMM_WORLD
    if comm.rank == 0:
        import shutil
        # Best effort: errors while deleting (e.g. a missing directory) are
        # ignored, then the empty directory is recreated.
        shutil.rmtree(_cachedir, ignore_errors=True)
        _ensure_cachedir(comm=comm)
def _ensure_cachedir(comm=None):
    """Ensure that the TSFC kernel cache directory exists.

    Only rank 0 of ``comm`` creates the directory; an already existing
    directory is not an error.

    :arg comm: communicator whose rank 0 creates the directory. A falsy
        value (including the default ``None``) selects ``COMM_WORLD``.
    """
    comm = comm or COMM_WORLD
    if comm.rank == 0:
        makedirs(_cachedir, exist_ok=True)
def gather_integer_subdomain_ids(knls):
    """Gather a dict of all integer subdomain IDs per integral type.

    This is needed to correctly interpret the ``"otherwise"`` subdomain ID.

    :arg knls: Iterable of :class:`SplitKernel` objects.
    :returns: a ``collections.defaultdict`` mapping each integral type to a
        sorted tuple of every subdomain ID other than ``"otherwise"`` found
        for that type (duplicates are kept). Integral types whose kernels
        only use ``"otherwise"`` get no entry.
    """
    ids_by_type = collections.defaultdict(list)
    for _, kinfo in knls:
        for subdomain_id in kinfo.subdomain_id:
            # Only touch the mapping when there is something to record, so
            # that no empty entries are created.
            if subdomain_id != "otherwise":
                ids_by_type[kinfo.integral_type].append(subdomain_id)

    # Freeze each list into a sorted tuple.  Only existing keys are
    # reassigned, so the mapping does not change size during iteration.
    for integral_type, ids in ids_by_type.items():
        ids_by_type[integral_type] = tuple(sorted(ids))
    return ids_by_type
def as_pyop2_local_kernel(ast, name, nargs, access=op2.INC, **kwargs):
    """Convert a loopy kernel to a PyOP2 ``pyop2.LocalKernel``.

    :arg ast: The kernel code. This could be, for example, a loopy kernel.
    :arg name: The kernel name.
    :arg nargs: The number of arguments expected by the kernel.
    :arg access: Access descriptor for the first kernel argument.
    :arg kwargs: Extra keyword arguments forwarded unchanged to
        ``op2.Kernel``.
    :returns: the object constructed by ``op2.Kernel``, created with
        ``requires_zeroed_output_arguments=True``.
    """
    # all but the first argument to the kernel are read-only
    accesses = (access,) + (op2.READ,) * (nargs - 1)
    return op2.Kernel(
        ast,
        name,
        accesses=accesses,
        requires_zeroed_output_arguments=True,
        **kwargs
    )
def extract_numbered_coefficients(expr, numbers):
    """Return expression coefficients specified by a numbering.

    :arg expr: A UFL expression.
    :arg numbers: Iterable of indices used for selecting the correct
        coefficients from ``expr``.
    :returns: A list of UFL coefficients. A selected coefficient whose
        element is exactly a ``finat.ufl.MixedElement`` is replaced by its
        ``subfunctions``.
    """
    orig_coefficients = ufl.algorithms.extract_coefficients(expr)
    coefficients = []
    for number in numbers:
        coeff = orig_coefficients[number]
        # Exact type comparison on purpose: subclasses of MixedElement are
        # not split.  Do not replace this with ``isinstance``.
        if type(coeff.ufl_element()) is finat.ufl.MixedElement:
            coefficients.extend(coeff.subfunctions)
        else:
            coefficients.append(coeff)
    return coefficients