11# Copyright 2025 Arm Limited and/or its affiliates.
2- # All rights reserved.
32#
43# This source code is licensed under the BSD-style license found in the
54# LICENSE file in the root directory of this source tree
2726
2827
2928@register_node_visitor
30- class ClampVisitor_INT (NodeVisitor ):
29+ class ClampVisitor (NodeVisitor ):
3130 target = "aten.clamp.default"
3231
3332 tosa_specs = [
3433 TosaSpecification .create_from_string ("TOSA-1.0+INT" ),
34+ TosaSpecification .create_from_string ("TOSA-1.0+FP" ),
3535 ]
3636
3737 def __init__ (self , * args ):
3838 super ().__init__ (* args )
3939
4040 def _get_min_max_arguments (
41- self , node : Node , dtype_min : int | float , dtype_max : int | float
41+ self , node : Node , dtype : torch . dtype
4242 ) -> Tuple [int | float , int | float ]:
4343
4444 def cast_type (value : Any ) -> int | float :
@@ -48,6 +48,13 @@ def cast_type(value: Any) -> int | float:
4848 # Attempt to cast to float
4949 return float (value )
5050
51+ if dtype .is_floating_point :
52+ dtype_min = torch .finfo (dtype ).min
53+ dtype_max = torch .finfo (dtype ).max
54+ else :
55+ dtype_min = torch .iinfo (dtype ).min
56+ dtype_max = torch .iinfo (dtype ).max
57+
5158 min_arg = dtype_min
5259 max_arg = dtype_max
5360
@@ -60,53 +67,15 @@ def cast_type(value: Any) -> int | float:
6067
6168 return min_arg , max_arg
6269
63- def define_node (
64- self ,
65- node : Node ,
66- tosa_graph : Any ,
67- inputs : List [TosaArg ],
68- output : TosaArg ,
69- ) -> None :
70- validate_num_inputs (self .target , inputs , [2 , 3 ])
71- validate_same_dtype (self .target , [inputs [0 ], output ], ts )
72- validate_valid_dtype (
73- self .target , [inputs [0 ], output ], [ts .DType .INT8 ], output .tosa_spec
74- )
75-
76- # NOTE: Quantization of the min/max arguments is handled by QuantizeOperatorArguments
77- min_int8 , max_int8 = self ._get_min_max_arguments (
78- node ,
79- torch .iinfo (torch .int8 ).min ,
80- torch .iinfo (torch .int8 ).max ,
81- )
82-
83- attr = ts .TosaSerializerAttribute ()
84- attr .ClampAttribute (
85- np .frombuffer (np .int8 (min_int8 ).tobytes (), dtype = np .uint8 ).tolist (),
86- np .frombuffer (np .int8 (max_int8 ).tobytes (), dtype = np .uint8 ).tolist (),
87- ts .NanPropagationMode .PROPAGATE ,
88- )
89-
90- self ._serialize_operator (
91- node ,
92- tosa_graph ,
93- ts .Op .CLAMP ,
94- [inputs [0 ].name ],
95- [output .name ],
96- attr ,
97- )
98-
99-
100- @register_node_visitor
101- class ClampVisitor_FP (ClampVisitor_INT ):
102- # inheriting 'target' from INT class
103-
104- tosa_specs = [
105- TosaSpecification .create_from_string ("TOSA-1.0+FP" ),
106- ]
107-
108- def __init__ (self , * args ):
109- super ().__init__ (* args )
70+ def _to_bytes (self , value : int | float , dtype : torch .dtype ) -> bytes :
71+ if dtype == torch .float32 :
72+ return np .frombuffer (np .float32 (value ).tobytes (), dtype = np .uint8 ).tolist ()
73+ elif dtype == torch .float16 :
74+ return np .frombuffer (np .float16 (value ).tobytes (), dtype = np .uint8 ).tolist ()
75+ elif dtype == torch .int8 :
76+ return np .frombuffer (np .int8 (value ).tobytes (), dtype = np .uint8 ).tolist ()
77+ else :
78+ raise ValueError (f"Unsupported dtype for to_bytes: { dtype } " )
11079
11180 def define_node (
11281 self ,
@@ -120,42 +89,20 @@ def define_node(
12089 validate_valid_dtype (
12190 self .target ,
12291 [inputs [0 ], output ],
123- [ts .DType .FP16 , ts .DType .FP32 ],
92+ [ts .DType .INT8 , ts . DType . FP16 , ts .DType .FP32 ],
12493 output .tosa_spec ,
12594 )
12695
96+ node_input_dtype = node .meta ["val" ].dtype
97+ # NOTE: Quantization of the min/max arguments is handled by QuantizeOperatorArguments
98+ min_val , max_val = self ._get_min_max_arguments (node , node_input_dtype )
99+
127100 attr = ts .TosaSerializerAttribute ()
128- match inputs [0 ].dtype :
129- case ts .DType .FP16 :
130- min_f , max_f = self ._get_min_max_arguments (
131- node ,
132- torch .finfo (torch .float16 ).min ,
133- torch .finfo (torch .float16 ).max ,
134- )
135- min_bytes = np .frombuffer (
136- np .float16 (min_f ).tobytes (), dtype = np .uint8
137- ).tolist ()
138- max_bytes = np .frombuffer (
139- np .float16 (max_f ).tobytes (), dtype = np .uint8
140- ).tolist ()
141- case ts .DType .FP32 :
142- min_f , max_f = self ._get_min_max_arguments (
143- node ,
144- torch .finfo (torch .float32 ).min ,
145- torch .finfo (torch .float32 ).max ,
146- )
147- min_bytes = np .frombuffer (
148- np .float32 (min_f ).tobytes (), dtype = np .uint8
149- ).tolist ()
150- max_bytes = np .frombuffer (
151- np .float32 (max_f ).tobytes (), dtype = np .uint8
152- ).tolist ()
153- case _:
154- raise RuntimeError (
155- f"Internal error: Unsupported dtype { inputs [0 ].dtype } in { self .target } "
156- )
157-
158- attr .ClampAttribute (min_bytes , max_bytes , ts .NanPropagationMode .PROPAGATE )
101+ attr .ClampAttribute (
102+ self ._to_bytes (min_val , node_input_dtype ),
103+ self ._to_bytes (max_val , node_input_dtype ),
104+ nan_mode = ts .NanPropagationMode .PROPAGATE ,
105+ )
159106
160107 self ._serialize_operator (
161108 node ,
0 commit comments