Skip to content

Commit 96189bf

Browse files
Arm backend: Make quantized where.self fuseable (#15520)
Makes where.self fuseable and changes how it is annotated. Signed-off-by: Oscar Andersson <[email protected]>
1 parent eb409ba commit 96189bf

File tree

3 files changed

+43
-8
lines changed

3 files changed

+43
-8
lines changed

backends/arm/_passes/fuse_constant_ops_pass.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ def resolve_arg(arg):
6565
if isinstance(arg, torch.fx.Node) and arg in input_nodes:
6666
idx = input_nodes.index(arg)
6767
t = get_param_tensor(self.exported_program, arg)
68-
if qparams:
68+
# Check if qparams exist for this arg
69+
if qparams and idx in qparams.keys():
6970
t = qparams[idx].dequantize_value(t)
7071
return t
7172
if isinstance(arg, tuple):

backends/arm/quantizer/quantization_annotator.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -608,12 +608,19 @@ def any_or_hardtanh_min_zero(n: Node):
608608
quant_properties.quant_output = _QuantProperty(0, shared_qspec)
609609
elif node.target in (torch.ops.aten.where.self,):
610610
true_node = ensure_type(Node, node.args[1])
611-
shared_qspec = SharedQuantizationSpec(true_node)
611+
input_qspec = (
612+
SharedQuantizationSpec(true_node)
613+
if is_output_annotated(true_node)
614+
else input_act_qspec
615+
)
612616
quant_properties.quant_inputs = [
613-
_QuantProperty(1, shared_qspec),
614-
_QuantProperty(2, shared_qspec),
617+
_QuantProperty(1, input_qspec),
618+
_QuantProperty(2, SharedQuantizationSpec((true_node, node))),
615619
]
616-
quant_properties.quant_output = _QuantProperty(0, shared_qspec)
620+
quant_properties.quant_output = _QuantProperty(
621+
0,
622+
SharedQuantizationSpec((true_node, node)),
623+
)
617624
elif node.target in _one_to_one_shared_input_or_input_act_qspec:
618625
input_node = ensure_type(Node, node.args[0])
619626
input_qspec = (

backends/arm/test/ops/test_where.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from typing import List, Tuple
77

88
import torch
9-
109
from executorch.backends.arm.quantizer import (
1110
EthosUQuantizer,
1211
get_symmetric_quantization_config,
@@ -65,6 +64,30 @@ def forward(
6564
return torch.where(self.condition(input_), input_, other_)
6665

6766

67+
class ConstWhere(torch.nn.Module):
68+
69+
def __init__(self, buffer: torch.Tensor, dtype: torch.dtype):
70+
super().__init__()
71+
self.buffer = buffer
72+
self.dtype = dtype
73+
self.min = torch.nn.Buffer(torch.tensor(0.0, dtype=self.dtype))
74+
self.input_1 = torch.nn.Buffer(torch.tensor(-1.0, dtype=self.dtype))
75+
self.input_2 = torch.nn.Buffer(torch.tensor(1.0, dtype=self.dtype))
76+
77+
def get_inputs(self):
78+
return (torch.rand(self.buffer.size(), dtype=self.dtype),)
79+
80+
def forward(self, input: torch.Tensor):
81+
return (
82+
torch.where(
83+
self.buffer > self.min,
84+
self.input_1,
85+
self.input_2,
86+
)
87+
+ input
88+
)
89+
90+
6891
def tensor_condition(input: torch.Tensor):
6992
return input > torch.zeros_like(input)
7093

@@ -128,13 +151,19 @@ def scalar_condition(input: torch.Tensor):
128151
scalar_condition,
129152
)
130153

154+
const_float32 = ConstWhere(
155+
buffer=torch.tensor([[1.0, -1.0], [-1.0, 1.0]]),
156+
dtype=torch.float32,
157+
)
158+
131159
test_modules_common = {
132160
"two_dim_tensor_cond": lambda: two_dim_tensor_cond,
133161
"three_dim_tensor_cond": lambda: three_dim_tensor_cond,
134162
"float32_tensor_cond": lambda: float32_tensor_cond,
135163
"two_dim_scalar_cond": lambda: two_dim_scalar_cond,
136164
"three_dim_scalar_cond": lambda: three_dim_scalar_cond,
137165
"float32_scalar_cond": lambda: float32_scalar_cond,
166+
"const_float32": lambda: const_float32,
138167
}
139168

140169
test_modules_FP = {
@@ -183,7 +212,6 @@ def test_where_self_tosa_INT(test_module):
183212
test_module().get_inputs(),
184213
aten_op,
185214
exir_op,
186-
symmetric_io_quantization=True,
187215
)
188216
pipeline.run()
189217

@@ -253,6 +281,5 @@ def test_where_self_vgf_INT(test_module):
253281
aten_op,
254282
exir_op,
255283
tosa_version="TOSA-1.0+INT",
256-
symmetric_io_quantization=True,
257284
)
258285
pipeline.run()

0 commit comments

Comments (0)