Skip to content

Commit

Permalink
Cover more cases with tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
pratyai committed Oct 23, 2024
1 parent 827f960 commit 99c6829
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 4 deletions.
7 changes: 4 additions & 3 deletions dace/subsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1406,9 +1406,10 @@ def map(self, r: Range) -> Optional[Range]:
if src_j is None:
return None

if Range(r.ranges[src_i: src_j]).volume_exact() == 1:
# If we are selecting just a single point in this segment, we can just pick the mapping of that point.
src_segment, dst_segment = Range(self.src.ranges[src_i: src_j]), Range(self.dst.ranges[dst_i: dst_j])
# If we are selecting just a single point in this segment, we can just pick the mapping of that point.
src_segment, dst_segment, r_segment = Range(self.src.ranges[src_i: src_j]), Range(
self.dst.ranges[dst_i: dst_j]), Range(r.ranges[src_i: src_j])
if r_segment.volume_exact() == 1:
# Compute the local 1D coordinate of the point on `src`.
loc = 0
for (idx, _, _), (ridx, _, _), s in zip(reversed(src_segment.ranges),
Expand Down
43 changes: 42 additions & 1 deletion tests/subsets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_mapping_without_symbols(self):
self.assertEqual(Range([(1, 1 + K // 2, 1), (2, 2 + N // 2, 1), (3, 3 + M // 2, 1)]),
sm.map(Range([(0, K // 2, 1), (0, N // 2, 1), (0, M // 2, 1)])))

def test_mapping_with_only_offsets(self):
def test_mapping_with_symbols(self):
K, N, M = dace.symbol('K', positive=True), dace.symbol('N', positive=True), dace.symbol('M', positive=True)

# A regular cube.
Expand Down Expand Up @@ -115,6 +115,47 @@ def test_mapping_with_only_offsets(self):
args)
self.assertEqual(want, got)

def test_mapping_with_reshaping(self):
K, N, M = dace.symbol('K', positive=True), dace.symbol('N', positive=True), dace.symbol('M', positive=True)

# A regular cube.
src = Range([(0, K - 1, 1), (0, N - 1, 1), (0, M - 1, 1)])
# A regular cube with different shape.
dst = Range([(0, K - 1, 1), (0, N * M - 1, 1)])
# A Mapper
sm = SubrangeMapper(src, dst)

# Pick the entire range.
self.assertEqual(dst, sm.map(src))

# NOTE: I couldn't make SymPy understand that `(K//2) % K == (K//2)` always holds for postive integers `K`.
# Hence, the numerical approach.
argslist = [{'K': k, 'N': n, 'M': m} for k, n, m in zip(np.random.randint(1, 10, size=20),
np.random.randint(1, 10, size=20),
np.random.randint(1, 10, size=20))]
# Pick a point K//2, N//2, M//2.
for args in argslist:
want = eval_range(
Range([(K // 2, K // 2, 1), ((N // 2) + (M // 2) * N, (N // 2) + (M // 2) * N, 1)]),
args)
got = eval_range(
sm.map(Range([(K // 2, K // 2, 1), (N // 2, N // 2, 1), (M // 2, M // 2, 1)])),
args)
self.assertEqual(want, got)
# Pick a quadrant.
for args in argslist:
# But its mapping cannot be expressed as a simple range with offset and stride.
self.assertIsNone(sm.map(Range([(0, K // 2, 1), (0, N // 2, 1), (0, M // 2, 1)])))
# Pick only points in problematic quadrants, but larger subsets elsewhere.
for args in argslist:
want = eval_range(
Range([(0, K // 2, 1), ((N // 2) + (M // 2) * N, (N // 2) + (M // 2) * N, 1)]),
args)
got = eval_range(
sm.map(Range([(0, K // 2, 1), (N // 2, N // 2, 1), (M // 2, M // 2, 1)])),
args)
self.assertEqual(want, got)


if __name__ == '__main__':
unittest.main()

0 comments on commit 99c6829

Please sign in to comment.