From 1944432ab023e009bf159fe5598ce89091419d19 Mon Sep 17 00:00:00 2001 From: ksagiyam <46749170+ksagiyam@users.noreply.github.com> Date: Wed, 6 Nov 2024 14:07:48 +0000 Subject: [PATCH] generalise VariableIndex and FlexiblyIndexed (#317) hex: enable interior facet integration --- gem/gem.py | 325 +++++++++++++++++++---- gem/node.py | 47 +++- gem/optimise.py | 27 +- gem/scheduling.py | 9 +- tsfc/driver.py | 3 - tsfc/fem.py | 54 +++- tsfc/kernel_args.py | 8 + tsfc/kernel_interface/__init__.py | 4 + tsfc/kernel_interface/common.py | 5 + tsfc/kernel_interface/firedrake_loopy.py | 22 +- tsfc/loopy.py | 33 ++- 11 files changed, 453 insertions(+), 84 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index 95e8f2f5..c37e4352 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -24,16 +24,21 @@ from gem.node import Node as NodeBase, traversal +from FIAT.orientation_utils import Orientation as FIATOrientation + __all__ = ['Node', 'Identity', 'Literal', 'Zero', 'Failure', - 'Variable', 'Sum', 'Product', 'Division', 'Power', + 'Variable', 'Sum', 'Product', 'Division', 'FloorDiv', 'Remainder', 'Power', 'MathFunction', 'MinValue', 'MaxValue', 'Comparison', 'LogicalNot', 'LogicalAnd', 'LogicalOr', 'Conditional', 'Index', 'VariableIndex', 'Indexed', 'ComponentTensor', - 'IndexSum', 'ListTensor', 'Concatenate', 'Delta', + 'IndexSum', 'ListTensor', 'Concatenate', 'Delta', 'OrientationVariableIndex', 'index_sum', 'partial_indexed', 'reshape', 'view', 'indices', 'as_gem', 'FlexiblyIndexed', - 'Inverse', 'Solve', 'extract_type'] + 'Inverse', 'Solve', 'extract_type', 'uint_type'] + + +uint_type = numpy.uintc class NodeMeta(type): @@ -130,6 +135,24 @@ def __truediv__(self, other): def __rtruediv__(self, other): return as_gem(other).__truediv__(self) + def __floordiv__(self, other): + other = as_gem_uint(other) + if other.shape: + raise ValueError("Denominator must be scalar") + return componentwise(FloorDiv, self, other) + + def __rfloordiv__(self, other): + return as_gem_uint(other).__floordiv__(self) + + def __mod__(self, other): + other = as_gem_uint(other) + if other.shape: + raise ValueError("Denominator must be scalar") + return componentwise(Remainder, self, other) + + def __rmod__(self, other): + return as_gem_uint(other).__mod__(self) + class Terminal(Node): """Abstract class for terminal GEM nodes.""" @@ -167,7 +190,8 @@ class Constant(Terminal): - array: numpy array of values - value: float or complex value (scalars only) """ - __slots__ = () + __slots__ = ('dtype',) + __back__ = ('dtype',) class Zero(Constant): @@ -176,13 +200,14 @@ class Zero(Constant): __slots__ = ('shape',) __front__ = ('shape',) - def __init__(self, shape=()): + def __init__(self, shape=(), dtype=float): self.shape = shape + self.dtype = dtype @property def value(self): assert not self.shape - return 0.0 + return numpy.array(0, dtype=self.dtype).item() class Identity(Constant): @@ -191,8 +216,9 @@ class Identity(Constant): __slots__ = ('dim',) __front__ = ('dim',) - def __init__(self, dim): + def __init__(self, dim, dtype=float): self.dim = dim + self.dtype = dtype @property def shape(self): @@ -200,7 +226,7 @@ def shape(self): @property def array(self): - return numpy.eye(self.dim) + return numpy.eye(self.dim, dtype=self.dtype) class Literal(Constant): @@ -209,16 +235,24 @@ class Literal(Constant): __slots__ = ('array',) __front__ = ('array',) - def __new__(cls, array): + def __new__(cls, array, dtype=None): array = asarray(array) return super(Literal, cls).__new__(cls) - def __init__(self, array): + def __init__(self, array, dtype=None): array = asarray(array) - try: - self.array = array.astype(float, casting="safe") - except TypeError: - self.array = array.astype(complex) + if dtype is None: + # Assume float or complex. + try: + self.array = array.astype(float, casting="safe") + self.dtype = float + except TypeError: + self.array = array.astype(complex) + self.dtype = complex + else: + # Can be int, etc. + self.array = array.astype(dtype) + self.dtype = dtype def is_equal(self, other): if type(self) is not type(other): @@ -243,12 +277,13 @@ def shape(self): class Variable(Terminal): """Symbolic variable tensor""" - __slots__ = ('name', 'shape') - __front__ = ('name', 'shape') + __slots__ = ('name', 'shape', 'dtype') + __front__ = ('name', 'shape', 'dtype') - def __init__(self, name, shape): + def __init__(self, name, shape, dtype=None): self.name = name self.shape = shape + self.dtype = dtype class Sum(Scalar): @@ -265,7 +300,8 @@ def __new__(cls, a, b): return a if isinstance(a, Constant) and isinstance(b, Constant): - return Literal(a.value + b.value) + dtype = numpy.result_type(a.dtype, b.dtype) + return Literal(a.value + b.value, dtype=dtype) self = super(Sum, cls).__new__(cls) self.children = a, b @@ -289,7 +325,8 @@ def __new__(cls, a, b): return a if isinstance(a, Constant) and isinstance(b, Constant): - return Literal(a.value * b.value) + dtype = numpy.result_type(a.dtype, b.dtype) + return Literal(a.value * b.value, dtype=dtype) self = super(Product, cls).__new__(cls) self.children = a, b @@ -313,13 +350,62 @@ def __new__(cls, a, b): return a if isinstance(a, Constant) and isinstance(b, Constant): - return Literal(a.value / b.value) + dtype = numpy.result_type(a.dtype, b.dtype) + return Literal(a.value / b.value, dtype=dtype) self = super(Division, cls).__new__(cls) self.children = a, b return self +class FloorDiv(Scalar): + __slots__ = ('children',) + + def __new__(cls, a, b): + assert not a.shape + assert not b.shape + # TODO: Attach dtype property to Node and check that + # numpy.result_dtype(a.dtype, b.dtype) is uint type. + # dtype is currently attached only to {Constant, Variable}. + # Constant folding + if isinstance(b, Zero): + raise ValueError("division by zero") + if isinstance(a, Zero): + return Zero(dtype=a.dtype) + if isinstance(b, Constant) and b.value == 1: + return a + if isinstance(a, Constant) and isinstance(b, Constant): + dtype = numpy.result_type(a.dtype, b.dtype) + return Literal(a.value // b.value, dtype=dtype) + self = super(FloorDiv, cls).__new__(cls) + self.children = a, b + return self + + +class Remainder(Scalar): + __slots__ = ('children',) + + def __new__(cls, a, b): + assert not a.shape + assert not b.shape + # TODO: Attach dtype property to Node and check that + # numpy.result_dtype(a.dtype, b.dtype) is uint type. + # dtype is currently attached only to {Constant, Variable}. + # Constant folding + if isinstance(b, Zero): + raise ValueError("division by zero") + if isinstance(a, Zero): + return Zero(dtype=a.dtype) + if isinstance(b, Constant) and b.value == 1: + return Zero(dtype=b.dtype) + if isinstance(a, Constant) and isinstance(b, Constant): + dtype = numpy.result_type(a.dtype, b.dtype) + return Literal(a.value % b.value, dtype=dtype) + self = super(Remainder, cls).__new__(cls) + self.children = a, b + return self + + class Power(Scalar): __slots__ = ('children',) @@ -329,14 +415,16 @@ def __new__(cls, base, exponent): # Constant folding if isinstance(base, Zero): + dtype = numpy.result_type(base.dtype, exponent.dtype) if isinstance(exponent, Zero): raise ValueError("cannot solve 0^0") - return Zero() + return Zero(dtype=dtype) elif isinstance(exponent, Zero): - return one - - if isinstance(base, Constant) and isinstance(exponent, Constant): - return Literal(base.value ** exponent.value) + dtype = numpy.result_type(base.dtype, exponent.dtype) + return Literal(1, dtype=dtype) + elif isinstance(base, Constant) and isinstance(exponent, Constant): + dtype = numpy.result_type(base.dtype, exponent.dtype) + return Literal(base.value ** exponent.value, dtype=dtype) self = super(Power, cls).__new__(cls) self.children = base, exponent @@ -502,7 +590,6 @@ class VariableIndex(IndexBase): def __init__(self, expression): assert isinstance(expression, Node) - assert not expression.free_indices assert not expression.shape self.expression = expression @@ -517,20 +604,20 @@ def __ne__(self, other): return not self.__eq__(other) def __hash__(self): - return hash((VariableIndex, self.expression)) + return hash((type(self), self.expression)) def __str__(self): return str(self.expression) def __repr__(self): - return "VariableIndex(%r)" % (self.expression,) + return "%r(%r)" % (type(self), self.expression,) def __reduce__(self): - return VariableIndex, (self.expression,) + return type(self), (self.expression,) class Indexed(Scalar): - __slots__ = ('children', 'multiindex') + __slots__ = ('children', 'multiindex', 'indirect_children') __back__ = ('multiindex',) def __new__(cls, aggregate, multiindex): @@ -553,41 +640,59 @@ def __new__(cls, aggregate, multiindex): # Zero folding if isinstance(aggregate, Zero): - return Zero() + return Zero(dtype=aggregate.dtype) # All indices fixed if all(isinstance(i, int) for i in multiindex): if isinstance(aggregate, Constant): - return Literal(aggregate.array[multiindex]) + return Literal(aggregate.array[multiindex], dtype=aggregate.dtype) elif isinstance(aggregate, ListTensor): return aggregate.array[multiindex] self = super(Indexed, cls).__new__(cls) self.children = (aggregate,) self.multiindex = multiindex + self.indirect_children = tuple(i.expression for i in self.multiindex if isinstance(i, VariableIndex)) - new_indices = tuple(i for i in multiindex if isinstance(i, Index)) - self.free_indices = unique(aggregate.free_indices + new_indices) + new_indices = [] + for i in multiindex: + if isinstance(i, Index): + new_indices.append(i) + elif isinstance(i, VariableIndex): + new_indices.extend(i.expression.free_indices) + self.free_indices = unique(aggregate.free_indices + tuple(new_indices)) return self def index_ordering(self): """Running indices in the order of indexing in this node.""" - return tuple(i for i in self.multiindex if isinstance(i, Index)) + free_indices = [] + for i in self.multiindex: + if isinstance(i, Index): + free_indices.append(i) + elif isinstance(i, VariableIndex): + free_indices.extend(i.expression.free_indices) + return tuple(free_indices) class FlexiblyIndexed(Scalar): """Flexible indexing of :py:class:`Variable`s to implement views and reshapes (splitting dimensions only).""" - __slots__ = ('children', 'dim2idxs') + __slots__ = ('children', 'dim2idxs', 'indirect_children') __back__ = ('dim2idxs',) def __init__(self, variable, dim2idxs): """Construct a flexibly indexed node. - :arg variable: a node that has a shape - :arg dim2idxs: describes the mapping of indices + Parameters + ---------- + variable : Node + `Node` that has a shape. + dim2idxs : tuple + Tuple of (offset, ((index, stride), (...), ...)) mapping indices, + where offset is {Node, int}, index is {Index, VariableIndex, int}, and + stride is {Node, int}. For example, if ``variable`` is rank two, and ``dim2idxs`` is @@ -600,40 +705,73 @@ def __init__(self, variable, dim2idxs): """ assert variable.shape assert len(variable.shape) == len(dim2idxs) - dim2idxs_ = [] free_indices = [] for dim, (offset, idxs) in zip(variable.shape, dim2idxs): offset_ = offset idxs_ = [] last = 0 - for idx in idxs: - index, stride = idx + if isinstance(offset, Node): + free_indices.extend(offset.free_indices) + for index, stride in idxs: if isinstance(index, Index): assert index.extent is not None free_indices.append(index) idxs_.append((index, stride)) last += (index.extent - 1) * stride + elif isinstance(index, VariableIndex): + base_indices = index.expression.free_indices + assert all(base_index.extent is not None for base_index in base_indices) + free_indices.extend(base_indices) + idxs_.append((index, stride)) + # last += (unknown_extent - 1) * stride elif isinstance(index, int): - offset_ += index * stride + # TODO: Attach dtype to each Node. + # Here, we should simply be able to do: + # >>> offset_ += index * stride + # but "+" and "*" are not currently correctly overloaded + # for indices (integers); they assume floats. + if not isinstance(offset, Integral): + raise NotImplementedError(f"Found non-Integral offset : {offset}") + if isinstance(stride, Constant): + offset_ += index * stride.value + else: + offset_ += index * stride else: raise ValueError("Unexpected index type for flexible indexing") - - if dim is not None and offset_ + last >= dim: + if isinstance(stride, Node): + free_indices.extend(stride.free_indices) + if dim is not None and isinstance(offset_ + last, Integral) and offset_ + last >= dim: raise ValueError("Offset {0} and indices {1} exceed dimension {2}".format(offset, idxs, dim)) - dim2idxs_.append((offset_, tuple(idxs_))) - self.children = (variable,) self.dim2idxs = tuple(dim2idxs_) self.free_indices = unique(free_indices) + indirect_children = [] + for offset, idxs in self.dim2idxs: + if isinstance(offset, Node): + indirect_children.append(offset) + for idx, stride in idxs: + if isinstance(idx, VariableIndex): + indirect_children.append(idx.expression) + if isinstance(stride, Node): + indirect_children.append(stride) + self.indirect_children = tuple(indirect_children) def index_ordering(self): """Running indices in the order of indexing in this node.""" - return tuple(index - for _, idxs in self.dim2idxs - for index, _ in idxs - if isinstance(index, Index)) + free_indices = [] + for offset, idxs in self.dim2idxs: + if isinstance(offset, Node): + free_indices.extend(offset.free_indices) + for index, stride in idxs: + if isinstance(index, Index): + free_indices.append(index) + elif isinstance(index, VariableIndex): + free_indices.extend(index.expression.free_indices) + if isinstance(stride, Node): + free_indices.extend(stride.free_indices) + return tuple(free_indices) class ComponentTensor(Node): @@ -653,7 +791,7 @@ def __new__(cls, expression, multiindex): # Zero folding if isinstance(expression, Zero): - return Zero(shape) + return Zero(shape, dtype=expression.dtype) self = super(ComponentTensor, cls).__new__(cls) self.children = (expression,) @@ -771,7 +909,8 @@ class Concatenate(Node): def __new__(cls, *children): if all(isinstance(child, Zero) for child in children): size = int(sum(numpy.prod(child.shape, dtype=int) for child in children)) - return Zero((size,)) + dtype = numpy.result_type(*(child.dtype for child in children)) + return Zero((size,), dtype=dtype) self = super(Concatenate, cls).__new__(cls) self.children = children @@ -802,7 +941,12 @@ def __new__(cls, i, j): self.i = i self.j = j # Set up free indices - free_indices = tuple(index for index in (i, j) if isinstance(index, Index)) + free_indices = [] + for index in (i, j): + if isinstance(index, Index): + free_indices.append(index) + elif isinstance(index, VariableIndex): + raise NotImplementedError("Can not make Delta with VariableIndex") self.free_indices = tuple(unique(free_indices)) return self @@ -845,6 +989,34 @@ def __init__(self, A, B): self.shape = A.shape[1:] + B.shape[1:] +class OrientationVariableIndex(VariableIndex, FIATOrientation): + """VariableIndex representing a fiat orientation. + + Notes + ----- + In the current implementation, we need to extract + `VariableIndex.expression` as index arithmetic + is not supported (indices are not `Node`). + + """ + + def __floordiv__(self, other): + other = other.expression if isinstance(other, VariableIndex) else as_gem_uint(other) + return type(self)(FloorDiv(self.expression, other)) + + def __rfloordiv__(self, other): + other = other.expression if isinstance(other, VariableIndex) else as_gem_uint(other) + return type(self)(FloorDiv(other, self.expression)) + + def __mod__(self, other): + other = other.expression if isinstance(other, VariableIndex) else as_gem_uint(other) + return type(self)(Remainder(self.expression, other)) + + def __rmod__(self, other): + other = other.expression if isinstance(other, VariableIndex) else as_gem_uint(other) + return type(self)(Remainder(other, self.expression)) + + def unique(indices): """Sorts free indices and eliminates duplicates. @@ -1027,11 +1199,23 @@ def componentwise(op, *exprs): def as_gem(expr): - """Attempt to convert an expression into GEM. + """Attempt to convert an expression into GEM of scalar type. + + Parameters + ---------- + expr : Node or Number + The expression. + + Returns + ------- + Node + A GEM representation of the expression. + + Raises + ------ + ValueError + If conversion was not possible. - :arg expr: The expression. - :returns: A GEM representation of the expression. - :raises ValueError: if conversion was not possible. """ if isinstance(expr, Node): return expr @@ -1041,6 +1225,33 @@ def as_gem(expr): raise ValueError("Do not know how to convert %r to GEM" % expr) +def as_gem_uint(expr): + """Attempt to convert an expression into GEM of uint type. + + Parameters + ---------- + expr : Node or Integral + The expression. + + Returns + ------- + Node + A GEM representation of the expression. + + Raises + ------ + ValueError + If conversion was not possible. + + """ + if isinstance(expr, Node): + return expr + elif isinstance(expr, Integral): + return Literal(expr, dtype=uint_type) + else: + raise ValueError("Do not know how to convert %r to GEM" % expr) + + def extract_type(expressions, klass): """Collects objects of type klass in expressions.""" return tuple(node for node in traversal(expressions) if isinstance(node, klass)) diff --git a/gem/node.py b/gem/node.py index 31d99b9e..5d9c5bf0 100644 --- a/gem/node.py +++ b/gem/node.py @@ -2,6 +2,7 @@ expression DAG languages.""" import collections +import gem class Node(object): @@ -99,8 +100,23 @@ def get_hash(self): return hash((type(self),) + self._cons_args(self.children)) +def _make_traversal_children(node): + if isinstance(node, (gem.Indexed, gem.FlexiblyIndexed)): + # Include child nodes hidden in index expressions. + return node.children + node.indirect_children + else: + return node.children + + def pre_traversal(expression_dags): - """Pre-order traversal of the nodes of expression DAGs.""" + """Pre-order traversal of the nodes of expression DAGs. + + Notes + ----- + This function also walks through nodes in index expressions + (e.g., `VariableIndex`s); see ``_make_traversal_children()``. + + """ seen = set() lifo = [] # Some roots might be same, but they must be visited only once. @@ -114,14 +130,23 @@ def pre_traversal(expression_dags): while lifo: node = lifo.pop() yield node - for child in reversed(node.children): + children = _make_traversal_children(node) + for child in reversed(children): if child not in seen: seen.add(child) lifo.append(child) def post_traversal(expression_dags): - """Post-order traversal of the nodes of expression DAGs.""" + """Post-order traversal of the nodes of expression DAGs. + + Notes + ----- + This function also walks through nodes in index expressions + (e.g., `VariableIndex`s); see ``_make_traversal_children()``. + + + """ seen = set() lifo = [] # Some roots might be same, but they must be visited only once. @@ -130,13 +155,13 @@ def post_traversal(expression_dags): for root in expression_dags: if root not in seen: seen.add(root) - lifo.append((root, list(root.children))) + lifo.append((root, list(_make_traversal_children(root)))) while lifo: node, deps = lifo[-1] for i, dep in enumerate(deps): if dep is not None and dep not in seen: - lifo.append((dep, list(dep.children))) + lifo.append((dep, list(_make_traversal_children(dep)))) deps[i] = None break else: @@ -150,10 +175,18 @@ def post_traversal(expression_dags): def collect_refcount(expression_dags): - """Collects reference counts for a multi-root expression DAG.""" + """Collects reference counts for a multi-root expression DAG. + + Notes + ----- + This function also collects reference counts of nodes + in index expressions (e.g., `VariableIndex`s); see + ``_make_traversal_children()``. + + """ result = collections.Counter(expression_dags) for node in traversal(expression_dags): - result.update(node.children) + result.update(_make_traversal_children(node)) return result diff --git a/gem/optimise.py b/gem/optimise.py index 3194e6ef..7d6c8ecd 100644 --- a/gem/optimise.py +++ b/gem/optimise.py @@ -4,6 +4,7 @@ from collections import OrderedDict, defaultdict from functools import singledispatch, partial, reduce from itertools import combinations, permutations, zip_longest +from numbers import Integral import numpy @@ -95,11 +96,19 @@ def replace_indices(node, self, subst): replace_indices.register(Node)(reuse_if_untouched_arg) +def _replace_indices_atomic(i, self, subst): + if isinstance(i, VariableIndex): + new_expr = self(i.expression, subst) + return i if new_expr == i.expression else VariableIndex(new_expr) + else: + substitute = dict(subst) + return substitute.get(i, i) + + @replace_indices.register(Delta) def replace_indices_delta(node, self, subst): - substitute = dict(subst) - i = substitute.get(node.i, node.i) - j = substitute.get(node.j, node.j) + i = _replace_indices_atomic(node.i, self, subst) + j = _replace_indices_atomic(node.j, self, subst) if i == node.i and j == node.j: return node else: @@ -110,7 +119,9 @@ def replace_indices_delta(node, self, subst): def replace_indices_indexed(node, self, subst): child, = node.children substitute = dict(subst) - multiindex = tuple(substitute.get(i, i) for i in node.multiindex) + multiindex = [] + for i in node.multiindex: + multiindex.append(_replace_indices_atomic(i, self, subst)) if isinstance(child, ComponentTensor): # Indexing into ComponentTensor # Inline ComponentTensor and augment the substitution rules @@ -130,9 +141,11 @@ def replace_indices_flexiblyindexed(node, self, subst): child, = node.children assert not child.free_indices - substitute = dict(subst) dim2idxs = tuple( - (offset, tuple((substitute.get(i, i), s) for i, s in idxs)) + ( + offset if isinstance(offset, Integral) else _replace_indices_atomic(offset, self, subst), + tuple((_replace_indices_atomic(i, self, subst), s if isinstance(s, Integral) else self(s, subst)) for i, s in idxs) + ) for offset, idxs in node.dim2idxs ) @@ -145,6 +158,8 @@ def replace_indices_flexiblyindexed(node, self, subst): def filtered_replace_indices(node, self, subst): """Wrapper for :func:`replace_indices`. At each call removes substitution rules that do not apply.""" + if any(isinstance(k, VariableIndex) for k, _ in subst): + raise NotImplementedError("Can not replace VariableIndex (will need inverse)") filtered_subst = tuple((k, v) for k, v in subst if k in node.free_indices) return replace_indices(node, self, filtered_subst) diff --git a/gem/scheduling.py b/gem/scheduling.py index 831ee048..41039a0e 100644 --- a/gem/scheduling.py +++ b/gem/scheduling.py @@ -3,6 +3,7 @@ import collections import functools +import itertools from gem import gem, impero from gem.node import collect_refcount @@ -116,8 +117,12 @@ def handle(ops, push, decref, node): elif isinstance(node, gem.Zero): # should rarely happen assert not node.shape elif isinstance(node, (gem.Indexed, gem.FlexiblyIndexed)): - # Indexing always inlined - decref(node.children[0]) + if node.indirect_children: + # Do not inline; + # Index expression can be involved if it contains VariableIndex. + ops.append(impero.Evaluate(node)) + for child in itertools.chain(node.children, node.indirect_children): + decref(child) elif isinstance(node, gem.IndexSum): ops.append(impero.Noop(node)) push(impero.Accumulate(node)) diff --git a/tsfc/driver.py b/tsfc/driver.py index 95fd43fb..6e3c3baa 100644 --- a/tsfc/driver.py +++ b/tsfc/driver.py @@ -92,9 +92,6 @@ def compile_integral(integral_data, form_data, prefix, parameters, interface, *, :arg log: bool if the Kernel should be profiled with Log events :returns: a kernel constructed by the kernel interface """ - if integral_data.domain.ufl_cell().cellname() == "hexahedron" and \ - integral_data.integral_type == "interior_facet": - raise NotImplementedError("interior facet integration in hex meshes not currently supported") parameters = preprocess_parameters(parameters) if interface is None: interface = firedrake_interface_loopy.KernelBuilder diff --git a/tsfc/fem.py b/tsfc/fem.py index 27439f1c..abc8bc7c 100644 --- a/tsfc/fem.py +++ b/tsfc/fem.py @@ -8,7 +8,9 @@ import gem import numpy import ufl -from FIAT.reference_element import UFCSimplex, make_affine_mapping +from FIAT.orientation_utils import Orientation as FIATOrientation +from FIAT.reference_element import UFCHexahedron, UFCSimplex, make_affine_mapping +from FIAT.reference_element import TensorProductCell from finat.physically_mapped import (NeedsCoordinateMappingElement, PhysicalGeometry) from finat.point_set import PointSet, PointSingleton @@ -108,6 +110,10 @@ def translator(self): # NOTE: reference cycle! return Translator(self) + @cached_property + def use_canonical_quadrature_point_ordering(self): + return isinstance(self.fiat_cell, UFCHexahedron) and self.integral_type in ['exterior_facet', 'interior_facet'] + class CoordinateMapping(PhysicalGeometry): """Callback class that provides physical geometry to FInAT elements. @@ -266,10 +272,13 @@ class PointSetContext(ContextBase): 'weight_expr', ) + @cached_property + def integration_cell(self): + return self.fiat_cell.construct_subelement(self.integration_dim) + @cached_property def quadrature_rule(self): - integration_cell = self.fiat_cell.construct_subelement(self.integration_dim) - return make_quadrature(integration_cell, self.quadrature_degree) + return make_quadrature(self.integration_cell, self.quadrature_degree) @cached_property def point_set(self): @@ -629,6 +638,11 @@ def callback(entity_id): # lives on after ditching FFC and switching to FInAT. return ffc_rounding(square, ctx.epsilon) table = ctx.entity_selector(callback, mt.restriction) + if ctx.use_canonical_quadrature_point_ordering: + quad_multiindex = ctx.quadrature_rule.point_set.indices + quad_multiindex_permuted = _make_quad_multiindex_permuted(mt, ctx) + mapper = gem.node.MemoizerArg(gem.optimise.filtered_replace_indices) + table = mapper(table, tuple(zip(quad_multiindex, quad_multiindex_permuted))) return gem.ComponentTensor(gem.Indexed(table, argument_multiindex + sigma), sigma) @@ -698,9 +712,43 @@ def take_singleton(xs): for node in traversal((result,)) if isinstance(node, gem.Literal)): result = gem.optimise.aggressive_unroll(result) + + if ctx.use_canonical_quadrature_point_ordering: + quad_multiindex = ctx.quadrature_rule.point_set.indices + quad_multiindex_permuted = _make_quad_multiindex_permuted(mt, ctx) + mapper = gem.node.MemoizerArg(gem.optimise.filtered_replace_indices) + result = mapper(result, tuple(zip(quad_multiindex, quad_multiindex_permuted))) return result +def _make_quad_multiindex_permuted(mt, ctx): + quad_rule = ctx.quadrature_rule + # Note that each quad index here represents quad points on a physical + # cell axis, but the table is indexed by indices representing the points + # on each reference cell axis, so we need to apply permutation based on the orientation. + cell = quad_rule.ref_el + quad_multiindex = quad_rule.point_set.indices + if isinstance(cell, TensorProductCell): + for comp in set(cell.cells): + extents = set(q.extent for c, q in zip(cell.cells, quad_multiindex) if c == comp) + if len(extents) != 1: + raise ValueError("Must have the same number of quadrature points in each symmetric axis") + quad_multiindex_permuted = [] + o = ctx.entity_orientation(mt.restriction) + if not isinstance(o, FIATOrientation): + raise ValueError(f"Expecting an instance of FIATOrientation : got {o}") + eo = cell.extract_extrinsic_orientation(o) + eo_perm_map = gem.Literal(quad_rule.extrinsic_orientation_permutation_map, dtype=gem.uint_type) + for ref_axis in range(len(quad_multiindex)): + io = cell.extract_intrinsic_orientation(o, ref_axis) + io_perm_map = gem.Literal(quad_rule.intrinsic_orientation_permutation_map_tuple[ref_axis], dtype=gem.uint_type) + # Effectively swap axes if needed. + ref_index = tuple((phys_index, gem.Indexed(eo_perm_map, (eo, ref_axis, phys_axis))) for phys_axis, phys_index in enumerate(quad_multiindex)) + quad_index_permuted = gem.VariableIndex(gem.FlexiblyIndexed(io_perm_map, ((0, ((io, 1), )), (0, ref_index)))) + quad_multiindex_permuted.append(quad_index_permuted) + return tuple(quad_multiindex_permuted) + + def compile_ufl(expression, context, interior_facet=False, point_sum=False): """Translate a UFL expression to GEM. diff --git a/tsfc/kernel_args.py b/tsfc/kernel_args.py index 1c79bdae..a397f0f9 100644 --- a/tsfc/kernel_args.py +++ b/tsfc/kernel_args.py @@ -52,3 +52,11 @@ class ExteriorFacetKernelArg(KernelArg): class InteriorFacetKernelArg(KernelArg): ... + + +class ExteriorFacetOrientationKernelArg(KernelArg): + ... + + +class InteriorFacetOrientationKernelArg(KernelArg): + ... diff --git a/tsfc/kernel_interface/__init__.py b/tsfc/kernel_interface/__init__.py index fc194260..51142638 100644 --- a/tsfc/kernel_interface/__init__.py +++ b/tsfc/kernel_interface/__init__.py @@ -33,6 +33,10 @@ def cell_size(self, restriction): def entity_number(self, restriction): """Facet or vertex number as a GEM index.""" + @abstractmethod + def entity_orientation(self, restriction): + """Entity orientation as a GEM index.""" + @abstractmethod def create_element(self, element, **kwargs): """Create a FInAT element (suitable for tabulating with) given diff --git a/tsfc/kernel_interface/common.py b/tsfc/kernel_interface/common.py index 18fd363f..945e367e 100644 --- a/tsfc/kernel_interface/common.py +++ b/tsfc/kernel_interface/common.py @@ -91,6 +91,11 @@ def entity_number(self, restriction): # Assume self._entity_number dict is set up at this point. return self._entity_number[restriction] + def entity_orientation(self, restriction): + """Facet orientation as a GEM index.""" + # Assume self._entity_orientation dict is set up at this point. + return self._entity_orientation[restriction] + def apply_glue(self, prepare=None, finalise=None): """Append glue code for operations that are not handled in the GEM abstraction. diff --git a/tsfc/kernel_interface/firedrake_loopy.py b/tsfc/kernel_interface/firedrake_loopy.py index c6d671b7..cc35a7c5 100644 --- a/tsfc/kernel_interface/firedrake_loopy.py +++ b/tsfc/kernel_interface/firedrake_loopy.py @@ -11,7 +11,7 @@ import loopy as lp -from tsfc import kernel_args +from tsfc import kernel_args, fem from tsfc.finatinterface import create_element from tsfc.kernel_interface.common import KernelBuilderBase as _KernelBuilderBase, KernelBuilderMixin, get_index_names, check_requirements, prepare_coefficient, prepare_arguments, prepare_constant from tsfc.loopy import generate as generate_loopy @@ -259,14 +259,26 @@ def __init__(self, integral_data_info, scalar_type, if integral_type in ['exterior_facet', 'exterior_facet_vert']: facet = gem.Variable('facet', (1,)) self._entity_number = {None: gem.VariableIndex(gem.Indexed(facet, (0,)))} + facet_orientation = gem.Variable('facet_orientation', (1,), dtype=gem.uint_type) + self._entity_orientation = {None: gem.OrientationVariableIndex(gem.Indexed(facet_orientation, (0,)))} elif integral_type in ['interior_facet', 'interior_facet_vert']: facet = gem.Variable('facet', (2,)) self._entity_number = { '+': gem.VariableIndex(gem.Indexed(facet, (0,))), '-': gem.VariableIndex(gem.Indexed(facet, (1,))) } + facet_orientation = gem.Variable('facet_orientation', (2,), dtype=gem.uint_type) + self._entity_orientation = { + '+': gem.OrientationVariableIndex(gem.Indexed(facet_orientation, (0,))), + '-': gem.OrientationVariableIndex(gem.Indexed(facet_orientation, (1,))) + } elif integral_type == 'interior_facet_horiz': self._entity_number = {'+': 1, '-': 0} + facet_orientation = gem.Variable('facet_orientation', (1,), dtype=gem.uint_type) # base mesh entity orientation + self._entity_orientation = { + '+': gem.OrientationVariableIndex(gem.Indexed(facet_orientation, (0,))), + '-': gem.OrientationVariableIndex(gem.Indexed(facet_orientation, (0,))) + } self.set_arguments(integral_data_info.arguments) self.integral_data_info = integral_data_info @@ -406,6 +418,14 @@ def construct_kernel(self, name, ctx, log=False): elif info.integral_type in ["interior_facet", "interior_facet_vert"]: int_loopy_arg = lp.GlobalArg("facet", numpy.uint32, shape=(2,)) args.append(kernel_args.InteriorFacetKernelArg(int_loopy_arg)) + # Will generalise this in the submesh PR. + if fem.PointSetContext(**self.fem_config()).use_canonical_quadrature_point_ordering: + if info.integral_type == "exterior_facet": + ext_ornt_loopy_arg = lp.GlobalArg("facet_orientation", gem.uint_type, shape=(1,)) + args.append(kernel_args.ExteriorFacetOrientationKernelArg(ext_ornt_loopy_arg)) + elif info.integral_type == "interior_facet": + int_ornt_loopy_arg = lp.GlobalArg("facet_orientation", gem.uint_type, shape=(2,)) + args.append(kernel_args.InteriorFacetOrientationKernelArg(int_ornt_loopy_arg)) for name_, shape in tabulations: tab_loopy_arg = lp.GlobalArg(name_, dtype=self.scalar_type, shape=shape) args.append(kernel_args.TabulationKernelArg(tab_loopy_arg)) diff --git a/tsfc/loopy.py b/tsfc/loopy.py index 943143ff..d19da423 100644 --- a/tsfc/loopy.py +++ b/tsfc/loopy.py @@ -2,6 +2,7 @@ This is the final stage of code generation in TSFC.""" +from numbers import Integral import numpy from functools import singledispatch from collections import defaultdict, OrderedDict @@ -49,6 +50,11 @@ def _assign_dtype_terminal(expression, self): return {self.scalar_type} +@_assign_dtype.register(gem.Variable) +def _assign_dtype_variable(expression, self): + return {expression.dtype or self.scalar_type} + + @_assign_dtype.register(gem.Zero) @_assign_dtype.register(gem.Identity) @_assign_dtype.register(gem.Delta) @@ -420,6 +426,16 @@ def _expression_division(expr, ctx): return p.Quotient(*(expression(c, ctx) for c in expr.children)) +@_expression.register(gem.FloorDiv) +def _expression_floordiv(expr, ctx): + return p.FloorDiv(*(expression(c, ctx) for c in expr.children)) + + +@_expression.register(gem.Remainder) +def _expression_remainder(expr, ctx): + return p.Remainder(*(expression(c, ctx) for c in expr.children)) + + @_expression.register(gem.Power) def _expression_power(expr, ctx): return p.Variable("pow")(*(expression(c, ctx) for c in expr.children)) @@ -538,12 +554,19 @@ def _expression_flexiblyindexed(expr, ctx): rank = [] for off, idxs in expr.dim2idxs: + rank_ = [expression(off, ctx)] for index, stride in idxs: - assert isinstance(index, gem.Index) - - rank_ = [off] - for index, stride in idxs: - rank_.append(p.Product((ctx.active_indices[index], stride))) + if isinstance(index, gem.Index): + rank_.append(p.Product((ctx.active_indices[index], expression(stride, ctx)))) + elif isinstance(index, gem.VariableIndex): + rank_.append(p.Product((expression(index.expression, ctx), expression(stride, ctx)))) + else: + raise ValueError(f"Expecting Index or VariableIndex, not {type(index)}") rank.append(p.Sum(tuple(rank_))) return p.Subscript(var, tuple(rank)) + + +@_expression.register(Integral) +def _expression_numbers_integral(expr, ctx): + return expr