Commit eb409ba

Arm backend: Add docstrings for quantizer/quantization_annotator.py (#15523)
Signed-off-by: Sebastian Larsson <[email protected]>
1 parent 5bcf9af commit eb409ba

1 file changed: +115 -16 lines

backends/arm/quantizer/quantization_annotator.py
115 additions & 16 deletions

@@ -2,6 +2,12 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+"""Provide quantization annotation logic for Arm backends.
+
+This module computes per-node quantization properties and applies input/output
+annotations to FX graphs using TorchAO qspecs.
+
+"""
 
 import logging
 import operator
@@ -44,12 +50,31 @@ class _QuantProperty:
 
 
 class _OpQuantProperties:
+    """Collect input/output quantization properties for a node.
+
+    Attributes:
+        quant_inputs (List[_QuantProperty]): Quantization specs for inputs
+            indexed by argument positions.
+        quant_output (Optional[_QuantProperty]): Quantization spec for the
+            node's output when applicable.
+
+    """
+
     def __init__(self):
         self.quant_inputs: List[_QuantProperty] = []
         self.quant_output: Optional[_QuantProperty] = None
 
 
 def _as_list(x):
+    """Return ``x`` wrapped as a list if needed.
+
+    Args:
+        x: Value or list of values.
+
+    Returns:
+        list: ``x`` if already a list; otherwise ``[x]``.
+
+    """
     if isinstance(x, list):
         return x
     else:
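
A quick behavioral check of the `_as_list` contract documented above; this is a standalone re-implementation for illustration, not the file's own code:

# Standalone sketch of the _as_list behavior described in the docstring.
def as_list_sketch(x):
    if isinstance(x, list):
        return x
    return [x]

assert as_list_sketch(3) == [3]          # a bare value gets wrapped
assert as_list_sketch([1, 2]) == [1, 2]  # an existing list passes through
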
@@ -66,16 +91,19 @@ def _is_ok_for_quantization(
     A node can be quantized if:
     - All inputs that are required for quantization are of type `float32`
       and are not large scalar values.
-    - The output of the node itself is of type `float32` and is not a large scalar.
+    - The output of the node itself is of type `float32` and is not a large
+      scalar.
 
     Args:
         node (Node): The node being analyzed.
-        quant_properties (_OpQuantProperties): Contains quantization properties for
-            the node, including input and output quantization specifications.
-        gm (torch.fx.GraphModule): The graph module containing the computational graph.
+        quant_properties (_OpQuantProperties): Contains quantization properties
+            for the node, including input and output quantization specifications.
+        gm (torch.fx.GraphModule): The graph module containing the computational
+            graph.
 
     Returns:
         bool: `True` if the node can be quantized, otherwise `False`.
+
     """
     # Check output
     if quant_properties.quant_output is not None:
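
The two quantizability criteria above reduce to a per-tensor test. A minimal sketch, assuming the `torch.histc` bound quoted later in this diff; the helper name is invented for illustration:

import torch

HISTC_UPPER_BOUND = 3.4028235e15  # same bound as in _is_large_scalar below

def tensor_ok_for_quantization(t: torch.Tensor) -> bool:
    # float32 dtype, and scalars small enough for torch.histc to observe.
    if t.dtype != torch.float32:
        return False
    if t.numel() == 1 and abs(t.item()) > HISTC_UPPER_BOUND:
        return False
    return True

assert tensor_ok_for_quantization(torch.tensor(1.0))
assert not tensor_ok_for_quantization(torch.tensor(1.0e16))  # too large a scalar
assert not tensor_ok_for_quantization(torch.ones(2, dtype=torch.int32))  # not float32
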
@@ -127,16 +155,28 @@ def _is_ok_for_quantization(
 
 
 def _get_node_target(module: torch.nn.Module | torch.fx.GraphModule, target_str: str):
+    """Get an attribute from a module by dotted path.
+
+    Args:
+        module (torch.nn.Module | torch.fx.GraphModule): Root module.
+        target_str (str): Dotted attribute path, e.g., ``"sub.weight"``.
+
+    Returns:
+        Any: Resolved attribute on the module.
+
+    """
     targets = target_str.split(".")
     for target in targets[:-1]:
         module = module.get_submodule(target)
     return getattr(module, targets[-1])
 
 
 def _is_large_scalar(node: Node, gm: torch.fx.GraphModule):
-    """Check if input is a large scalar value. So that we can skip quantization for the
-    node since histc op (in HistogramObserver) only works for values up to certain upper
-    bound.
+    """Return True if input is a large scalar value.
+
+    Large scalars are skipped because ``torch.histc`` supports values only up
+    to a certain upper bound.
+
     """
     HISTC_UPPER_BOUND = 3.4028235e15
     if node.op == "get_attr" and isinstance(node.target, str):
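
To see how the dotted-path lookup in `_get_node_target` behaves, a self-contained sketch with a hypothetical two-level module (the class names are made up):

import torch

class Inner(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(2))

class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.sub = Inner()

# Walk every path component except the last with get_submodule(),
# then fetch the final attribute with getattr(), as the diff does.
module: torch.nn.Module = Outer()
targets = "sub.weight".split(".")
for target in targets[:-1]:
    module = module.get_submodule(target)
print(getattr(module, targets[-1]))  # Parameter containing tensor([1., 1.])
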
@@ -166,11 +206,12 @@ def _is_non_float_tensor(node: Node) -> bool:
         bool: `True` if the data type is not float32, otherwise `False`.
 
     Note:
-        - If `node.meta["val"]` is a `list`, the function returns `True` if **any**
-          element is **not** an instance of `FakeTensor` or does **not** have
+        - If `node.meta["val"]` is a `list`, the function returns `True` if
+          any element is not an instance of `FakeTensor` or does not have
           `torch.float32` as its data type.
-        - If node.meta["val"] is missing or is not an instance of `FakeTensor`, the
-          function returns True.
+        - If node.meta["val"] is missing or is not an instance of `FakeTensor`,
+          the function returns True.
+
     """
     if "val" in node.meta and isinstance(node.meta["val"], Sequence):
         return any(
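
The list case in the note can be mirrored directly; a sketch assuming `FakeTensor` from `torch._subclasses.fake_tensor`, with an invented helper name:

import torch
from torch._subclasses.fake_tensor import FakeTensor

def any_non_float(vals) -> bool:
    # Mirrors the note: with a list-valued meta["val"], the node counts as
    # non-float if ANY element is not a float32 FakeTensor.
    return any(
        not isinstance(v, FakeTensor) or v.dtype != torch.float32
        for v in vals
    )
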
@@ -186,6 +227,20 @@ def _is_non_float_tensor(node: Node) -> bool:
 
 
 def _annotate_input(node: Node, quant_property: _QuantProperty):
+    """Annotate a node's input with the given qspec.
+
+    Map the specified input argument(s) to the provided quantization spec and
+    optionally mark the input node(s) as annotated.
+
+    Args:
+        node (Node): Node whose input should be annotated.
+        quant_property (_QuantProperty): Input index and qspec(s).
+
+    Raises:
+        RuntimeError: If the node is already annotated.
+        TypeError: If an input argument is not a ``Node`` instance.
+
+    """
     if is_annotated(node):
         raise RuntimeError(
             f"Cannot annotate input: node '{node.name}' is already annotated"
@@ -211,6 +266,18 @@ def _annotate_input(node: Node, quant_property: _QuantProperty):
 
 
 def _annotate_output(node: Node, quant_property: _QuantProperty):
+    """Annotate a node's output with the given qspec.
+
+    Args:
+        node (Node): Node whose output should be annotated.
+        quant_property (_QuantProperty): Output index and qspec.
+
+    Raises:
+        RuntimeError: If the node is already annotated.
+        ValueError: If ``mark_annotated`` is True, ``optional`` is True, or
+            ``index`` is not zero.
+
+    """
     if is_annotated(node):
         raise RuntimeError(
             f"Cannot annotate output: node '{node.name}' is already annotated"
@@ -230,12 +297,13 @@ def _annotate_output(node: Node, quant_property: _QuantProperty):
 def _match_pattern(
     node: Node, pattern: List[List], filter_fn: Optional[Callable[[Node], bool]] = None
 ) -> bool:
-    """
-    Check if there's a chain of node.ancestors? -> node -> node.descendant? that matches the
-    chain provided in 'pattern'. If 'filter_fn' is provided, check that all the nodes in the
-    chain pass the filtering.
+    """Check whether a node chain matches a pattern.
+
+    Verify a chain of ancestors -> node -> descendants matches the provided
+    ``pattern``. If ``filter_fn`` is provided, require all nodes in the chain
+    to pass the filter. Each pattern element is a list of disjunctive node
+    targets.
 
-    Each 'pattern' element is composed of a list of disjunctive nodes types.
     """
     if len(pattern) < 1:
         raise ValueError("No pattern provided")
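
To make "disjunctive node targets" concrete: each outer element of `pattern` is one position in the chain, and each inner list gives the alternative targets accepted at that position. A hypothetical pattern (not one taken from the file) matching a conv followed by either relu or hardtanh:

import torch

pattern = [
    [torch.ops.aten.conv2d.default],                                 # position 0
    [torch.ops.aten.relu.default, torch.ops.aten.hardtanh.default],  # position 1
]
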
@@ -382,6 +450,21 @@ def _match_pattern(
 def get_quant_properties(  # noqa: C901
     node: Node, gm: torch.fx.GraphModule, quantization_config
 ) -> _OpQuantProperties | None:
+    """Compute quantization properties for a node.
+
+    Determine which inputs and/or outputs should be annotated for quantization
+    based on the node's operator and surrounding pattern.
+
+    Args:
+        node (Node): Node to analyze.
+        gm (torch.fx.GraphModule): Owning graph module.
+        quantization_config: Source for activation/weight/bias qspecs.
+
+    Returns:
+        _OpQuantProperties | None: Properties to apply, or ``None`` if the
+            node is unsupported or not suitable for quantization.
+
+    """
     input_act_qspec = quantization_config.get_input_act_qspec()
     weight_qspec = quantization_config.get_weight_qspec()
     output_act_qspec = quantization_config.get_output_act_qspec()
@@ -390,6 +473,7 @@ def get_quant_properties( # noqa: C901
     quant_properties = _OpQuantProperties()
 
     def any_or_hardtanh_min_zero(n: Node):
+        """Return True for any op or hardtanh with ``min_val == 0``."""
         # Check that if the node is a hardtanh, its min_val is zero
         return n.target != torch.ops.aten.hardtanh.default or n.args[1] == 0
 
@@ -636,6 +720,21 @@ def annotate_graph( # type: ignore[return]
     quantization_config: QuantizationConfig,
     filter_fn: Optional[Callable[[Node], bool]] = None,
 ) -> Optional[List[List[Node]]]:
+    """Annotate supported nodes in a graph with quantization specs.
+
+    Iterate through call_function nodes, compute quantization properties, and
+    apply input/output annotations. A filter can restrict which nodes are
+    considered.
+
+    Args:
+        gm (torch.fx.GraphModule): Graph to annotate.
+        quantization_config (QuantizationConfig): Default qspecs for nodes.
+        filter_fn (Optional[Callable[[Node], bool]]): Optional node predicate.
+
+    Returns:
+        Optional[List[List[Node]]]: Reserved for future use; currently None.
+
+    """
     for node in gm.graph.nodes:
         if node.op != "call_function":
             continue
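
Read together, the docstrings describe the following flow. This is a hedged paraphrase built only from the signatures and docstrings in this diff; the real loop body beyond the lines shown above may differ:

def annotate_graph_sketch(gm, quantization_config, filter_fn=None):
    # Paraphrase of the documented flow, not the file's exact body.
    for node in gm.graph.nodes:
        if node.op != "call_function":
            continue
        if filter_fn is not None and not filter_fn(node):
            continue
        quant_properties = get_quant_properties(node, gm, quantization_config)
        if quant_properties is None:
            continue  # unsupported or not suitable for quantization
        for quant_property in quant_properties.quant_inputs:
            _annotate_input(node, quant_property)
        if quant_properties.quant_output is not None:
            _annotate_output(node, quant_properties.quant_output)
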
