Skip to content

Commit 73b6918

Browse files
authored
Don't assert for symbolic stride in dim_order_from_stride() (#15472)
Summary: Currently dim_order_from_stride() checks if any stride is 0, and fails if so. N7613577 shows a min repro from the factorized joiner usecase with a symbolic stride, where this check fails (P2015309933). This diff skips the assert for symbolic strides. Differential Revision: D85875885
1 parent 487dae0 commit 73b6918

File tree

3 files changed

+27
-3
lines changed

3 files changed

+27
-3
lines changed

exir/tensor.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,15 @@ def dim_order_from_stride(stride: Tuple[int]) -> Tuple[bytes]:
6767
Another example is: sizes = (1, 3, 1, 1) with strides = (3, 1, 3, 3), returned
6868
value is (0, 2, 3, 1)
6969
"""
70+
from torch.fx.experimental.symbolic_shapes import (
71+
guard_or_false,
72+
guard_size_oblivious,
73+
)
74+
7075
for _, s in enumerate(stride):
71-
if s == 0:
76+
if guard_or_false(s == 0):
7277
raise ValueError("0 in strides is not supported for ExecuTorch.")
7378

74-
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
75-
7679
class K(NamedTuple):
7780
stride: int
7881

exir/tests/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,7 @@ python_unittest(
385385
deps = [
386386
"//caffe2:torch",
387387
"//executorch/exir:dim_order_utils",
388+
"//executorch/exir:lib",
388389
],
389390
)
390391

exir/tests/test_dim_order_utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import unittest
99

1010
import torch
11+
from executorch.exir import to_edge_transform_and_lower
1112
from executorch.exir.dim_order_utils import get_dim_order, get_memory_format
1213

1314

@@ -27,3 +28,22 @@ def test_get_dim_order(self) -> None:
2728
list(range(ndim)), get_dim_order(torch.contiguous_format, ndim)
2829
)
2930
self.assertEqual([0, 2, 3, 1], get_dim_order(torch.channels_last, 4))
31+
32+
def test_dim_order_from_stride(self):
33+
class Test(torch.nn.Module):
34+
def __init__(self):
35+
super().__init__()
36+
37+
def forward(self, t1, t2):
38+
idx = torch.nonzero(t1).reshape(-1)
39+
y = torch.index_select(t2, 0, idx)
40+
return y
41+
42+
M = Test()
43+
x = torch.tensor([0, 1, 1, 0, 1], dtype=torch.bool)
44+
y = torch.randn(5, 6)
45+
M(x, y)
46+
47+
expo_prog = torch.export.export_for_training(M, (x, y))
48+
edge_prog = to_edge_transform_and_lower(expo_prog)
49+
edge_prog.to_executorch()

0 commit comments

Comments
 (0)