diff --git a/gem/gem.py b/gem/gem.py index 8369b6f7..01bf2c8d 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -117,9 +117,8 @@ def __matmul__(self, other): raise ValueError(f"Mismatching shapes {self.shape} and {other.shape} in matmul") *i, k = indices(len(self.shape)) _, *j = indices(len(other.shape)) - expr = Product(Indexed(self, tuple(i) + (k, )), - Indexed(other, (k, ) + tuple(j))) - return ComponentTensor(IndexSum(expr, (k, )), tuple(i) + tuple(j)) + expr = Product(Indexed(self, (*i, k)), Indexed(other, (k, *j))) + return ComponentTensor(IndexSum(expr, (k, )), (*i, *j)) def __rmatmul__(self, other): return as_gem(other).__matmul__(self) @@ -307,6 +306,9 @@ def value(self): def shape(self): return self.array.shape + def __getitem__(self, i): + return self.array[i] + class Variable(Terminal): """Symbolic variable tensor""" @@ -674,12 +676,31 @@ def __new__(cls, aggregate, multiindex): if isinstance(aggregate, Zero): return Zero(dtype=aggregate.dtype) - # All indices fixed - if all(isinstance(i, int) for i in multiindex): - if isinstance(aggregate, Constant): - return Literal(aggregate.array[multiindex], dtype=aggregate.dtype) - elif isinstance(aggregate, ListTensor): - return aggregate.array[multiindex] + # Simplify Indexed(ComponentTensor(Indexed(C, kk), jj), ii) -> Indexed(C, ll) + if isinstance(aggregate, ComponentTensor): + B, = aggregate.children + jj = aggregate.multiindex + if isinstance(B, Indexed): + C, = B.children + kk = B.multiindex + if all(j in kk for j in jj): + ii = tuple(multiindex) + rep = dict(zip(jj, ii)) + ll = tuple(rep.get(k, k) for k in kk) + return Indexed(C, ll) + + # Simplify Constant and ListTensor + if isinstance(aggregate, (Constant, ListTensor)): + if all(isinstance(i, int) for i in multiindex): + # All indices fixed + sub = aggregate[multiindex] + return Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else sub + elif any(isinstance(i, int) for i in multiindex) and all(isinstance(i, (int, Index)) for i in multiindex): + # Some indices fixed + slices = tuple(i if isinstance(i, int) else slice(None) for i in multiindex) + sub = aggregate[slices] + sub = Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else ListTensor(sub) + return Indexed(sub, tuple(i for i in multiindex if not isinstance(i, int))) self = super(Indexed, cls).__new__(cls) self.children = (aggregate,) @@ -825,6 +846,11 @@ def __new__(cls, expression, multiindex): if isinstance(expression, Zero): return Zero(shape, dtype=expression.dtype) + # Index folding + if isinstance(expression, Indexed): + if multiindex == expression.multiindex: + return expression.children[0] + self = super(ComponentTensor, cls).__new__(cls) self.children = (expression,) self.multiindex = multiindex @@ -881,9 +907,17 @@ def __new__(cls, array): dtype = Node.inherit_dtype_from_children(tuple(array.flat)) # Handle children with shape - child_shape = array.flat[0].shape + e0 = array.flat[0] + child_shape = e0.shape assert all(elem.shape == child_shape for elem in array.flat) + # Index folding + if child_shape == array.shape: + if all(isinstance(elem, Indexed) for elem in array.flat): + if all(elem.children == e0.children for elem in array.flat[1:]): + if all(elem.multiindex == idx for elem, idx in zip(array.flat, numpy.ndindex(array.shape))): + return e0.children[0] + if child_shape: # Destroy structure direct_array = numpy.empty(array.shape + child_shape, dtype=object) @@ -911,6 +945,9 @@ def shape(self): def __reduce__(self): return type(self), (self.array,) + def __getitem__(self, i): + return self.array[i] + def reconstruct(self, *args): return ListTensor(asarray(args).reshape(self.array.shape)) diff --git a/gem/node.py b/gem/node.py index 5d9c5bf0..71f81463 100644 --- a/gem/node.py +++ b/gem/node.py @@ -3,6 +3,7 @@ import collections import gem +from itertools import repeat class Node(object): @@ -36,14 +37,17 @@ def _cons_args(self, children): Internally used utility function. """ - front_args = [getattr(self, name) for name in self.__front__] - back_args = [getattr(self, name) for name in self.__back__] + front_args = (getattr(self, name) for name in self.__front__) + back_args = (getattr(self, name) for name in self.__back__) - return tuple(front_args) + tuple(children) + tuple(back_args) + return (*front_args, *children, *back_args) + + def _arguments(self): + return self._cons_args(self.children) def __reduce__(self): # Gold version: - return type(self), self._cons_args(self.children) + return type(self), self._arguments() def reconstruct(self, *args): """Reconstructs the node with new children from @@ -54,8 +58,7 @@ def reconstruct(self, *args): return type(self)(*self._cons_args(args)) def __repr__(self): - cons_args = self._cons_args(self.children) - return "%s(%s)" % (type(self).__name__, ", ".join(map(repr, cons_args))) + return "%s(%s)" % (type(self).__name__, ", ".join(map(repr, self._arguments()))) def __eq__(self, other): """Provides equality testing with quick positive and negative @@ -87,9 +90,7 @@ def is_equal(self, other): """ if type(self) is not type(other): return False - self_consargs = self._cons_args(self.children) - other_consargs = other._cons_args(other.children) - return self_consargs == other_consargs + return self._arguments() == other._arguments() def get_hash(self): """Hash function. @@ -97,7 +98,7 @@ def get_hash(self): This is the method to potentially override in derived classes, not :meth:`__hash__`. """ - return hash((type(self),) + self._cons_args(self.children)) + return hash((type(self), *self._arguments())) def _make_traversal_children(node): @@ -235,8 +236,7 @@ def __call__(self, node): return self.cache[node] except KeyError: result = self.function(node, self) - self.cache[node] = result - return result + return self.cache.setdefault(node, result) class MemoizerArg(object): @@ -259,14 +259,13 @@ def __call__(self, node, arg): return self.cache[cache_key] except KeyError: result = self.function(node, self, arg) - self.cache[cache_key] = result - return result + return self.cache.setdefault(cache_key, result) def reuse_if_untouched(node, self): """Reuse if untouched recipe""" - new_children = list(map(self, node.children)) - if all(nc == c for nc, c in zip(new_children, node.children)): + new_children = tuple(map(self, node.children)) + if new_children == node.children: return node else: return node.reconstruct(*new_children) @@ -274,8 +273,8 @@ def reuse_if_untouched(node, self): def reuse_if_untouched_arg(node, self, arg): """Reuse if touched recipe propagating an extra argument""" - new_children = [self(child, arg) for child in node.children] - if all(nc == c for nc, c in zip(new_children, node.children)): + new_children = tuple(map(self, node.children, repeat(arg))) + if new_children == node.children: return node else: return node.reconstruct(*new_children) diff --git a/gem/optimise.py b/gem/optimise.py index 7d6c8ecd..6206d360 100644 --- a/gem/optimise.py +++ b/gem/optimise.py @@ -101,8 +101,7 @@ def _replace_indices_atomic(i, self, subst): new_expr = self(i.expression, subst) return i if new_expr == i.expression else VariableIndex(new_expr) else: - substitute = dict(subst) - return substitute.get(i, i) + return dict(subst).get(i, i) @replace_indices.register(Delta) @@ -117,20 +116,18 @@ def replace_indices_delta(node, self, subst): @replace_indices.register(Indexed) def replace_indices_indexed(node, self, subst): + multiindex = tuple(_replace_indices_atomic(i, self, subst) for i in node.multiindex) child, = node.children - substitute = dict(subst) - multiindex = [] - for i in node.multiindex: - multiindex.append(_replace_indices_atomic(i, self, subst)) if isinstance(child, ComponentTensor): # Indexing into ComponentTensor # Inline ComponentTensor and augment the substitution rules + substitute = dict(subst) substitute.update(zip(child.multiindex, multiindex)) return self(child.children[0], tuple(sorted(substitute.items()))) else: # Replace indices new_child = self(child, subst) - if new_child == child and multiindex == node.multiindex: + if multiindex == node.multiindex and new_child == child: return node else: return Indexed(new_child, multiindex) @@ -138,9 +135,6 @@ def replace_indices_indexed(node, self, subst): @replace_indices.register(FlexiblyIndexed) def replace_indices_flexiblyindexed(node, self, subst): - child, = node.children - assert not child.free_indices - dim2idxs = tuple( ( offset if isinstance(offset, Integral) else _replace_indices_atomic(offset, self, subst), @@ -149,6 +143,8 @@ def replace_indices_flexiblyindexed(node, self, subst): for offset, idxs in node.dim2idxs ) + child, = node.children + assert not child.free_indices if dim2idxs == node.dim2idxs: return node else: