22#
33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
5+ """Provide quantization annotation logic for Arm backends.
6+
7+ This module computes per-node quantization properties and applies input/output
8+ annotations to FX graphs using TorchAO qspecs.
9+
10+ """
511
612import logging
713import operator
@@ -44,12 +50,31 @@ class _QuantProperty:
4450
4551
4652class _OpQuantProperties :
53+ """Collect input/output quantization properties for a node.
54+
55+ Attributes:
56+ quant_inputs (List[_QuantProperty]): Quantization specs for inputs
57+ indexed by argument positions.
58+ quant_output (Optional[_QuantProperty]): Quantization spec for the
59+ node's output when applicable.
60+
61+ """
62+
4763 def __init__ (self ):
4864 self .quant_inputs : List [_QuantProperty ] = []
4965 self .quant_output : Optional [_QuantProperty ] = None
5066
5167
5268def _as_list (x ):
69+ """Return ``x`` wrapped as a list if needed.
70+
71+ Args:
72+ x: Value or list of values.
73+
74+ Returns:
75+ list: ``x`` if already a list; otherwise ``[x]``.
76+
77+ """
5378 if isinstance (x , list ):
5479 return x
5580 else :
@@ -66,16 +91,19 @@ def _is_ok_for_quantization(
6691 A node can be quantized if:
6792 - All inputs that are required for quantization are of type `float32`
6893 and are not large scalar values.
69- - The output of the node itself is of type `float32` and is not a large scalar.
94+ - The output of the node itself is of type `float32` and is not a large
95+ scalar.
7096
7197 Args:
7298 node (Node): The node being analyzed.
73- quant_properties (_OpQuantProperties): Contains quantization properties for
74- the node, including input and output quantization specifications.
75- gm (torch.fx.GraphModule): The graph module containing the computational graph.
99+ quant_properties (_OpQuantProperties): Contains quantization properties
100+ for the node, including input and output quantization specifications.
101+ gm (torch.fx.GraphModule): The graph module containing the computational
102+ graph.
76103
77104 Returns:
78105 bool: `True` if the node can be quantized, otherwise `False`.
106+
79107 """
80108 # Check output
81109 if quant_properties .quant_output is not None :
@@ -127,16 +155,28 @@ def _is_ok_for_quantization(
127155
128156
129157def _get_node_target (module : torch .nn .Module | torch .fx .GraphModule , target_str : str ):
158+ """Get an attribute from a module by dotted path.
159+
160+ Args:
161+ module (torch.nn.Module | torch.fx.GraphModule): Root module.
162+ target_str (str): Dotted attribute path, e.g., ``"sub.weight"``.
163+
164+ Returns:
165+ Any: Resolved attribute on the module.
166+
167+ """
130168 targets = target_str .split ("." )
131169 for target in targets [:- 1 ]:
132170 module = module .get_submodule (target )
133171 return getattr (module , targets [- 1 ])
134172
135173
136174def _is_large_scalar (node : Node , gm : torch .fx .GraphModule ):
137- """Check if input is a large scalar value. So that we can skip quantization for the
138- node since histc op (in HistogramObserver) only works for values up to certain upper
139- bound.
175+ """Return True if input is a large scalar value.
176+
177+ Large scalars are skipped because ``torch.histc`` supports values only up
178+ to a certain upper bound.
179+
140180 """
141181 HISTC_UPPER_BOUND = 3.4028235e15
142182 if node .op == "get_attr" and isinstance (node .target , str ):
@@ -166,11 +206,12 @@ def _is_non_float_tensor(node: Node) -> bool:
166206 bool: `True` if the data type is not float32, otherwise `False`.
167207
168208 Note:
169- - If `node.meta["val"]` is a `list`, the function returns `True` if **any**
170- element is ** not** an instance of `FakeTensor` or does ** not** have
209+ - If `node.meta["val"]` is a `list`, the function returns `True` if
210+ any element is not an instance of `FakeTensor` or does not have
171211 `torch.float32` as its data type.
172- - If node.meta["val"] is missing or is not an instance of `FakeTensor`, the
173- function returns True.
212+ - If node.meta["val"] is missing or is not an instance of `FakeTensor`,
213+ the function returns True.
214+
174215 """
175216 if "val" in node .meta and isinstance (node .meta ["val" ], Sequence ):
176217 return any (
@@ -186,6 +227,20 @@ def _is_non_float_tensor(node: Node) -> bool:
186227
187228
188229def _annotate_input (node : Node , quant_property : _QuantProperty ):
230+ """Annotate a node's input with the given qspec.
231+
232+ Maps the specified input argument(s) to the provided quantization spec and
233+ optionally marks the input node(s) as annotated.
234+
235+ Args:
236+ node (Node): Node whose input should be annotated.
237+ quant_property (_QuantProperty): Input index and qspec(s).
238+
239+ Raises:
240+ RuntimeError: If the node is already annotated.
241+ TypeError: If an input argument is not a ``Node`` instance.
242+
243+ """
189244 if is_annotated (node ):
190245 raise RuntimeError (
191246 f"Cannot annotate input: node '{ node .name } ' is already annotated"
@@ -211,6 +266,18 @@ def _annotate_input(node: Node, quant_property: _QuantProperty):
211266
212267
213268def _annotate_output (node : Node , quant_property : _QuantProperty ):
269+ """Annotate a node's output with the given qspec.
270+
271+ Args:
272+ node (Node): Node whose output should be annotated.
273+ quant_property (_QuantProperty): Output index and qspec.
274+
275+ Raises:
276+ RuntimeError: If the node is already annotated.
277+ ValueError: If ``mark_annotated`` is True, ``optional`` is True, or
278+ ``index`` is not zero.
279+
280+ """
214281 if is_annotated (node ):
215282 raise RuntimeError (
216283 f"Cannot annotate output: node '{ node .name } ' is already annotated"
@@ -230,12 +297,13 @@ def _annotate_output(node: Node, quant_property: _QuantProperty):
230297def _match_pattern (
231298 node : Node , pattern : List [List ], filter_fn : Optional [Callable [[Node ], bool ]] = None
232299) -> bool :
233- """
234- Check if there's a chain of node.ancestors? -> node -> node.descendant? that matches the
235- chain provided in 'pattern'. If 'filter_fn' is provided, check that all the nodes in the
236- chain pass the filtering.
300+ """Check whether a node chain matches a pattern.
301+
302+ Verify a chain of ancestors -> node -> descendants matches the provided
303+ ``pattern``. If ``filter_fn`` is provided, require all nodes in the chain
304+ to pass the filter. Each pattern element is a list of disjunctive node
305+ targets.
237306
238- Each 'pattern' element is composed of a list of disjunctive nodes types.
239307 """
240308 if len (pattern ) < 1 :
241309 raise ValueError ("No pattern provided" )
@@ -382,6 +450,21 @@ def _match_pattern(
382450def get_quant_properties ( # noqa: C901
383451 node : Node , gm : torch .fx .GraphModule , quantization_config
384452) -> _OpQuantProperties | None :
453+ """Compute quantization properties for a node.
454+
455+ Determine which inputs and/or outputs should be annotated for quantization
456+ based on the node's operator and surrounding pattern.
457+
458+ Args:
459+ node (Node): Node to analyze.
460+ gm (torch.fx.GraphModule): Owning graph module.
461+ quantization_config: Source for activation/weight/bias qspecs.
462+
463+ Returns:
464+ _OpQuantProperties | None: Properties to apply, or ``None`` if the
465+ node is unsupported or not suitable for quantization.
466+
467+ """
385468 input_act_qspec = quantization_config .get_input_act_qspec ()
386469 weight_qspec = quantization_config .get_weight_qspec ()
387470 output_act_qspec = quantization_config .get_output_act_qspec ()
@@ -390,6 +473,7 @@ def get_quant_properties( # noqa: C901
390473 quant_properties = _OpQuantProperties ()
391474
392475 def any_or_hardtanh_min_zero (n : Node ):
 476+ """Return True for any non-hardtanh op, or for hardtanh with ``min_val == 0``."""
393477 # Check that if the node is a hardtanh, its min_val is zero
394478 return n .target != torch .ops .aten .hardtanh .default or n .args [1 ] == 0
395479
@@ -636,6 +720,21 @@ def annotate_graph( # type: ignore[return]
636720 quantization_config : QuantizationConfig ,
637721 filter_fn : Optional [Callable [[Node ], bool ]] = None ,
638722) -> Optional [List [List [Node ]]]:
723+ """Annotate supported nodes in a graph with quantization specs.
724+
 725+ Iterate through call_function nodes, compute quantization properties, and
 726+ apply input/output annotations. A filter can restrict which nodes are
 727+ considered.
728+
729+ Args:
730+ gm (torch.fx.GraphModule): Graph to annotate.
731+ quantization_config (QuantizationConfig): Default qspecs for nodes.
732+ filter_fn (Optional[Callable[[Node], bool]]): Optional node predicate.
733+
734+ Returns:
735+ Optional[List[List[Node]]]: Reserved for future use; currently None.
736+
737+ """
639738 for node in gm .graph .nodes :
640739 if node .op != "call_function" :
641740 continue
0 commit comments