Skip to content

Commit e487d29

Browse files
authored
Arm backend: Fix bug in ConvertExpandToRepeat pass (#15589)
The pass assumed that if all repeat multiples are one, the op is a no-op. However, it can still change the rank. Signed-off-by: Erik Lundell <[email protected]>
1 parent 50fe3b3 commit e487d29

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

backends/arm/_passes/convert_expand_copy_to_repeat.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121

2222
def calculate_multiples(args):
23+
"""Returns expand args converted to repeat args, and whether the expand changes the rank"""
2324
input_node_or_tensor = args[0]
2425

2526
if isinstance(input_node_or_tensor, torch.fx.node.Node):
@@ -45,7 +46,7 @@ def calculate_multiples(args):
4546
multiples[i] if multiples[i] != -1 and extended_shape[i] == 1 else 1
4647
for i in range(expanded_rank)
4748
]
48-
return multiples
49+
return multiples, expanded_rank != len(input_shape)
4950

5051

5152
class ConvertExpandCopyToRepeatPass(ArmPass):
@@ -62,9 +63,9 @@ def call_operator(self, op, args, kwargs, meta):
6263
if op != self.expand_copy:
6364
return super().call_operator(op, args, kwargs, meta)
6465

65-
multiples = calculate_multiples(args)
66+
multiples, changes_rank = calculate_multiples(args)
6667

67-
if all((x == 1 for x in multiples)):
68+
if all((x == 1 for x in multiples)) and not changes_rank:
6869
# All dimensions/repetitions occur only once. Remove node
6970
# altogether since it's in practice just a copy.
7071
logger.warning("Found redundant expand node (no-op). Removing it.")

backends/arm/tosa/partitioner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,8 @@ def is_noop_expand(node: torch.fx.node.Node) -> bool:
110110
if node.target != exir_ops.edge.aten.expand_copy.default:
111111
return False
112112
else:
113-
multiples = calculate_multiples(node.args)
114-
return all(m == 1 for m in multiples)
113+
multiples, changes_rank = calculate_multiples(node.args)
114+
return all(m == 1 for m in multiples) and not changes_rank
115115

116116

117117
def is_partitioned(

0 commit comments

Comments
 (0)