Skip to content
This repository has been archived by the owner on Dec 6, 2024. It is now read-only.

Commit

Permalink
generalise VariableIndex and FlexiblyIndexed (#317)
Browse files Browse the repository at this point in the history
hex: enable interior facet integration
  • Loading branch information
ksagiyam authored Nov 6, 2024
1 parent 8d9abb0 commit 1944432
Show file tree
Hide file tree
Showing 11 changed files with 453 additions and 84 deletions.
325 changes: 268 additions & 57 deletions gem/gem.py

Large diffs are not rendered by default.

47 changes: 40 additions & 7 deletions gem/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
expression DAG languages."""

import collections
import gem


class Node(object):
Expand Down Expand Up @@ -99,8 +100,23 @@ def get_hash(self):
return hash((type(self),) + self._cons_args(self.children))


def _make_traversal_children(node):
if isinstance(node, (gem.Indexed, gem.FlexiblyIndexed)):
# Include child nodes hidden in index expressions.
return node.children + node.indirect_children
else:
return node.children


def pre_traversal(expression_dags):
"""Pre-order traversal of the nodes of expression DAGs."""
"""Pre-order traversal of the nodes of expression DAGs.
Notes
-----
This function also walks through nodes in index expressions
(e.g., `VariableIndex`s); see ``_make_traversal_children()``.
"""
seen = set()
lifo = []
# Some roots might be same, but they must be visited only once.
Expand All @@ -114,14 +130,23 @@ def pre_traversal(expression_dags):
while lifo:
node = lifo.pop()
yield node
for child in reversed(node.children):
children = _make_traversal_children(node)
for child in reversed(children):
if child not in seen:
seen.add(child)
lifo.append(child)


def post_traversal(expression_dags):
"""Post-order traversal of the nodes of expression DAGs."""
"""Post-order traversal of the nodes of expression DAGs.
Notes
-----
This function also walks through nodes in index expressions
(e.g., `VariableIndex`s); see ``_make_traversal_children()``.
"""
seen = set()
lifo = []
# Some roots might be same, but they must be visited only once.
Expand All @@ -130,13 +155,13 @@ def post_traversal(expression_dags):
for root in expression_dags:
if root not in seen:
seen.add(root)
lifo.append((root, list(root.children)))
lifo.append((root, list(_make_traversal_children(root))))

while lifo:
node, deps = lifo[-1]
for i, dep in enumerate(deps):
if dep is not None and dep not in seen:
lifo.append((dep, list(dep.children)))
lifo.append((dep, list(_make_traversal_children(dep))))
deps[i] = None
break
else:
Expand All @@ -150,10 +175,18 @@ def post_traversal(expression_dags):


def collect_refcount(expression_dags):
"""Collects reference counts for a multi-root expression DAG."""
"""Collects reference counts for a multi-root expression DAG.
Notes
-----
This function also collects reference counts of nodes
in index expressions (e.g., `VariableIndex`s); see
``_make_traversal_children()``.
"""
result = collections.Counter(expression_dags)
for node in traversal(expression_dags):
result.update(node.children)
result.update(_make_traversal_children(node))
return result


Expand Down
27 changes: 21 additions & 6 deletions gem/optimise.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections import OrderedDict, defaultdict
from functools import singledispatch, partial, reduce
from itertools import combinations, permutations, zip_longest
from numbers import Integral

import numpy

Expand Down Expand Up @@ -95,11 +96,19 @@ def replace_indices(node, self, subst):
replace_indices.register(Node)(reuse_if_untouched_arg)


def _replace_indices_atomic(i, self, subst):
if isinstance(i, VariableIndex):
new_expr = self(i.expression, subst)
return i if new_expr == i.expression else VariableIndex(new_expr)
else:
substitute = dict(subst)
return substitute.get(i, i)


@replace_indices.register(Delta)
def replace_indices_delta(node, self, subst):
substitute = dict(subst)
i = substitute.get(node.i, node.i)
j = substitute.get(node.j, node.j)
i = _replace_indices_atomic(node.i, self, subst)
j = _replace_indices_atomic(node.j, self, subst)
if i == node.i and j == node.j:
return node
else:
Expand All @@ -110,7 +119,9 @@ def replace_indices_delta(node, self, subst):
def replace_indices_indexed(node, self, subst):
child, = node.children
substitute = dict(subst)
multiindex = tuple(substitute.get(i, i) for i in node.multiindex)
multiindex = []
for i in node.multiindex:
multiindex.append(_replace_indices_atomic(i, self, subst))
if isinstance(child, ComponentTensor):
# Indexing into ComponentTensor
# Inline ComponentTensor and augment the substitution rules
Expand All @@ -130,9 +141,11 @@ def replace_indices_flexiblyindexed(node, self, subst):
child, = node.children
assert not child.free_indices

substitute = dict(subst)
dim2idxs = tuple(
(offset, tuple((substitute.get(i, i), s) for i, s in idxs))
(
offset if isinstance(offset, Integral) else _replace_indices_atomic(offset, self, subst),
tuple((_replace_indices_atomic(i, self, subst), s if isinstance(s, Integral) else self(s, subst)) for i, s in idxs)
)
for offset, idxs in node.dim2idxs
)

Expand All @@ -145,6 +158,8 @@ def replace_indices_flexiblyindexed(node, self, subst):
def filtered_replace_indices(node, self, subst):
"""Wrapper for :func:`replace_indices`. At each call removes
substitution rules that do not apply."""
if any(isinstance(k, VariableIndex) for k, _ in subst):
raise NotImplementedError("Can not replace VariableIndex (will need inverse)")
filtered_subst = tuple((k, v) for k, v in subst if k in node.free_indices)
return replace_indices(node, self, filtered_subst)

Expand Down
9 changes: 7 additions & 2 deletions gem/scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import collections
import functools
import itertools

from gem import gem, impero
from gem.node import collect_refcount
Expand Down Expand Up @@ -116,8 +117,12 @@ def handle(ops, push, decref, node):
elif isinstance(node, gem.Zero): # should rarely happen
assert not node.shape
elif isinstance(node, (gem.Indexed, gem.FlexiblyIndexed)):
# Indexing always inlined
decref(node.children[0])
if node.indirect_children:
# Do not inline;
# Index expression can be involved if it contains VariableIndex.
ops.append(impero.Evaluate(node))
for child in itertools.chain(node.children, node.indirect_children):
decref(child)
elif isinstance(node, gem.IndexSum):
ops.append(impero.Noop(node))
push(impero.Accumulate(node))
Expand Down
3 changes: 0 additions & 3 deletions tsfc/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,6 @@ def compile_integral(integral_data, form_data, prefix, parameters, interface, *,
:arg log: bool if the Kernel should be profiled with Log events
:returns: a kernel constructed by the kernel interface
"""
if integral_data.domain.ufl_cell().cellname() == "hexahedron" and \
integral_data.integral_type == "interior_facet":
raise NotImplementedError("interior facet integration in hex meshes not currently supported")
parameters = preprocess_parameters(parameters)
if interface is None:
interface = firedrake_interface_loopy.KernelBuilder
Expand Down
54 changes: 51 additions & 3 deletions tsfc/fem.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
import gem
import numpy
import ufl
from FIAT.reference_element import UFCSimplex, make_affine_mapping
from FIAT.orientation_utils import Orientation as FIATOrientation
from FIAT.reference_element import UFCHexahedron, UFCSimplex, make_affine_mapping
from FIAT.reference_element import TensorProductCell
from finat.physically_mapped import (NeedsCoordinateMappingElement,
PhysicalGeometry)
from finat.point_set import PointSet, PointSingleton
Expand Down Expand Up @@ -108,6 +110,10 @@ def translator(self):
# NOTE: reference cycle!
return Translator(self)

@cached_property
def use_canonical_quadrature_point_ordering(self):
return isinstance(self.fiat_cell, UFCHexahedron) and self.integral_type in ['exterior_facet', 'interior_facet']


class CoordinateMapping(PhysicalGeometry):
"""Callback class that provides physical geometry to FInAT elements.
Expand Down Expand Up @@ -266,10 +272,13 @@ class PointSetContext(ContextBase):
'weight_expr',
)

@cached_property
def integration_cell(self):
return self.fiat_cell.construct_subelement(self.integration_dim)

@cached_property
def quadrature_rule(self):
integration_cell = self.fiat_cell.construct_subelement(self.integration_dim)
return make_quadrature(integration_cell, self.quadrature_degree)
return make_quadrature(self.integration_cell, self.quadrature_degree)

@cached_property
def point_set(self):
Expand Down Expand Up @@ -629,6 +638,11 @@ def callback(entity_id):
# lives on after ditching FFC and switching to FInAT.
return ffc_rounding(square, ctx.epsilon)
table = ctx.entity_selector(callback, mt.restriction)
if ctx.use_canonical_quadrature_point_ordering:
quad_multiindex = ctx.quadrature_rule.point_set.indices
quad_multiindex_permuted = _make_quad_multiindex_permuted(mt, ctx)
mapper = gem.node.MemoizerArg(gem.optimise.filtered_replace_indices)
table = mapper(table, tuple(zip(quad_multiindex, quad_multiindex_permuted)))
return gem.ComponentTensor(gem.Indexed(table, argument_multiindex + sigma), sigma)


Expand Down Expand Up @@ -698,9 +712,43 @@ def take_singleton(xs):
for node in traversal((result,))
if isinstance(node, gem.Literal)):
result = gem.optimise.aggressive_unroll(result)

if ctx.use_canonical_quadrature_point_ordering:
quad_multiindex = ctx.quadrature_rule.point_set.indices
quad_multiindex_permuted = _make_quad_multiindex_permuted(mt, ctx)
mapper = gem.node.MemoizerArg(gem.optimise.filtered_replace_indices)
result = mapper(result, tuple(zip(quad_multiindex, quad_multiindex_permuted)))
return result


def _make_quad_multiindex_permuted(mt, ctx):
quad_rule = ctx.quadrature_rule
# Note that each quad index here represents quad points on a physical
# cell axis, but the table is indexed by indices representing the points
# on each reference cell axis, so we need to apply permutation based on the orientation.
cell = quad_rule.ref_el
quad_multiindex = quad_rule.point_set.indices
if isinstance(cell, TensorProductCell):
for comp in set(cell.cells):
extents = set(q.extent for c, q in zip(cell.cells, quad_multiindex) if c == comp)
if len(extents) != 1:
raise ValueError("Must have the same number of quadrature points in each symmetric axis")
quad_multiindex_permuted = []
o = ctx.entity_orientation(mt.restriction)
if not isinstance(o, FIATOrientation):
raise ValueError(f"Expecting an instance of FIATOrientation : got {o}")
eo = cell.extract_extrinsic_orientation(o)
eo_perm_map = gem.Literal(quad_rule.extrinsic_orientation_permutation_map, dtype=gem.uint_type)
for ref_axis in range(len(quad_multiindex)):
io = cell.extract_intrinsic_orientation(o, ref_axis)
io_perm_map = gem.Literal(quad_rule.intrinsic_orientation_permutation_map_tuple[ref_axis], dtype=gem.uint_type)
# Effectively swap axes if needed.
ref_index = tuple((phys_index, gem.Indexed(eo_perm_map, (eo, ref_axis, phys_axis))) for phys_axis, phys_index in enumerate(quad_multiindex))
quad_index_permuted = gem.VariableIndex(gem.FlexiblyIndexed(io_perm_map, ((0, ((io, 1), )), (0, ref_index))))
quad_multiindex_permuted.append(quad_index_permuted)
return tuple(quad_multiindex_permuted)


def compile_ufl(expression, context, interior_facet=False, point_sum=False):
"""Translate a UFL expression to GEM.
Expand Down
8 changes: 8 additions & 0 deletions tsfc/kernel_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,11 @@ class ExteriorFacetKernelArg(KernelArg):

class InteriorFacetKernelArg(KernelArg):
...


class ExteriorFacetOrientationKernelArg(KernelArg):
...


class InteriorFacetOrientationKernelArg(KernelArg):
...
4 changes: 4 additions & 0 deletions tsfc/kernel_interface/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ def cell_size(self, restriction):
def entity_number(self, restriction):
"""Facet or vertex number as a GEM index."""

@abstractmethod
def entity_orientation(self, restriction):
"""Entity orientation as a GEM index."""

@abstractmethod
def create_element(self, element, **kwargs):
"""Create a FInAT element (suitable for tabulating with) given
Expand Down
5 changes: 5 additions & 0 deletions tsfc/kernel_interface/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ def entity_number(self, restriction):
# Assume self._entity_number dict is set up at this point.
return self._entity_number[restriction]

def entity_orientation(self, restriction):
"""Facet orientation as a GEM index."""
# Assume self._entity_orientation dict is set up at this point.
return self._entity_orientation[restriction]

def apply_glue(self, prepare=None, finalise=None):
"""Append glue code for operations that are not handled in the
GEM abstraction.
Expand Down
22 changes: 21 additions & 1 deletion tsfc/kernel_interface/firedrake_loopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import loopy as lp

from tsfc import kernel_args
from tsfc import kernel_args, fem
from tsfc.finatinterface import create_element
from tsfc.kernel_interface.common import KernelBuilderBase as _KernelBuilderBase, KernelBuilderMixin, get_index_names, check_requirements, prepare_coefficient, prepare_arguments, prepare_constant
from tsfc.loopy import generate as generate_loopy
Expand Down Expand Up @@ -259,14 +259,26 @@ def __init__(self, integral_data_info, scalar_type,
if integral_type in ['exterior_facet', 'exterior_facet_vert']:
facet = gem.Variable('facet', (1,))
self._entity_number = {None: gem.VariableIndex(gem.Indexed(facet, (0,)))}
facet_orientation = gem.Variable('facet_orientation', (1,), dtype=gem.uint_type)
self._entity_orientation = {None: gem.OrientationVariableIndex(gem.Indexed(facet_orientation, (0,)))}
elif integral_type in ['interior_facet', 'interior_facet_vert']:
facet = gem.Variable('facet', (2,))
self._entity_number = {
'+': gem.VariableIndex(gem.Indexed(facet, (0,))),
'-': gem.VariableIndex(gem.Indexed(facet, (1,)))
}
facet_orientation = gem.Variable('facet_orientation', (2,), dtype=gem.uint_type)
self._entity_orientation = {
'+': gem.OrientationVariableIndex(gem.Indexed(facet_orientation, (0,))),
'-': gem.OrientationVariableIndex(gem.Indexed(facet_orientation, (1,)))
}
elif integral_type == 'interior_facet_horiz':
self._entity_number = {'+': 1, '-': 0}
facet_orientation = gem.Variable('facet_orientation', (1,), dtype=gem.uint_type) # base mesh entity orientation
self._entity_orientation = {
'+': gem.OrientationVariableIndex(gem.Indexed(facet_orientation, (0,))),
'-': gem.OrientationVariableIndex(gem.Indexed(facet_orientation, (0,)))
}

self.set_arguments(integral_data_info.arguments)
self.integral_data_info = integral_data_info
Expand Down Expand Up @@ -406,6 +418,14 @@ def construct_kernel(self, name, ctx, log=False):
elif info.integral_type in ["interior_facet", "interior_facet_vert"]:
int_loopy_arg = lp.GlobalArg("facet", numpy.uint32, shape=(2,))
args.append(kernel_args.InteriorFacetKernelArg(int_loopy_arg))
# Will generalise this in the submesh PR.
if fem.PointSetContext(**self.fem_config()).use_canonical_quadrature_point_ordering:
if info.integral_type == "exterior_facet":
ext_ornt_loopy_arg = lp.GlobalArg("facet_orientation", gem.uint_type, shape=(1,))
args.append(kernel_args.ExteriorFacetOrientationKernelArg(ext_ornt_loopy_arg))
elif info.integral_type == "interior_facet":
int_ornt_loopy_arg = lp.GlobalArg("facet_orientation", gem.uint_type, shape=(2,))
args.append(kernel_args.InteriorFacetOrientationKernelArg(int_ornt_loopy_arg))
for name_, shape in tabulations:
tab_loopy_arg = lp.GlobalArg(name_, dtype=self.scalar_type, shape=shape)
args.append(kernel_args.TabulationKernelArg(tab_loopy_arg))
Expand Down
Loading

0 comments on commit 1944432

Please sign in to comment.