Commit 2fa6327

Arm backend: Remove FP specific NodeVisitors (#15519)
Merge the NodeVisitors of targets that have one FP and one INT visitor. Having two node visitors to choose from makes the INT+FP case tricky, hence the merge. Signed-off-by: Oscar Andersson <[email protected]>
1 parent bde6b11 commit 2fa6327
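
For context, the same pattern is applied to all three operators below: the separate _INT and _FP visitor classes are collapsed into a single class that registers both TOSA profiles and branches on the input dtype inside define_node. A minimal sketch of the shape of the change follows; ExampleVisitor and its target string are placeholders, not code from the patch, while the registration helpers mirror the ones visible in the diffs.

@register_node_visitor
class ExampleVisitor(NodeVisitor):      # placeholder name, not in the patch
    target = "aten.example.default"     # placeholder target

    # One visitor now covers both profiles instead of an INT subclass and an FP subclass.
    tosa_specs = [
        TosaSpecification.create_from_string("TOSA-1.0+INT"),
        TosaSpecification.create_from_string("TOSA-1.0+FP"),
    ]

    def define_node(self, node, tosa_graph, inputs, output) -> None:
        if inputs[0].dtype == ts.DType.INT8:
            ...  # quantized path: INT32 accumulation, zero points from qparams
        else:
            ...  # float path: FP accumulation, zero points set to 0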

3 files changed: +54, -176 lines

backends/arm/operators/op_avg_pool2d.py

Lines changed: 10 additions & 43 deletions
@@ -33,6 +33,7 @@ class AvgPool2dVisitor(NodeVisitor):
 
     tosa_specs = [
         TosaSpecification.create_from_string("TOSA-1.0+INT"),
+        TosaSpecification.create_from_string("TOSA-1.0+FP"),
     ]
 
     def __init__(self, *args):
@@ -105,43 +106,6 @@ def _build_generic_avgpool2d(
             attr,
         )
 
-    def define_node(
-        self,
-        node: torch.fx.Node,
-        tosa_graph: Any,
-        inputs: List[TosaArg],
-        output: TosaArg,
-    ) -> None:
-        validate_num_inputs(self.target, inputs, [3, 4, 5, 6, 7])
-        validate_same_dtype(self.target, [inputs[0], output], ts)
-        validate_valid_dtype(
-            self.target, [inputs[0], output], ts.DType.INT8, output.tosa_spec
-        )
-
-        accumulator_type = ts.DType.INT32
-
-        input_qargs = get_input_qparams(node)
-        input_zp = input_qargs[0].get_zp_per_tensor()
-
-        output_qargs = get_output_qparams(node)
-        output_zp = output_qargs[0].get_zp_per_tensor()
-
-        self._build_generic_avgpool2d(
-            node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
-        )
-
-
-@register_node_visitor
-class AvgPool2dVisitor_FP(AvgPool2dVisitor):
-    target = "aten.avg_pool2d.default"
-
-    tosa_specs = [
-        TosaSpecification.create_from_string("TOSA-1.0+FP"),
-    ]
-
-    def __init__(self, *args):
-        super().__init__(*args)
-
     def define_node(
         self,
         node: torch.fx.Node,
@@ -159,14 +123,17 @@ def define_node(
         )
 
         if inputs[0].dtype == ts.DType.INT8:
-            super().define_node(node, tosa_graph, inputs, output)
+            accumulator_type = ts.DType.INT32
+            input_qargs = get_input_qparams(node)
+            input_zp = input_qargs[0].get_zp_per_tensor()
 
-        if inputs[0].dtype == ts.DType.FP32:
+            output_qargs = get_output_qparams(node)
+            output_zp = output_qargs[0].get_zp_per_tensor()
+        else:
             accumulator_type = ts.DType.FP32
-            # Initilize zero point to zero.
             input_zp = 0
             output_zp = 0
 
-            self._build_generic_avgpool2d(
-                node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
-            )
+        self._build_generic_avgpool2d(
+            node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
+        )
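
Pieced together from the hunks above, the merged define_node body for avg_pool2d ends up roughly as follows; indentation and the surrounding validation calls are inferred from the diff context rather than copied from the full file.

        if inputs[0].dtype == ts.DType.INT8:
            # Quantized input: accumulate in INT32 and take the zero points
            # from the quantization parameters attached to the node.
            accumulator_type = ts.DType.INT32
            input_qargs = get_input_qparams(node)
            input_zp = input_qargs[0].get_zp_per_tensor()

            output_qargs = get_output_qparams(node)
            output_zp = output_qargs[0].get_zp_per_tensor()
        else:
            # Float input: FP32 accumulator, zero points stay at 0.
            accumulator_type = ts.DType.FP32
            input_zp = 0
            output_zp = 0

        self._build_generic_avgpool2d(
            node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
        )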

backends/arm/operators/op_clamp.py

Lines changed: 29 additions & 82 deletions
@@ -1,5 +1,4 @@
 # Copyright 2025 Arm Limited and/or its affiliates.
-# All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree
@@ -27,18 +26,19 @@
 
 
 @register_node_visitor
-class ClampVisitor_INT(NodeVisitor):
+class ClampVisitor(NodeVisitor):
     target = "aten.clamp.default"
 
     tosa_specs = [
         TosaSpecification.create_from_string("TOSA-1.0+INT"),
+        TosaSpecification.create_from_string("TOSA-1.0+FP"),
     ]
 
     def __init__(self, *args):
         super().__init__(*args)
 
     def _get_min_max_arguments(
-        self, node: Node, dtype_min: int | float, dtype_max: int | float
+        self, node: Node, dtype: torch.dtype
     ) -> Tuple[int | float, int | float]:
 
         def cast_type(value: Any) -> int | float:
@@ -48,6 +48,13 @@ def cast_type(value: Any) -> int | float:
             # Attempt to cast to float
             return float(value)
 
+        if dtype.is_floating_point:
+            dtype_min = torch.finfo(dtype).min
+            dtype_max = torch.finfo(dtype).max
+        else:
+            dtype_min = torch.iinfo(dtype).min
+            dtype_max = torch.iinfo(dtype).max
+
         min_arg = dtype_min
         max_arg = dtype_max
 
@@ -60,53 +67,15 @@ def cast_type(value: Any) -> int | float:
 
         return min_arg, max_arg
 
-    def define_node(
-        self,
-        node: Node,
-        tosa_graph: Any,
-        inputs: List[TosaArg],
-        output: TosaArg,
-    ) -> None:
-        validate_num_inputs(self.target, inputs, [2, 3])
-        validate_same_dtype(self.target, [inputs[0], output], ts)
-        validate_valid_dtype(
-            self.target, [inputs[0], output], [ts.DType.INT8], output.tosa_spec
-        )
-
-        # NOTE: Quantization of the min/max arguments is handled by QuantizeOperatorArguments
-        min_int8, max_int8 = self._get_min_max_arguments(
-            node,
-            torch.iinfo(torch.int8).min,
-            torch.iinfo(torch.int8).max,
-        )
-
-        attr = ts.TosaSerializerAttribute()
-        attr.ClampAttribute(
-            np.frombuffer(np.int8(min_int8).tobytes(), dtype=np.uint8).tolist(),
-            np.frombuffer(np.int8(max_int8).tobytes(), dtype=np.uint8).tolist(),
-            ts.NanPropagationMode.PROPAGATE,
-        )
-
-        self._serialize_operator(
-            node,
-            tosa_graph,
-            ts.Op.CLAMP,
-            [inputs[0].name],
-            [output.name],
-            attr,
-        )
-
-
-@register_node_visitor
-class ClampVisitor_FP(ClampVisitor_INT):
-    # inheriting 'target' from INT class
-
-    tosa_specs = [
-        TosaSpecification.create_from_string("TOSA-1.0+FP"),
-    ]
-
-    def __init__(self, *args):
-        super().__init__(*args)
+    def _to_bytes(self, value: int | float, dtype: torch.dtype) -> bytes:
+        if dtype == torch.float32:
+            return np.frombuffer(np.float32(value).tobytes(), dtype=np.uint8).tolist()
+        elif dtype == torch.float16:
+            return np.frombuffer(np.float16(value).tobytes(), dtype=np.uint8).tolist()
+        elif dtype == torch.int8:
+            return np.frombuffer(np.int8(value).tobytes(), dtype=np.uint8).tolist()
+        else:
+            raise ValueError(f"Unsupported dtype for to_bytes: {dtype}")
 
     def define_node(
         self,
@@ -120,42 +89,20 @@ def define_node(
         validate_valid_dtype(
             self.target,
             [inputs[0], output],
-            [ts.DType.FP16, ts.DType.FP32],
+            [ts.DType.INT8, ts.DType.FP16, ts.DType.FP32],
             output.tosa_spec,
         )
 
+        node_input_dtype = node.meta["val"].dtype
+        # NOTE: Quantization of the min/max arguments is handled by QuantizeOperatorArguments
+        min_val, max_val = self._get_min_max_arguments(node, node_input_dtype)
+
         attr = ts.TosaSerializerAttribute()
-        match inputs[0].dtype:
-            case ts.DType.FP16:
-                min_f, max_f = self._get_min_max_arguments(
-                    node,
-                    torch.finfo(torch.float16).min,
-                    torch.finfo(torch.float16).max,
-                )
-                min_bytes = np.frombuffer(
-                    np.float16(min_f).tobytes(), dtype=np.uint8
-                ).tolist()
-                max_bytes = np.frombuffer(
-                    np.float16(max_f).tobytes(), dtype=np.uint8
-                ).tolist()
-            case ts.DType.FP32:
-                min_f, max_f = self._get_min_max_arguments(
-                    node,
-                    torch.finfo(torch.float32).min,
-                    torch.finfo(torch.float32).max,
-                )
-                min_bytes = np.frombuffer(
-                    np.float32(min_f).tobytes(), dtype=np.uint8
-                ).tolist()
-                max_bytes = np.frombuffer(
-                    np.float32(max_f).tobytes(), dtype=np.uint8
-                ).tolist()
-            case _:
-                raise RuntimeError(
-                    f"Internal error: Unsupported dtype {inputs[0].dtype} in {self.target}"
-                )
-
-        attr.ClampAttribute(min_bytes, max_bytes, ts.NanPropagationMode.PROPAGATE)
+        attr.ClampAttribute(
+            self._to_bytes(min_val, node_input_dtype),
+            self._to_bytes(max_val, node_input_dtype),
+            nan_mode=ts.NanPropagationMode.PROPAGATE,
+        )
 
         self._serialize_operator(
             node,
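
The new _to_bytes helper packs the clamp bounds as the raw bytes of the chosen dtype, which is the form ClampAttribute consumes. A standalone illustration of that packing (not part of the patch, and assuming a little-endian host, which is what numpy's native byte order gives on typical x86/Arm machines):

import numpy as np

# INT8 lower bound -128 is the single byte 0x80.
print(np.frombuffer(np.int8(-128).tobytes(), dtype=np.uint8).tolist())    # [128]

# FP32 value 6.0 (0x40C00000) serializes to its four little-endian bytes.
print(np.frombuffer(np.float32(6.0).tobytes(), dtype=np.uint8).tolist())  # [0, 0, 192, 64]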

backends/arm/operators/op_where.py

Lines changed: 15 additions & 51 deletions
@@ -3,7 +3,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, List, Sequence
+from typing import Any, List
 
 import tosa_serializer as ts
 
@@ -23,25 +23,36 @@
 
 
 @register_node_visitor
-class WhereVisitor_INT(NodeVisitor):
+class WhereVisitor(NodeVisitor):
     target = "aten.where.self"
 
     tosa_specs = [
         TosaSpecification.create_from_string("TOSA-1.0+INT"),
+        TosaSpecification.create_from_string("TOSA-1.0+FP"),
     ]
 
     def __init__(self, *args):
         super().__init__(*args)
 
-    def _add_node_to_tosa_graph(
+    def define_node(
         self,
         node: Node,
         tosa_graph: Any,
         inputs: List[TosaArg],
         output: TosaArg,
-        supported_dtypes: Sequence,
     ) -> None:
 
+        supported_dtypes = []
+        if output.tosa_spec.support_integer():
+            supported_dtypes += [
+                ts.DType.BOOL,
+                ts.DType.INT8,
+                ts.DType.INT16,
+                ts.DType.INT32,
+            ]
+        if output.tosa_spec.support_float():
+            supported_dtypes += [ts.DType.BOOL, ts.DType.FP16, ts.DType.FP32]
+
         validate_num_inputs(self.target, inputs, 3)
         # Not first input, which is condition tensor.
         validate_same_dtype(self.target, inputs[1:], ts)
@@ -63,50 +74,3 @@ def _add_node_to_tosa_graph(
             [output.name],
             attr,
         )
-
-    def define_node(
-        self,
-        node: Node,
-        tosa_graph: Any,
-        inputs: List[TosaArg],
-        output: TosaArg,
-    ) -> None:
-        bi_supported_dtypes = [
-            ts.DType.INT8,
-            ts.DType.INT16,
-            ts.DType.INT32,
-            ts.DType.BOOL,
-        ]
-        self._add_node_to_tosa_graph(
-            node, tosa_graph, inputs, output, bi_supported_dtypes
-        )
-
-
-@register_node_visitor
-class WhereVisitor_FP(WhereVisitor_INT):
-
-    tosa_specs = [
-        TosaSpecification.create_from_string("TOSA-1.0+FP"),
-    ]
-
-    def __init__(self, *args):
-        super().__init__(*args)
-
-    def define_node(
-        self,
-        node: Node,
-        tosa_graph: Any,
-        inputs: List[TosaArg],
-        output: TosaArg,
-    ) -> None:
-        mi_supported_dtypes = [
-            ts.DType.FP16,
-            ts.DType.FP32,
-            ts.DType.INT8,
-            ts.DType.INT16,
-            ts.DType.INT32,
-            ts.DType.BOOL,
-        ]
-        self._add_node_to_tosa_graph(
-            node, tosa_graph, inputs, output, mi_supported_dtypes
-        )
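
For op_where the dtype whitelist is now derived from the capabilities of the active TOSA specification instead of being passed in by two subclasses. A rough free-function illustration of the added logic (not code from the patch), assuming support_integer()/support_float() report the profiles named in the spec string and that the downstream validation is a membership check, so BOOL appearing twice under a combined INT+FP spec is harmless:

def collect_supported_dtypes(tosa_spec):
    # Hypothetical standalone version of the list-building added to define_node.
    supported = []
    if tosa_spec.support_integer():      # e.g. TOSA-1.0+INT
        supported += [ts.DType.BOOL, ts.DType.INT8, ts.DType.INT16, ts.DType.INT32]
    if tosa_spec.support_float():        # e.g. TOSA-1.0+FP
        supported += [ts.DType.BOOL, ts.DType.FP16, ts.DType.FP32]
    return supported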
