Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a SubrangeMapper helper class which maps a _subrange_ of src range to its counterpart in dst range, if possible. #1702

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
3 changes: 2 additions & 1 deletion dace/libraries/standard/nodes/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ def expansion(node: 'Reduce', state: SDFGState, sdfg: SDFG):
'reduce_init', {'_o%d' % i: '0:%s' % symstr(d)
for i, d in enumerate(outedge.data.subset.size())}, {},
'__out = %s' % node.identity,
{'__out': dace.Memlet.simple('_out', ','.join(['_o%d' % i for i in range(output_dims)]))},
# {'__out': dace.Memlet.simple('_out', ','.join(['_o%d' % i for i in range(output_dims)]))},
{'__out': dace.Memlet.simple('_out', ','.join(['_o%d' % i for i in osqdim]))},
external_edges=True)
else:
nstate = nsdfg.add_state()
Expand Down
164 changes: 138 additions & 26 deletions dace/subsets.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
import dace.serialize
from dace import data, symbolic, dtypes
import re
import sympy as sp
import warnings
from functools import reduce
import sympy.core.sympify
from typing import List, Optional, Sequence, Set, Union
import warnings

import sympy as sp
import sympy.core.sympify
from sympy import ceiling

import dace.serialize
from dace import symbolic
from dace.config import Config


Expand All @@ -20,6 +23,7 @@ def nng(expr):
except AttributeError: # No free_symbols in expr
return expr


def bounding_box_cover_exact(subset_a, subset_b) -> bool:
min_elements_a = subset_a.min_element()
max_elements_a = subset_a.max_element()
Expand All @@ -29,16 +33,17 @@ def bounding_box_cover_exact(subset_a, subset_b) -> bool:
# Covering only make sense if the two subsets have the same number of dimensions.
if len(min_elements_a) != len(min_elements_b):
return ValueError(
f"A bounding box of dimensionality {len(min_elements_a)} cannot"
f" test covering a bounding box of dimensionality {len(min_elements_b)}."
f"A bounding box of dimensionality {len(min_elements_a)} cannot"
f" test covering a bounding box of dimensionality {len(min_elements_b)}."
)

return all([(symbolic.simplify_ext(nng(rb)) <= symbolic.simplify_ext(nng(orb))) == True
and (symbolic.simplify_ext(nng(re)) >= symbolic.simplify_ext(nng(ore))) == True
for rb, re, orb, ore in zip(min_elements_a, max_elements_a,
min_elements_b, max_elements_b)])

def bounding_box_symbolic_positive(subset_a, subset_b, approximation = False)-> bool:

def bounding_box_symbolic_positive(subset_a, subset_b, approximation=False) -> bool:
min_elements_a = subset_a.min_element_approx() if approximation else subset_a.min_element()
max_elements_a = subset_a.max_element_approx() if approximation else subset_a.max_element()
min_elements_b = subset_b.min_element_approx() if approximation else subset_b.min_element()
Expand All @@ -47,8 +52,8 @@ def bounding_box_symbolic_positive(subset_a, subset_b, approximation = False)->
# Covering only make sense if the two subsets have the same number of dimensions.
if len(min_elements_a) != len(min_elements_b):
return ValueError(
f"A bounding box of dimensionality {len(min_elements_a)} cannot"
f" test covering a bounding box of dimensionality {len(min_elements_b)}."
f"A bounding box of dimensionality {len(min_elements_a)} cannot"
f" test covering a bounding box of dimensionality {len(min_elements_b)}."
)

for rb, re, orb, ore in zip(min_elements_a, max_elements_a,
Expand All @@ -70,6 +75,7 @@ def bounding_box_symbolic_positive(subset_a, subset_b, approximation = False)->
return False
return True


class Subset(object):
""" Defines a subset of a data descriptor. """

Expand All @@ -80,7 +86,7 @@ def covers(self, other):
# Subsets of different dimensionality can never cover each other.
if self.dims() != other.dims():
return ValueError(
f"A subset of dimensionality {self.dim()} cannot test covering a subset of dimensionality {other.dims()}"
f"A subset of dimensionality {self.dim()} cannot test covering a subset of dimensionality {other.dims()}"
)

if not Config.get('optimizer', 'symbolic_positive'):
Expand All @@ -99,20 +105,22 @@ def covers(self, other):
return False

return True

def covers_precise(self, other):
""" Returns True if self contains all the elements in other. """

# Subsets of different dimensionality can never cover each other.
if self.dims() != other.dims():
return ValueError(
f"A subset of dimensionality {self.dim()} cannot test covering a subset of dimensionality {other.dims()}"
f"A subset of dimensionality {self.dim()} cannot test covering a subset of dimensionality {other.dims()}"
)

# If self does not cover other with a bounding box union, return false.
symbolic_positive = Config.get('optimizer', 'symbolic_positive')
try:
bounding_box_cover = bounding_box_cover_exact(self, other) if symbolic_positive else bounding_box_symbolic_positive(self, other)
bounding_box_cover = bounding_box_cover_exact(self,
other) if symbolic_positive else bounding_box_symbolic_positive(
self, other)
if not bounding_box_cover:
return False
except TypeError:
Expand Down Expand Up @@ -151,14 +159,13 @@ def covers_precise(self, other):
except:
return False
return True
# unknown type
# unknown type
else:
raise TypeError

except TypeError:
return False


def __repr__(self):
return '%s (%s)' % (type(self).__name__, self.__str__())

Expand Down Expand Up @@ -229,6 +236,7 @@ def _tuple_to_symexpr(val):
@dace.serialize.serializable
class Range(Subset):
""" Subset defined in terms of a fixed range. """

def __init__(self, ranges):
parsed_ranges = []
parsed_tiles = []
Expand Down Expand Up @@ -334,6 +342,10 @@ def size_exact(self):
for (iMin, iMax, step), ts in zip(self.ranges, self.tile_sizes)
]

def volume_exact(self) -> int:
""" Returns the total number of elements in all dimensions together. """
return reduce(lambda a, b: a * b, self.size_exact())

def bounding_box_size(self):
""" Returns the size of a bounding box around this range. """
return [
Expand Down Expand Up @@ -578,7 +590,7 @@ def from_string(string):
value = symbolic.pystr_to_symbolic(uni_dim_tokens[0].strip())
ranges.append((value, value, 1))
continue
#return Range(ranges)
# return Range(ranges)
# If dimension has more than 4 tokens, the range is invalid
if len(uni_dim_tokens) > 4:
raise SyntaxError("Invalid range: {}".format(multi_dim_tokens))
Expand Down Expand Up @@ -848,6 +860,7 @@ def intersects(self, other: 'Range'):
class Indices(Subset):
""" A subset of one element representing a single index in an
N-dimensional data descriptor. """

def __init__(self, indices):
if indices is None or len(indices) == 0:
raise TypeError('Expected an array of index expressions: got empty' ' array or None')
Expand All @@ -874,7 +887,7 @@ def from_json(obj, context=None):
raise TypeError("from_json of class \"Indices\" called on json "
"with type %s (expected 'Indices')" % obj['type'])

#return Indices(symbolic.SymExpr(obj['indices']))
# return Indices(symbolic.SymExpr(obj['indices']))
return Indices([*map(symbolic.pystr_to_symbolic, obj['indices'])])

def __hash__(self):
Expand All @@ -895,6 +908,10 @@ def size(self):
def size_exact(self):
return self.size()

def volume_exact(self) -> int:
""" Returns the total number of elements in all dimensions together. """
return reduce(lambda a, b: a * b, self.size_exact())

def min_element(self):
return self.indices

Expand Down Expand Up @@ -1081,6 +1098,7 @@ def intersection(self, other: 'Indices'):
return self
return None


class SubsetUnion(Subset):
"""
Wrapper subset type that stores multiple Subsets in a list.
Expand Down Expand Up @@ -1118,7 +1136,7 @@ def covers(self, other):
return False
else:
return any(s.covers(other) for s in self.subset_list)

def covers_precise(self, other):
"""
Returns True if this SubsetUnion covers another
Expand All @@ -1144,7 +1162,7 @@ def __str__(self):
string += " "
string += subset.__str__()
return string

def dims(self):
if not self.subset_list:
return 0
Expand All @@ -1168,7 +1186,7 @@ def free_symbols(self) -> Set[str]:
for subset in self.subset_list:
result |= subset.free_symbols
return result

def replace(self, repl_dict):
for subset in self.subset_list:
subset.replace(repl_dict)
Expand All @@ -1178,13 +1196,12 @@ def num_elements(self):
min = 0
for subset in self.subset_list:
try:
if subset.num_elements() < min or min ==0:
if subset.num_elements() < min or min == 0:
min = subset.num_elements()
except:
continue

return min

return min


def _union_special_cases(arb: symbolic.SymbolicType, brb: symbolic.SymbolicType, are: symbolic.SymbolicType,
Expand Down Expand Up @@ -1251,8 +1268,6 @@ def bounding_box_union(subset_a: Subset, subset_b: Subset) -> Range:
return Range(result)




def union(subset_a: Subset, subset_b: Subset) -> Subset:
""" Compute the union of two Subset objects.
If the subsets are not of the same type, degenerates to bounding-box
Expand Down Expand Up @@ -1321,6 +1336,7 @@ def list_union(subset_a: Subset, subset_b: Subset) -> Subset:
except TypeError:
return None


def intersects(subset_a: Subset, subset_b: Subset) -> Union[bool, None]:
"""
Returns True if two subsets intersect, False if they do not, or
Expand All @@ -1342,3 +1358,99 @@ def intersects(subset_a: Subset, subset_b: Subset) -> Union[bool, None]:
return None
except TypeError: # cannot determine truth value of Relational
return None


class SubrangeMapper:
"""
Equipped with a `src` and a `dst` range of equal volumes but possibly different shapes or even dimensions (i.e.,
there is an 1-to-1 correspondence between the elements of `src` and `dst`), maps a subrange of `src` to its
counterpart in `dst`, if possible.

Note that such subrange-to-subrange mapping may not always exist.
"""

def __init__(self, src: Range, dst: Range):
src, dst = self.canonical(src), self.canonical(dst)
assert src.volume_exact() == dst.volume_exact()
self.src, self.dst = src, dst

@staticmethod
def canonical(r: Range) -> Range:
"""
Extends the (excluded) upper bound of each component of the ranges as much as possible, without affecting the
volume of the range.
"""
return Range([(b, b + s * ceiling((e - b + 1) / s) - 1, s)
for b, e, s in r.ndrange()])

def map(self, r: Range) -> Optional[Range]:
r = self.canonical(r)
# Ideally we also have `assert self.src.covers_precise(r)`. However, we cannot determine that for symbols.
assert self.src.dims() == r.dims()
out = []
src_i, dst_i = 0, 0
while src_i < self.src.dims() and dst_i < self.dst.dims():
# If we run out only on one side, handle that case after the loop.

# Find the next smallest segments of `src` and `dst` whose volumes matches (and therefore can possibly have
# a mapping).
# TODO: It's possible to do this in a O(max(|src|, |dst|)) loop instead of O(|src| * |dst|).
src_j, dst_j = None, None
for sj in range(src_i + 1, self.src.dims() + 1):
for dj in range(dst_i + 1, self.dst.dims() + 1):
if Range(self.src.ranges[src_i:sj]).volume_exact() == Range(
self.dst.ranges[dst_i:dj]).volume_exact():
src_j, dst_j = sj, dj
break
else:
continue
break
if src_j is None:
# Somehow, we couldn't find a matching segment. This should have been caught earlier.
return None

src_segment = Range(self.src.ranges[src_i: src_j])
dst_segment = Range(self.dst.ranges[dst_i: dst_j])
r_segment = Range(r.ranges[src_i: src_j])
if r_segment.volume_exact() == 1:
# If we are selecting just a single point in this segment, we can just pick the mapping of that point.
# Compute the local 1D coordinate of the point on `src`.
loc = 0
for (idx, _, _), (ridx, _, _), s in zip(reversed(src_segment.ranges),
reversed(r.ranges[src_i: src_j]),
reversed(src_segment.size())):
loc = loc * s + (ridx - idx)
# Translate that local 1D coordinate onto `dst`.
dst_coord = []
for (idx, _, _), s in zip(dst_segment.ranges, dst_segment.size()):
dst_coord.append(loc % s + idx)
loc = loc // s
out.extend([(idx, idx, 1) for idx in dst_coord])
elif self.src.ranges[src_i: src_j] == r.ranges[src_i: src_j]:
# If we are selecting the entirety of this segment, we can just pick the corresponding mapped segment in
# its entirety too.
out.extend(self.dst.ranges[dst_i:dst_j])
elif src_j - src_i == 1 and dst_j - dst_i == 1:
# If the segment lengths on both sides are just 1, the mapping is easy to compute -- it's just a shift.
sb, se, ss = self.src.ranges[src_i]
db, de, ds = self.dst.ranges[dst_i]
b, e, s = r.ranges[src_i]
lb, le, ls = (b - sb) // ss, (e - se) // ss - 1, s // ss
tb, te, ts = db + lb * ds, de + (le + 1) * ds, ds * ls
out.append((tb, te, ts))
else:
# TODO: Can we narrow down this case even more? That would be number theoretic problem.
# E.g., If we are reshaping [6, 5] to [2, 15], we are demanding that these dimensions must be wholly
# selected for now.
return None

src_i, dst_i = src_j, dst_j
if src_i < self.src.dims():
src_segment = Range(self.src.ranges[src_i: self.src.dims()])
assert src_segment.volume_exact() == 1
if dst_i < self.dst.dims():
# Take the remaining dst segment which must have a volume of 1 by now.
dst_segment = Range(self.dst.ranges[dst_i: self.dst.dims()])
assert dst_segment.volume_exact() == 1
out.extend(dst_segment.ranges)
return Range(out)
Loading