 from typing import List, Tuple
 
 import torch
-
 from executorch.backends.arm.quantizer import (
     EthosUQuantizer,
     get_symmetric_quantization_config,
@@ -65,6 +64,30 @@ def forward(
         return torch.where(self.condition(input_), input_, other_)
 
 
+class ConstWhere(torch.nn.Module):
+
+    def __init__(self, buffer: torch.Tensor, dtype: torch.dtype):
+        super().__init__()
+        self.buffer = buffer
+        self.dtype = dtype
+        self.min = torch.nn.Buffer(torch.tensor(0.0, dtype=self.dtype))
+        self.input_1 = torch.nn.Buffer(torch.tensor(-1.0, dtype=self.dtype))
+        self.input_2 = torch.nn.Buffer(torch.tensor(1.0, dtype=self.dtype))
+
+    def get_inputs(self):
+        return (torch.rand(self.buffer.size(), dtype=self.dtype),)
+
+    def forward(self, input: torch.Tensor):
+        return (
+            torch.where(
+                self.buffer > self.min,
+                self.input_1,
+                self.input_2,
+            )
+            + input
+        )
+
+
 def tensor_condition(input: torch.Tensor):
     return input > torch.zeros_like(input)
 
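For a sense of what the new module computes: the where() here reads only constant buffers, so it folds to a constant tensor of +/-1.0 values that is then added to the runtime input. A minimal eager-mode sketch, assuming ConstWhere as defined in the hunk above:

import torch

module = ConstWhere(
    buffer=torch.tensor([[1.0, -1.0], [-1.0, 1.0]]),
    dtype=torch.float32,
)
(example_input,) = module.get_inputs()
output = module(example_input)

# Positive buffer entries select input_1 (-1.0); the rest select
# input_2 (1.0). Subtracting the input recovers that constant tensor.
expected_const = torch.tensor([[-1.0, 1.0], [1.0, -1.0]])
assert torch.allclose(output - example_input, expected_const)
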
@@ -128,13 +151,19 @@ def scalar_condition(input: torch.Tensor):
     scalar_condition,
 )
 
+const_float32 = ConstWhere(
+    buffer=torch.tensor([[1.0, -1.0], [-1.0, 1.0]]),
+    dtype=torch.float32,
+)
+
 test_modules_common = {
     "two_dim_tensor_cond": lambda: two_dim_tensor_cond,
     "three_dim_tensor_cond": lambda: three_dim_tensor_cond,
     "float32_tensor_cond": lambda: float32_tensor_cond,
     "two_dim_scalar_cond": lambda: two_dim_scalar_cond,
     "three_dim_scalar_cond": lambda: three_dim_scalar_cond,
     "float32_scalar_cond": lambda: float32_scalar_cond,
+    "const_float32": lambda: const_float32,
 }
 
 test_modules_FP = {
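For context on the registration above: each value in test_modules_common is a zero-argument factory returning a module instance built at import time, and the parametrized tests further down resolve an entry roughly like this:

# Call the factory to get the module, then pull fresh random inputs
# from the module itself, as the pipelines in this file do.
make_module = test_modules_common["const_float32"]
module = make_module()
example_inputs = module.get_inputs()
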
@@ -183,7 +212,6 @@ def test_where_self_tosa_INT(test_module):
         test_module().get_inputs(),
         aten_op,
         exir_op,
-        symmetric_io_quantization=True,
     )
     pipeline.run()
 
@@ -253,6 +281,5 @@ def test_where_self_vgf_INT(test_module):
         aten_op,
         exir_op,
         tosa_version="TOSA-1.0+INT",
-        symmetric_io_quantization=True,
     )
     pipeline.run()
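
Both INT pipelines above simply drop the symmetric_io_quantization=True argument. Should symmetric quantization ever need to be configured per test instead, the helpers imported at the top of the file can be used on the quantizer directly; a hedged sketch, where compile_spec stands in for whatever Ethos-U compile spec the pipeline under test builds (not part of this diff):

from executorch.backends.arm.quantizer import (
    EthosUQuantizer,
    get_symmetric_quantization_config,
)

# Assumed usage: apply a symmetric quantization config globally on the
# quantizer itself rather than via the removed pipeline flag.
quantizer = EthosUQuantizer(compile_spec)
quantizer.set_global(get_symmetric_quantization_config())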