GEM: Simplify Indexed tensors #131

Open · wants to merge 9 commits into master
76 changes: 34 additions & 42 deletions finat/fiat_elements.py
@@ -98,8 +98,8 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None):
:param ps: the point set.
:param entity: the cell entity on which to tabulate.
'''
space_dimension = self._element.space_dimension()
value_size = np.prod(self._element.value_shape(), dtype=int)
value_shape = self.value_shape
value_size = np.prod(value_shape, dtype=int)
fiat_result = self._element.tabulate(order, ps.points, entity)
result = {}
# In almost all cases, we have
@@ -109,52 +109,44 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None):
# basis functions, and the additional 3 are for
# dealing with transformations between physical
# and reference space).
index_shape = (self._element.space_dimension(),)
space_dimension = self._element.space_dimension()
if self.space_dimension() == space_dimension:
beta = self.get_indices()
index_shape = tuple(index.extent for index in beta)
else:
index_shape = (space_dimension,)
beta = tuple(gem.Index(extent=i) for i in index_shape)
assert len(beta) == len(self.get_indices())

zeta = self.get_value_indices()
result_indices = beta + zeta

for alpha, fiat_table in fiat_result.items():
if isinstance(fiat_table, Exception):
result[alpha] = gem.Failure(self.index_shape + self.value_shape, fiat_table)
result[alpha] = gem.Failure(index_shape + value_shape, fiat_table)
continue

derivative = sum(alpha)
table_roll = fiat_table.reshape(
space_dimension, value_size, len(ps.points)
).transpose(1, 2, 0)

exprs = []
for table in table_roll:
if derivative == self.degree and not self.complex.is_macrocell():
# Make sure numerics satisfies theory
exprs.append(gem.Literal(table[0]))
elif derivative > self.degree:
# Make sure numerics satisfies theory
assert np.allclose(table, 0.0)
exprs.append(gem.Literal(np.zeros(self.index_shape)))
else:
point_indices = ps.indices
point_shape = tuple(index.extent for index in point_indices)

exprs.append(gem.partial_indexed(
gem.Literal(table.reshape(point_shape + index_shape)),
point_indices
))
if self.value_shape:
# As above, this extent may be different from that
# advertised by the finat element.
beta = tuple(gem.Index(extent=i) for i in index_shape)
assert len(beta) == len(self.get_indices())

zeta = self.get_value_indices()
result[alpha] = gem.ComponentTensor(
gem.Indexed(
gem.ListTensor(np.array(
[gem.Indexed(expr, beta) for expr in exprs]
).reshape(self.value_shape)),
zeta),
beta + zeta
)
fiat_table = fiat_table.reshape(space_dimension, value_size, -1)

point_indices = ()
if derivative == self.degree and not self.complex.is_macrocell():
# Make sure numerics satisfies theory
fiat_table = fiat_table[..., 0]
elif derivative > self.degree:
# Make sure numerics satisfies theory
assert np.allclose(fiat_table, 0.0)
fiat_table = np.zeros(fiat_table.shape[:-1])
else:
expr, = exprs
result[alpha] = expr
point_indices = ps.indices

point_shape = tuple(index.extent for index in point_indices)
table_shape = index_shape + value_shape + point_shape
table_indices = beta + zeta + point_indices

expr = gem.Indexed(gem.Literal(fiat_table.reshape(table_shape)), table_indices)
expr = gem.ComponentTensor(expr, result_indices)
result[alpha] = expr
return result

def point_evaluation(self, order, refcoords, entity=None):
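The rewritten `basis_evaluation` builds, for each derivative multi-index, a single `gem.Literal` holding the whole FIAT table (basis, value and point axes) and wraps it once in a `ComponentTensor` over the basis and value indices, instead of assembling a `ListTensor` of per-component expressions. A minimal sketch of the resulting GEM structure, with made-up sizes (3 basis functions, value shape `(2,)`, 4 points) and the gem primitives used in the hunk above:

```python
import numpy as np
import gem

# Hypothetical sizes; a real table comes from self._element.tabulate(...)
space_dimension, value_shape, npts = 3, (2,), 4
table = np.arange(space_dimension * 2 * npts, dtype=float).reshape(
    (space_dimension,) + value_shape + (npts,))

beta = (gem.Index(extent=space_dimension),)              # basis index
zeta = tuple(gem.Index(extent=e) for e in value_shape)   # value indices
point_indices = (gem.Index(extent=npts),)                # point index

# One Literal carries the whole table; index it by basis + value + point
# indices, then re-wrap over basis + value only, leaving the point index free.
expr = gem.Indexed(gem.Literal(table), beta + zeta + point_indices)
expr = gem.ComponentTensor(expr, beta + zeta)
print(expr.shape)          # (3, 2)
print(expr.free_indices)   # only the point index remains free
```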
62 changes: 45 additions & 17 deletions finat/physically_mapped.py
@@ -1,4 +1,6 @@
from abc import ABCMeta, abstractmethod
from collections.abc import Mapping
from functools import reduce

import gem
import numpy
@@ -247,6 +249,46 @@ class NeedsCoordinateMappingElement(metaclass=ABCMeta):
pass


class MappedTabulation(Mapping):
"""A lazy tabulation dict that applies the basis transformation only
to the requested derivatives."""

def __init__(self, M, ref_tabulation):
self.M = M
self.ref_tabulation = ref_tabulation
# we expect M to be sparse with O(1) nonzeros per row
# for each row, get the column index of each nonzero entry
csr = [[j for j in range(M.shape[1]) if not isinstance(M.array[i, j], gem.Zero)]
for i in range(M.shape[0])]
self.csr = csr
self._tabulation_cache = {}

def matvec(self, table):
# basis recombination using hand-rolled sparse-dense matrix multiplication
ii = gem.indices(len(table.shape)-1)
phi = [gem.Indexed(table, (j, *ii)) for j in range(self.M.shape[1])]
# the sum approach is faster than calling numpy.dot or gem.IndexSum
exprs = [gem.ComponentTensor(reduce(gem.Sum, (self.M.array[i, j] * phi[j] for j in js)), ii)
for i, js in enumerate(self.csr)]

val = gem.ListTensor(exprs)
# val = self.M @ table
return gem.optimise.aggressive_unroll(val)

def __getitem__(self, alpha):
try:
return self._tabulation_cache[alpha]
except KeyError:
result = self.matvec(self.ref_tabulation[alpha])
return self._tabulation_cache.setdefault(alpha, result)

def __iter__(self):
return iter(self.ref_tabulation)

def __len__(self):
return len(self.ref_tabulation)


class PhysicallyMappedElement(NeedsCoordinateMappingElement):
"""A mixin that applies a "physical" transformation to tabulated
basis functions."""
@@ -267,24 +309,10 @@ def basis_transformation(self, coordinate_mapping):
def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None):
assert coordinate_mapping is not None

M = self.basis_transformation(coordinate_mapping)
# we expect M to be sparse with O(1) nonzeros per row
# for each row, get the column index of each nonzero entry
csr = [[j for j in range(M.shape[1]) if not isinstance(M.array[i, j], gem.Zero)]
for i in range(M.shape[0])]

def matvec(table):
# basis recombination using hand-rolled sparse-dense matrix multiplication
table = [gem.partial_indexed(table, (j,)) for j in range(M.shape[1])]
# the sum approach is faster than calling numpy.dot or gem.IndexSum
expressions = [sum(M.array[i, j] * table[j] for j in js) for i, js in enumerate(csr)]
val = gem.ListTensor(expressions)
return gem.optimise.aggressive_unroll(val)
ref_tabulation = super().basis_evaluation(order, ps, entity=entity)

result = super().basis_evaluation(order, ps, entity=entity)

return {alpha: matvec(table)
for alpha, table in result.items()}
M = self.basis_transformation(coordinate_mapping)
return MappedTabulation(M, ref_tabulation)

def point_evaluation(self, order, refcoords, entity=None):
raise NotImplementedError("TODO: not yet thought about it")
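`MappedTabulation` is a lazy `Mapping`: the basis transformation is applied to a derivative table only when that entry is first looked up, and the result is cached, so derivatives that are never requested are never transformed. A stripped-down sketch of just that caching pattern (the transform here is a stand-in lambda, not the real sparse matvec):

```python
from collections.abc import Mapping


class LazyTransformed(Mapping):
    """Sketch of the MappedTabulation idea: transform values of a reference
    dict only on first access, then serve them from a cache."""

    def __init__(self, transform, reference):
        self.transform = transform
        self.reference = reference
        self._cache = {}

    def __getitem__(self, key):
        try:
            return self._cache[key]
        except KeyError:
            return self._cache.setdefault(key, self.transform(self.reference[key]))

    def __iter__(self):
        return iter(self.reference)

    def __len__(self):
        return len(self.reference)


tab = LazyTransformed(lambda t: 2 * t, {(0, 0): 1.0, (1, 0): 3.0})
print(tab[(0, 0)])   # 2.0, computed and cached on this first access
print(dict(tab))     # {(0, 0): 2.0, (1, 0): 6.0}
```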
7 changes: 3 additions & 4 deletions gem/coffee.py
@@ -4,8 +4,7 @@
This file is NOT for code generation as a COFFEE AST.
"""

from collections import OrderedDict
import itertools
from itertools import chain, repeat
import logging

import numpy
@@ -58,10 +57,10 @@ def find_optimal_atomics(monomials, linear_indices):

:returns: list of atomic GEM expressions
"""
atomics = tuple(OrderedDict.fromkeys(itertools.chain(*(monomial.atomics for monomial in monomials))))
atomics = tuple(dict.fromkeys(chain.from_iterable(monomial.atomics for monomial in monomials)))

def cost(solution):
extent = sum(map(lambda atomic: index_extent(atomic, linear_indices), solution))
extent = sum(map(index_extent, solution, repeat(linear_indices)))
# Prefer shorter solutions, but larger extents
return (len(solution), -extent)

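Both replacements lean on standard-library behaviour: plain `dict` preserves insertion order (Python 3.7+), so `dict.fromkeys` deduplicates while keeping first-seen order just as `OrderedDict.fromkeys` did, and `map` accepts several iterables, so `repeat` can supply the fixed second argument. A small illustration with toy data, not real atomics:

```python
from itertools import chain, repeat

# Order-preserving deduplication, as in find_optimal_atomics
atomics = tuple(dict.fromkeys(chain.from_iterable([[1, 2], [2, 3], [1]])))
print(atomics)  # (1, 2, 3)

# map over two iterables; repeat() pins the second argument of every call
print(list(map(pow, [2, 3, 4], repeat(2))))  # [4, 9, 16]
```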
71 changes: 51 additions & 20 deletions gem/gem.py
@@ -54,8 +54,8 @@ def __call__(self, *args, **kwargs):

# Set free_indices if not set already
if not hasattr(obj, 'free_indices'):
obj.free_indices = unique(chain(*[c.free_indices
for c in obj.children]))
obj.free_indices = unique(chain.from_iterable(c.free_indices
for c in obj.children))
# Set dtype if not set already.
if not hasattr(obj, 'dtype'):
obj.dtype = obj.inherit_dtype_from_children(obj.children)
@@ -117,9 +117,8 @@ def __matmul__(self, other):
raise ValueError(f"Mismatching shapes {self.shape} and {other.shape} in matmul")
*i, k = indices(len(self.shape))
_, *j = indices(len(other.shape))
expr = Product(Indexed(self, tuple(i) + (k, )),
Indexed(other, (k, ) + tuple(j)))
return ComponentTensor(IndexSum(expr, (k, )), tuple(i) + tuple(j))
expr = Product(Indexed(self, (*i, k)), Indexed(other, (k, *j)))
return ComponentTensor(IndexSum(expr, (k, )), (*i, *j))

def __rmatmul__(self, other):
return as_gem(other).__matmul__(self)
@@ -335,7 +334,7 @@ def __new__(cls, a, b):
return a

if isinstance(a, Constant) and isinstance(b, Constant):
return Literal(a.value + b.value, dtype=Node.inherit_dtype_from_children([a, b]))
return Literal(a.value + b.value, dtype=Node.inherit_dtype_from_children((a, b)))

self = super(Sum, cls).__new__(cls)
self.children = a, b
@@ -359,7 +358,7 @@ def __new__(cls, a, b):
return a

if isinstance(a, Constant) and isinstance(b, Constant):
return Literal(a.value * b.value, dtype=Node.inherit_dtype_from_children([a, b]))
return Literal(a.value * b.value, dtype=Node.inherit_dtype_from_children((a, b)))

self = super(Product, cls).__new__(cls)
self.children = a, b
@@ -383,7 +382,7 @@ def __new__(cls, a, b):
return a

if isinstance(a, Constant) and isinstance(b, Constant):
return Literal(a.value / b.value, dtype=Node.inherit_dtype_from_children([a, b]))
return Literal(a.value / b.value, dtype=Node.inherit_dtype_from_children((a, b)))

self = super(Division, cls).__new__(cls)
self.children = a, b
@@ -396,7 +395,7 @@ class FloorDiv(Scalar):
def __new__(cls, a, b):
assert not a.shape
assert not b.shape
dtype = Node.inherit_dtype_from_children([a, b])
dtype = Node.inherit_dtype_from_children((a, b))
if dtype != uint_type:
raise ValueError(f"dtype ({dtype}) != unit_type ({uint_type})")
# Constant folding
@@ -419,7 +418,7 @@ class Remainder(Scalar):
def __new__(cls, a, b):
assert not a.shape
assert not b.shape
dtype = Node.inherit_dtype_from_children([a, b])
dtype = Node.inherit_dtype_from_children((a, b))
if dtype != uint_type:
raise ValueError(f"dtype ({dtype}) != uint_type ({uint_type})")
# Constant folding
@@ -442,7 +441,7 @@ class Power(Scalar):
def __new__(cls, base, exponent):
assert not base.shape
assert not exponent.shape
dtype = Node.inherit_dtype_from_children([base, exponent])
dtype = Node.inherit_dtype_from_children((base, exponent))

# Constant folding
if isinstance(base, Zero):
@@ -558,7 +557,7 @@ def __new__(cls, condition, then, else_):
self = super(Conditional, cls).__new__(cls)
self.children = condition, then, else_
self.shape = then.shape
self.dtype = Node.inherit_dtype_from_children([then, else_])
self.dtype = Node.inherit_dtype_from_children((then, else_))
return self


@@ -674,12 +673,31 @@ def __new__(cls, aggregate, multiindex):
if isinstance(aggregate, Zero):
return Zero(dtype=aggregate.dtype)

# All indices fixed
if all(isinstance(i, int) for i in multiindex):
if isinstance(aggregate, Constant):
return Literal(aggregate.array[multiindex], dtype=aggregate.dtype)
elif isinstance(aggregate, ListTensor):
return aggregate.array[multiindex]
# Simplify Literal and ListTensor
if isinstance(aggregate, (Constant, ListTensor)):
if all(isinstance(i, int) for i in multiindex):
# All indices fixed
sub = aggregate.array[multiindex]
return Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else sub
elif any(isinstance(i, int) for i in multiindex) and all(isinstance(i, (int, Index)) for i in multiindex):
# Some indices fixed
slices = tuple(i if isinstance(i, int) else slice(None) for i in multiindex)
sub = aggregate.array[slices]
sub = Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else ListTensor(sub)
return Indexed(sub, tuple(i for i in multiindex if not isinstance(i, int)))

# Simplify Indexed(ComponentTensor(Indexed(C, kk), jj), ii) -> Indexed(C, ll)
if isinstance(aggregate, ComponentTensor):
B, = aggregate.children
jj = aggregate.multiindex
if isinstance(B, Indexed):
C, = B.children
kk = B.multiindex
if all(j in kk for j in jj):
ii = tuple(multiindex)
rep = dict(zip(jj, ii))
ll = tuple(rep.get(k, k) for k in kk)
return Indexed(C, ll)

self = super(Indexed, cls).__new__(cls)
self.children = (aggregate,)
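Two new `Indexed` simplifications appear in this hunk: partially fixed indices on a `Literal`/`ListTensor` slice the array immediately, and indexing straight back into a `ComponentTensor` of an `Indexed` eliminates the round trip. A sketch against this branch's gem, with made-up values:

```python
import numpy as np
import gem

A = gem.Literal(np.arange(6.0).reshape(2, 3))
i, j = gem.indices(2)

# Partially fixed indexing slices the Literal up front:
# Indexed(A, (1, i)) -> Indexed(Literal(A.array[1, :]), (i,))
row = gem.Indexed(A, (1, i))
print(row.children[0].array)  # [3. 4. 5.]

# Indexing back into a ComponentTensor of an Indexed drops the wrapper:
At = gem.ComponentTensor(gem.Indexed(A, (j, i)), (i, j))  # transpose of A
k, l = gem.indices(2)
expr = gem.Indexed(At, (k, l))
print(expr.children[0] is A)  # True: expr is Indexed(A, (l, k)) directly
```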
@@ -825,6 +843,11 @@ def __new__(cls, expression, multiindex):
if isinstance(expression, Zero):
return Zero(shape, dtype=expression.dtype)

# Index folding
if isinstance(expression, Indexed):
if multiindex == expression.multiindex:
return expression.children[0]

self = super(ComponentTensor, cls).__new__(cls)
self.children = (expression,)
self.multiindex = multiindex
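The matching rule on the `ComponentTensor` side: re-wrapping an `Indexed` over exactly the same multi-index returns the original aggregate. Sketch with a hypothetical `gem.Variable`:

```python
import gem

v = gem.Variable("v", (4,))   # hypothetical vector-valued variable
i, = gem.indices(1)

expr = gem.ComponentTensor(gem.Indexed(v, (i,)), (i,))
print(expr is v)  # True: the take-components/re-assemble round trip folds away
```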
@@ -881,9 +904,17 @@ def __new__(cls, array):
dtype = Node.inherit_dtype_from_children(tuple(array.flat))

# Handle children with shape
child_shape = array.flat[0].shape
e0 = array.flat[0]
child_shape = e0.shape
assert all(elem.shape == child_shape for elem in array.flat)

# Index folding
if all(isinstance(elem, Indexed) for elem in array.flat):
if e0.children[0].shape == array.shape:
if all(elem.children == e0.children for elem in array.flat[1:]):
if all(elem.multiindex == idx for elem, idx in zip(array.flat, numpy.ndindex(array.shape))):
return e0.children[0]

if child_shape:
# Destroy structure
direct_array = numpy.empty(array.shape + child_shape, dtype=object)
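Analogously, a `ListTensor` that merely re-lists the components of one aggregate in their natural order is meant to fold back to that aggregate (given the aggregate-shape guard above). Sketch with a hypothetical `gem.Variable`:

```python
import gem

v = gem.Variable("v", (2,))   # hypothetical vector-valued variable

lt = gem.ListTensor([gem.Indexed(v, (0,)), gem.Indexed(v, (1,))])
print(lt is v)  # expected True once the folding branch fires
```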
@@ -921,7 +952,7 @@ def is_equal(self, other):
"""Common subexpression eliminating equality predicate."""
if type(self) is not type(other):
return False
if (self.array == other.array).all():
if numpy.array_equal(self.array, other.array):
self.array = other.array
return True
return False
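`numpy.array_equal` compares shapes before elements, so arrays of different shape simply compare unequal instead of raising from the element-wise `==`. For plain arrays:

```python
import numpy as np

a = np.ones((2, 3))
b = np.ones((3, 2))
print(np.array_equal(a, b))  # False: shapes differ, no broadcasting attempted
# (a == b).all() would raise here, since (2, 3) and (3, 2) do not broadcast
```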