Commit 19c9ff3

Add xnnpack pass to propagate custom meta field to q/dq nodes (#14864)
### Summary

Enable quantization with program-data separation. To select weights for separation, we tag nodes on the eager model. After quantization, q/dq nodes are generated, but they do not carry the external tags that their inputs have. This PR propagates the tags to the q/dq nodes, so that quantized weights are moved to the external file and can be shared.

### Test plan

```
python -m unittest executorch.backends.xnnpack.test.passes.test_propagate_custom_meta_pass
```
1 parent 71bbb1b commit 19c9ff3
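
Roughly, the flow the summary describes looks like the minimal sketch below, based on the Export stage in the new test file (the Linear model and shapes are placeholders): tag the constants on the unlifted module so they are routed to an external ".ptd" file, re-export, then quantize and lower; the new pass carries the tags onto the q/dq nodes produced by quantization.

```python
import torch
from executorch.exir.passes.external_constants_pass import (
    delegate_external_constants_pass_unlifted,
)

model = torch.nn.Linear(32, 2)
example_inputs = (torch.randn(1, 32),)

# Tag constants on the unlifted module; the tag is the external file name,
# so these weights would be saved to "model.ptd".
tagged_module = torch.export.export(model, example_inputs, strict=True).module()
delegate_external_constants_pass_unlifted(
    module=tagged_module,
    gen_tag_fn=lambda x: "model",
)

# Re-export the tagged module; quantization and XNNPACK lowering follow, and
# PropagateCustomMetaPass copies the tags onto the q/dq nodes it encounters.
exported = torch.export.export(tagged_module, example_inputs, strict=True)
```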

File tree

5 files changed: +259 −2 lines changed

backends/test/harness/stages/__init__.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -1,6 +1,6 @@
 from .export import Export
 from .partition import Partition
-from .quantize import Quantize
+from .quantize import Quantize, Quantize_
 from .run_passes import RunPasses
 from .serialize import Serialize
 from .stage import Stage, StageType
@@ -12,6 +12,7 @@
     "Export",
     "Partition",
     "Quantize",
+    "Quantize_",
     "RunPasses",
     "Serialize",
     "Stage",
```

backends/test/harness/stages/quantize.py

Lines changed: 48 additions & 1 deletion
```diff
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Sequence, Tuple
+from typing import Any, Callable, Optional, Sequence, Tuple
 
 import torch
 
@@ -15,6 +15,8 @@
     prepare_qat_pt2e,
 )
 from torchao.quantization.pt2e.quantizer import Quantizer
+from torchao.quantization.quant_api import quantize_
+from torchao.utils import unwrap_tensor_subclass
 
 
 class Quantize(Stage):
@@ -79,3 +81,48 @@ def graph_module(self) -> str:
 
     def run_artifact(self, inputs):
         return self.converted_graph.forward(*inputs)
+
+
+class Quantize_(Stage):
+    """
+    TorchAO quantization stage using the quantize_ API.
+    """
+
+    def __init__(
+        self,
+        config: Any,
+        filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None,
+    ):
+        """
+        Args:
+            config: TorchAO quantization config (e.g., Int4WeightOnlyConfig, Int8DynamicActivationInt8WeightConfig)
+            filter_fn: Optional filter function to select which modules to quantize
+        """
+        self.config = config
+        self.filter_fn = filter_fn
+        self.quantized_module = None
+
+    def stage_type(self) -> str:
+        return StageType.QUANTIZE
+
+    def run(
+        self, artifact: torch.nn.Module, inputs: Optional[Tuple[torch.Tensor]]
+    ) -> None:
+        # Apply quantize_ to the model
+        quantize_(artifact, self.config, self.filter_fn)
+
+        # Unwrap tensor subclasses for export compatibility
+        unwrap_tensor_subclass(artifact)
+
+        self.quantized_module = artifact
+
+    @property
+    def artifact(self) -> torch.nn.Module:
+        return self.quantized_module
+
+    @property
+    def graph_module(self) -> torch.nn.Module:
+        return self.quantized_module
+
+    def run_artifact(self, inputs):
+        return self.quantized_module.forward(*inputs)
```
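For reference, a short usage sketch of the new Quantize_ stage, mirroring test_quantize_ in the new test file below (the model and config values are illustrative, not required):

```python
import torch
from executorch.backends.test.harness.stages import Quantize_
from executorch.backends.xnnpack.test.tester import Tester
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import Int8DynamicActivationIntxWeightConfig

model = torch.nn.Linear(32, 2)
example_inputs = (torch.randn(1, 32),)

# quantize_() mutates the module in place; the stage then unwraps tensor
# subclasses so the quantized module can be exported.
stage = Quantize_(
    config=Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        weight_granularity=PerGroup(32),
    )
)
Tester(model, example_inputs).quantize(stage)
```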
backends/xnnpack/_passes/__init__.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -23,6 +23,9 @@
 from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass
 from executorch.backends.xnnpack._passes.fuse_batch_norm import FuseBatchNormPass
 from executorch.backends.xnnpack._passes.prelu_reshape_pass import PReLUReshapePass
+from executorch.backends.xnnpack._passes.propagate_custom_meta_pass import (
+    PropagateCustomMetaPass,
+)
 from executorch.backends.xnnpack._passes.remove_redundant_copy_pass import (
     RemoveRedundantCopyPass,
 )
@@ -59,6 +62,7 @@ def __init__(
             DimOrderOpsRevertPass,
             ConvertToUpsampleBilinear2d,
             ConvertToLinearPass,
+            PropagateCustomMetaPass,
             ConvertToSDPAPass,
             ConstPropPass,
             FuseBatchNormPass,
```
backends/xnnpack/_passes/propagate_custom_meta_pass.py

Lines changed: 44 additions & 0 deletions

```diff
@@ -0,0 +1,44 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
+from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant
+from executorch.exir.pass_base import PassResult
+
+
+class PropagateCustomMetaPass(XNNPACKPass):
+    """
+    Pass to propagate node.meta['custom'] from parent nodes to their q/dq child nodes.
+    For all quantize/dequantize nodes in the graph, if the parent node has a
+    node.meta['custom'] entry, this pass will copy that value to the q/dq node's meta.
+    """
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        graph = graph_module.graph
+
+        for node in graph.nodes:
+            if not (is_quant(node) or is_dequant(node)):
+                continue
+
+            # Get the parent node (first input argument)
+            if len(node.all_input_nodes) == 0:
+                continue
+
+            parent_node = node.args[0]
+            if not isinstance(parent_node, torch.fx.Node):
+                continue
+
+            if "custom" in parent_node.meta:
+                node.meta["custom"] = parent_node.meta["custom"]
+
+        graph_module.recompile()
+
+        # Since we are overriding "call", we need to call the parent's "call"
+        # to retrace the graph and regenerate metadata
+        graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, True)
```
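To make the propagation rule concrete, here is a standalone torch.fx illustration of what the pass does. This is not the XNNPACK pass itself: torch.relu stands in for the q/dq ops that is_quant/is_dequant would match, and the tag payload is made up for illustration.

```python
import torch
import torch.fx


class Toy(torch.nn.Module):
    def forward(self, x):
        # Stand-in for a q/dq node created by quantization.
        return torch.relu(x)


gm = torch.fx.symbolic_trace(Toy())

# Pretend the input was tagged upstream (the real tag comes from the
# external-constants pass; this key/value is hypothetical).
for node in gm.graph.nodes:
    if node.op == "placeholder":
        node.meta["custom"] = {"external_tag": "model"}

# The propagation rule: copy the parent's "custom" entry onto the child node.
for node in gm.graph.nodes:
    if node.op == "call_function" and node.all_input_nodes:
        parent = node.args[0]
        if isinstance(parent, torch.fx.Node) and "custom" in parent.meta:
            node.meta["custom"] = parent.meta["custom"]

# Both the placeholder and the relu node now carry the same "custom" meta.
for node in gm.graph.nodes:
    print(node.name, node.meta.get("custom"))
```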
backends/xnnpack/test/passes/test_propagate_custom_meta_pass.py

Lines changed: 161 additions & 0 deletions

```diff
@@ -0,0 +1,161 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+from typing import Tuple, Union
+
+import executorch.backends.test.harness.stages as BaseStages
+
+import torch
+from executorch.backends.xnnpack.partition.config.xnnpack_config import (
+    ConfigPrecisionType,
+)
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
+from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
+    get_symmetric_quantization_config,
+)
+from executorch.backends.xnnpack.test.tester import Quantize as XNNPackQuantize, Tester
+from executorch.backends.xnnpack.test.tester.tester import ToEdgeTransformAndLower
+from executorch.exir.passes.external_constants_pass import (
+    delegate_external_constants_pass_unlifted,
+)
+
+from torchao.quantization.granularity import PerGroup
+from torchao.quantization.quant_api import Int8DynamicActivationIntxWeightConfig
+
+try:
+    import executorch.extension.pybindings.portable_lib  # noqa[F401]
+    import executorch.kernels.quantized  # noqa[F401]
+
+    has_quantized_ops = True
+except:
+    has_quantized_ops = False
+    print("Missing quantized ops")
+
+
+class TestPropagateCustomMetaPass(unittest.TestCase):
+    class ModuleLinear(torch.nn.Module):
+        def __init__(
+            self,
+            in_size: int = 2,
+            input_channels: int = 4,
+            output_channels: int = 4,
+            dtype: torch.dtype = torch.float,
+            use_bias: bool = False,
+        ):
+            super().__init__()
+            self.linear = torch.nn.Linear(
+                input_channels, output_channels, bias=use_bias
+            ).to(dtype=dtype)
+
+            self.ic = input_channels
+            self.oc = output_channels
+            assert dtype in [torch.float, torch.half], "Unsupported op dtype"
+            self.op_dtype = dtype
+            self.in_size = in_size
+
+        def forward(self, x: torch.Tensor):
+            return self.linear(x)
+
+        def get_random_inputs(self):
+            inp = torch.randn(self.in_size, self.ic).to(self.op_dtype)
+            return (inp,)
+
+    class Export(BaseStages.Export):
+        def run(
+            self,
+            artifact: torch.nn.Module,
+            inputs: Tuple[torch.Tensor],
+        ) -> None:
+
+            tagged_module = torch.export.export(
+                artifact, inputs, dynamic_shapes=self.dynamic_shapes, strict=True
+            ).module()
+            delegate_external_constants_pass_unlifted(
+                module=tagged_module,
+                gen_tag_fn=lambda x: "model",  # This is the filename the weights will be saved to. In this case, weights will be saved as "model.ptd"
+            )
+            self.exported_program = torch.export.export(
+                tagged_module, inputs, dynamic_shapes=self.dynamic_shapes, strict=True
+            )
+
+    def _test_linear(
+        self,
+        partitioner: XnnpackPartitioner,
+        quantization_stage: Union[BaseStages.Quantize, BaseStages.Quantize_],
+    ):
+        eager_model = self.ModuleLinear(
+            in_size=1,
+            input_channels=32,
+            output_channels=2,
+        )
+        test_inputs = eager_model.get_random_inputs()
+
+        tester = Tester(eager_model, test_inputs)
+        tester.quantize(quantization_stage)
+        tester.export(self.Export())
+        tester.to_edge_transform_and_lower(
+            ToEdgeTransformAndLower([partitioner])
+        ).to_executorch()
+        tester.run_method_and_compare_outputs()
+
+        exec = tester.get_artifact()
+        program_buffer = exec.buffer
+        self.assertEqual(len(exec._tensor_data), 1)
+        data_buffer = bytes(exec._tensor_data.pop("model"))
+        self.assertTrue(len(data_buffer) > 200)
+        from executorch.extension.pybindings import portable_lib as runtime
+
+        module = runtime._load_for_executorch_from_buffer(program_buffer, data_buffer)
+        output = module.forward(test_inputs)
+        reference_output = exec.exported_program().module()(
+            test_inputs[0],
+        )
+        self.assertTrue(torch.allclose(output[0], reference_output, 1e-2))
+
+        # with self.assertRaises(RuntimeError):
+        #     runtime._load_for_executorch_from_buffer(program_buffer).forward(
+        #         test_inputs
+        #     )
+
+    def test_quantize_(self):
+        # Quantize with torchao quantize_ API.
+        DynamicallyQuantizedPartitioner = XnnpackPartitioner(
+            config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
+            per_op_mode=False,
+        )
+        linear_config = Int8DynamicActivationIntxWeightConfig(
+            weight_dtype=torch.int4,
+            weight_granularity=PerGroup(32),
+        )
+        self._test_linear(
+            DynamicallyQuantizedPartitioner, BaseStages.Quantize_(config=linear_config)
+        )
+
+    def test_pt2e_quantize(self):
+        # Quantize with pt2e quantize.
+        quant_configs = [
+            # per_tensor
+            get_symmetric_quantization_config(is_per_channel=False, is_dynamic=False),
+            # per_channel
+            get_symmetric_quantization_config(is_per_channel=True, is_dynamic=False),
+            # per_channel_dynamic
+            get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True),
+        ]
+        for quant_config in quant_configs:
+            precision = (
+                ConfigPrecisionType.DYNAMIC_QUANT
+                if quant_config.input_activation.is_dynamic
+                else ConfigPrecisionType.STATIC_QUANT
+            )
+            for per_op_mode in [True, False]:
+                partitioner = XnnpackPartitioner(
+                    config_precisions=precision, per_op_mode=per_op_mode
+                )
+                self._test_linear(
+                    partitioner, XNNPackQuantize(quantization_config=quant_config)
+                )
```
