This section sketches the proof of "Complete, Fast and Correct View.reshapes without using Symbolic". The goal is to reduce the number of multi-views, which cost runtime.
- old_shape = (s1,s2,...,si,s(i+1),...,sn)
- old_stride = (st1, st2, ... ,sti, st(i+1), ..., stn)
- merge_old_shape = (p1, p2), where p1 = s1 * ... * si and p2 = s(i+1) * ... * sn
- new_shape = (k1, ..., kp, k(p+1), ..., kl)
- prod(new_shape) = p1 * p2 (trivial)
- mask and new_mask represent the valid indexes before and after the reshape, respectively.
p1 and p2 are each mergeable on their own (the conditions are discussed below), but p1 and p2 cannot be merged with each other.
If prod([k1 ... kp]) < p1 and prod([k1 ... k(p+1)]) > p1, reshape is not possible.
Proof
k(p+1) would have to draw some of its extent from p1 and some from p2, which requires p1 and p2 to be mergeable, but they are not.
Conclusion
Hence, reshape is only possible if ∃ a p, where prod([k1 .. kp]) = p1.
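The condition above can be sketched as a prefix-product search. This is a minimal illustration, not the actual implementation; the name `find_split` is hypothetical, and all dimensions are assumed to be positive integers.

```python
def find_split(new_shape, p1):
    """Return the first p with prod(new_shape[:p]) == p1, or None if no prefix matches.

    Hypothetical helper; assumes every dimension is a positive integer.
    """
    acc = 1
    if acc == p1:
        return 0
    for p, k in enumerate(new_shape, start=1):
        acc *= k
        if acc == p1:
            return p
        if acc > p1:
            # The prefix product jumped over p1, so k_p would straddle p1 and p2.
            return None
    return None
```

For example, with p1 = 6, new_shape (2, 3, 4) splits after the second dimension, while (4, 3, 2) has no valid split because the prefix products go from 4 straight to 12.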
Case 1 - All non-zero strides
The dimensions inside p1 (and likewise inside p2) will merge if stx = st(x+1) * s(x+1) for every x ∈ [1, ..., i-1, i+1, ..., n-1].
Proof
Let's consider merging (s1 ... si) -> p1. Here we have to get a single new stride corresponding to p1, and for that the merged dimensions have to be contiguous with respect to each other.
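As a rough sketch of the merge condition for non-zero strides, the check below walks adjacent dimensions and tests stx = st(x+1) * s(x+1). The function name `merge_dims` is an assumption made for illustration.

```python
def merge_dims(shape, strides):
    """Return (merged_size, merged_stride) if all dims collapse into one, else None.

    Hypothetical helper; assumes all strides are non-zero (Case 1).
    """
    assert len(shape) == len(strides) and len(shape) > 0
    for x in range(len(shape) - 1):
        # Stepping once in dim x must equal walking all of dim x+1.
        if strides[x] != strides[x + 1] * shape[x + 1]:
            return None
    size = 1
    for s in shape:
        size *= s
    # The merged dimension keeps the innermost stride.
    return size, strides[-1]
```

Under these assumptions, shape (2, 3, 4) with strides (12, 4, 1) merges into a single dimension of size 24 with stride 1, whereas shape (2, 3) with strides (4, 1) does not merge.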
Case 2 - Some stride is zero
Let stj = 0 & st(j+1) != 0 & s(j+1) > 1, where 1 < j < i.
If sj = 1 , reshape is trivial.
If sj > 1,
- If maskj has range > 1, reshape is not possible, because s(j+1) will need to be repeated at least once and a single stride can't capture repetition.
- If maskj has range = 1, reshape is possible, since the dimension then effectively has shape 1, with some offset.
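To see why a zero stride blocks the merge unless the mask range is 1, one can enumerate the memory offsets of the valid elements and test whether a single stride could generate them. This is only an illustrative sketch; `offsets` and `is_single_stride` are hypothetical names, and masks are assumed to be half-open (lo, hi) ranges per dimension.

```python
from itertools import product

def offsets(strides, mask):
    """Memory offsets of the valid elements, visited in row-major index order."""
    ranges = [range(lo, hi) for lo, hi in mask]
    return [sum(i * st for i, st in zip(idx, strides)) for idx in product(*ranges)]

def is_single_stride(offs):
    """True if the offsets form an arithmetic progression, i.e. one stride suffices."""
    if len(offs) < 3:
        return True
    step = offs[1] - offs[0]
    return all(b - a == step for a, b in zip(offs, offs[1:]))

# Shape (2, 3) with strides (0, 1): the inner dimension is repeated twice.
repeated = offsets((0, 1), ((0, 2), (0, 3)))
# Same view, but the mask on the zero-stride dimension has range 1.
single = offsets((0, 1), ((1, 2), (0, 3)))
```

With the full mask the offsets wrap back to the start, which no single stride can express; with a range-1 mask on the zero-stride dimension the offsets are a plain progression.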
Case 1 - Splitting Dimension - Mask shouldn't be cut for successful reshape.
- Example - [1,2,3,4,5,6,7,8] -> [[1,2,3,4], [5,6,7,8]]; mask = ((2,6)); new_mask[0] = (0,2) (trivial split).
- new_mask[1] is not possible. It would only be possible if the mask spanned the whole range [1-8] or lay within a single row of the new shape, [1-4] or [5-8].
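The splitting rule can be sketched for the simple case of one dimension split into (outer, inner). The helper name `split_mask` is hypothetical, masks are assumed to be non-empty half-open (lo, hi) ranges, and accepting any mask made of whole rows is a generalisation of the "spans everything" case in the example, stated here as an assumption.

```python
def split_mask(mask, inner):
    """Split a 1-D mask (lo, hi) over a dimension reshaped into rows of length `inner`.

    Returns ((outer_lo, outer_hi), (inner_lo, inner_hi)), or None if the mask is cut.
    Hypothetical helper; assumes a non-empty half-open mask.
    """
    lo, hi = mask
    assert 0 <= lo < hi
    if lo // inner == (hi - 1) // inner:
        # The mask lies entirely within one row.
        row = lo // inner
        return (row, row + 1), (lo % inner, (hi - 1) % inner + 1)
    if lo % inner == 0 and hi % inner == 0:
        # The mask covers whole rows only.
        return (lo // inner, hi // inner), (0, inner)
    return None
```

Applied to the example, mask (2, 6) with rows of length 4 is rejected because it cuts both rows partway, while (0, 8) and (5, 8) both split cleanly.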
Case 2 - Combining Dimension - Mask should unfold continuously.
- Example - [[1,2],[3,4],[5,6]] -> [1,2,3,4,5,6]; mask = ((0,2),(0,2)).
- new_mask = ((0,4)); only possible because mask1 spans the whole dimension.
- If mask1 did not span the whole dimension, the only way combining would be possible is if mask0 had range 1, as shown below.
- [[1,2,3],[4,5,6]] -> [1,2,3,4,5,6]; mask = ((1,2),(0,2)); new_mask = ((3,5))
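Both combining examples follow from a two-branch rule, sketched below for merging exactly two dimensions. The name `combine_mask` is an assumption for illustration, and masks are again treated as half-open (lo, hi) ranges.

```python
def combine_mask(mask0, mask1, inner):
    """Combine masks of an (outer, inner) pair into one mask over outer * inner.

    Returns (lo, hi), or None if the valid region does not unfold continuously.
    Hypothetical helper; `inner` is the size of the inner dimension.
    """
    (lo0, hi0), (lo1, hi1) = mask0, mask1
    if (lo1, hi1) == (0, inner):
        # The inner mask spans the whole dimension, so valid rows are back to back.
        return lo0 * inner, hi0 * inner
    if hi0 - lo0 == 1:
        # Only one row is valid, so the inner mask is simply shifted by that row.
        return lo0 * inner + lo1, lo0 * inner + hi1
    return None
```

A mask such as ((0,2),(0,2)) over an inner dimension of size 3 is rejected, since the valid elements of the two rows are separated by an invalid one.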