diff --git a/dace/subsets.py b/dace/subsets.py index ae0c9b1a11..2f1bfb0d1b 100644 --- a/dace/subsets.py +++ b/dace/subsets.py @@ -1406,9 +1406,10 @@ def map(self, r: Range) -> Optional[Range]: if src_j is None: return None - if Range(r.ranges[src_i: src_j]).volume_exact() == 1: - # If we are selecting just a single point in this segment, we can just pick the mapping of that point. - src_segment, dst_segment = Range(self.src.ranges[src_i: src_j]), Range(self.dst.ranges[dst_i: dst_j]) + # If we are selecting just a single point in this segment, we can just pick the mapping of that point. + src_segment, dst_segment, r_segment = Range(self.src.ranges[src_i: src_j]), Range( + self.dst.ranges[dst_i: dst_j]), Range(r.ranges[src_i: src_j]) + if r_segment.volume_exact() == 1: # Compute the local 1D coordinate of the point on `src`. loc = 0 for (idx, _, _), (ridx, _, _), s in zip(reversed(src_segment.ranges), diff --git a/tests/subsets_test.py b/tests/subsets_test.py index 65a2f19a98..4af4449c72 100644 --- a/tests/subsets_test.py +++ b/tests/subsets_test.py @@ -79,7 +79,7 @@ def test_mapping_without_symbols(self): self.assertEqual(Range([(1, 1 + K // 2, 1), (2, 2 + N // 2, 1), (3, 3 + M // 2, 1)]), sm.map(Range([(0, K // 2, 1), (0, N // 2, 1), (0, M // 2, 1)]))) - def test_mapping_with_only_offsets(self): + def test_mapping_with_symbols(self): K, N, M = dace.symbol('K', positive=True), dace.symbol('N', positive=True), dace.symbol('M', positive=True) # A regular cube. @@ -115,6 +115,47 @@ def test_mapping_with_only_offsets(self): args) self.assertEqual(want, got) + def test_mapping_with_reshaping(self): + K, N, M = dace.symbol('K', positive=True), dace.symbol('N', positive=True), dace.symbol('M', positive=True) + + # A regular cube. + src = Range([(0, K - 1, 1), (0, N - 1, 1), (0, M - 1, 1)]) + # A regular cube with different shape. + dst = Range([(0, K - 1, 1), (0, N * M - 1, 1)]) + # A Mapper + sm = SubrangeMapper(src, dst) + + # Pick the entire range. + self.assertEqual(dst, sm.map(src)) + + # NOTE: I couldn't make SymPy understand that `(K//2) % K == (K//2)` always holds for postive integers `K`. + # Hence, the numerical approach. + argslist = [{'K': k, 'N': n, 'M': m} for k, n, m in zip(np.random.randint(1, 10, size=20), + np.random.randint(1, 10, size=20), + np.random.randint(1, 10, size=20))] + # Pick a point K//2, N//2, M//2. + for args in argslist: + want = eval_range( + Range([(K // 2, K // 2, 1), ((N // 2) + (M // 2) * N, (N // 2) + (M // 2) * N, 1)]), + args) + got = eval_range( + sm.map(Range([(K // 2, K // 2, 1), (N // 2, N // 2, 1), (M // 2, M // 2, 1)])), + args) + self.assertEqual(want, got) + # Pick a quadrant. + for args in argslist: + # But its mapping cannot be expressed as a simple range with offset and stride. + self.assertIsNone(sm.map(Range([(0, K // 2, 1), (0, N // 2, 1), (0, M // 2, 1)]))) + # Pick only points in problematic quadrants, but larger subsets elsewhere. + for args in argslist: + want = eval_range( + Range([(0, K // 2, 1), ((N // 2) + (M // 2) * N, (N // 2) + (M // 2) * N, 1)]), + args) + got = eval_range( + sm.map(Range([(0, K // 2, 1), (N // 2, N // 2, 1), (M // 2, M // 2, 1)])), + args) + self.assertEqual(want, got) + if __name__ == '__main__': unittest.main()