From f5446a8a95eb11a6f78b8dad537102ff399981b5 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Tue, 10 Sep 2024 10:18:33 +0100
Subject: [PATCH] Cleanup

---
 src/brevitas/export/inference/handler.py | 90 ++++++++++++++++---
 src/brevitas/export/inference/manager.py | 13 ++-
 src/brevitas/proxy/parameter_quant.py    |  6 --
 .../ptq/ptq_evaluate.py                  | 10 ++-
 .../stable_diffusion/main.py             |  4 +-
 .../test_torchvision_models.py           | 53 ++++++-----
 6 files changed, 123 insertions(+), 53 deletions(-)

diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py
index 0be56bc5b..bd81cadc3 100644
--- a/src/brevitas/export/inference/handler.py
+++ b/src/brevitas/export/inference/handler.py
@@ -1,3 +1,8 @@
+# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+
+from abc import ABC
+from abc import abstractmethod
 from typing import Tuple
 
 import torch
@@ -14,9 +19,26 @@
 from brevitas.utils.torch_utils import float_internal_scale
 
 
-class IntInferencetHandler(torch.nn.Module):
-    handled_layer = (
-        ActQuantProxyFromInjector, WeightQuantProxyFromInjector, BiasQuantProxyFromInjector)
+class InferenceHandler(torch.nn.Module, ABC):
+
+    def attach_debug_info(self, module):
+        pass
+
+    @abstractmethod
+    def prepare_for_export(self, module):
+        pass
+
+    @abstractmethod
+    def quantize(self, x):
+        pass
+
+    @abstractmethod
+    def dequantize(self, x):
+        pass
+
+
+class IntInferencetHandler(InferenceHandler):
+    handled_layer = (ActQuantProxyFromInjector, BiasQuantProxyFromInjector)
 
     def attach_debug_info(self, module):
         pass
@@ -29,22 +51,38 @@ def prepare_for_export(self, module):
             self.min_clamp = min_int(module.is_signed, module.is_narrow_range, self.bit_width)
             self.max_clamp = max_int(module.is_signed, module.is_narrow_range, self.bit_width)
 
-    def quant(self, x):
+    def quantize(self, x):
         return torch.clamp(
             torch.round(x / self.scale + self.zero_point), self.min_clamp, self.max_clamp)
 
-    def dequant(self, x):
+    def dequantize(self, x):
         return (x - self.zero_point) * self.scale
 
     def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
-        return self.dequant(self.quant(x)), self.scale, self.zero_point, self.bit_width
+        return self.dequantize(self.quantize(x)), self.scale, self.zero_point, self.bit_width
 
 
-class FloatInferencetHandler(IntInferencetHandler):
-    handled_layer = (ActFloatQuantProxyFromInjector, WeightFloatQuantProxyFromInjector)
+class IntWeightInferencetHandler(IntInferencetHandler):
+    handled_layer = WeightQuantProxyFromInjector
+
+    def prepare_for_export(self, module):
+        if module.is_quant_enabled:
+            self.cached_weight = None
+            if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only:
+                self.cached_weight = module._cached_weight
+            else:
+                super().prepare_for_export(module)
+
+    def forward(self, x) -> Tuple[torch.Tensor]:
+        if self.cached_weight is not None:
+            x = self.cached_weight
+        else:
+            x = self.dequantize(self.quantize(x))
+        return x, self.scale, self.zero_point, self.bit_width
 
-    def attach_debug_info(self, module):
-        pass
+
+class FloatInferencetHandler(InferenceHandler):
+    handled_layer = (ActFloatQuantProxyFromInjector, BiasQuantProxyFromInjector)
 
     def prepare_for_export(self, module):
         if module.is_quant_enabled:
@@ -72,20 +110,46 @@ def prepare_for_export(self, module):
                 self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias)
             self.min_value = torch.tensor(0.) if not module.is_signed else -self.max_value
 
-    def quant(self, x):
+    def quantize(self, x):
         # Compute masks
         inf_mask = x.isinf()
         p_max_val_mask = x > self.max_value
         n_max_val_mask = -x > self.max_value
+
+        # Quantize
         x = x / self.scale
         internal_scale = float_internal_scale(
             x, self.mantissa_bit_width, self.fp_internal_scale_min, self.eps)
         x = internal_scale * self.float_to_int_impl(x / internal_scale)
+
+        # Clamp
         x = self.float_clamp_impl.saturating_clamp(x, self.max_value, self.min_value)
         if not self.saturating:
             x = self.float_clamp_impl.inf_nan_clamp(x, inf_mask, p_max_val_mask, n_max_val_mask)
 
         return x
 
-    def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
-        return self.dequant(self.quant(x)), self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values
+    def dequantize(self, x):
+        return (x - self.zero_point) * self.scale
+
+    def forward(self, x) -> Tuple[torch.Tensor]:
+        return self.dequantize(self.quantize(x)), self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values
+
+
+class FloatWeightInferencetHandler(FloatInferencetHandler):
+    handled_layer = (ActFloatQuantProxyFromInjector, WeightFloatQuantProxyFromInjector)
+
+    def prepare_for_export(self, module):
+        if module.is_quant_enabled:
+            self.cached_weight = None
+            if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only:
+                self.cached_weight = module._cached_weight
+            else:
+                super().prepare_for_export(module)
+
+    def forward(self, x) -> Tuple[torch.Tensor]:
+        if self.cached_weight is not None:
+            x = self.cached_weight
+        else:
+            x = self.dequantize(self.quantize(x))
+        return x, self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values
diff --git a/src/brevitas/export/inference/manager.py b/src/brevitas/export/inference/manager.py
index 15723c0c0..936106884 100644
--- a/src/brevitas/export/inference/manager.py
+++ b/src/brevitas/export/inference/manager.py
@@ -1,8 +1,13 @@
+# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+
 from torch.nn import Module
 import torch.nn as nn
 
 from brevitas.export.inference.handler import FloatInferencetHandler
+from brevitas.export.inference.handler import FloatWeightInferencetHandler
 from brevitas.export.inference.handler import IntInferencetHandler
+from brevitas.export.inference.handler import IntWeightInferencetHandler
 from brevitas.export.manager import _set_proxy_export_handler
 from brevitas.export.manager import _set_proxy_export_mode
 from brevitas.export.manager import _set_recurrent_layer_export_handler
@@ -32,7 +37,7 @@ def _override_weight_caching_mode(m: nn.Module, enabled: bool, metadata_only: bo
     _override_caching_mode(m, 'weight', enabled, metadata_only)
 
 
-class inference_mode:
+class quant_inference_mode:
 
     def __init__(self, model, cache_quant_weight=False, enabled=True):
         self.model = model
@@ -84,7 +89,11 @@ def hook(self, module, inp, out):
 
 # Inheritance from BaseManager is not techincally needed
 class InferenceManager(BaseManager):
-    handlers = [IntInferencetHandler, FloatInferencetHandler]
+    handlers = [
+        IntInferencetHandler,
+        FloatInferencetHandler,
+        IntWeightInferencetHandler,
+        FloatWeightInferencetHandler]
 
     @classmethod
     def set_export_mode(cls, model: Module, enabled: bool):
diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py
index d76177e52..dc7c704c3 100644
--- a/src/brevitas/proxy/parameter_quant.py
+++ b/src/brevitas/proxy/parameter_quant.py
@@ -136,7 +136,6 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
         if self.is_quant_enabled:
             # If quant is enabled the priority is:
             # - export mode
-            # - cached weight
             # - quantization flow
             if self.export_mode:
                 out = self.export_handler(x)
@@ -144,11 +143,6 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
                     out = out[0]
                 else:
                     out = self.create_quant_tensor(out)
-            elif self._cached_weight is not None and not self.cache_inference_quant_weight_metadata_only:
-                if is_dynamo_compiling():
-                    out = self._cached_weight.value
-                else:
-                    out = self._cached_weight.quant_tensor
             else:
                 out = self.tensor_quant(x)
                 if is_dynamo_compiling():
diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
index a561dc0c4..f33946079 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
@@ -18,7 +18,7 @@
 
 from brevitas.export import export_onnx_qcdq
 from brevitas.export import export_torch_qcdq
-from brevitas.export.inference.manager import inference_mode
+from brevitas.export.inference import quant_inference_mode
 from brevitas.graph.quantize import preprocess_for_quantize
 from brevitas.graph.target.flexml import preprocess_for_flexml_quantize
 from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_act_equalization
@@ -268,6 +268,7 @@ def parse_type(v, default_type):
     'uint_sym_act_for_unsigned_values',
     default=True,
     help='Use unsigned act quant when possible (default: enabled)')
+add_bool_arg(parser, 'compile', default=False, help='Use torch.compile (default: disabled)')
 
 
 def generate_ref_input(args, device, dtype):
@@ -482,13 +483,14 @@ def main():
 
     # Validate the quant_model on the validation dataloader
     print("Starting validation:")
-    with torch.no_grad(), inference_mode(quant_model):
+    with torch.no_grad(), quant_inference_mode(quant_model):
         param = next(iter(quant_model.parameters()))
         device, dtype = param.device, param.dtype
         ref_input = generate_ref_input(args, device, dtype)
         quant_model(ref_input)
-        quant_model = torch.compile(quant_model, fullgraph=True, dynamic=True)
-        validate(val_loader, quant_model, stable=dtype != torch.bfloat16)
+        compiled_model = torch.compile(
+            quant_model, fullgraph=True, dynamic=True, disable=not args.compile)
+        validate(val_loader, compiled_model, stable=dtype != torch.bfloat16)
 
     if args.export_onnx_qcdq or args.export_torch_qcdq:
         # Generate reference input tensor to drive the export process
diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py
index 3db0bca9a..5ac69d147 100644
--- a/src/brevitas_examples/stable_diffusion/main.py
+++ b/src/brevitas_examples/stable_diffusion/main.py
@@ -29,7 +29,6 @@
 from brevitas.graph.base import ModuleToModuleByClass
 from brevitas.graph.calibrate import bias_correction_mode
 from brevitas.graph.calibrate import calibration_mode
-from brevitas.graph.calibrate import inference_mode
 from brevitas.graph.calibrate import load_quant_model_mode
 from brevitas.graph.equalize import activation_equalization_mode
 from brevitas.graph.gptq import gptq_mode
@@ -150,6 +149,7 @@ def main(args):
     calibration_prompts = CALIBRATION_PROMPTS
     if args.calibration_prompt_path is not None:
         calibration_prompts = load_calib_prompts(args.calibration_prompt_path)
+    print(args.calibration_prompt, len(calibration_prompts))
     assert args.calibration_prompt <= len(calibration_prompts) , f"Only {len(calibration_prompts)} prompts are available"
     calibration_prompts = calibration_prompts[:args.calibration_prompt]
 
@@ -231,6 +231,8 @@ def main(args):
                     non_blacklist[name_to_add] = 1
                 else:
                     non_blacklist[name_to_add] += 1
+    print(f"Blacklisted layers: {set(blacklist)}")
+    print(f"Non blacklisted layers: {non_blacklist}")
 
     # Make sure there all LoRA layers are fused first, otherwise raise an error
     for m in pipe.unet.modules():
diff --git a/tests/brevitas_end_to_end/test_torchvision_models.py b/tests/brevitas_end_to_end/test_torchvision_models.py
index af409f8ab..528319f26 100644
--- a/tests/brevitas_end_to_end/test_torchvision_models.py
+++ b/tests/brevitas_end_to_end/test_torchvision_models.py
@@ -13,7 +13,7 @@
 from brevitas import torch_version
 from brevitas.export import export_onnx_qcdq
 from brevitas.export import export_torch_qcdq
-from brevitas.export.inference.manager import inference_mode
+from brevitas.export.inference import quant_inference_mode
 from brevitas.graph.calibrate import calibration_mode
 from brevitas.graph.quantize import layerwise_quantize
 from brevitas.graph.quantize import quantize
@@ -28,19 +28,19 @@
 MODEL_LIST = [
     'vit_b_32',
     'efficientnet_b0',
-    # 'mobilenet_v3_small',
-    # 'mobilenet_v2',
-    # 'resnet50',
-    # 'resnet18',
-    # 'mnasnet0_5',
-    # 'alexnet',
-    # 'googlenet',
-    # 'vgg11',
-    # 'densenet121',
-    # 'deeplabv3_resnet50',
-    # 'fcn_resnet50',
-    # 'regnet_x_400mf',
-    # 'squeezenet1_0',
+    'mobilenet_v3_small',
+    'mobilenet_v2',
+    'resnet50',
+    'resnet18',
+    'mnasnet0_5',
+    'alexnet',
+    'googlenet',
+    'vgg11',
+    'densenet121',
+    'deeplabv3_resnet50',
+    'fcn_resnet50',
+    'regnet_x_400mf',
+    'squeezenet1_0',
     'inception_v3']
 
 
@@ -71,7 +71,7 @@ def quantize_float(model):
 
 @fixture
 @parametrize('model_name', MODEL_LIST)
-@parametrize('quantize_fn', [quantize_float])
+@parametrize('quantize_fn', [quantize_float, quantize, layerwise_quantize, quantize_flexml])
 def torchvision_model(model_name, quantize_fn):
     inp = torch.randn(BATCH, IN_CH, HEIGHT, WIDTH)
 
@@ -122,28 +122,27 @@ def test_torchvision_graph_quantization_flexml_qcdq_onnx(
         pytest.skip('Model not instantiated')
     if enable_compile:
         model_name = test_id.split("-")[1]
-        if torch_version <= version.parse('2.0'):
-            pytest.skip("Pytorch 2.0 is required to test compile")
+        if torch_version <= version.parse('2.2'):
+            pytest.skip("Pytorch 2.2 is required to test compile")
+        else:
+            torch._dynamo.config.capture_scalar_outputs = True
         if 'vit' in model_name:
             pytest.skip("QuantMHA not supported with compile")
 
     inp = torch.randn(BATCH, IN_CH, HEIGHT, WIDTH)
     quantize_fn_name = test_id.split("-")[0]
 
-    if enable_compile:
-        torch._dynamo.config.capture_scalar_outputs = True
-        with torch.no_grad(), inference_mode(torchvision_model):
-            prehook_non_compiled_out = torchvision_model(inp)
-            post_hook_non_compiled_out = torchvision_model(inp)
-            assert torch.allclose(prehook_non_compiled_out, post_hook_non_compiled_out)
+    with torch.no_grad(), quant_inference_mode(torchvision_model):
+        prehook_non_compiled_out = torchvision_model(inp)
+        post_hook_non_compiled_out = torchvision_model(inp)
+        assert torch.allclose(prehook_non_compiled_out, post_hook_non_compiled_out)
 
+        if enable_compile:
             compiled_model = torch.compile(torchvision_model, fullgraph=True)
             compiled_out = compiled_model(inp)
 
-        # This fails! Compile might needs more small-scoped tests for accuracy evaluation
-        # assert torch.allclose(post_hook_non_compiled_out, compiled_out)
-    else:
-        torchvision_model(inp)
+            # This fails! Compile might needs more small-scoped tests for accuracy evaluation
+            # assert torch.allclose(post_hook_non_compiled_out, compiled_out)
 
     if quantize_fn_name != 'quantize_float' and not enable_compile:
         export_onnx_qcdq(torchvision_model, args=inp)
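
Usage sketch (not introduced by this commit): the renamed quant_inference_mode context manager, used the same way ptq_evaluate.py and the torchvision tests in this patch exercise it. The toy QuantConv2d model, the input shape, and the cache_quant_weight=True flag below are illustrative assumptions only.

    # Usage sketch under the assumptions stated above.
    import torch
    import torch.nn as nn

    import brevitas.nn as qnn
    from brevitas.export.inference import quant_inference_mode

    # Hypothetical toy model; any Brevitas-quantized model can be used instead.
    quant_model = nn.Sequential(qnn.QuantConv2d(3, 8, kernel_size=3), nn.ReLU())
    quant_model.eval()
    inp = torch.randn(1, 3, 32, 32)

    # Same pattern as ptq_evaluate.py: the first forward inside the context lets the
    # manager switch the quant proxies to the inference handlers added in this patch
    # (optionally caching dequantized weights); the model can then be compiled.
    with torch.no_grad(), quant_inference_mode(quant_model, cache_quant_weight=True):
        quant_model(inp)
        compiled_model = torch.compile(quant_model, fullgraph=True, dynamic=True)
        out = compiled_model(inp)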