Skip to content

Commit fd4eb9d

Browse files
Arm backend: Upcast index argument to int64 for aten.index_copy ops (#15595)
Signed-off-by: Yufeng Shi <[email protected]>
1 parent c9fcb24 commit fd4eb9d

File tree

2 files changed

+71
-1
lines changed

2 files changed

+71
-1
lines changed

backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ class InsertInt32CastsAfterInt64PlaceholdersPass(ArmPass):
3636
# Key: op overload; Value: zero-based indices of positional args that must be i64.
3737
I64_INPUT_ARG_POSITIONS = {
3838
torch.ops.aten.one_hot.default: (0,),
39+
torch.ops.aten.index_copy_.default: (2,),
40+
torch.ops.aten.index_copy.default: (2,),
3941
}
4042

4143
def _insert_callsite_i32_to_i64_casts(self, graph_module: torch.fx.GraphModule):

backends/arm/test/passes/test_insert_int32_casts_after_int64_placeholders_pass.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,13 @@
88
import torch
99
from executorch.backends.arm._passes import InsertInt32CastsAfterInt64PlaceholdersPass
1010

11-
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
11+
from executorch.backends.arm.test.tester.test_pipeline import (
12+
PassPipeline,
13+
TosaPipelineINT,
14+
)
1215

1316
input_t = Tuple[torch.Tensor, torch.Tensor] # weights, indices
17+
input_t3 = Tuple[torch.Tensor, torch.LongTensor, torch.Tensor]
1418

1519

1620
class Int64InputModel(torch.nn.Module):
@@ -44,3 +48,67 @@ def test_int64_model_tosa_FP():
4448
)
4549
pipeline.pop_stage(-1) # Do not compare output
4650
pipeline.run()
51+
52+
53+
class UpcastToInt64ForIndexCopyInplaceModel(torch.nn.Module):
54+
aten_op = "torch.ops.aten.index_copy_.default"
55+
56+
def forward(self, x: torch.Tensor, index: torch.LongTensor, y: torch.tensor):
57+
return x.index_copy_(0, index, y)
58+
59+
def get_inputs(self) -> input_t3:
60+
return (
61+
torch.zeros(5, 3),
62+
torch.tensor([0, 4, 2]),
63+
torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float),
64+
)
65+
66+
67+
def test_upcast_to_int64_for_index_copy_inplace_tosa_INT():
68+
module = UpcastToInt64ForIndexCopyInplaceModel()
69+
pipeline = TosaPipelineINT[input_t3](
70+
module,
71+
module.get_inputs(),
72+
aten_op=module.aten_op,
73+
)
74+
pipeline.pop_stage("check.quant_nodes")
75+
pipeline.change_args(
76+
"check_count.exir",
77+
{
78+
"torch.ops.higher_order.executorch_call_delegate": 0,
79+
},
80+
)
81+
pipeline.pop_stage("run_method_and_compare_outputs")
82+
pipeline.run()
83+
84+
85+
class UpcastToInt64ForIndexCopyModel(torch.nn.Module):
86+
aten_op = "torch.ops.aten.index_copy.default"
87+
88+
def forward(self, x: torch.Tensor, index: torch.LongTensor, y: torch.tensor):
89+
return x.index_copy(0, index, y)
90+
91+
def get_inputs(self) -> input_t3:
92+
return (
93+
torch.zeros(5, 3),
94+
torch.tensor([0, 4, 2]),
95+
torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float),
96+
)
97+
98+
99+
def test_upcast_to_int64_for_index_copy_tosa_INT():
100+
module = UpcastToInt64ForIndexCopyModel()
101+
pipeline = TosaPipelineINT[input_t3](
102+
module,
103+
module.get_inputs(),
104+
aten_op=module.aten_op,
105+
)
106+
pipeline.pop_stage("check.quant_nodes")
107+
pipeline.change_args(
108+
"check_count.exir",
109+
{
110+
"torch.ops.higher_order.executorch_call_delegate": 0,
111+
},
112+
)
113+
pipeline.pop_stage("run_method_and_compare_outputs")
114+
pipeline.run()

0 commit comments

Comments
 (0)