From 71c5011c64a7cefc1a54b539f8ceb5ad5a2f53f8 Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Wed, 23 Oct 2024 15:24:04 +0200 Subject: [PATCH] Add some explanatory comments. --- dace/subsets.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/dace/subsets.py b/dace/subsets.py index 2f1bfb0d1b..aa3a3269e7 100644 --- a/dace/subsets.py +++ b/dace/subsets.py @@ -1393,6 +1393,9 @@ def map(self, r: Range) -> Optional[Range]: while src_i < self.src.dims(): assert dst_i < self.dst.dims() + # Find the next smallest segments of `src` and `dst` whose volumes matches (and therefore can possibly have + # a mapping). + # TODO: It's possible to do this in a O(max(|src|, |dst|)) loop instead of O(|src| * |dst|). src_j, dst_j = None, None for sj in range(src_i + 1, self.src.dims() + 1): for dj in range(dst_i + 1, self.dst.dims() + 1): @@ -1404,12 +1407,14 @@ def map(self, r: Range) -> Optional[Range]: continue break if src_j is None: + # Somehow, we couldn't find a matching segment. This should have been caught earlier. return None - # If we are selecting just a single point in this segment, we can just pick the mapping of that point. - src_segment, dst_segment, r_segment = Range(self.src.ranges[src_i: src_j]), Range( - self.dst.ranges[dst_i: dst_j]), Range(r.ranges[src_i: src_j]) + src_segment = Range(self.src.ranges[src_i: src_j]) + dst_segment = Range(self.dst.ranges[dst_i: dst_j]) + r_segment = Range(r.ranges[src_i: src_j]) if r_segment.volume_exact() == 1: + # If we are selecting just a single point in this segment, we can just pick the mapping of that point. # Compute the local 1D coordinate of the point on `src`. loc = 0 for (idx, _, _), (ridx, _, _), s in zip(reversed(src_segment.ranges), @@ -1427,7 +1432,7 @@ def map(self, r: Range) -> Optional[Range]: # its entirety too. out.extend(self.dst.ranges[dst_i:dst_j]) elif src_j - src_i == 1 and dst_j - dst_i == 1: - # If the segment lengths on both sides are just 1, the mapping is easy to compute. + # If the segment lengths on both sides are just 1, the mapping is easy to compute -- it's just a shift. sb, se, ss = self.src.ranges[src_i] db, de, ds = self.dst.ranges[dst_i] b, e, s = r.ranges[src_i]