diff --git a/backends/arm/_passes/fuse_constant_ops_pass.py b/backends/arm/_passes/fuse_constant_ops_pass.py index efc140889d6..2c8986114db 100644 --- a/backends/arm/_passes/fuse_constant_ops_pass.py +++ b/backends/arm/_passes/fuse_constant_ops_pass.py @@ -65,7 +65,8 @@ def resolve_arg(arg): if isinstance(arg, torch.fx.Node) and arg in input_nodes: idx = input_nodes.index(arg) t = get_param_tensor(self.exported_program, arg) - if qparams: + # Check if qparams exist for this arg + if qparams and idx in qparams.keys(): t = qparams[idx].dequantize_value(t) return t if isinstance(arg, tuple): diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index bd0fcd7d64f..9bf7414f2ed 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -524,12 +524,19 @@ def any_or_hardtanh_min_zero(n: Node): quant_properties.quant_output = _QuantProperty(0, shared_qspec) elif node.target in (torch.ops.aten.where.self,): true_node = ensure_type(Node, node.args[1]) - shared_qspec = SharedQuantizationSpec(true_node) + input_qspec = ( + SharedQuantizationSpec(true_node) + if is_output_annotated(true_node) + else input_act_qspec + ) quant_properties.quant_inputs = [ - _QuantProperty(1, shared_qspec), - _QuantProperty(2, shared_qspec), + _QuantProperty(1, input_qspec), + _QuantProperty(2, SharedQuantizationSpec((true_node, node))), ] - quant_properties.quant_output = _QuantProperty(0, shared_qspec) + quant_properties.quant_output = _QuantProperty( + 0, + SharedQuantizationSpec((true_node, node)), + ) elif node.target in _one_to_one_shared_input_or_input_act_qspec: input_node = ensure_type(Node, node.args[0]) input_qspec = ( diff --git a/backends/arm/test/ops/test_where.py b/backends/arm/test/ops/test_where.py index a35a9fc3b7d..e05aa2ffa5a 100644 --- a/backends/arm/test/ops/test_where.py +++ b/backends/arm/test/ops/test_where.py @@ -6,7 +6,6 @@ from typing import List, Tuple import torch - from executorch.backends.arm.quantizer import ( EthosUQuantizer, get_symmetric_quantization_config, @@ -65,6 +64,30 @@ def forward( return torch.where(self.condition(input_), input_, other_) +class ConstWhere(torch.nn.Module): + + def __init__(self, buffer: torch.Tensor, dtype: torch.dtype): + super().__init__() + self.buffer = buffer + self.dtype = dtype + self.min = torch.nn.Buffer(torch.tensor(0.0, dtype=self.dtype)) + self.input_1 = torch.nn.Buffer(torch.tensor(-1.0, dtype=self.dtype)) + self.input_2 = torch.nn.Buffer(torch.tensor(1.0, dtype=self.dtype)) + + def get_inputs(self): + return (torch.rand(self.buffer.size(), dtype=self.dtype),) + + def forward(self, input: torch.Tensor): + return ( + torch.where( + self.buffer > self.min, + self.input_1, + self.input_2, + ) + + input + ) + + def tensor_condition(input: torch.Tensor): return input > torch.zeros_like(input) @@ -128,6 +151,11 @@ def scalar_condition(input: torch.Tensor): scalar_condition, ) +const_float32 = ConstWhere( + buffer=torch.tensor([[1.0, -1.0], [-1.0, 1.0]]), + dtype=torch.float32, +) + test_modules_common = { "two_dim_tensor_cond": lambda: two_dim_tensor_cond, "three_dim_tensor_cond": lambda: three_dim_tensor_cond, @@ -135,6 +163,7 @@ def scalar_condition(input: torch.Tensor): "two_dim_scalar_cond": lambda: two_dim_scalar_cond, "three_dim_scalar_cond": lambda: three_dim_scalar_cond, "float32_scalar_cond": lambda: float32_scalar_cond, + "const_float32": lambda: const_float32, } test_modules_FP = { @@ -183,7 +212,6 @@ def test_where_self_tosa_INT(test_module): test_module().get_inputs(), aten_op, exir_op, - symmetric_io_quantization=True, ) pipeline.run() @@ -253,6 +281,5 @@ def test_where_self_vgf_INT(test_module): aten_op, exir_op, tosa_version="TOSA-1.0+INT", - symmetric_io_quantization=True, ) pipeline.run()