"""
Provides the interface to TSFC for compiling a form, and
transforms the TSFC-generated code to make it suitable for
passing to the backends.
"""
from os import path, environ, getuid, makedirs
import tempfile
import collections
import cachetools
import ufl
import finat.ufl
from ufl import Form, conj
from .ufl_expr import TestFunction
from tsfc import compile_form as original_tsfc_compile_form
from tsfc.parameters import PARAMETERS as tsfc_default_parameters
from tsfc.ufl_utils import extract_firedrake_constants
from pyop2 import op2
from pyop2.caching import memory_and_disk_cache, default_parallel_hashkey
from pyop2.mpi import COMM_WORLD
from firedrake.formmanipulation import split_form
from firedrake.parameters import parameters as default_parameters
from firedrake.petsc import PETSc
from firedrake import utils
# Set TSFC default scalar type at load time.
# The imported TSFC ``PARAMETERS`` object is modified in place, as a side
# effect of importing this module, using the values from ``firedrake.utils``.
tsfc_default_parameters["scalar_type"] = utils.ScalarType
tsfc_default_parameters["scalar_type_c"] = utils.ScalarType_c
# Record describing one compiled kernel together with the metadata that
# accompanies it.  Instances are created in ``TSFCKernel.__init__``.
KernelInfo = collections.namedtuple(
    "KernelInfo",
    (
        "kernel",               # the kernel object itself
        "integral_type",        # integral type reported by the form compiler
        "oriented",
        "subdomain_id",         # iterable of subdomain ids (may contain "otherwise")
        "domain_number",
        "coefficient_numbers",  # (coefficient number, subindices) pairs
        "constant_numbers",
        "needs_cell_facets",
        "pass_layer_arg",
        "needs_cell_sizes",
        "arguments",            # the kernel's arguments
        "events",               # tuple of events attached to the kernel
    ),
)
# Location of the on-disk TSFC kernel cache.  The environment variable
# FIREDRAKE_TSFC_KERNEL_CACHE_DIR takes precedence; otherwise a per-user
# directory inside the system temporary directory is used.
_default_cachedir = path.join(
    tempfile.gettempdir(),
    f'firedrake-tsfc-kernel-cache-uid{getuid()}',
)
_cachedir = environ.get('FIREDRAKE_TSFC_KERNEL_CACHE_DIR', _default_cachedir)
# Decorate the original tsfc.compile_form with a memory-and-disk cache.
# The cache decorator is configured first and then applied to the original
# function; results are stored under ``_cachedir``.
_tsfc_compile_form_cache = memory_and_disk_cache(
    hashkey=tsfc_compile_form_hashkey,
    comm_fetcher=tsfc_compile_form_comm_fetcher,
    cachedir=_cachedir,
)
tsfc_compile_form = _tsfc_compile_form_cache(original_tsfc_compile_form)
[docs]
class TSFCKernel:
    def __init__(
        self,
        form,
        name,
        parameters,
        coefficient_numbers,
        constant_numbers,
        interface,
        diagonal=False
    ):
        """Compile a :class:`~ufl.classes.Form` with TSFC and wrap the resulting kernels.

        After construction, ``self.kernels`` holds a tuple with one
        :class:`KernelInfo` per kernel returned by the form compiler.

        :arg form: the :class:`~ufl.classes.Form` from which to compile the kernels.
        :arg name: a prefix to be applied to the compiled kernel names.
            This is primarily useful for debugging.
        :arg parameters: a dict of parameters to pass to the form compiler.
        :arg coefficient_numbers: Map from coefficient numbers in the provided
            (split) form to coefficient numbers in the original form.
        :arg constant_numbers: Map from local constant numbers in the provided
            (split) form to constant numbers in the original form.
        :arg interface: the KernelBuilder interface for TSFC (may be None)
        :arg diagonal: If assembling a matrix is it diagonal?
        """
        compiled_kernels = tsfc_compile_form(
            form,
            prefix=name,
            parameters=parameters,
            interface=interface,
            diagonal=diagonal,
            log=PETSc.Log.isActive(),
        )

        def to_kernel_info(tsfc_kernel):
            # A kernel need not use every coefficient of the (split) form.
            # Compose (kernel coefficient -> split form coefficient) with
            # (split form coefficient -> original form coefficient) so that
            # the stored numbers refer to the original form.
            original_coefficient_numbers = tuple(
                (coefficient_numbers[split_index], subindices)
                for split_index, subindices in tsfc_kernel.coefficient_numbers
            )
            events = (tsfc_kernel.event,)
            local_kernel = as_pyop2_local_kernel(
                tsfc_kernel.ast,
                tsfc_kernel.name,
                len(tsfc_kernel.arguments),
                flop_count=tsfc_kernel.flop_count,
                events=events,
            )
            return KernelInfo(
                kernel=local_kernel,
                integral_type=tsfc_kernel.integral_type,
                oriented=tsfc_kernel.oriented,
                subdomain_id=tsfc_kernel.subdomain_id,
                domain_number=tsfc_kernel.domain_number,
                coefficient_numbers=original_coefficient_numbers,
                # Constants of the split form are currently passed to every
                # kernel, so their numbering is used unchanged.
                constant_numbers=constant_numbers,
                needs_cell_facets=False,
                pass_layer_arg=False,
                needs_cell_sizes=tsfc_kernel.needs_cell_sizes,
                arguments=tsfc_kernel.arguments,
                events=events,
            )

        self.kernels = tuple(
            to_kernel_info(tsfc_kernel) for tsfc_kernel in compiled_kernels
        )
# Pairs a ``kinfo`` entry with the ``indices`` it is associated with.
SplitKernel = collections.namedtuple(
    "SplitKernel",
    ("indices", "kinfo"),
)
def _compile_form_hashkey(*args, **kwargs):
    """Build the cache key for a form-compilation call.

    The first positional argument must provide a ``signature()`` method; its
    signature replaces the object itself in the key.  The ``parameters``
    keyword (``None`` when absent) is passed through ``utils.tuplify`` and
    added to the key positionally; every other positional and keyword
    argument is forwarded to ``cachetools.keys.hashkey`` unchanged.
    """
    # Expected call shape: form, name, parameters, split, diagonal
    parameters = kwargs.get("parameters")
    # Everything except "parameters" stays a keyword in the key.
    other_kwargs = {
        keyword: value
        for keyword, value in kwargs.items()
        if keyword != "parameters"
    }
    form, trailing_args = args[0], args[1:]
    return cachetools.keys.hashkey(
        form.signature(),
        *trailing_args,
        utils.tuplify(parameters),
        **other_kwargs
    )
def _compile_form_comm(*args, **kwargs):
    """Return the ``comm`` attribute of the first domain of the first positional argument.

    All other positional and keyword arguments are accepted and ignored.
    """
    form = args[0]
    first_domain = form.ufl_domains()[0]
    return first_domain.comm
def _real_mangle(form):
    """Replace arguments in the Real function space with literal 1 before passing to tsfc.

    :arg form: the form to inspect.
    :returns: ``form`` itself when none of its arguments has element family
        ``"Real"``; otherwise the result of ``ufl.replace`` with the
        substitutions applied.
    """
    arguments = form.arguments()
    is_real = [
        argument.ufl_element().family() == "Real" for argument in arguments
    ]
    if not any(is_real):
        # Nothing to substitute: hand the form back untouched.
        return form

    replacements = {
        argument: 1 for argument, real in zip(arguments, is_real) if real
    }
    # If only the test space is Real, we need to turn the trial function into
    # a test function.
    if is_real == [True, False]:
        trial = arguments[1]
        replacements[trial] = conj(TestFunction(trial.function_space()))
    return ufl.replace(form, replacements)
[docs]
def clear_cache(comm=None):
    """Clear the Firedrake TSFC kernel cache.

    :arg comm: communicator whose rank 0 performs the deletion; falls back
        to ``COMM_WORLD`` when omitted (or otherwise falsy).
    """
    comm = comm or COMM_WORLD
    if comm.rank != 0:
        # Only rank 0 touches the filesystem.
        return
    import shutil
    # Errors while deleting are ignored; the directory is recreated afterwards.
    shutil.rmtree(_cachedir, ignore_errors=True)
    _ensure_cachedir(comm=comm)
def _ensure_cachedir(comm=None):
    """Ensure that the TSFC kernel cache directory exists.

    :arg comm: communicator whose rank 0 creates the directory; falls back
        to ``COMM_WORLD`` when omitted (or otherwise falsy).
    """
    comm = comm or COMM_WORLD
    if comm.rank != 0:
        return
    # exist_ok=True: an already existing directory is not an error.
    makedirs(_cachedir, exist_ok=True)
[docs]
def gather_integer_subdomain_ids(knls):
    """Gather a dict of all integer subdomain IDs per integral type.

    This is needed to correctly interpret the ``"otherwise"`` subdomain ID.

    :arg knls: Iterable of :class:`SplitKernel` objects.
    :returns: a ``collections.defaultdict(list)`` mapping each integral type
        to a sorted tuple of every subdomain ID other than ``"otherwise"``
        (duplicates are kept).  Integral types that contribute no such ID
        get no entry.
    """
    ids_by_integral_type = collections.defaultdict(list)
    for _indices, info in knls:
        explicit_ids = [
            subdomain_id
            for subdomain_id in info.subdomain_id
            if subdomain_id != "otherwise"
        ]
        # Only touch the mapping when there is something to record, so that
        # no empty entry is created for this integral type.
        if explicit_ids:
            ids_by_integral_type[info.integral_type].extend(explicit_ids)

    # Freeze each collected list into a sorted tuple (keys are unchanged, so
    # rebinding values while iterating is safe).
    for integral_type in ids_by_integral_type:
        ids_by_integral_type[integral_type] = tuple(
            sorted(ids_by_integral_type[integral_type])
        )
    return ids_by_integral_type
[docs]
def as_pyop2_local_kernel(ast, name, nargs, access=op2.INC, **kwargs):
    """Wrap kernel code in a PyOP2 kernel object built with ``op2.Kernel``.

    :arg ast: The kernel code. This could be, for example, a loopy kernel.
    :arg name: The kernel name.
    :arg nargs: The number of arguments expected by the kernel.
    :arg access: Access descriptor for the first kernel argument.
    :arg kwargs: Extra keyword arguments forwarded to ``op2.Kernel``.
    :returns: The object returned by ``op2.Kernel``, created with
        ``requires_zeroed_output_arguments=True``.
    """
    # The first argument gets ``access``; all remaining ones are read-only.
    read_only = (op2.READ,) * (nargs - 1)
    accesses = (access,) + read_only
    return op2.Kernel(
        ast,
        name,
        accesses=accesses,
        requires_zeroed_output_arguments=True,
        **kwargs
    )