"""
Provides the interface to TSFC for compiling a form, and
transforms the TSFC-generated code to make it suitable for
passing to the backends.
"""
import pickle
from hashlib import md5
from os import path, environ, getuid, makedirs
import gzip
import os
import zlib
import tempfile
import collections
import ufl
import finat.ufl
from ufl import Form, conj
from .ufl_expr import TestFunction
from tsfc import compile_form as tsfc_compile_form
from tsfc.parameters import PARAMETERS as tsfc_default_parameters
from tsfc.ufl_utils import extract_firedrake_constants
from pyop2 import op2
from pyop2.caching import Cached
from pyop2.mpi import COMM_WORLD, MPI
from firedrake.formmanipulation import split_form
from firedrake.parameters import parameters as default_parameters
from firedrake.petsc import PETSc
from firedrake import utils
# Set TSFC default scalar type at load time
tsfc_default_parameters["scalar_type"] = utils.ScalarType
tsfc_default_parameters["scalar_type_c"] = utils.ScalarType_c
KernelInfo = collections.namedtuple("KernelInfo",
["kernel",
"integral_type",
"oriented",
"subdomain_id",
"domain_number",
"coefficient_numbers",
"constant_numbers",
"needs_cell_facets",
"pass_layer_arg",
"needs_cell_sizes",
"arguments",
"events"])
[docs]
class TSFCKernel(Cached):
_cache = {}
_cachedir = environ.get('FIREDRAKE_TSFC_KERNEL_CACHE_DIR',
path.join(tempfile.gettempdir(),
'firedrake-tsfc-kernel-cache-uid%d' % getuid()))
@classmethod
def _cache_lookup(cls, key):
key, comm = key
# comm has to be part of the in memory key so that when
# compiling the same code on different subcommunicators we
# don't get deadlocks. But MPI_Comm objects are not hashable,
# so use comm.py2f() since this is an internal communicator and
# hence the C handle is stable.
commkey = comm.py2f()
assert commkey != MPI.COMM_NULL.py2f()
return cls._cache.get((key, commkey)) or cls._read_from_disk(key, comm)
@classmethod
def _read_from_disk(cls, key, comm):
if comm.rank == 0:
cache = cls._cachedir
shard, disk_key = key[:2], key[2:]
filepath = os.path.join(cache, shard, disk_key)
val = None
if os.path.exists(filepath):
try:
with gzip.open(filepath, 'rb') as f:
val = f.read()
except zlib.error:
pass
comm.bcast(val, root=0)
else:
val = comm.bcast(None, root=0)
if val is None:
raise KeyError(f"Object with key {key} not found")
return cls._cache.setdefault((key, comm.py2f()), pickle.loads(val))
@classmethod
def _cache_store(cls, key, val):
key, comm = key
cls._cache[(key, comm.py2f())] = val
_ensure_cachedir(comm=comm)
if comm.rank == 0:
val._key = key
shard, disk_key = key[:2], key[2:]
filepath = os.path.join(cls._cachedir, shard, disk_key)
tempfile = os.path.join(cls._cachedir, shard, "%s_p%d.tmp" % (disk_key, os.getpid()))
# No need for a barrier after this, since non root
# processes will never race on this file.
os.makedirs(os.path.join(cls._cachedir, shard), exist_ok=True)
with gzip.open(tempfile, 'wb') as f:
pickle.dump(val, f, 0)
os.rename(tempfile, filepath)
comm.barrier()
@classmethod
def _cache_key(cls, form, name, parameters, coefficient_numbers, constant_numbers, interface, diagonal=False):
return md5((form.signature() + name
+ str(sorted(parameters.items()))
+ str(coefficient_numbers)
+ str(constant_numbers)
+ str(type(interface))
+ str(diagonal)).encode()).hexdigest(), form.ufl_domains()[0].comm
def __init__(
self,
form,
name,
parameters,
coefficient_numbers,
constant_numbers,
interface,
diagonal=False
):
"""A wrapper object for one or more TSFC kernels compiled from a given :class:`~ufl.classes.Form`.
:arg form: the :class:`~ufl.classes.Form` from which to compile the kernels.
:arg name: a prefix to be applied to the compiled kernel names. This is primarily useful for debugging.
:arg parameters: a dict of parameters to pass to the form compiler.
:arg coefficient_numbers: Map from coefficient numbers in the provided (split) form to coefficient numbers in the original form.
:arg constant_numbers: Map from local constant numbers in the provided (split) form to constant numbers in the original form.
:arg interface: the KernelBuilder interface for TSFC (may be None)
:arg diagonal: If assembling a matrix is it diagonal?
"""
if self._initialized:
return
tree = tsfc_compile_form(form, prefix=name, parameters=parameters,
interface=interface,
diagonal=diagonal, log=PETSc.Log.isActive())
kernels = []
for kernel in tree:
# Individual kernels do not have to use all of the coefficients
# provided by the (split) form. Here we combine the numberings
# of (kernel coefficients -> split form coefficients) and
# (split form coefficients -> original form coefficients) to give
# the map (kernel coefficients -> original form coefficients).
coefficient_numbers_per_kernel = tuple(
(coefficient_numbers[index], subindices)
for index, subindices in kernel.coefficient_numbers
)
# Constants from the split form are currently passed to all of
# the kernels so the numbering is trivial.
constant_numbers_per_kernel = constant_numbers
events = (kernel.event,)
pyop2_kernel = as_pyop2_local_kernel(kernel.ast, kernel.name,
len(kernel.arguments),
flop_count=kernel.flop_count,
events=events)
kernels.append(KernelInfo(kernel=pyop2_kernel,
integral_type=kernel.integral_type,
oriented=kernel.oriented,
subdomain_id=kernel.subdomain_id,
domain_number=kernel.domain_number,
coefficient_numbers=coefficient_numbers_per_kernel,
constant_numbers=constant_numbers_per_kernel,
needs_cell_facets=False,
pass_layer_arg=False,
needs_cell_sizes=kernel.needs_cell_sizes,
arguments=kernel.arguments,
events=events))
self.kernels = tuple(kernels)
self._initialized = True
SplitKernel = collections.namedtuple("SplitKernel", ["indices",
"kinfo"])
def _real_mangle(form):
"""If the form contains arguments in the Real function space, replace these with literal 1 before passing to tsfc."""
a = form.arguments()
reals = [x.ufl_element().family() == "Real" for x in a]
if not any(reals):
return form
replacements = {}
for arg, r in zip(a, reals):
if r:
replacements[arg] = 1
# If only the test space is Real, we need to turn the trial function into a test function.
if reals == [True, False]:
replacements[a[1]] = conj(TestFunction(a[1].function_space()))
return ufl.replace(form, replacements)
[docs]
def clear_cache(comm=None):
"""Clear the Firedrake TSFC kernel cache."""
comm = comm or COMM_WORLD
if comm.rank == 0:
import shutil
shutil.rmtree(TSFCKernel._cachedir, ignore_errors=True)
_ensure_cachedir(comm=comm)
def _ensure_cachedir(comm=None):
"""Ensure that the TSFC kernel cache directory exists."""
comm = comm or COMM_WORLD
if comm.rank == 0:
makedirs(TSFCKernel._cachedir, exist_ok=True)
[docs]
def gather_integer_subdomain_ids(knls):
"""Gather a dict of all integer subdomain IDs per integral type.
This is needed to correctly interpret the ``"otherwise"`` subdomain ID.
:arg knls: Iterable of :class:`SplitKernel` objects.
"""
all_integer_subdomain_ids = collections.defaultdict(list)
for _, kinfo in knls:
for subdomain_id in kinfo.subdomain_id:
if subdomain_id != "otherwise":
all_integer_subdomain_ids[kinfo.integral_type].append(subdomain_id)
for k, v in all_integer_subdomain_ids.items():
all_integer_subdomain_ids[k] = tuple(sorted(v))
return all_integer_subdomain_ids
[docs]
def as_pyop2_local_kernel(ast, name, nargs, access=op2.INC, **kwargs):
"""Convert a loopy kernel to a PyOP2 ``pyop2.LocalKernel``.
:arg ast: The kernel code. This could be, for example, a loopy kernel.
:arg name: The kernel name.
:arg nargs: The number of arguments expected by the kernel.
:arg access: Access descriptor for the first kernel argument.
"""
# all but the first argument to the kernel are read-only
accesses = tuple([access] + [op2.READ]*(nargs-1))
return op2.Kernel(ast, name, accesses=accesses,
requires_zeroed_output_arguments=True, **kwargs)