diff --git a/test/quantization/pt2e/test_arm_inductor_quantizer.py b/test/quantization/pt2e/test_arm_inductor_quantizer.py new file mode 100644 index 0000000000..750e88d451 --- /dev/null +++ b/test/quantization/pt2e/test_arm_inductor_quantizer.py @@ -0,0 +1,1505 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +# Owner(s): ["oncall: quantization"] +import copy +import itertools +import unittest +from enum import Enum + +import torch +import torch.nn as nn + +import torchao.quantization.pt2e.quantizer.arm_inductor_quantizer as armiq +from torchao.quantization.pt2e import ObserverBase +from torchao.quantization.pt2e.quantize_pt2e import ( + convert_pt2e, + prepare_pt2e, + prepare_qat_pt2e, +) +from torchao.quantization.pt2e.quantizer.arm_inductor_quantizer import ( + ArmInductorQuantizer, +) +from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import ( + QUANT_ANNOTATION_KEY, +) +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7 + +if TORCH_VERSION_AT_LEAST_2_5: + from torch.export import export_for_training + +import functools +import platform + +from torch.testing._internal.common_quantization import ( + NodeSpec as ns, +) +from torch.testing._internal.common_quantization import ( + QuantizationTestCase, + skipIfNoInductorSupport, +) +from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo + + +def skipIfNoArm(fn): + reason = "Quantized operations require Arm." + if isinstance(fn, type): + if platform.processor() != "aarch64": + fn.__unittest_skip__ = True + fn.__unittest_skip_why__ = reason + return fn + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if platform.processor() != "aarch64": + raise unittest.SkipTest(reason) + else: + fn(*args, **kwargs) + + return wrapper + + +class NodePosType(Enum): + left = 1 + right = 2 + both = 3 + + +class TestHelperModules: + class SingleConv2dModule(torch.nn.Module): + def __init__(self, with_bn=False) -> None: + super().__init__() + self.conv = nn.Conv2d(3, 6, (2, 2), stride=(1, 1), padding=(1, 1)) + self.bn = torch.nn.BatchNorm2d(6) + self.with_bn = with_bn + + def forward(self, x): + x = self.conv(x) + if self.with_bn: + x = self.bn(x) + return x + + class Conv2dAddModule(torch.nn.Module): + def __init__( + self, + inplace_add: bool = False, + conv2d_type: NodePosType = NodePosType.left, + use_bias: bool = False, + with_bn: bool = False, + ) -> None: + super().__init__() + self.conv = torch.nn.Conv2d( + in_channels=3, + out_channels=3, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias, + ) + self.conv2 = torch.nn.Conv2d( + in_channels=3, + out_channels=3, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias, + ) + self.relu = nn.ReLU() + self.inplace_add = inplace_add + self.conv2d_type = conv2d_type + self.bn = torch.nn.BatchNorm2d(3) + self.with_bn = with_bn + + def forward(self, x): + if self.conv2d_type == NodePosType.left: + if self.inplace_add: + tmp = self.conv(x) + if self.with_bn: + tmp = self.bn(tmp) + tmp += self.relu(x) + return tmp + else: + tmp = self.conv(x) + if self.with_bn: + tmp = self.bn(tmp) + return tmp + self.relu(x) + elif self.conv2d_type == NodePosType.right: + if self.inplace_add: + tmp = self.relu(x) + tmp += self.conv(x) + return tmp + else: + return self.relu(x) + self.conv(x) + elif self.conv2d_type == NodePosType.both: + if self.inplace_add: + tmp = 
self.conv(x) + tmp += self.conv2(x) + return tmp + else: + return self.conv(x) + self.conv2(x) + + class Conv2dSingleOpPowModule(nn.Module): + def __init__(self, single_op): + super().__init__() + self.conv = nn.Conv2d(2, 2, 1) + self.single_op = single_op + + def forward(self, x): + x = self.conv(x) + x = self.single_op(x) + return torch.pow(x, 2) + + class SingleLinearModule(torch.nn.Module): + def __init__(self, use_bias) -> None: + super().__init__() + self.linear = nn.Linear(4, 4, bias=use_bias) + + def forward(self, x): + return self.linear(x) + + class LinearUnaryModule(torch.nn.Module): + def __init__( + self, use_bias, postop, inplace_postop=False, post_op_algo="none" + ) -> None: + super().__init__() + self.linear = nn.Linear(4, 4, bias=use_bias) + if postop == nn.GELU: + self.postop = postop(approximate=post_op_algo) + else: + self.postop = postop(inplace=inplace_postop) + + def forward(self, x): + return self.postop(self.linear(x)) + + class LinearAddModule(torch.nn.Module): + def __init__( + self, + inplace_add: bool = False, + linear_pos: NodePosType = NodePosType.left, + use_bias: bool = False, + ) -> None: + super().__init__() + self.linear = torch.nn.Linear( + in_features=16, out_features=16, bias=use_bias + ) + self.linear2 = torch.nn.Linear( + in_features=16, out_features=16, bias=use_bias + ) + self.relu = nn.ReLU() + self.inplace_add = inplace_add + self.linear_pos = linear_pos + + def forward(self, x): + if self.linear_pos == NodePosType.left: + if self.inplace_add: + tmp = self.linear(x) + tmp += self.relu(x) + return tmp + else: + tmp = self.linear(x) + return tmp + self.relu(x) + elif self.linear_pos == NodePosType.right: + if self.inplace_add: + tmp = self.relu(x) + tmp += self.linear(x) + return tmp + else: + return self.relu(x) + self.linear(x) + elif self.linear_pos == NodePosType.both: + if self.inplace_add: + tmp = self.linear(x) + tmp += self.linear2(x) + return tmp + else: + return self.linear(x) + self.linear2(x) + + class LinearAddModule2(torch.nn.Module): + def __init__( + self, + inplace_add: bool = False, + ) -> None: + super().__init__() + self.linear = torch.nn.Linear(in_features=16, out_features=16, bias=True) + self.linear2 = torch.nn.Linear(in_features=16, out_features=16, bias=True) + self.inplace_add = inplace_add + + def forward(self, x): + if self.inplace_add: + tmp = self.linear(x) + tmp += self.linear2(tmp) + return tmp + else: + tmp = self.linear(x) + return tmp + self.linear2(tmp) + + class Conv2dAddModule2(torch.nn.Module): + def __init__( + self, + inplace_add: bool = False, + ) -> None: + super().__init__() + self.conv = torch.nn.Conv2d( + in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1 + ) + self.conv2 = torch.nn.Conv2d( + in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1 + ) + self.inplace_add = inplace_add + self.bn = torch.nn.BatchNorm2d(3) + self.bn2 = torch.nn.BatchNorm2d(3) + + def forward(self, x): + if self.inplace_add: + tmp = self.bn(self.conv(x)) + tmp += self.bn2(self.conv2(tmp)) + return tmp + else: + tmp = self.bn(self.conv(x)) + return tmp + self.bn2(self.conv2(tmp)) + + class SelfAttnLikeModule(torch.nn.Module): + def __init__( + self, + input_dim, + transpose_for_score=False, + num_attention_heads=None, + attention_head_size=None, + ) -> None: + super().__init__() + self.input_dim = input_dim + self.q_proj = nn.Linear(input_dim, input_dim, bias=False) + self.k_proj = nn.Linear(input_dim, input_dim, bias=False) + self.v_proj = nn.Linear(input_dim, input_dim, bias=False) + self.softmax = 
nn.Softmax(dim=-1) + self.transpose_for_score = transpose_for_score + if self.transpose_for_score: + assert num_attention_heads is not None + assert attention_head_size is not None + self.num_attention_heads = num_attention_heads + self.attention_head_size = attention_head_size + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, x): + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + if self.transpose_for_score: + q = self.transpose_for_scores(q) + k = self.transpose_for_scores(k) + v = self.transpose_for_scores(v) + scores = torch.matmul(q, k.transpose(-1, -2)) / (self.input_dim**0.5) + attention = self.softmax(scores) + weighted = torch.matmul(attention, v) + return weighted + + +class ArmInductorQuantTestCase(QuantizationTestCase): + def _test_quantizer( + self, + model, + example_inputs, + quantizer, + expected_node_occurrence, + expected_node_list=None, + is_qat=False, + debug=False, + lower=False, + ): + m_eager = model.train() if is_qat else model.eval() + + # program capture + m = copy.deepcopy(m_eager) + m = export_for_training( + m, + example_inputs, + ).module() + + # QAT Model failed to deepcopy + export_model = m if is_qat else copy.deepcopy(m) + m = prepare_qat_pt2e(m, quantizer) if is_qat else prepare_pt2e(m, quantizer) + # Calibrate + m(*example_inputs) + prepare_model = copy.deepcopy(m) + m = convert_pt2e(m) + convert_model = copy.deepcopy(m) + if debug: + convert_model.print_readable(True) + if lower: + from torch._inductor.constant_folding import constant_fold + from torch._inductor.fx_passes.freezing_patterns import freezing_passes + + m.recompile() + freezing_passes(m, example_inputs) + constant_fold(m) + m(*example_inputs) + node_occurrence = { + ns.call_function(k): v for k, v in expected_node_occurrence.items() + } + if expected_node_list is None: + expected_node_list = [] + node_list = [ns.call_function(n) for n in expected_node_list] + self.checkGraphModuleNodes( + m, expected_node_occurrence=node_occurrence, expected_node_list=node_list + ) + + return export_model, prepare_model, convert_model + + +@skipIfNoInductorSupport +@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+") +class TestQuantizePT2EArmInductor(ArmInductorQuantTestCase): + @skipIfNoArm + def test_conv2d(self): + """ + Test pattern of single conv2d with ArmInductorQuantizer. 
+ """ + with torch.no_grad(): + m = TestHelperModules.SingleConv2dModule().eval() + example_inputs = (torch.randn(2, 3, 16, 16),) + quantizer = ArmInductorQuantizer().set_global( + armiq.get_default_arm_inductor_quantization_config() + ) + node_occurrence = { + # one for input and weight of the conv + torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 0, + } + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + @skipIfNoArm + def test_conv2d_binary(self): + """ + Test pattern of conv2d with binary post ops (such as add) with ArmInductorQuantizer. + Currently, only add as binary post op is supported. + """ + conv2d_type_list = [NodePosType.left, NodePosType.both] + example_inputs = (torch.randn(2, 3, 6, 6),) + quantizer = ArmInductorQuantizer().set_global( + armiq.get_default_arm_inductor_quantization_config() + ) + with torch.no_grad(): + for conv2d_type in conv2d_type_list: + m = TestHelperModules.Conv2dAddModule(conv2d_type=conv2d_type).eval() + if conv2d_type != NodePosType.both: + node_occurrence = { + # one for input and weight of the conv + # one for extra input node of add + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 0, + } + else: + node_occurrence = { + # one for input of the conv + # one for input of another conv + # 2 conv will share same input quant/dequant + # one for extra input node of add + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 0, + } + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.add.Tensor, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + @skipIfNoArm + def test_conv2d_binary2(self): + """ + Test Pattern: + tmp = conv2d_1(x) + tmp2 = conv2d_2(tmp) + return tmp + tmp2 + Since conv2d_1 has 2 users, we should annotate conv2d_2 for binary fusion instead of conv2d_1 + """ + example_inputs = (torch.randn(2, 3, 6, 6),) + quantizer = ArmInductorQuantizer().set_global( + armiq.get_default_arm_inductor_quantization_config() + ) + inplace_add_list = [True, False] + with torch.no_grad(): + for inplace_add in inplace_add_list: + m = TestHelperModules.Conv2dAddModule2(inplace_add=inplace_add).eval() + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5, + # quantize_per_channel for weights are const propagated + 
torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 0, + } + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + ( + torch.ops.aten.add_.Tensor + if inplace_add + else torch.ops.aten.add.Tensor + ), + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + def _single_op_share_observer_recipe_test_helper(self, m, x, single_op): + quantizer = ArmInductorQuantizer().set_global( + armiq.get_default_arm_inductor_quantization_config() + ) + example_inputs = (x,) + node_occurrence = { + # one for input and weight of the conv, two for input/output for the maxpool2d + torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + single_op, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ] + _, prepare_model, _ = self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + for node in prepare_model.graph.nodes: + if node.op == "call_function" and node.target is single_op: + single_op_node = node + input_obs_of_single_op = getattr( + prepare_model, single_op_node.args[0].target + ) + output_obs_of_single_op = getattr( + prepare_model, next(iter(single_op_node.users)).target + ) + elif ( + node.op == "call_function" + and node.target is torch.ops.aten.conv2d.default + ): + conv_node = node + input_obs_of_conv = getattr(prepare_model, conv_node.args[0].target) + self.assertTrue(isinstance(input_obs_of_single_op, ObserverBase)) + self.assertTrue(isinstance(output_obs_of_single_op, ObserverBase)) + self.assertTrue(isinstance(input_obs_of_conv, ObserverBase)) + self.assertTrue(input_obs_of_single_op is output_obs_of_single_op) + self.assertTrue(input_obs_of_single_op is not input_obs_of_conv) + + @skipIfNoArm + def test_linear(self): + """ + Test pattern of single linear with ArmInductorQuantizer. 
+ """ + with torch.no_grad(): + for use_bias in [True, False]: + m = TestHelperModules.SingleLinearModule(use_bias).eval() + example_inputs = (torch.randn(2, 4),) + quantizer = ArmInductorQuantizer().set_global( + armiq.get_default_arm_inductor_quantization_config() + ) + node_occurrence = { + # one for input and weight, one for output + torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 0, + } + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.linear.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + def _test_linear_unary_helper( + self, + post_op_module, + post_op_aten, + post_op_aten_inplace, + post_op_algo_list=None, + is_qat=False, + is_dynamic=False, + ): + """ + Test pattern of linear with unary post ops (e.g. relu) with ArmInductorQuantizer. + """ + use_bias_list = [True, False] + # TODO test for inplace add after refactoring of export_for_training + inplace_list = [False] + if post_op_algo_list is None: + post_op_algo_list = [None] + cases = itertools.product(use_bias_list, inplace_list, post_op_algo_list) + with torch.no_grad(): + for use_bias, inplace, post_op_algo in cases: + if inplace and post_op_aten_inplace is None: + continue + m = TestHelperModules.LinearUnaryModule( + use_bias=use_bias, + postop=post_op_module, + inplace_postop=inplace, + post_op_algo=post_op_algo, + ).eval() + example_inputs = (torch.randn(2, 4),) + quantizer = ArmInductorQuantizer().set_global( + armiq.get_default_arm_inductor_quantization_config( + is_qat=is_qat, + is_dynamic=is_dynamic, + ) + ) + quantize_per_tensor_op = ( + torch.ops.quantized_decomposed.quantize_per_tensor.tensor + if is_dynamic + else torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + dequantize_per_tensor_op = ( + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor + if is_dynamic + else torch.ops.quantized_decomposed.dequantize_per_tensor.default + ) + node_occurrence = { + # one for input of the linear + quantize_per_tensor_op: 1, + dequantize_per_tensor_op: 1 if is_dynamic else 2, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 0, + } + node_list = [ + quantize_per_tensor_op, + dequantize_per_tensor_op, + torch.ops.aten.linear.default, + post_op_aten_inplace if inplace else post_op_aten, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + is_qat=is_qat, + ) + + @skipIfNoArm + def test_linear_unary(self): + aten = torch.ops.aten + self._test_linear_unary_helper(nn.ReLU, aten.relu.default, aten.relu_.default) + self._test_linear_unary_helper( + nn.LeakyReLU, aten.leaky_relu.default, aten.leaky_relu_.default + ) + self._test_linear_unary_helper( + nn.GELU, aten.gelu.default, None, ["none", "tanh"] + ) + + @skipIfNoArm + def test_linear_unary_qat(self): + aten = torch.ops.aten + self._test_linear_unary_helper( + nn.ReLU, aten.relu.default, aten.relu_.default, is_qat=True + ) + self._test_linear_unary_helper( + nn.LeakyReLU, aten.leaky_relu.default, aten.leaky_relu_.default, is_qat=True + ) + 
self._test_linear_unary_helper( + nn.GELU, aten.gelu.default, None, ["none", "tanh"], is_qat=True + ) + + @skipIfNoArm + def test_linear_unary_dynamic(self): + aten = torch.ops.aten + self._test_linear_unary_helper( + nn.ReLU, aten.relu.default, aten.relu_.default, is_dynamic=True + ) + self._test_linear_unary_helper( + nn.LeakyReLU, + aten.leaky_relu.default, + aten.leaky_relu_.default, + is_dynamic=True, + ) + self._test_linear_unary_helper( + nn.GELU, aten.gelu.default, None, ["none", "tanh"], is_dynamic=True + ) + + @skipIfNoArm + def test_linear_unary_dynamic_qat(self): + aten = torch.ops.aten + self._test_linear_unary_helper( + nn.ReLU, aten.relu.default, aten.relu_.default, is_qat=True, is_dynamic=True + ) + self._test_linear_unary_helper( + nn.LeakyReLU, + aten.leaky_relu.default, + aten.leaky_relu_.default, + is_qat=True, + is_dynamic=True, + ) + self._test_linear_unary_helper( + nn.GELU, + aten.gelu.default, + None, + ["none", "tanh"], + is_qat=True, + is_dynamic=True, + ) + + def _check_annotation_stat(self, gm, expected_stat_dict): + # Check expected annotation statistics to ensure the annotation is correct + + def _check_annotation(node): + annot = node.meta.get(QUANT_ANNOTATION_KEY, None) + if annot is None: + return False, False + return annot._annotated, annot._is_output_of_quantized_pattern + + for node in gm.graph.nodes: + if node.target in expected_stat_dict.keys(): + annotated, is_quant_out = _check_annotation(node) + expected_stat_dict[node.target]["annotated"] -= annotated + expected_stat_dict[node.target]["is_quant_out"] -= is_quant_out + for op_stat in expected_stat_dict.values(): + assert all(v == 0 for v in op_stat.values()) + + def _test_linear_binary_helper(self, is_qat=False, is_dynamic=False): + """ + Test pattern of linear with binary post ops (such as add) with ArmInductorQuantizer. + Currently, only add as binary post op is supported. 
+ """ + linear_pos_list = [NodePosType.left, NodePosType.right, NodePosType.both] + # TODO test for inplace add after refactoring of export_for_training + inplace_add_list = [False] + example_inputs = (torch.randn(2, 16),) + quantizer = ArmInductorQuantizer().set_global( + armiq.get_default_arm_inductor_quantization_config( + is_qat=is_qat, + is_dynamic=is_dynamic, + ) + ) + quantize_per_tensor_op = ( + torch.ops.quantized_decomposed.quantize_per_tensor.tensor + if is_dynamic + else torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + dequantize_per_tensor_op = ( + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor + if is_dynamic + else torch.ops.quantized_decomposed.dequantize_per_tensor.default + ) + cases = itertools.product(linear_pos_list, inplace_add_list) + with torch.no_grad(): + for linear_pos, inplace_add in cases: + m = TestHelperModules.LinearAddModule( + inplace_add=inplace_add, linear_pos=linear_pos + ).eval() + if linear_pos != NodePosType.both: + node_occurrence = { + # Only one 1 q-dq for input of the linear + # No q-dq for extra input node of add + quantize_per_tensor_op: 1, + dequantize_per_tensor_op: 1, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + else: + # convert_pt2e disables duplicate dequant for dynamic quant + num_dequant = 1 if is_dynamic else 2 + node_occurrence = { + # One quantize_per_tensor for both linear nodes (shared) + # Two dequantize_per_tensor for two linear nodes + # No q-dq for extra input node of add + quantize_per_tensor_op: 1, + dequantize_per_tensor_op: num_dequant, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + node_list = [ + quantize_per_tensor_op, + dequantize_per_tensor_op, + torch.ops.aten.linear.default, + ( + torch.ops.aten.add_.Tensor + if inplace_add + else torch.ops.aten.add.Tensor + ), + ] + fq_m = self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + is_qat=is_qat, + )[-1] + # One linear and add are fused. The other linear is quantized alone if present + aten = torch.ops.aten + add_op = aten.add_.Tensor if inplace_add else aten.add.Tensor + expected_annotation_stat = { + aten.linear.default: { + "annotated": 2 if linear_pos == NodePosType.both else 1, + "is_quant_out": 1 if linear_pos == NodePosType.both else 0, + }, + add_op: {"annotated": 1, "is_quant_out": 1}, + } + self._check_annotation_stat(fq_m, expected_annotation_stat) + + @skipIfTorchDynamo("very slow") + @skipIfNoArm + def test_qat_conv2d(self): + """ + Test QAT pattern of conv2d_bn with ArmInductorQuantizer. 
+ """ + m = TestHelperModules.SingleConv2dModule(with_bn=True) + example_inputs = (torch.randn(2, 3, 16, 16),) + quantizer = ArmInductorQuantizer().set_global( + armiq.get_default_arm_inductor_quantization_config(is_qat=True) + ) + node_occurrence = { + # one for input and weight of the conv, one for output for the conv + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 0, + # BN should be folded into Conv + torch.ops.aten._native_batch_norm_legit.default: 0, + } + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + is_qat=True, + ) + + @skipIfTorchDynamo("very slow") + @skipIfNoArm + def test_qat_conv2d_binary(self): + """ + Test qat pattern of conv2d_bn with binary post ops (such as add) with ArmInductorQuantizer. + Currently, only add as binary post op is supported. + """ + example_inputs = (torch.randn(2, 3, 6, 6),) + quantizer = ArmInductorQuantizer().set_global( + armiq.get_default_arm_inductor_quantization_config(is_qat=True) + ) + for inplace_add in [True, False]: + m = TestHelperModules.Conv2dAddModule(inplace_add=inplace_add, with_bn=True) + node_occurrence = { + # one for input and weight of the conv + # one for output for the add + # one for extra input node of add + torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 0, + # BN should be folded into Conv + torch.ops.aten._native_batch_norm_legit.default: 0, + } + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + ( + torch.ops.aten.add_.Tensor + if inplace_add + else torch.ops.aten.add.Tensor + ), + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + is_qat=True, + ) + + @skipIfTorchDynamo("very slow") + @skipIfNoArm + def test_qat_conv2d_binary2(self): + """ + Test qat Pattern: + tmp = bn1(conv2d_1(x)) + tmp2 = bn2(conv2d_2(tmp)) + return tmp + tmp2 + Since conv2d_1 has 2 users, we should annotate conv2d_2 for binary fusion instead of conv2d_1 + """ + example_inputs = (torch.randn(2, 3, 6, 6),) + quantizer = ArmInductorQuantizer().set_global( + armiq.get_default_arm_inductor_quantization_config(is_qat=True) + ) + inplace_add_list = [True, False] + with torch.no_grad(): + for inplace_add in inplace_add_list: + m = TestHelperModules.Conv2dAddModule2(inplace_add=inplace_add) + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 6, + # quantize_per_channel for weights are const 
propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 0, + # BN should be folded into Conv + torch.ops.aten._native_batch_norm_legit.default: 0, + } + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + ( + torch.ops.aten.add_.Tensor + if inplace_add + else torch.ops.aten.add.Tensor + ), + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + is_qat=True, + ) + + @skipIfNoArm + def test_dynamic_quant_linear(self): + """ + Test pattern of dynamic quantization of linear with ArmInductorQuantizer. + """ + with torch.no_grad(): + m = TestHelperModules.SelfAttnLikeModule(input_dim=64).eval() + example_inputs = (torch.randn(1, 4, 64),) + quantizer = ArmInductorQuantizer().set_global( + armiq.get_default_arm_inductor_quantization_config(is_dynamic=True) + ) + node_occurrence = { + torch.ops.quantized_decomposed.choose_qparams.tensor: 1, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 0, + } + node_list = [ + torch.ops.quantized_decomposed.choose_qparams.tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.aten.linear.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + @skipIfNoArm + def test_qat_dynamic_quant_linear(self): + """ + Test pattern of qat dynamic quantization of linear with ArmInductorQuantizer. + """ + with torch.no_grad(): + m = TestHelperModules.SelfAttnLikeModule(input_dim=64).eval() + example_inputs = (torch.randn(1, 4, 64),) + quantizer = ArmInductorQuantizer().set_global( + armiq.get_default_arm_inductor_quantization_config( + is_qat=True, is_dynamic=True + ) + ) + node_occurrence = { + torch.ops.quantized_decomposed.choose_qparams.tensor: 1, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 0, + } + node_list = [ + torch.ops.quantized_decomposed.choose_qparams.tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.aten.linear.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + is_qat=True, + ) + + @skipIfNoArm + def test_set_module_name_qconfig(self): + """Test case for quantizing a specific submodule by configuring `set_module_name_qconfig`. + Expect that all linear layers within the submodule `sub` are quantized. 
+ """ + + class Sub(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear1 = torch.nn.Linear(5, 10) + self.relu1 = torch.nn.ReLU(inplace=False) + self.linear2 = torch.nn.Linear(10, 5) + + def forward(self, x): + x = self.linear1(x) + x = self.relu1(x) + x = self.linear2(x) + return x + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 5) + self.sub = Sub() + + def forward(self, x): + x = self.linear(x) + x = self.sub(x) + return x + + m = M().eval() + example_inputs = (torch.randn(3, 5),) + # Set global to `None` and then default config for a specific submodule. + quantizer = ArmInductorQuantizer() + quantizer.set_module_name_qconfig( + "sub", armiq.get_default_arm_inductor_quantization_config() + ) + node_occurrence = { + torch.ops.aten.linear.default: 3, + # quantize and dequantize the input of two linear layers from `sub` + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4, + # dequantize the weight of two linear layers from `sub` + torch.ops.quantized_decomposed.dequantize_per_channel.default: 0, + } + node_list = [ + # first linear is not quantized + torch.ops.aten.linear.default, + # two Q/DQ pairs for two linear layers from `sub` + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.linear.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.linear.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + @skipIfNoArm + def test_set_module_name_qconfig_with_underscores(self) -> None: + """Test that if a module name has an underscore, we can still quantize it.""" + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + # This module name has underscores, which can be part of a mangled name. + self.foo_bar = torch.nn.Linear(2, 2) + self.baz = torch.nn.Linear(2, 2) + + def forward(self, x): + return self.baz(self.foo_bar(x)) + + # Set global to no quantization and then default config for a specific submodule whose name includes an underscore. + quantizer = ArmInductorQuantizer() + quantizer.set_module_name_qconfig( + "foo_bar", armiq.get_default_arm_inductor_quantization_config() + ) + example_inputs = (torch.randn(2, 2),) + m = M().eval() + m = export_for_training(m, example_inputs).module() + m = prepare_pt2e(m, quantizer) + # Use a linear count instead of names because the names might change, but + # the order should be the same. + count = 0 + for n in m.graph.nodes: + if n.op == "call_function" and n.target == torch.ops.aten.linear.default: + # Get the weight observer to see the per-channel vs per-tensor. + weight_observer_node = n.args[1] + if count == 0: + # for foo_bar. + self.assertEqual( + weight_observer_node.op, + "call_module", + f"The op of linear({count})'s weight_observer_node is {weight_observer_node.op} instead call_module", + ) + observer_instance = getattr(m, weight_observer_node.target) + self.assertEqual( + observer_instance.qscheme, torch.per_tensor_symmetric + ) + else: + # For baz it should have no observer at all. 
+                    self.assertNotEqual(
+                        weight_observer_node.op,
+                        "call_module",
+                        f"The op of linear({count})'s weight_observer_node is {weight_observer_node.op} instead of call_module",
+                    )
+                count += 1
+
+    @skipIfNoArm
+    def test_set_module_name_and_module_type_case1(self):
+        """Test setting `module_name_qconfig` and `module_type_qconfig` at the same time.
+        Expect that all linear layers are not quantized except the last one.
+        """
+
+        class M(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.linear1 = torch.nn.Linear(5, 10)
+                self.linear2 = torch.nn.Linear(10, 5)
+                self.sub = torch.nn.Linear(5, 5)
+
+            def forward(self, x):
+                x = self.linear1(x)
+                x = self.linear2(x)
+                x = self.sub(x)
+                return x
+
+        m = M().eval()
+        example_inputs = (torch.randn(3, 5),)
+        # Set `sub` with default config and then `None` for all `Linear`.
+        # The config set by `set_module_name_qconfig` has higher priority than `set_module_type_qconfig`.
+        quantizer = ArmInductorQuantizer()
+        quantizer.set_module_name_qconfig(
+            "sub", armiq.get_default_arm_inductor_quantization_config()
+        ).set_module_type_qconfig(torch.nn.Linear, None)
+
+        node_occurrence = {
+            torch.ops.aten.linear.default: 3,
+            # quantize and dequantize the input of the last linear
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
+            # dequantize the weight of the last linear
+            torch.ops.quantized_decomposed.dequantize_per_channel.default: 0,
+        }
+        node_list = [
+            # first and second linear are not quantized
+            torch.ops.aten.linear.default,
+            torch.ops.aten.linear.default,
+            # last linear is quantized
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.linear.default,
+        ]
+        self._test_quantizer(
+            m,
+            example_inputs,
+            quantizer,
+            node_occurrence,
+            node_list,
+        )
+
+    @skipIfNoArm
+    def test_set_module_name_and_module_type_case2(self):
+        """Test setting `module_name_qconfig` and `module_type_qconfig` at the same time.
+        Expect that all linear layers are quantized except the last one.
+        """
+
+        class M(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.linear1 = torch.nn.Linear(5, 10)
+                self.linear2 = torch.nn.Linear(10, 5)
+                self.sub = torch.nn.Linear(5, 5)
+
+            def forward(self, x):
+                x = self.linear1(x)
+                x = self.linear2(x)
+                x = self.sub(x)
+                return x
+
+        m = M().eval()
+        example_inputs = (torch.randn(3, 5),)
+        # Set `sub` to `None` and then the default config for all `Linear`.
+ quantizer = ArmInductorQuantizer() + quantizer.set_module_name_qconfig("sub", None).set_module_type_qconfig( + torch.nn.Linear, armiq.get_default_arm_inductor_quantization_config() + ) + + node_occurrence = { + torch.ops.aten.linear.default: 3, + # quantize and dequantize the input and output of the first and second linear + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4, + # dequantize the weight of the first and second linear + torch.ops.quantized_decomposed.dequantize_per_channel.default: 0, + } + node_list = [ + # Q/DQ for first lienar + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.linear.default, + # Q/DQ for second lienar + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.linear.default, + # last linear is not quantized + torch.ops.aten.linear.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + @skipIfNoArm + def test_set_module_name_qconfig_for_dynamic_quant(self): + """Test that quantize a specific submodule for dynamic quantization.""" + + with torch.no_grad(): + for is_qat in [False, True]: + m = TestHelperModules.SelfAttnLikeModule(input_dim=64).eval() + example_inputs = (torch.randn(1, 4, 64),) + # only quantize `q_proj` `v_proj` + dynamic_config = armiq.get_default_arm_inductor_quantization_config( + is_dynamic=True, is_qat=is_qat + ) + quantizer = ( + ArmInductorQuantizer() + .set_module_name_qconfig("q_proj", dynamic_config) + .set_module_name_qconfig("v_proj", dynamic_config) + ) + node_occurrence = { + # quantize and dequantize the input + torch.ops.quantized_decomposed.choose_qparams.tensor: 1, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1, + # dequantize the weight of q_proj and v_proj + torch.ops.quantized_decomposed.dequantize_per_channel.default: 0, + } + node_list = [ + # quantize and dequantize the input + torch.ops.quantized_decomposed.choose_qparams.tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + # q_proj + torch.ops.aten.linear.default, + # k_proj + torch.ops.aten.linear.default, + # v_proj + torch.ops.aten.linear.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + is_qat=is_qat, + ) + + @skipIfNoArm + def test_set_module_name_with_mixed_configs(self): + """Test case for setting module names with mixed static/dynamic or QAT/non-QAT configurations. + The config for 'v_proj' will always be ignored and raise a warning. 
+ """ + with torch.no_grad(): + with self.assertWarns(UserWarning) as context: + for q_is_dynamic, v_is_dynamic, q_is_qat, v_is_qat in itertools.product( + [False, True], repeat=4 + ): + if q_is_dynamic == v_is_dynamic and q_is_qat == v_is_qat: + continue + m = TestHelperModules.SelfAttnLikeModule(input_dim=64).eval() + example_inputs = (torch.randn(1, 4, 64),) + quantizer = ( + ArmInductorQuantizer() + .set_module_name_qconfig( + "q_proj", + armiq.get_default_arm_inductor_quantization_config( + is_qat=q_is_qat, is_dynamic=q_is_dynamic + ), + ) + .set_module_name_qconfig( + "v_proj", + armiq.get_default_arm_inductor_quantization_config( + is_qat=v_is_qat, is_dynamic=v_is_dynamic + ), + ) + ) + quant_op = ( + torch.ops.quantized_decomposed.quantize_per_tensor.tensor + if q_is_dynamic + else torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + dequant_op = ( + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor + if q_is_dynamic + else torch.ops.quantized_decomposed.dequantize_per_tensor.default + ) + node_occurrence = { + # quantize and dequantize the input + quant_op: 1, + dequant_op: 1 if q_is_dynamic else 2, + # only `q_proj` was quantized, dequantize its weight + torch.ops.quantized_decomposed.dequantize_per_channel.default: 0, + } + node_list = [ + # quantize and dequantize the input + quant_op, + dequant_op, + # q_proj + torch.ops.aten.linear.default, + # k_proj/v_proj + torch.ops.aten.linear.default, + torch.ops.aten.linear.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + is_qat=q_is_qat, + ) + warning_msg = ( + "Mixed QAT and Non-QAT" + if q_is_qat != v_is_qat + else "Mixed dynamic and static" + ) + self.assertTrue( + any( + warning_msg in msg + for msg in [str(w.message) for w in context.warnings] + ) + ) + + @skipIfNoArm + def test_set_module_name_and_module_type_with_mixed_configs(self): + """Test that set `module_name_qconfig` and `module_type_qconfig` at the same time with mixed the configs. + Expect that only the last linear(`sub`) is quantized using static quantization. + """ + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear1 = torch.nn.Linear(5, 10) + self.linear2 = torch.nn.Linear(10, 5) + self.sub = torch.nn.Linear(5, 5) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + x = self.sub(x) + return x + + m = M().eval() + example_inputs = (torch.randn(3, 5),) + # Set `sub` with static config and then dynamic config for a all `Linear`(ignored). 
+ quantizer = ArmInductorQuantizer() + quantizer.set_module_name_qconfig( + "sub", armiq.get_default_arm_inductor_quantization_config(is_dynamic=False) + ).set_module_type_qconfig( + torch.nn.Linear, + armiq.get_default_arm_inductor_quantization_config(is_dynamic=True), + ) + + node_occurrence = { + torch.ops.aten.linear.default: 3, + # quantize and dequantize the input of the last linear + torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + # dequantize the weight of the last linear + torch.ops.quantized_decomposed.dequantize_per_channel.default: 0, + } + node_list = [ + # first and second linear are not quantized + torch.ops.aten.linear.default, + torch.ops.aten.linear.default, + # Q/DQ pairs for the last linear + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.linear.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + @skipIfNoArm + def test_filter_linear_recipe(self): + """ + Test removing linear from default recipe of ArmInductorQuantizer. + """ + with torch.no_grad(): + m = TestHelperModules.LinearUnaryModule( + use_bias=True, + postop=nn.ReLU, + ).eval() + example_inputs = (torch.randn(2, 4),) + quantizer = ArmInductorQuantizer().set_global( + armiq.get_default_arm_inductor_quantization_config() + ) + quantizer.set_function_type_qconfig(torch.nn.functional.linear, None) + node_occurrence = { + # one for input and weight of the conv + torch.ops.quantized_decomposed.quantize_per_tensor.default: 0, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 0, + } + node_list = [ + torch.ops.aten.linear.default, + torch.ops.aten.relu.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + @skipIfNoArm + def test_attention_block(self): + """ + Test pattern of Attention like Block with ArmInductorQuantizer. 
+ """ + for annotate_matmul in [False, True]: + with torch.no_grad(): + m = TestHelperModules.SelfAttnLikeModule( + input_dim=64 * 16, + transpose_for_score=True, + num_attention_heads=16, + attention_head_size=64, + ).eval() + example_inputs = (torch.randn(2, 384, 1024),) + + m(*example_inputs) + + quantizer = ArmInductorQuantizer().set_global( + armiq.get_default_arm_inductor_quantization_config() + ) + + if annotate_matmul: + quantizer.set_function_type_qconfig( + torch.matmul, quantizer.get_global_quantization_config() + ) + + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: ( + 5 if annotate_matmul else 1 + ), + torch.ops.quantized_decomposed.dequantize_per_tensor.default: ( + 10 if annotate_matmul else 6 + ), + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 0, + } + if annotate_matmul: + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.linear.default, + torch.ops.aten.view.default, + torch.ops.aten.permute.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.matmul.default, + torch.ops.aten.div.Tensor, + torch.ops.aten.softmax.int, + ] + else: + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.linear.default, + torch.ops.aten.view.default, + torch.ops.aten.permute.default, + torch.ops.aten.matmul.default, + torch.ops.aten.div.Tensor, + torch.ops.aten.softmax.int, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/torchao/quantization/pt2e/quantizer/arm_inductor_quantizer.py b/torchao/quantization/pt2e/quantizer/arm_inductor_quantizer.py new file mode 100644 index 0000000000..af0c04a79d --- /dev/null +++ b/torchao/quantization/pt2e/quantizer/arm_inductor_quantizer.py @@ -0,0 +1,396 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
+ +# mypy: allow-untyped-defs +import functools +import operator +import warnings +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional + +import torch +import torch.nn.functional as F +from torch.fx import Node +from typing_extensions import TypeAlias + +from torchao.quantization.pt2e.fake_quantize import ( + FakeQuantize, + FusedMovingAvgObsFakeQuantize, +) +from torchao.quantization.pt2e.observer import ( + HistogramObserver, + MinMaxObserver, + MovingAverageMinMaxObserver, + PlaceholderObserver, +) +from torchao.quantization.pt2e.quantizer import ( + QuantizationConfig, + get_module_name_filter, +) +from torchao.quantization.pt2e.quantizer.quantizer import ( + QuantizationAnnotation, + QuantizationSpec, +) + +from .x86_inductor_quantizer import ( + X86InductorQuantizer, +) + +FilterFn: TypeAlias = Callable[[List[Node]], bool] + + +if TYPE_CHECKING: + from torchao.quantization.pt2e import _ObserverOrFakeQuantizeConstructor + +__all__ = [ + "ArmInductorQuantizer", + "get_default_arm_inductor_quantization_config", +] + + +@dataclass +class _ArmInductorQuantizationAnnotation(QuantizationAnnotation): + # _is_output_of_quantized_pattern: + # * Node as output node of a fusion pattern. + # * The fusion pattern supports int8 data type. + # * The fusion pattern has inputs annotated to insert observer. + # * The quantization_config is not `None`. + _is_output_of_quantized_pattern: bool = False + + +# Operators support the int8 data type +# and recipe is configured by default in ArmInductorQuantizer. +default_quantizable_ops = { + torch.ops.aten.conv2d.default, + torch.ops.aten.linear.default, +} + +# A superset of default_quantizable_ops includes operators support the int8 data type +# but not enabled by default recipe of ArmInductorQuantizer. +quantizable_ops = default_quantizable_ops | { + torch.ops.aten.matmul.default, +} + + +def _create_module_name_filter(module_name: str) -> FilterFn: + """Create a filter function for a given module name. + + The filter function takes a list of nodes (as determined by the annotate function) + and return True if *all* nodes come from the specified module name, False otherwise. + + For example: + linear_1: "f32[3, 10]" = torch.ops.aten.linear.default(...) # comes from a module with name `sub.linear1` + relu: "f32[3, 10]" = torch.ops.aten.relu.default(linear_1); # comes from a module with name `sub.relu1` + + >> module_name_filter = _create_module_name_filter_inner("sub") + >> print(module_name_filter([relu, linear_1])) + # True # These two nodes are determined by `_annotate_linear_unary` function and from "sub". + """ + + filter_fn = get_module_name_filter(module_name) + + def check_all_nodes_from_module(nodes: list[Node]) -> bool: + all_nodes_from_module_name: bool = all(filter_fn(n) for n in nodes) + return all_nodes_from_module_name + + return check_all_nodes_from_module + + +def _create_operator_type_filter( + operator_type: Callable, +) -> FilterFn: + """Create a filter function for a given operator type. + + The filter function takes a list of nodes and returns True if it contains + exactly one node with the specified operator type, False otherwise. + + For example: + linear_1: "f32[3, 10]" = torch.ops.aten.linear.default(...) 
# comes from a module with name `sub.linear1` + relu: "f32[3, 10]" = torch.ops.aten.relu.default(linear_1); # comes from a module with name `sub.relu1` + + >> operator_type_filter = _create_operator_type_filter(torch.ops.aten.linear.default) + >> print(operator_type_filter([relu, linear_1])) + # True # These two nodes are determined by `_annotate_linear_unary` function and the second node is `linear`. + """ + + def operator_type_filter(nodes: list[Node]): + num_nodes_with_operator_type = sum( + node.target == operator_type for node in nodes + ) + if num_nodes_with_operator_type > 1: + raise NotImplementedError( + f"Several nodes within a single pattern are {operator_type}." + ) + return num_nodes_with_operator_type == 1 + + return operator_type_filter + + +def _global_config_filter(nodes: List[Node]) -> bool: + """Filter function for global configuration. + + This filter function takes a list of nodes and returns True if there is exactly one node + in the list that is a default quantizable operation, False otherwise. + """ + num_nodes_in_default_quantizable_ops = sum( + node.target in default_quantizable_ops for node in nodes + ) + if num_nodes_in_default_quantizable_ops > 1: + raise NotImplementedError( + "Several nodes within a single pattern are default quantizable operations." + ) + return num_nodes_in_default_quantizable_ops == 1 + + +def _map_module_function_to_aten_operator_type(): + module_function_to_aten_operator: Dict[Callable, torch._ops.OpOverloadPacket] = {} + map_list = ( + ([torch.nn.Conv2d, F.conv2d], torch.ops.aten.conv2d.default), + ([torch.nn.Linear, F.linear], torch.ops.aten.linear.default), + ( + [ + torch.matmul, + ], + torch.ops.aten.matmul.default, + ), + ) + for map_item in map_list: + module_function_to_aten_operator.update(dict.fromkeys(map_item[0], map_item[1])) # type: ignore[arg-type, call-overload] + return module_function_to_aten_operator + + +@functools.lru_cache +def get_default_arm_inductor_quantization_config( + is_qat: bool = False, + is_dynamic: bool = False, +): + extra_args: Dict[str, Any] = {"eps": 2**-12} + if is_qat: + if is_dynamic: + act_observer_or_fake_quant_ctr = FakeQuantize + dynamic_quant_observer = MovingAverageMinMaxObserver.with_args( + averaging_constant=1 + ) + extra_args["observer"] = dynamic_quant_observer + else: + act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize # type: ignore[assignment] + else: + if is_dynamic: + act_observer_or_fake_quant_ctr = PlaceholderObserver # type: ignore[assignment] + else: + act_observer_or_fake_quant_ctr = HistogramObserver # type: ignore[assignment] + # check for the qconfig ------------------------- + act_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + qscheme=torch.per_tensor_affine, + is_dynamic=is_dynamic, + observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args( + **extra_args + ), + ) + + weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = ( + FusedMovingAvgObsFakeQuantize if is_qat else MinMaxObserver + ) + + if is_qat: + # Only support per tensor quant for now + extra_args["observer"] = MovingAverageMinMaxObserver # type: ignore[dict-item] + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + qscheme=torch.per_tensor_symmetric, + is_dynamic=False, + observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args( + **extra_args + ), + ) + bias_quantization_spec = None # will use placeholder observer by default + quantization_config = 
QuantizationConfig(
+        act_quantization_spec,
+        act_quantization_spec,
+        weight_quantization_spec,
+        bias_quantization_spec,
+        is_qat,
+    )
+    return quantization_config
+
+
+def _config_checker(method: Callable) -> Callable:
+    @functools.wraps(method)
+    def wrapper(
+        quantizer: "ArmInductorQuantizer",
+        name: Any,
+        quantization_config: Optional["QuantizationConfig"],
+    ) -> "ArmInductorQuantizer":
+        if quantizer._need_skip_config(quantization_config):
+            warnings.warn(
+                f"Skip the quantization config for {name}.",
+            )
+            return quantizer
+        return method(quantizer, name, quantization_config)
+
+    return wrapper
+
+
+class ArmInductorQuantizer(X86InductorQuantizer):
+    module_function_to_aten_operator_type = _map_module_function_to_aten_operator_type()
+
+    def get_global_quantization_config(self):
+        if not isinstance(self.global_config, QuantizationConfig):
+            warnings.warn(
+                "The global_config for ArmInductorQuantizer is currently invalid. "
+                "Please ensure that you use set_global to establish the global quantization configuration."
+            )
+        return self.global_config
+
+    @_config_checker
+    def set_function_type_qconfig(
+        self,
+        function_type: Callable,
+        quantization_config: Optional[QuantizationConfig],
+    ) -> "ArmInductorQuantizer":
+        if function_type in ArmInductorQuantizer.module_function_to_aten_operator_type:
+            self._set_aten_operator_qconfig(
+                ArmInductorQuantizer.module_function_to_aten_operator_type[
+                    function_type
+                ],
+                quantization_config,
+            )
+        else:
+            warnings.warn(
+                f"function: Unable to customize quantization config for {function_type} by ArmInductorQuantizer."
+            )
+        return self
+
+    @_config_checker
+    def set_module_type_qconfig(
+        self,
+        module_type: torch.nn.Module,
+        quantization_config: Optional[QuantizationConfig],
+    ) -> "ArmInductorQuantizer":
+        if module_type in ArmInductorQuantizer.module_function_to_aten_operator_type:
+            self._set_aten_operator_qconfig(
+                ArmInductorQuantizer.module_function_to_aten_operator_type[module_type],
+                quantization_config,
+            )
+        else:
+            warnings.warn(
+                f"Module: Unable to customize quantization config for {module_type} by ArmInductorQuantizer."
+            )
+        return self
+
+    @_config_checker
+    def set_module_name_qconfig(
+        self, module_name: str, quantization_config: Optional[QuantizationConfig]
+    ):
+        """Set quantization_config for a submodule with name `module_name`. For example,
+        quantizer.set_module_name_qconfig("blocks.sub") will quantize all supported operator/operator
+        patterns in the submodule with this module name using the given `quantization_config`.
+
+        The supported operators include `quantizable_ops` only.
+        """
+        self.module_name_qconfig[module_name] = quantization_config
+        return self
+
+    def _set_aten_operator_qconfig(
+        self,
+        operator_type: torch._ops.OpOverloadPacket,
+        quantization_config: Optional[QuantizationConfig],
+    ) -> "ArmInductorQuantizer":
+        if operator_type in quantizable_ops:
+            self.operator_type_qconfig[operator_type] = quantization_config
+        else:
+            warnings.warn(
+                f"operator: Unable to quantize {operator_type} by ArmInductorQuantizer."
+            )
+        return self
+
+    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        """Annotate the given model with quantization configurations.
+
+        Annotation contracts:
+        1. Annotate each node according to the user's qconfig in the following order:
+        `module_name_qconfig`, `operator_type_qconfig`, and `global_config`.
+        2. Avoid re-annotating nodes already annotated in prior stages.
For example, + if `linear1` has been annotated by `module_name_qconfig`, it won't be annotated again + during the processing of the 'operator_type_qconfig' or 'global_config'. + 3. For config is `None`, the node will be annotated with `_ArmInductorQuantizationAnnotation(_annotated=True)`. + + For each pair of (module_name_or_operator_type_or_global, qconfig), a filter function is created. + This filter function checks if the node is marked by current stage and not annotated by the previous stage. + """ + for module_name, quantization_config in self.module_name_qconfig.items(): + self._annotate_with_config( + model, quantization_config, _create_module_name_filter(module_name) + ) + + for operator_type, quantization_config in self.operator_type_qconfig.items(): + self._annotate_with_config( + model, quantization_config, _create_operator_type_filter(operator_type) + ) + + if self.global_config: + self._annotate_with_config( + model, + self.global_config, + _global_config_filter, + ) + + return model + + def _annotate_with_config( + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: FilterFn, + ) -> None: + """Annotate the model with the given quantization configuration. + + High-level description of quantization recipe for Arm Inductor Backend: + Apply quantization recipe for fusion patterns of conv/linear to enable int8 data type actively. + """ + + # Step1: Recipe of fusion patterns like conv/linear. + self._annotate_conv2d_fusion_pattern(model, quantization_config, filter_fn) + self._annotate_linear_fusion_pattern(model, quantization_config, filter_fn) + self._annotate_matmul(model, quantization_config, filter_fn) + + def _annotate_qat_conv2d_fusion_pattern( + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ): + # Annotate QAT Specific patterns + self._annotate_qat_conv2d_bn_binary(model, quantization_config, filter_fn) + self._annotate_qat_conv2d_bn(model, quantization_config, filter_fn) + + def _annotate_conv2d_fusion_pattern( + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ): + if (quantization_config is None) or (quantization_config.is_qat): + # Annotate QAT specific pattern: mainly due to BN not folded in prepare_qat + self._annotate_qat_conv2d_fusion_pattern( + model, quantization_config, filter_fn + ) + self._annotate_conv2d_binary(model, quantization_config, filter_fn) + self._annotate_conv2d(model, quantization_config, filter_fn) + + def _annotate_linear_fusion_pattern( + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ): + self._annotate_linear_unary(model, quantization_config, filter_fn) + self._annotate_linear(model, quantization_config, filter_fn)
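
For reference, below is a minimal end-to-end sketch of how the new quantizer is exercised by the tests above (the same export -> prepare -> calibrate -> convert flow used in `_test_quantizer`). The toy model, input shapes, and single-batch calibration here are illustrative assumptions, not part of the patch.

# Minimal post-training static quantization sketch with ArmInductorQuantizer.
# The toy model below is hypothetical; the API calls mirror the tests in this PR.
import torch
import torch.nn as nn
from torch.export import export_for_training

import torchao.quantization.pt2e.quantizer.arm_inductor_quantizer as armiq
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.quantization.pt2e.quantizer.arm_inductor_quantizer import (
    ArmInductorQuantizer,
)


class ToyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 6, kernel_size=3, padding=1)
        self.linear = nn.Linear(6 * 16 * 16, 10)

    def forward(self, x):
        return self.linear(torch.flatten(self.conv(x), 1))


example_inputs = (torch.randn(2, 3, 16, 16),)
m = ToyModel().eval()

# Global static int8 config; conv2d/linear fusion patterns are annotated by default.
quantizer = ArmInductorQuantizer().set_global(
    armiq.get_default_arm_inductor_quantization_config()
)

with torch.no_grad():
    # Capture the model, insert observers, calibrate, then convert to the Q/DQ representation.
    m = export_for_training(m, example_inputs).module()
    m = prepare_pt2e(m, quantizer)
    m(*example_inputs)  # calibration pass
    m = convert_pt2e(m)
    m(*example_inputs)

Dynamic or QAT variants follow the same flow, swapping in get_default_arm_inductor_quantization_config(is_dynamic=True) or prepare_qat_pt2e with is_qat=True, as the tests above demonstrate.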