Skip to content

Commit

Permalink
GEM: Simplify Indexed tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jan 21, 2025
1 parent 15ae324 commit 24c6952
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 38 deletions.
57 changes: 47 additions & 10 deletions gem/gem.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,8 @@ def __matmul__(self, other):
raise ValueError(f"Mismatching shapes {self.shape} and {other.shape} in matmul")
*i, k = indices(len(self.shape))
_, *j = indices(len(other.shape))
expr = Product(Indexed(self, tuple(i) + (k, )),
Indexed(other, (k, ) + tuple(j)))
return ComponentTensor(IndexSum(expr, (k, )), tuple(i) + tuple(j))
expr = Product(Indexed(self, (*i, k)), Indexed(other, (k, *j)))
return ComponentTensor(IndexSum(expr, (k, )), (*i, *j))

def __rmatmul__(self, other):
return as_gem(other).__matmul__(self)
Expand Down Expand Up @@ -307,6 +306,9 @@ def value(self):
def shape(self):
return self.array.shape

def __getitem__(self, i):
return self.array[i]


class Variable(Terminal):
"""Symbolic variable tensor"""
Expand Down Expand Up @@ -674,12 +676,31 @@ def __new__(cls, aggregate, multiindex):
if isinstance(aggregate, Zero):
return Zero(dtype=aggregate.dtype)

# All indices fixed
if all(isinstance(i, int) for i in multiindex):
if isinstance(aggregate, Constant):
return Literal(aggregate.array[multiindex], dtype=aggregate.dtype)
elif isinstance(aggregate, ListTensor):
return aggregate.array[multiindex]
# Simplify Indexed(ComponentTensor(Indexed(C, kk), jj), ii) -> Indexed(C, ll)
if isinstance(aggregate, ComponentTensor):
B, = aggregate.children
jj = aggregate.multiindex
if isinstance(B, Indexed):
C, = B.children
kk = B.multiindex
if all(j in kk for j in jj):
ii = tuple(multiindex)
rep = dict(zip(jj, ii))
ll = tuple(rep.get(k, k) for k in kk)
return Indexed(C, ll)

# Simplify Constant and ListTensor
if isinstance(aggregate, (Constant, ListTensor)):
if all(isinstance(i, int) for i in multiindex):
# All indices fixed
sub = aggregate[multiindex]
return Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else sub
elif any(isinstance(i, int) for i in multiindex) and all(isinstance(i, (int, Index)) for i in multiindex):
# Some indices fixed
slices = tuple(i if isinstance(i, int) else slice(None) for i in multiindex)
sub = aggregate[slices]
sub = Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else ListTensor(sub)
return Indexed(sub, tuple(i for i in multiindex if not isinstance(i, int)))

self = super(Indexed, cls).__new__(cls)
self.children = (aggregate,)
Expand Down Expand Up @@ -825,6 +846,11 @@ def __new__(cls, expression, multiindex):
if isinstance(expression, Zero):
return Zero(shape, dtype=expression.dtype)

# Index folding
if isinstance(expression, Indexed):
if multiindex == expression.multiindex:
return expression.children[0]

self = super(ComponentTensor, cls).__new__(cls)
self.children = (expression,)
self.multiindex = multiindex
Expand Down Expand Up @@ -881,9 +907,17 @@ def __new__(cls, array):
dtype = Node.inherit_dtype_from_children(tuple(array.flat))

# Handle children with shape
child_shape = array.flat[0].shape
e0 = array.flat[0]
child_shape = e0.shape
assert all(elem.shape == child_shape for elem in array.flat)

# Index folding
if child_shape == array.shape:
if all(isinstance(elem, Indexed) for elem in array.flat):
if all(elem.children == e0.children for elem in array.flat[1:]):
if all(elem.multiindex == idx for elem, idx in zip(array.flat, numpy.ndindex(array.shape))):
return e0.children[0]

if child_shape:
# Destroy structure
direct_array = numpy.empty(array.shape + child_shape, dtype=object)
Expand Down Expand Up @@ -911,6 +945,9 @@ def shape(self):
def __reduce__(self):
return type(self), (self.array,)

def __getitem__(self, i):
return self.array[i]

def reconstruct(self, *args):
return ListTensor(asarray(args).reshape(self.array.shape))

Expand Down
35 changes: 17 additions & 18 deletions gem/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import collections
import gem
from itertools import repeat


class Node(object):
Expand Down Expand Up @@ -36,14 +37,17 @@ def _cons_args(self, children):
Internally used utility function.
"""
front_args = [getattr(self, name) for name in self.__front__]
back_args = [getattr(self, name) for name in self.__back__]
front_args = (getattr(self, name) for name in self.__front__)
back_args = (getattr(self, name) for name in self.__back__)

return tuple(front_args) + tuple(children) + tuple(back_args)
return (*front_args, *children, *back_args)

def _arguments(self):
return self._cons_args(self.children)

def __reduce__(self):
# Gold version:
return type(self), self._cons_args(self.children)
return type(self), self._arguments()

def reconstruct(self, *args):
"""Reconstructs the node with new children from
Expand All @@ -54,8 +58,7 @@ def reconstruct(self, *args):
return type(self)(*self._cons_args(args))

def __repr__(self):
cons_args = self._cons_args(self.children)
return "%s(%s)" % (type(self).__name__, ", ".join(map(repr, cons_args)))
return "%s(%s)" % (type(self).__name__, ", ".join(map(repr, self._arguments())))

def __eq__(self, other):
"""Provides equality testing with quick positive and negative
Expand Down Expand Up @@ -87,17 +90,15 @@ def is_equal(self, other):
"""
if type(self) is not type(other):
return False
self_consargs = self._cons_args(self.children)
other_consargs = other._cons_args(other.children)
return self_consargs == other_consargs
return self._arguments() == other._arguments()

def get_hash(self):
"""Hash function.
This is the method to potentially override in derived classes,
not :meth:`__hash__`.
"""
return hash((type(self),) + self._cons_args(self.children))
return hash((type(self), *self._arguments()))


def _make_traversal_children(node):
Expand Down Expand Up @@ -235,8 +236,7 @@ def __call__(self, node):
return self.cache[node]
except KeyError:
result = self.function(node, self)
self.cache[node] = result
return result
return self.cache.setdefault(node, result)


class MemoizerArg(object):
Expand All @@ -259,23 +259,22 @@ def __call__(self, node, arg):
return self.cache[cache_key]
except KeyError:
result = self.function(node, self, arg)
self.cache[cache_key] = result
return result
return self.cache.setdefault(cache_key, result)


def reuse_if_untouched(node, self):
"""Reuse if untouched recipe"""
new_children = list(map(self, node.children))
if all(nc == c for nc, c in zip(new_children, node.children)):
new_children = tuple(map(self, node.children))
if new_children == node.children:
return node
else:
return node.reconstruct(*new_children)


def reuse_if_untouched_arg(node, self, arg):
"""Reuse if touched recipe propagating an extra argument"""
new_children = [self(child, arg) for child in node.children]
if all(nc == c for nc, c in zip(new_children, node.children)):
new_children = tuple(map(self, node.children, repeat(arg)))
if new_children == node.children:
return node
else:
return node.reconstruct(*new_children)
16 changes: 6 additions & 10 deletions gem/optimise.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,7 @@ def _replace_indices_atomic(i, self, subst):
new_expr = self(i.expression, subst)
return i if new_expr == i.expression else VariableIndex(new_expr)
else:
substitute = dict(subst)
return substitute.get(i, i)
return dict(subst).get(i, i)


@replace_indices.register(Delta)
Expand All @@ -117,30 +116,25 @@ def replace_indices_delta(node, self, subst):

@replace_indices.register(Indexed)
def replace_indices_indexed(node, self, subst):
multiindex = tuple(_replace_indices_atomic(i, self, subst) for i in node.multiindex)
child, = node.children
substitute = dict(subst)
multiindex = []
for i in node.multiindex:
multiindex.append(_replace_indices_atomic(i, self, subst))
if isinstance(child, ComponentTensor):
# Indexing into ComponentTensor
# Inline ComponentTensor and augment the substitution rules
substitute = dict(subst)
substitute.update(zip(child.multiindex, multiindex))
return self(child.children[0], tuple(sorted(substitute.items())))
else:
# Replace indices
new_child = self(child, subst)
if new_child == child and multiindex == node.multiindex:
if multiindex == node.multiindex and new_child == child:
return node
else:
return Indexed(new_child, multiindex)


@replace_indices.register(FlexiblyIndexed)
def replace_indices_flexiblyindexed(node, self, subst):
child, = node.children
assert not child.free_indices

dim2idxs = tuple(
(
offset if isinstance(offset, Integral) else _replace_indices_atomic(offset, self, subst),
Expand All @@ -149,6 +143,8 @@ def replace_indices_flexiblyindexed(node, self, subst):
for offset, idxs in node.dim2idxs
)

child, = node.children
assert not child.free_indices
if dim2idxs == node.dim2idxs:
return node
else:
Expand Down

0 comments on commit 24c6952

Please sign in to comment.