From eb9b9b43d93be6da2d7b21049d2ea3d8d3dbeafe Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 29 Aug 2024 10:57:19 +0100 Subject: [PATCH 01/25] Compile --- src/brevitas/export/manager.py | 8 ++- src/brevitas/nn/mixin/base.py | 10 +++- src/brevitas/proxy/parameter_quant.py | 56 +++++++++++++------ src/brevitas/proxy/quant_proxy.py | 2 +- src/brevitas/proxy/runtime_quant.py | 30 +++++++--- .../imagenet_classification/ptq/ptq_common.py | 1 - .../ptq/ptq_evaluate.py | 29 +++++++--- .../imagenet_classification/utils.py | 2 - .../stable_diffusion/main.py | 4 +- 9 files changed, 99 insertions(+), 43 deletions(-) diff --git a/src/brevitas/export/manager.py b/src/brevitas/export/manager.py index 2805c6174..7b7e7a145 100644 --- a/src/brevitas/export/manager.py +++ b/src/brevitas/export/manager.py @@ -166,11 +166,15 @@ def _trace_fn_dispatcher(cls, fn, input, *args, **kwargs): @classmethod def handler_from_module(cls, module: Module, no_inheritance=False): for handler in cls.handlers: + if not isinstance(handler.handled_layer, tuple): + handled_classes = (handler.handled_layer,) + else: + handled_classes = handler.handled_layer if no_inheritance: - if type(module) == handler.handled_layer: + if type(module) in handled_classes: return handler else: - if isinstance(module, handler.handled_layer): + if any([isinstance(module, handler) for handler in handled_classes]): return handler return None diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py index d64271cb5..59b559787 100644 --- a/src/brevitas/nn/mixin/base.py +++ b/src/brevitas/nn/mixin/base.py @@ -8,12 +8,15 @@ from typing import Optional, Tuple, Union from warnings import warn +import packaging.version +import torch from torch import nn from torch import Tensor import torch.jit from torch.nn.utils.rnn import PackedSequence from brevitas import config +from brevitas import torch_version from brevitas.common import ExportMixin from brevitas.inject import ExtendedInjector from brevitas.inject import Injector @@ -26,6 +29,11 @@ from .utils import filter_kwargs +if torch_version < packaging.version.parse('2.0'): + is_dynamo_compiling = lambda: False +else: + is_dynamo_compiling = torch._dynamo.is_compiling + class QuantProxyMixin(object): __metaclass__ = ABCMeta @@ -85,7 +93,7 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe qt_class = self.get_quant_tensor_class(inp) if qt_class is not None: inp = qt_class(*inp) - if not torch._C._get_tracing_state(): + if not torch._C._get_tracing_state() and not is_dynamo_compiling(): if isinstance(inp, QuantTensor): inp = inp.set(value=inp.value.rename(None)) else: diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index 77a806ee8..f91d90a75 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -4,10 +4,19 @@ from abc import ABC from abc import ABCMeta from abc import abstractmethod -from typing import Any, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union from warnings import warn +import packaging.version import torch + +from brevitas import torch_version + +if torch_version < packaging.version.parse('2.0'): + is_dynamo_compiling = lambda: False +else: + is_dynamo_compiling = torch._dynamo.is_compiling + from torch import Tensor import torch.nn as nn from typing_extensions import Protocol @@ -16,6 +25,7 @@ from brevitas import config from brevitas.function import max_int from brevitas.inject import BaseInjector as Injector +from 
brevitas.quant_tensor import _unpack_quant_tensor from brevitas.quant_tensor import IntQuantTensor from brevitas.quant_tensor import QuantTensor from brevitas.utils.quant_utils import _CachedIO @@ -122,15 +132,25 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]: # - quantization flow if self.export_mode: out = self.export_handler(x) - out = self.create_quant_tensor(out) + if is_dynamo_compiling(): + out = out[0] + else: + out = self.create_quant_tensor(out) elif self._cached_weight is not None and not self.cache_inference_quant_weight_metadata_only: - out = self._cached_weight.quant_tensor + if is_dynamo_compiling(): + out = self._cached_weight.value + else: + out = self._cached_weight.quant_tensor else: out = self.tensor_quant(x) - out = self.create_quant_tensor(out) - if not self.training and self.cache_inference_quant_weight and self._cached_weight is None: - self._cached_weight = self.cache_class( - out.detach(), metadata_only=self.cache_inference_quant_weight_metadata_only) + if is_dynamo_compiling(): + out = out[0] + else: + out = self.create_quant_tensor(out) + if not self.training and self.cache_inference_quant_weight and self._cached_weight is None: + self._cached_weight = self.cache_class( + out.detach(), + metadata_only=self.cache_inference_quant_weight_metadata_only) else: # quantization disabled out = self.apply_input_view(x) return out @@ -151,9 +171,10 @@ def tracked_parameter_list(self): def get_cached(self, attr): if self._cached_bias is None: - warn( - "No quant bias cache found, set cache_inference_quant_bias=True and run an " - "inference pass first") + if not is_dynamo_compiling(): + warn( + "No quant bias cache found, set cache_inference_quant_bias=True and run an " + "inference pass first") return None if self.training: warn("Cached quant bias scale is being used in training mode.") @@ -268,7 +289,7 @@ class BiasQuantProxyFromInjector(BiasQuantProxyFromInjectorBase): def scale(self): if not self.is_quant_enabled: return None - if self.requires_input_scale and self.is_quant_enabled and self.is_quant_enabled: + if self.requires_input_scale and self.is_quant_enabled: cache = self.get_cached('scale') return cache zhs = self._zero_hw_sentinel() @@ -335,12 +356,13 @@ def forward( out, out_scale, out_zp, out_bit_width = impl(x, input_scale) else: out, out_scale, out_zp, out_bit_width = impl(x) - out = IntQuantTensor( - out, out_scale, out_zp, out_bit_width, self.is_signed, self.training) - if not self.training and self.cache_inference_quant_bias: - cached_bias = _CachedIO( - out.detach(), metadata_only=self.cache_inference_quant_bias_metadata_only) - self._cached_bias = cached_bias + if not is_dynamo_compiling(): + out = IntQuantTensor( + out, out_scale, out_zp, out_bit_width, self.is_signed, self.training) + if not self.training and self.cache_inference_quant_bias: + cached_bias = _CachedIO( + out.detach(), metadata_only=self.cache_inference_quant_bias_metadata_only) + self._cached_bias = cached_bias else: out = x return out diff --git a/src/brevitas/proxy/quant_proxy.py b/src/brevitas/proxy/quant_proxy.py index 9c4255773..7f61d3c45 100644 --- a/src/brevitas/proxy/quant_proxy.py +++ b/src/brevitas/proxy/quant_proxy.py @@ -122,7 +122,7 @@ def add_tracked_module(self, module: nn.Module) -> None: raise RuntimeError("Trying to add None as a parent module.") def apply_input_view(self, x): - return self.quant_injector.input_view_impl(x) + return self.tensor_quant.int_quant.input_view_impl(x) def _load_from_state_dict( self, state_dict, prefix, local_metadata, 
strict, missing_keys, unexpected_keys, diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 511f914e6..7a5381148 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -5,7 +5,15 @@ from abc import abstractmethod from typing import Any, Optional, Tuple, Union +import packaging.version import torch + +from brevitas import torch_version + +if torch_version < packaging.version.parse('2.0'): + is_dynamo_compiling = lambda: False +else: + is_dynamo_compiling = torch._dynamo.is_compiling from torch import nn from torch import Tensor from torch.nn import Identity @@ -115,6 +123,9 @@ def retrieve_attribute(self, attribute, force_eval): elif self._cached_act is None: return None + def apply_input_view(self, x): + return self.fused_activation_quant_proxy.tensor_quant.int_quant.input_view_impl(x) + @property def is_quant_enabled(self): return self._is_quant_enabled and not self.disable_quant @@ -176,15 +187,18 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: # If y is an empty QuantTensor, we need to check if this is a passthrough proxy, # otherwise return a simple Tensor - # If the second value (i.e., scale) is None, then quant is disabled - if isinstance(y, tuple) and y[1] is not None: - out = self.create_quant_tensor(y) - elif self.is_passthrough_act and isinstance(x, QuantTensor): - # preserve quant_metadata - y = y[0] - out = self.create_quant_tensor(y, x=x) - else: + if is_dynamo_compiling(): out = y[0] + else: + # If the second value (i.e., scale) is None, then quant is disabled + if y[1] is not None: + out = self.create_quant_tensor(y) + elif self.is_passthrough_act and isinstance(x, QuantTensor): + # preserve scale/zp/bit/sign even without output quant + y = y[0] + out = self.create_quant_tensor(y, x=x) + else: + out = y[0] if not self.training and self.cache_inference_quant_act and isinstance(out, QuantTensor): cached_out = self.cache_class(out.detach(), self.cache_quant_io_metadata_only) diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index bac596be5..661328d7e 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -22,7 +22,6 @@ from brevitas.graph.target.flexml import quantize_flexml from brevitas.inject import value import brevitas.nn as qnn -from brevitas.quant.experimental.float import Fp8e4m3Act from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloatMSE from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index c960a89e6..a561dc0c4 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -18,6 +18,7 @@ from brevitas.export import export_onnx_qcdq from brevitas.export import export_torch_qcdq +from brevitas.export.inference.manager import inference_mode from brevitas.graph.quantize import preprocess_for_quantize from brevitas.graph.target.flexml import preprocess_for_flexml_quantize from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_act_equalization @@ -269,6 +270,13 @@ def parse_type(v, default_type): help='Use 
unsigned act quant when possible (default: enabled)') +def generate_ref_input(args, device, dtype): + model_config = get_model_config(args.model_name) + center_crop_shape = model_config['center_crop_shape'] + img_shape = center_crop_shape + return torch.ones(1, 3, img_shape, img_shape, device=device, dtype=dtype) + + def main(): args = parser.parse_args() dtype = getattr(torch, args.dtype) @@ -474,23 +482,28 @@ def main(): # Validate the quant_model on the validation dataloader print("Starting validation:") - validate(val_loader, quant_model, stable=dtype != torch.bfloat16) + with torch.no_grad(), inference_mode(quant_model): + param = next(iter(quant_model.parameters())) + device, dtype = param.device, param.dtype + ref_input = generate_ref_input(args, device, dtype) + quant_model(ref_input) + quant_model = torch.compile(quant_model, fullgraph=True, dynamic=True) + validate(val_loader, quant_model, stable=dtype != torch.bfloat16) if args.export_onnx_qcdq or args.export_torch_qcdq: # Generate reference input tensor to drive the export process - model_config = get_model_config(args.model_name) - center_crop_shape = model_config['center_crop_shape'] - img_shape = center_crop_shape - device, dtype = next(model.parameters()).device, next(model.parameters()).dtype - ref_input = torch.ones(1, 3, img_shape, img_shape, device=device, dtype=dtype) + param = next(iter(quant_model.parameters())) + device, dtype = param.device, param.dtype + ref_input = generate_ref_input(args, device, dtype) export_name = os.path.join(args.export_dir, config) if args.export_onnx_qcdq: export_name = export_name + '.onnx' - export_onnx_qcdq(model, ref_input, export_name, opset_version=args.onnx_opset_version) + export_onnx_qcdq( + quant_model, ref_input, export_name, opset_version=args.onnx_opset_version) if args.export_torch_qcdq: export_name = export_name + '.pt' - export_torch_qcdq(model, ref_input, export_name) + export_torch_qcdq(quant_model, ref_input, export_name) if __name__ == '__main__': diff --git a/src/brevitas_examples/imagenet_classification/utils.py b/src/brevitas_examples/imagenet_classification/utils.py index d506b8a61..460e7d77f 100644 --- a/src/brevitas_examples/imagenet_classification/utils.py +++ b/src/brevitas_examples/imagenet_classification/utils.py @@ -1,5 +1,3 @@ -import csv - import torch import torchvision.datasets as datasets import torchvision.transforms as transforms diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index a1c4fef53..3db0bca9a 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -29,6 +29,7 @@ from brevitas.graph.base import ModuleToModuleByClass from brevitas.graph.calibrate import bias_correction_mode from brevitas.graph.calibrate import calibration_mode +from brevitas.graph.calibrate import inference_mode from brevitas.graph.calibrate import load_quant_model_mode from brevitas.graph.equalize import activation_equalization_mode from brevitas.graph.gptq import gptq_mode @@ -149,7 +150,6 @@ def main(args): calibration_prompts = CALIBRATION_PROMPTS if args.calibration_prompt_path is not None: calibration_prompts = load_calib_prompts(args.calibration_prompt_path) - print(args.calibration_prompt, len(calibration_prompts)) assert args.calibration_prompt <= len(calibration_prompts) , f"Only {len(calibration_prompts)} prompts are available" calibration_prompts = calibration_prompts[:args.calibration_prompt] @@ -231,8 +231,6 @@ def main(args): 
                 non_blacklist[name_to_add] = 1
             else:
                 non_blacklist[name_to_add] += 1
-    print(f"Blacklisted layers: {set(blacklist)}")
-    print(f"Non blacklisted layers: {non_blacklist}")
 
     # Make sure there all LoRA layers are fused first, otherwise raise an error
     for m in pipe.unet.modules():

From d7657047af782b93e6c3e419b2072f4ec78930d5 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Sat, 7 Sep 2024 11:16:57 +0100
Subject: [PATCH 02/25] Inference handler

---
 src/brevitas/export/inference/handler.py | 35 +++++++++
 src/brevitas/export/inference/manager.py | 95 ++++++++++++++++++++++++
 2 files changed, 130 insertions(+)
 create mode 100644 src/brevitas/export/inference/handler.py
 create mode 100644 src/brevitas/export/inference/manager.py

diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py
new file mode 100644
index 000000000..6ff5bf258
--- /dev/null
+++ b/src/brevitas/export/inference/handler.py
@@ -0,0 +1,35 @@
+from typing import Tuple
+
+import torch
+
+from brevitas.function.ops import max_int
+from brevitas.function.ops import min_int
+from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector
+from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
+from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector
+
+
+class IntInferencetHandler(torch.nn.Module):
+    handled_layer = (
+        ActQuantProxyFromInjector, WeightQuantProxyFromInjector, BiasQuantProxyFromInjector)
+
+    def attach_debug_info(self, module):
+        pass
+
+    def prepare_for_export(self, module):
+        if module.is_quant_enabled:
+            self.scale = module.scale()
+            self.zero_point = module.zero_point().to(self.scale.device)
+            self.bit_width = module.bit_width()
+            self.min_int = min_int(module.is_signed, module.is_narrow_range, self.bit_width)
+            self.max_int = max_int(module.is_signed, module.is_narrow_range, self.bit_width)
+
+    def quant(self, x):
+        return torch.clamp(
+            torch.round(x / self.scale + self.zero_point), self.min_int, self.max_int)
+
+    def dequant(self, x):
+        return (x - self.zero_point) * self.scale
+
+    def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
+        return self.dequant(self.quant(x)), self.scale, self.zero_point, self.bit_width
diff --git a/src/brevitas/export/inference/manager.py b/src/brevitas/export/inference/manager.py
new file mode 100644
index 000000000..b1db6501d
--- /dev/null
+++ b/src/brevitas/export/inference/manager.py
@@ -0,0 +1,95 @@
+from torch.nn import Module
+import torch.nn as nn
+
+from brevitas.export.inference.handler import IntInferencetHandler
+from brevitas.export.manager import _set_proxy_export_handler
+from brevitas.export.manager import _set_proxy_export_mode
+from brevitas.export.manager import _set_recurrent_layer_export_handler
+from brevitas.export.manager import _set_recurrent_layer_export_mode
+from brevitas.export.manager import BaseManager
+from brevitas.graph.calibrate import disable_return_quant_tensor
+from brevitas.graph.calibrate import restore_return_quant_tensor
+
+
+def _override_caching_mode(m: nn.Module, attr: str, enabled: bool, metadata_only: bool = True):
+    cache_var = 'cache_inference_quant_' + attr
+    cache_var_metadata_only = cache_var + '_metadata_only'
+    if hasattr(m, cache_var):
+        setattr(m, cache_var, enabled)
+        setattr(m, cache_var_metadata_only, metadata_only)
+
+
+def _override_bias_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = True):
+    _override_caching_mode(m, 'bias', enabled, metadata_only)
+
+
+def _override_act_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = True):
+    _override_caching_mode(m, 'act', enabled, metadata_only)
+
+
+def _override_weight_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = False):
+    _override_caching_mode(m, 'weight', enabled, metadata_only)
+
+
+class inference_mode:
+
+    def __init__(self, model, cache_quant_weight=False, enabled=True):
+        self.model = model
+        self.enabled = enabled
+        self.cache_quant_weight = cache_quant_weight
+        self.export_manager = InferenceManager
+        self.hook_list = []
+
+    def __enter__(self):
+        if self.enabled:
+            # Register the hook and store it in the list so that it can be removed by the hook itself when called
+            handle = self.model.register_forward_hook(self.hook)
+            self.hook_list.append(handle)
+
+            # Enable bias for everything. Optionally, store the fully fake-quantized weights
+            self.model.apply(
+                lambda m: _override_bias_caching_mode(m, enabled=True, metadata_only=True))
+            self.model.apply(lambda m: _override_act_caching_mode(m, enabled=True))
+            self.model.apply(
+                lambda m: _override_weight_caching_mode(
+                    m, enabled=True, metadata_only=not self.cache_quant_weight))
+
+    def __exit__(self, type, value, traceback):
+        # Disable all caching
+        # deactivate export mode
+        # restore return quant tensor
+        self.model.apply(
+            lambda m: _override_bias_caching_mode(m, enabled=False, metadata_only=False))
+        self.model.apply(
+            lambda m: _override_act_caching_mode(m, enabled=False, metadata_only=False))
+        if self.cache_quant_weight:
+            self.model.apply(
+                lambda m: _override_weight_caching_mode(m, enabled=False, metadata_only=False))
+        InferenceManager.set_export_mode(self.model, enabled=False)
+        restore_return_quant_tensor(self.model, self.return_quant_tensor_state)
+
+    def hook(self, module, inp, out):
+        # After one forward pass with caching enabled, we can:
+        # - Set the model in export mode
+        # - Attach export handlers
+        # - Disable return quant tensor since all quant metadata is cached
+        assert len(self.hook_list) == 1
+        self.hook_list[0].remove()
+        self.model.apply(InferenceManager.set_export_handler)
+        InferenceManager.set_export_mode(self.model, enabled=True)
+        self.return_quant_tensor_state = disable_return_quant_tensor(self.model)
+
+
+# Inheritance from BaseManager is not techincally needed
+class InferenceManager(BaseManager):
+    handlers = [IntInferencetHandler]
+
+    @classmethod
+    def set_export_mode(cls, model: Module, enabled: bool):
+        _set_proxy_export_mode(model, enabled)
+        _set_recurrent_layer_export_mode(model, enabled)
+
+    @classmethod
+    def set_export_handler(cls, module: Module):
+        _set_proxy_export_handler(cls, module)
+        _set_recurrent_layer_export_handler(cls, module)

From a73e5789b0817f16b265c015c334da13fd141b83 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Mon, 9 Sep 2024 17:08:24 +0100
Subject: [PATCH 03/25] update

---
 src/brevitas/core/function_wrapper/clamp.py | 50 +++++++++-------
 src/brevitas/export/inference/handler.py    | 58 +++++++++++++++++--
 src/brevitas/export/inference/manager.py    |  5 +-
 src/brevitas/graph/calibrate.py             |  2 +-
 src/brevitas/proxy/float_parameter_quant.py |  7 +++
 src/brevitas/proxy/float_runtime_quant.py   | 10 +++-
 src/brevitas/proxy/parameter_quant.py       |  8 +++
 src/brevitas/proxy/quant_proxy.py           |  2 +-
 src/brevitas/proxy/runtime_quant.py         |  9 ++-
 .../test_torchvision_models.py              | 32 ++++++++--
 10 files changed, 147 insertions(+), 36 deletions(-)

diff --git a/src/brevitas/core/function_wrapper/clamp.py b/src/brevitas/core/function_wrapper/clamp.py
index 70d1fc23f..65ad6175f 100644
---
a/src/brevitas/core/function_wrapper/clamp.py +++ b/src/brevitas/core/function_wrapper/clamp.py @@ -113,6 +113,32 @@ def __init__( else: self.max_available_float = None + def inf_nan_clamp(self, x, max_value): + inf_mask = x.isinf() + + p_max_val_mask = x > max_value + n_max_val_mask = -x > max_value + # if non-saturating, we need to map values greater than max_val to nan or inf + if self.inf_values is not None: + # we have inf values, so we set abs values > max_value to +- inf, and leave inf at inf + x[p_max_val_mask] = torch.tensor(float('inf')) + x[n_max_val_mask] = torch.tensor(float('-inf')) + elif self.nan_values is not None: + # no inf values, so we need to map them to NaN + full_max_val_mask = torch.logical_or(p_max_val_mask, n_max_val_mask) + x[full_max_val_mask] = torch.tensor(float('nan')) + + # we also map the inf values to NaN in this case + x[inf_mask] = torch.tensor(float('nan')) + else: + raise RuntimeError( + "Clamping is not saturating, but neither `inf_values` nor `nan_values` is specified" + ) + return x + + def saturating_clamp(self, x, max_value, min_value): + return self.tensor_clamp_impl(x, min_val=min_value, max_val=max_value) + @brevitas.jit.script_method def forward( self, @@ -120,33 +146,15 @@ def forward( exponent_bit_width: Tensor, mantissa_bit_width: Tensor, exponent_bias: Tensor): - inf_mask = x.isinf() max_value = max_float(exponent_bit_width, mantissa_bit_width, exponent_bias) max_value = max_value if self.max_available_float is None else torch.min( max_value, self.max_available_float()) - p_max_val_mask = x > max_value - n_max_val_mask = -x > max_value - min_float = torch.tensor(0.) if not self.signed else -max_value + min_value = torch.tensor(0.) if not self.signed else -max_value # first clamp everything to +- max_value, basically the saturating case - x = self.tensor_clamp_impl(x, min_val=min_float, max_val=max_value) + x = self.saturating_clamp(x, max_value, min_value) if not self.saturating: - # if non-saturating, we need to map values greater than max_val to nan or inf - if self.inf_values is not None: - # we have inf values, so we set abs values > max_value to +- inf, and leave inf at inf - x[p_max_val_mask] = torch.tensor(float('inf')) - x[n_max_val_mask] = torch.tensor(float('-inf')) - elif self.nan_values is not None: - # no inf values, so we need to map them to NaN - full_max_val_mask = torch.logical_or(p_max_val_mask, n_max_val_mask) - x[full_max_val_mask] = torch.tensor(float('nan')) - - # we also map the inf values to NaN in this case - x[inf_mask] = torch.tensor(float('nan')) - else: - raise RuntimeError( - "Clamping is not saturating, but neither `inf_values` nor `nan_values` is specified" - ) + x = self.inf_nan_clamp(x, max_value) return x, self.saturating, self.inf_values, self.nan_values diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py index 6ff5bf258..f39e3dba8 100644 --- a/src/brevitas/export/inference/handler.py +++ b/src/brevitas/export/inference/handler.py @@ -2,11 +2,14 @@ import torch -from brevitas.function.ops import max_int +from brevitas.function.ops import max_float, max_int from brevitas.function.ops import min_int +from brevitas.proxy.float_parameter_quant import WeightFloatQuantProxyFromInjector +from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjector, ActFloatQuantProxyFromInjectorBase from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector from 
brevitas.proxy.runtime_quant import ActQuantProxyFromInjector +from brevitas.utils.torch_utils import float_internal_scale class IntInferencetHandler(torch.nn.Module): @@ -21,15 +24,62 @@ def prepare_for_export(self, module): self.scale = module.scale() self.zero_point = module.zero_point().to(self.scale.device) self.bit_width = module.bit_width() - self.min_int = min_int(module.is_signed, module.is_narrow_range, self.bit_width) - self.max_int = max_int(module.is_signed, module.is_narrow_range, self.bit_width) + self.min_clamp = min_int(module.is_signed, module.is_narrow_range, self.bit_width) + self.max_clamp = max_int(module.is_signed, module.is_narrow_range, self.bit_width) def quant(self, x): return torch.clamp( - torch.round(x / self.scale + self.zero_point), self.min_int, self.max_int) + torch.round(x / self.scale + self.zero_point), self.min_clamp, self.max_clamp) def dequant(self, x): return (x - self.zero_point) * self.scale def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]: return self.dequant(self.quant(x)), self.scale, self.zero_point, self.bit_width + + + +class FloatInferencetHandler(IntInferencetHandler): + handled_layer = ( + ActFloatQuantProxyFromInjector, WeightFloatQuantProxyFromInjector) + + def attach_debug_info(self, module): + pass + + def prepare_for_export(self, module): + if module.is_quant_enabled: + self.scale = module.scale() + self.zero_point = module.zero_point().to(self.scale.device) + self.exponent_bit_width = module.exponent_bit_width() + self.mantissa_bit_width = module.mantissa_bit_width() + self.exponent_bias = module.exponent_bias() + self.saturating = module.is_saturating() + self.inf_values = module.inf_values() + self.nan_values = module.nan_values() + self.eps = torch.finfo(self.scale.dtype).tiny + if hasattr(module.tensor_quant, 'float_to_int_impl'): + self.float_to_int_impl = module.tensor_quant.float_to_int_impl + self.float_clamp_impl = module.tensor_quant.float_clamp_impl + elif hasattr(module, 'fused_activation_quant_proxy'): + self.float_to_int_impl = module.fused_activation_quant_proxy.tensor_quant.float_to_int_impl + self.float_clamp_impl = module.fused_activation_quant_proxy.tensor_quant.float_clamp_impl + + self.max_clamp = max_float(self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias) + self.min_clamp = -self.max_clamp + self.fp_internal_scale_min = 1. - self.exponent_bias - self.mantissa_bit_width + self.max_value = max_float(self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias) + self.min_value = torch.tensor(0.) 
if not module.is_signed else -self.max_value + + def quant(self, x): + x = x/self.scale + internal_scale = float_internal_scale( + x, self.mantissa_bit_width, self.fp_internal_scale_min, self.eps) + x = internal_scale * self.float_to_int_impl(x / internal_scale) + x = self.float_clamp_impl.saturating_clamp(x, self.max_value, self.min_value) + if not self.saturating: + x = self.float_clamp_impl.inf_nan_clamp(x, self.max_value) + + return x + + def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]: + return self.dequant(self.quant(x)), self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values \ No newline at end of file diff --git a/src/brevitas/export/inference/manager.py b/src/brevitas/export/inference/manager.py index b1db6501d..7248ab362 100644 --- a/src/brevitas/export/inference/manager.py +++ b/src/brevitas/export/inference/manager.py @@ -1,7 +1,7 @@ from torch.nn import Module import torch.nn as nn -from brevitas.export.inference.handler import IntInferencetHandler +from brevitas.export.inference.handler import FloatInferencetHandler, IntInferencetHandler from brevitas.export.manager import _set_proxy_export_handler from brevitas.export.manager import _set_proxy_export_mode from brevitas.export.manager import _set_recurrent_layer_export_handler @@ -39,6 +39,7 @@ def __init__(self, model, cache_quant_weight=False, enabled=True): self.cache_quant_weight = cache_quant_weight self.export_manager = InferenceManager self.hook_list = [] + self.return_quant_tensor_state = dict() def __enter__(self): if self.enabled: @@ -82,7 +83,7 @@ def hook(self, module, inp, out): # Inheritance from BaseManager is not techincally needed class InferenceManager(BaseManager): - handlers = [IntInferencetHandler] + handlers = [IntInferencetHandler, FloatInferencetHandler] @classmethod def set_export_mode(cls, model: Module, enabled: bool): diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py index 92228b7a3..2b1f6833e 100644 --- a/src/brevitas/graph/calibrate.py +++ b/src/brevitas/graph/calibrate.py @@ -58,7 +58,7 @@ def disable_return_quant_tensor(model): def restore_return_quant_tensor(model, previous_state): for module in model.modules(): - if hasattr(module, 'return_quant_tensor'): + if hasattr(module, 'return_quant_tensor') and module in previous_state: module.return_quant_tensor = previous_state[module] diff --git a/src/brevitas/proxy/float_parameter_quant.py b/src/brevitas/proxy/float_parameter_quant.py index 4e6452792..47b8604a7 100644 --- a/src/brevitas/proxy/float_parameter_quant.py +++ b/src/brevitas/proxy/float_parameter_quant.py @@ -4,6 +4,7 @@ from torch import Tensor import torch.nn as nn +from brevitas.core.function_wrapper.misc import Identity from brevitas.inject import BaseInjector as Injector from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjectorBase from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjectorBase @@ -83,6 +84,12 @@ def is_fnuz(self): ) is None and self.exponent_bias() == 16 return is_fnuz_e4m3 or is_fnuz_e5m2 + @property + def input_view_impl(self): + if self.tensor_quant is not None: + return self.tensor_quant.input_view_impl + else: + return Identity() class WeightFloatQuantProxyFromInjector(WeightFloatQuantProxyFromInjectorBase): diff --git a/src/brevitas/proxy/float_runtime_quant.py b/src/brevitas/proxy/float_runtime_quant.py index b38f4ecdb..7350e5e32 100644 --- a/src/brevitas/proxy/float_runtime_quant.py +++ 
b/src/brevitas/proxy/float_runtime_quant.py @@ -4,6 +4,7 @@ import torch import torch.nn as nn +from brevitas.core.function_wrapper.misc import Identity from brevitas.inject import BaseInjector as Injector from brevitas.proxy.runtime_quant import ActQuantProxyFromInjectorBase from brevitas.quant_tensor import FloatQuantTensor @@ -27,7 +28,7 @@ def mantissa_bit_width(self, force_eval=True): def exponent_bias(self, force_eval=True): return self.retrieve_attribute('exponent_bias', force_eval) - def saturating(self, force_eval=True): + def is_saturating(self, force_eval=True): return self.retrieve_attribute('saturating', force_eval) def inf_values(self, force_eval=True): @@ -36,6 +37,13 @@ def inf_values(self, force_eval=True): def nan_values(self, force_eval=True): return self.retrieve_attribute('nan_values', force_eval) + @property + def input_view_impl(self): + if self.fused_activation_quant_proxy.tensor_quant is not None: + return self.fused_activation_quant_proxy.tensor_quant.input_view_impl + else: + return Identity() + @property def is_ocp(self): is_e4m3 = self.mantissa_bit_width() == 3 and self.exponent_bit_width() == 4 diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index f91d90a75..d76177e52 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -11,6 +11,7 @@ import torch from brevitas import torch_version +from brevitas.core.function_wrapper.misc import Identity if torch_version < packaging.version.parse('2.0'): is_dynamo_compiling = lambda: False @@ -102,6 +103,13 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None: self.cache_class = None # To be redefined by each class self.quant_tensor_class = None # To be redefined by each class + @property + def input_view_impl(self): + if self.tensor_quant is not None: + return self.tensor_quant.int_quant.input_view_impl + else: + return Identity() + @property def cache_inference_quant_weight(self): return self._cache_inference_quant_weight diff --git a/src/brevitas/proxy/quant_proxy.py b/src/brevitas/proxy/quant_proxy.py index 7f61d3c45..845bfd515 100644 --- a/src/brevitas/proxy/quant_proxy.py +++ b/src/brevitas/proxy/quant_proxy.py @@ -122,7 +122,7 @@ def add_tracked_module(self, module: nn.Module) -> None: raise RuntimeError("Trying to add None as a parent module.") def apply_input_view(self, x): - return self.tensor_quant.int_quant.input_view_impl(x) + return self.input_view_impl(x) def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 7a5381148..1ba9330ad 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -106,6 +106,13 @@ def __init__(self, quant_layer, quant_injector): self.cache_quant_io_metadata_only = True self.cache_class = None + @property + def input_view_impl(self): + if self.fused_activation_quant_proxy.tensor_quant is not None: + return self.fused_activation_quant_proxy.tensor_quant.int_quant.input_view_impl + else: + return Identity() + def internal_forward(self, force_eval): current_status = self.training if force_eval: @@ -124,7 +131,7 @@ def retrieve_attribute(self, attribute, force_eval): return None def apply_input_view(self, x): - return self.fused_activation_quant_proxy.tensor_quant.int_quant.input_view_impl(x) + return self.input_view_impl(x) @property def is_quant_enabled(self): diff --git 
a/tests/brevitas_end_to_end/test_torchvision_models.py b/tests/brevitas_end_to_end/test_torchvision_models.py index 0d76ae2db..1f1e2f614 100644 --- a/tests/brevitas_end_to_end/test_torchvision_models.py +++ b/tests/brevitas_end_to_end/test_torchvision_models.py @@ -70,7 +70,7 @@ def quantize_float(model): @fixture @parametrize('model_name', MODEL_LIST) -@parametrize('quantize_fn', [quantize, quantize_flexml, layerwise_quantize]) +@parametrize('quantize_fn', [quantize_float, quantize, quantize_flexml, layerwise_quantize]) def torchvision_model(model_name, quantize_fn): inp = torch.randn(BATCH, IN_CH, HEIGHT, WIDTH) @@ -113,15 +113,37 @@ def torchvision_model(model_name, quantize_fn): @requires_pt_ge('1.8.1') -def test_torchvision_graph_quantization_flexml_qcdq_onnx(torchvision_model, request): +@parametrize('enable_compile', [True, False]) +def test_torchvision_graph_quantization_flexml_qcdq_onnx(torchvision_model, enable_compile, request): + test_id = request.node.callspec.id if torchvision_model is None: pytest.skip('Model not instantiated') + if enable_compile: + torch._dynamo.config.capture_scalar_outputs = True + model_name = test_id.split("-")[1] + if torch_version <= version.parse('2.0'): + pytest.skip("Pytorch 2.0 is required to test compile") + if 'vit' in model_name: + pytest.skip("QuantMHA not supported with compile") + inp = torch.randn(BATCH, IN_CH, HEIGHT, WIDTH) - test_id = request.node.callspec.id quantize_fn_name = test_id.split("-")[0] - torchvision_model(inp) - if quantize_fn_name != 'quantize_float': + if enable_compile: + with torch.no_grad(), inference_mode(torchvision_model): + prehook_non_compiled_out = torchvision_model(inp) + post_hook_non_compiled_out = torchvision_model(inp) + assert torch.allclose(prehook_non_compiled_out, post_hook_non_compiled_out) + + compiled_model = torch.compile(torchvision_model, fullgraph=True) + compiled_out = compiled_model(inp) + + # This fails! 
Compile might needs more small-scoped tests for accuracy evaluation + # assert torch.allclose(post_hook_non_compiled_out, compiled_out) + else: + torchvision_model(inp) + + if quantize_fn_name != 'quantize_float' and not enable_compile: export_onnx_qcdq(torchvision_model, args=inp) From db899aee5dc4ea243e77ba18ad4ee87af491e0b2 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 9 Sep 2024 17:08:51 +0100 Subject: [PATCH 04/25] precommit --- src/brevitas/export/inference/handler.py | 22 ++++++++++--------- src/brevitas/export/inference/manager.py | 3 ++- src/brevitas/proxy/float_parameter_quant.py | 1 + .../test_torchvision_models.py | 7 +++--- 4 files changed, 19 insertions(+), 14 deletions(-) diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py index f39e3dba8..0cab461a0 100644 --- a/src/brevitas/export/inference/handler.py +++ b/src/brevitas/export/inference/handler.py @@ -2,10 +2,12 @@ import torch -from brevitas.function.ops import max_float, max_int +from brevitas.function.ops import max_float +from brevitas.function.ops import max_int from brevitas.function.ops import min_int from brevitas.proxy.float_parameter_quant import WeightFloatQuantProxyFromInjector -from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjector, ActFloatQuantProxyFromInjectorBase +from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjector +from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjectorBase from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector @@ -38,10 +40,8 @@ def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]: return self.dequant(self.quant(x)), self.scale, self.zero_point, self.bit_width - class FloatInferencetHandler(IntInferencetHandler): - handled_layer = ( - ActFloatQuantProxyFromInjector, WeightFloatQuantProxyFromInjector) + handled_layer = (ActFloatQuantProxyFromInjector, WeightFloatQuantProxyFromInjector) def attach_debug_info(self, module): pass @@ -64,22 +64,24 @@ def prepare_for_export(self, module): self.float_to_int_impl = module.fused_activation_quant_proxy.tensor_quant.float_to_int_impl self.float_clamp_impl = module.fused_activation_quant_proxy.tensor_quant.float_clamp_impl - self.max_clamp = max_float(self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias) + self.max_clamp = max_float( + self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias) self.min_clamp = -self.max_clamp self.fp_internal_scale_min = 1. - self.exponent_bias - self.mantissa_bit_width - self.max_value = max_float(self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias) + self.max_value = max_float( + self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias) self.min_value = torch.tensor(0.) 
if not module.is_signed else -self.max_value def quant(self, x): - x = x/self.scale + x = x / self.scale internal_scale = float_internal_scale( x, self.mantissa_bit_width, self.fp_internal_scale_min, self.eps) x = internal_scale * self.float_to_int_impl(x / internal_scale) x = self.float_clamp_impl.saturating_clamp(x, self.max_value, self.min_value) if not self.saturating: x = self.float_clamp_impl.inf_nan_clamp(x, self.max_value) - + return x def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]: - return self.dequant(self.quant(x)), self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values \ No newline at end of file + return self.dequant(self.quant(x)), self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values diff --git a/src/brevitas/export/inference/manager.py b/src/brevitas/export/inference/manager.py index 7248ab362..15723c0c0 100644 --- a/src/brevitas/export/inference/manager.py +++ b/src/brevitas/export/inference/manager.py @@ -1,7 +1,8 @@ from torch.nn import Module import torch.nn as nn -from brevitas.export.inference.handler import FloatInferencetHandler, IntInferencetHandler +from brevitas.export.inference.handler import FloatInferencetHandler +from brevitas.export.inference.handler import IntInferencetHandler from brevitas.export.manager import _set_proxy_export_handler from brevitas.export.manager import _set_proxy_export_mode from brevitas.export.manager import _set_recurrent_layer_export_handler diff --git a/src/brevitas/proxy/float_parameter_quant.py b/src/brevitas/proxy/float_parameter_quant.py index 47b8604a7..0d6ffd106 100644 --- a/src/brevitas/proxy/float_parameter_quant.py +++ b/src/brevitas/proxy/float_parameter_quant.py @@ -91,6 +91,7 @@ def input_view_impl(self): else: return Identity() + class WeightFloatQuantProxyFromInjector(WeightFloatQuantProxyFromInjectorBase): def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None: diff --git a/tests/brevitas_end_to_end/test_torchvision_models.py b/tests/brevitas_end_to_end/test_torchvision_models.py index 1f1e2f614..fcfc90d80 100644 --- a/tests/brevitas_end_to_end/test_torchvision_models.py +++ b/tests/brevitas_end_to_end/test_torchvision_models.py @@ -114,7 +114,8 @@ def torchvision_model(model_name, quantize_fn): @requires_pt_ge('1.8.1') @parametrize('enable_compile', [True, False]) -def test_torchvision_graph_quantization_flexml_qcdq_onnx(torchvision_model, enable_compile, request): +def test_torchvision_graph_quantization_flexml_qcdq_onnx( + torchvision_model, enable_compile, request): test_id = request.node.callspec.id if torchvision_model is None: pytest.skip('Model not instantiated') @@ -133,13 +134,13 @@ def test_torchvision_graph_quantization_flexml_qcdq_onnx(torchvision_model, enab with torch.no_grad(), inference_mode(torchvision_model): prehook_non_compiled_out = torchvision_model(inp) post_hook_non_compiled_out = torchvision_model(inp) - assert torch.allclose(prehook_non_compiled_out, post_hook_non_compiled_out) + assert torch.allclose(prehook_non_compiled_out, post_hook_non_compiled_out) compiled_model = torch.compile(torchvision_model, fullgraph=True) compiled_out = compiled_model(inp) # This fails! 
Compile might needs more small-scoped tests for accuracy evaluation - # assert torch.allclose(post_hook_non_compiled_out, compiled_out) + # assert torch.allclose(post_hook_non_compiled_out, compiled_out) else: torchvision_model(inp) From f2cdb416badd3b6a9e8465630795de86b7c31bb0 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 9 Sep 2024 17:39:07 +0100 Subject: [PATCH 05/25] fix --- src/brevitas/proxy/runtime_quant.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 1ba9330ad..03303bcc8 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -108,7 +108,8 @@ def __init__(self, quant_layer, quant_injector): @property def input_view_impl(self): - if self.fused_activation_quant_proxy.tensor_quant is not None: + if self.fused_activation_quant_proxy.tensor_quant is not None and not isinstance( + self.fused_activation_quant_proxy.tensor_quant, _TensorQuantDisabledIdentity): return self.fused_activation_quant_proxy.tensor_quant.int_quant.input_view_impl else: return Identity() From 726933e14c96a29996b0a3f06171d95d547b5717 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 9 Sep 2024 17:45:50 +0100 Subject: [PATCH 06/25] fix test --- tests/brevitas_end_to_end/test_torchvision_models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/brevitas_end_to_end/test_torchvision_models.py b/tests/brevitas_end_to_end/test_torchvision_models.py index fcfc90d80..ce12c8e7c 100644 --- a/tests/brevitas_end_to_end/test_torchvision_models.py +++ b/tests/brevitas_end_to_end/test_torchvision_models.py @@ -13,6 +13,7 @@ from brevitas import torch_version from brevitas.export import export_onnx_qcdq from brevitas.export import export_torch_qcdq +from brevitas.export.inference.manager import inference_mode from brevitas.graph.calibrate import calibration_mode from brevitas.graph.quantize import layerwise_quantize from brevitas.graph.quantize import quantize @@ -120,7 +121,6 @@ def test_torchvision_graph_quantization_flexml_qcdq_onnx( if torchvision_model is None: pytest.skip('Model not instantiated') if enable_compile: - torch._dynamo.config.capture_scalar_outputs = True model_name = test_id.split("-")[1] if torch_version <= version.parse('2.0'): pytest.skip("Pytorch 2.0 is required to test compile") @@ -131,6 +131,7 @@ def test_torchvision_graph_quantization_flexml_qcdq_onnx( quantize_fn_name = test_id.split("-")[0] if enable_compile: + torch._dynamo.config.capture_scalar_outputs = True with torch.no_grad(), inference_mode(torchvision_model): prehook_non_compiled_out = torchvision_model(inp) post_hook_non_compiled_out = torchvision_model(inp) From f2c22abbd64640745e8c4808367243443a082e2f Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 9 Sep 2024 18:09:28 +0100 Subject: [PATCH 07/25] cleanup --- src/brevitas/proxy/parameter_quant.py | 2 +- src/brevitas/proxy/runtime_quant.py | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index d76177e52..6ee882ed5 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -160,7 +160,7 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]: out.detach(), metadata_only=self.cache_inference_quant_weight_metadata_only) else: # quantization disabled - out = self.apply_input_view(x) + out = self.input_view_impl(x) return out diff --git 
a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 03303bcc8..ebe864f82 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -131,9 +131,6 @@ def retrieve_attribute(self, attribute, force_eval): elif self._cached_act is None: return None - def apply_input_view(self, x): - return self.input_view_impl(x) - @property def is_quant_enabled(self): return self._is_quant_enabled and not self.disable_quant @@ -188,7 +185,7 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: # A tuple helps later with control flows # The second None value is used later # If quant is not enabled, we still apply input_view in the case of groupwise + padding - y = self.apply_input_view(self.fused_activation_quant_proxy.activation_impl(y)) + y = self.input_view_impl(self.fused_activation_quant_proxy.activation_impl(y)) y = (y, None) else: y = self.fused_activation_quant_proxy(y) From 12a428329d9a67ad731e7ce48b95cbb21d3d26f5 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 9 Sep 2024 18:23:20 +0100 Subject: [PATCH 08/25] fix inf mask --- src/brevitas/core/function_wrapper/clamp.py | 6 +++--- src/brevitas/export/inference/handler.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/brevitas/core/function_wrapper/clamp.py b/src/brevitas/core/function_wrapper/clamp.py index 65ad6175f..cfb0a693b 100644 --- a/src/brevitas/core/function_wrapper/clamp.py +++ b/src/brevitas/core/function_wrapper/clamp.py @@ -113,8 +113,7 @@ def __init__( else: self.max_available_float = None - def inf_nan_clamp(self, x, max_value): - inf_mask = x.isinf() + def inf_nan_clamp(self, x, max_value, inf_mask): p_max_val_mask = x > max_value n_max_val_mask = -x > max_value @@ -146,6 +145,7 @@ def forward( exponent_bit_width: Tensor, mantissa_bit_width: Tensor, exponent_bias: Tensor): + inf_mask = x.isinf() max_value = max_float(exponent_bit_width, mantissa_bit_width, exponent_bias) max_value = max_value if self.max_available_float is None else torch.min( max_value, self.max_available_float()) @@ -155,6 +155,6 @@ def forward( x = self.saturating_clamp(x, max_value, min_value) if not self.saturating: - x = self.inf_nan_clamp(x, max_value) + x = self.inf_nan_clamp(x, max_value, inf_mask) return x, self.saturating, self.inf_values, self.nan_values diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py index 0cab461a0..a35fa74b8 100644 --- a/src/brevitas/export/inference/handler.py +++ b/src/brevitas/export/inference/handler.py @@ -73,13 +73,14 @@ def prepare_for_export(self, module): self.min_value = torch.tensor(0.) 
if not module.is_signed else -self.max_value def quant(self, x): + inf_mask = x.isinf() x = x / self.scale internal_scale = float_internal_scale( x, self.mantissa_bit_width, self.fp_internal_scale_min, self.eps) x = internal_scale * self.float_to_int_impl(x / internal_scale) x = self.float_clamp_impl.saturating_clamp(x, self.max_value, self.min_value) if not self.saturating: - x = self.float_clamp_impl.inf_nan_clamp(x, self.max_value) + x = self.float_clamp_impl.inf_nan_clamp(x, self.max_value, inf_mask) return x From 1e4a89b7b81c2201b9dce9619e315e70afaa7fbe Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 9 Sep 2024 19:11:29 +0100 Subject: [PATCH 09/25] fix float tests --- src/brevitas/core/function_wrapper/clamp.py | 13 ++++++++----- src/brevitas/export/inference/handler.py | 5 ++++- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/brevitas/core/function_wrapper/clamp.py b/src/brevitas/core/function_wrapper/clamp.py index cfb0a693b..163e63a22 100644 --- a/src/brevitas/core/function_wrapper/clamp.py +++ b/src/brevitas/core/function_wrapper/clamp.py @@ -113,10 +113,8 @@ def __init__( else: self.max_available_float = None - def inf_nan_clamp(self, x, max_value, inf_mask): + def inf_nan_clamp(self, x, inf_mask, p_max_val_mask, n_max_val_mask): - p_max_val_mask = x > max_value - n_max_val_mask = -x > max_value # if non-saturating, we need to map values greater than max_val to nan or inf if self.inf_values is not None: # we have inf values, so we set abs values > max_value to +- inf, and leave inf at inf @@ -145,16 +143,21 @@ def forward( exponent_bit_width: Tensor, mantissa_bit_width: Tensor, exponent_bias: Tensor): - inf_mask = x.isinf() + max_value = max_float(exponent_bit_width, mantissa_bit_width, exponent_bias) max_value = max_value if self.max_available_float is None else torch.min( max_value, self.max_available_float()) min_value = torch.tensor(0.) if not self.signed else -max_value + # Compute masks + inf_mask = x.isinf() + p_max_val_mask = x > max_value + n_max_val_mask = -x > max_value + # first clamp everything to +- max_value, basically the saturating case x = self.saturating_clamp(x, max_value, min_value) if not self.saturating: - x = self.inf_nan_clamp(x, max_value, inf_mask) + x = self.inf_nan_clamp(x, inf_mask, p_max_val_mask, n_max_val_mask) return x, self.saturating, self.inf_values, self.nan_values diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py index a35fa74b8..0be56bc5b 100644 --- a/src/brevitas/export/inference/handler.py +++ b/src/brevitas/export/inference/handler.py @@ -73,14 +73,17 @@ def prepare_for_export(self, module): self.min_value = torch.tensor(0.) 
if not module.is_signed else -self.max_value def quant(self, x): + # Compute masks inf_mask = x.isinf() + p_max_val_mask = x > self.max_value + n_max_val_mask = -x > self.max_value x = x / self.scale internal_scale = float_internal_scale( x, self.mantissa_bit_width, self.fp_internal_scale_min, self.eps) x = internal_scale * self.float_to_int_impl(x / internal_scale) x = self.float_clamp_impl.saturating_clamp(x, self.max_value, self.min_value) if not self.saturating: - x = self.float_clamp_impl.inf_nan_clamp(x, self.max_value, inf_mask) + x = self.float_clamp_impl.inf_nan_clamp(x, inf_mask, p_max_val_mask, n_max_val_mask) return x From f392e0eb8695a608c24569813df8823370e9d17b Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 9 Sep 2024 19:47:18 +0100 Subject: [PATCH 10/25] restore apply_input_view --- src/brevitas/proxy/parameter_quant.py | 2 +- src/brevitas/proxy/runtime_quant.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index 6ee882ed5..d76177e52 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -160,7 +160,7 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]: out.detach(), metadata_only=self.cache_inference_quant_weight_metadata_only) else: # quantization disabled - out = self.input_view_impl(x) + out = self.apply_input_view(x) return out diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index ebe864f82..03303bcc8 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -131,6 +131,9 @@ def retrieve_attribute(self, attribute, force_eval): elif self._cached_act is None: return None + def apply_input_view(self, x): + return self.input_view_impl(x) + @property def is_quant_enabled(self): return self._is_quant_enabled and not self.disable_quant @@ -185,7 +188,7 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: # A tuple helps later with control flows # The second None value is used later # If quant is not enabled, we still apply input_view in the case of groupwise + padding - y = self.input_view_impl(self.fused_activation_quant_proxy.activation_impl(y)) + y = self.apply_input_view(self.fused_activation_quant_proxy.activation_impl(y)) y = (y, None) else: y = self.fused_activation_quant_proxy(y) From 1c638c09280357aeb1a1d75d74b53313b0727f41 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 9 Sep 2024 21:15:10 +0100 Subject: [PATCH 11/25] Fix API export --- src/brevitas/export/common/handler/qcdq.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py index 44061ce42..bbc03b630 100644 --- a/src/brevitas/export/common/handler/qcdq.py +++ b/src/brevitas/export/common/handler/qcdq.py @@ -454,7 +454,7 @@ def prepare_for_export(self, module): self.symbolic_kwargs['exponent_bit_width'] = module.exponent_bit_width() self.symbolic_kwargs['mantissa_bit_width'] = module.mantissa_bit_width() self.symbolic_kwargs['exponent_bias'] = module.exponent_bias() - self.symbolic_kwargs['saturating'] = module.saturating() + self.symbolic_kwargs['saturating'] = module.is_saturating() self.symbolic_kwargs['inf_values'] = module.inf_values() self.symbolic_kwargs['nan_values'] = module.nan_values() @@ -659,7 +659,7 @@ def prepare_for_export(self, module): 'exponent_bit_width': module.exponent_bit_width(), 'mantissa_bit_width': 
module.mantissa_bit_width(), 'exponent_bias': module.exponent_bias(), - 'saturating': module.saturating(), + 'saturating': module.is_saturating(), 'inf_values': module.inf_values(), 'nan_values': module.nan_values()} From 8299eb91b076abee7fd345ca0ad78f0ec32a180c Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 9 Sep 2024 23:00:40 +0100 Subject: [PATCH 12/25] small test, temp --- .../test_torchvision_models.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/brevitas_end_to_end/test_torchvision_models.py b/tests/brevitas_end_to_end/test_torchvision_models.py index ce12c8e7c..82e1cdb57 100644 --- a/tests/brevitas_end_to_end/test_torchvision_models.py +++ b/tests/brevitas_end_to_end/test_torchvision_models.py @@ -28,19 +28,19 @@ MODEL_LIST = [ 'vit_b_32', 'efficientnet_b0', - 'mobilenet_v3_small', - 'mobilenet_v2', - 'resnet50', - 'resnet18', - 'mnasnet0_5', - 'alexnet', - 'googlenet', - 'vgg11', - 'densenet121', - 'deeplabv3_resnet50', - 'fcn_resnet50', - 'regnet_x_400mf', - 'squeezenet1_0', + # 'mobilenet_v3_small', + # 'mobilenet_v2', + # 'resnet50', + # 'resnet18', + # 'mnasnet0_5', + # 'alexnet', + # 'googlenet', + # 'vgg11', + # 'densenet121', + # 'deeplabv3_resnet50', + # 'fcn_resnet50', + # 'regnet_x_400mf', + # 'squeezenet1_0', 'inception_v3'] From d8cab12149a65d9a61baed370d356e0e52315bed Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 10 Sep 2024 00:03:36 +0100 Subject: [PATCH 13/25] even smaller tests --- tests/brevitas_end_to_end/test_torchvision_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/brevitas_end_to_end/test_torchvision_models.py b/tests/brevitas_end_to_end/test_torchvision_models.py index 82e1cdb57..af409f8ab 100644 --- a/tests/brevitas_end_to_end/test_torchvision_models.py +++ b/tests/brevitas_end_to_end/test_torchvision_models.py @@ -71,7 +71,7 @@ def quantize_float(model): @fixture @parametrize('model_name', MODEL_LIST) -@parametrize('quantize_fn', [quantize_float, quantize, quantize_flexml, layerwise_quantize]) +@parametrize('quantize_fn', [quantize_float]) def torchvision_model(model_name, quantize_fn): inp = torch.randn(BATCH, IN_CH, HEIGHT, WIDTH) From f5446a8a95eb11a6f78b8dad537102ff399981b5 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 10 Sep 2024 10:18:33 +0100 Subject: [PATCH 14/25] Cleanup --- src/brevitas/export/inference/handler.py | 90 ++++++++++++++++--- src/brevitas/export/inference/manager.py | 13 ++- src/brevitas/proxy/parameter_quant.py | 6 -- .../ptq/ptq_evaluate.py | 10 ++- .../stable_diffusion/main.py | 4 +- .../test_torchvision_models.py | 53 ++++++----- 6 files changed, 123 insertions(+), 53 deletions(-) diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py index 0be56bc5b..bd81cadc3 100644 --- a/src/brevitas/export/inference/handler.py +++ b/src/brevitas/export/inference/handler.py @@ -1,3 +1,8 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause + +from abc import ABC +from abc import abstractmethod from typing import Tuple import torch @@ -14,9 +19,26 @@ from brevitas.utils.torch_utils import float_internal_scale -class IntInferencetHandler(torch.nn.Module): - handled_layer = ( - ActQuantProxyFromInjector, WeightQuantProxyFromInjector, BiasQuantProxyFromInjector) +class InferenceHandler(torch.nn.Module, ABC): + + def attach_debug_info(self, module): + pass + + @abstractmethod + def prepare_for_export(self, module): + pass + + @abstractmethod + def quantize(self, x): + pass + + @abstractmethod + def dequantize(self, x): + pass + + +class IntInferencetHandler(InferenceHandler): + handled_layer = (ActQuantProxyFromInjector, BiasQuantProxyFromInjector) def attach_debug_info(self, module): pass @@ -29,22 +51,38 @@ def prepare_for_export(self, module): self.min_clamp = min_int(module.is_signed, module.is_narrow_range, self.bit_width) self.max_clamp = max_int(module.is_signed, module.is_narrow_range, self.bit_width) - def quant(self, x): + def quantize(self, x): return torch.clamp( torch.round(x / self.scale + self.zero_point), self.min_clamp, self.max_clamp) - def dequant(self, x): + def dequantize(self, x): return (x - self.zero_point) * self.scale def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]: - return self.dequant(self.quant(x)), self.scale, self.zero_point, self.bit_width + return self.dequantize(self.quantize(x)), self.scale, self.zero_point, self.bit_width -class FloatInferencetHandler(IntInferencetHandler): - handled_layer = (ActFloatQuantProxyFromInjector, WeightFloatQuantProxyFromInjector) +class IntWeightInferencetHandler(IntInferencetHandler): + handled_layer = WeightQuantProxyFromInjector + + def prepare_for_export(self, module): + if module.is_quant_enabled: + self.cached_weight = None + if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only: + self.cached_weight = module._cached_weight + else: + super.prepare_for_export(module) + + def forward(self, x) -> Tuple[torch.Tensor]: + if self.cached_weight is not None: + x = self.cached_weight + else: + x = self.dequantize(self.quantize(x)) + return x, self.scale, self.zero_point, self.bit_width - def attach_debug_info(self, module): - pass + +class FloatInferencetHandler(InferenceHandler): + handled_layer = (ActFloatQuantProxyFromInjector, BiasQuantProxyFromInjector) def prepare_for_export(self, module): if module.is_quant_enabled: @@ -72,20 +110,46 @@ def prepare_for_export(self, module): self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias) self.min_value = torch.tensor(0.) 
if not module.is_signed else -self.max_value - def quant(self, x): + def quantize(self, x): # Compute masks inf_mask = x.isinf() p_max_val_mask = x > self.max_value n_max_val_mask = -x > self.max_value + + # Quantize x = x / self.scale internal_scale = float_internal_scale( x, self.mantissa_bit_width, self.fp_internal_scale_min, self.eps) x = internal_scale * self.float_to_int_impl(x / internal_scale) + + # Clamp x = self.float_clamp_impl.saturating_clamp(x, self.max_value, self.min_value) if not self.saturating: x = self.float_clamp_impl.inf_nan_clamp(x, inf_mask, p_max_val_mask, n_max_val_mask) return x - def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]: - return self.dequant(self.quant(x)), self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values + def dequantize(self, x): + return (x - self.zero_point) * self.scale + + def forward(self, x) -> Tuple[torch.Tensor]: + return self.dequantize(self.quantize(x)), self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values + + +class FloatWeightInferencetHandler(FloatInferencetHandler): + handled_layer = (ActFloatQuantProxyFromInjector, WeightFloatQuantProxyFromInjector) + + def prepare_for_export(self, module): + if module.is_quant_enabled: + self.cached_weight = None + if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only: + self.cached_weight = module._cached_weight + else: + super().prepare_for_export(module) + + def forward(self, x) -> Tuple[torch.Tensor]: + if self.cached_weight is not None: + x = self.cached_weight + else: + x = self.dequantize(self.quantize(x)) + return x, self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values diff --git a/src/brevitas/export/inference/manager.py b/src/brevitas/export/inference/manager.py index 15723c0c0..936106884 100644 --- a/src/brevitas/export/inference/manager.py +++ b/src/brevitas/export/inference/manager.py @@ -1,8 +1,13 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
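The integer handler above reduces inference-time quantization to a plain round trip: clamp(round(x / scale + zero_point)) followed by (q - zero_point) * scale, while the float handler additionally rescales each element by a power-of-two internal scale derived from the mantissa width before rounding. A minimal, self-contained sketch of the integer round trip; the min/max computation is a simplified stand-in for Brevitas' min_int/max_int helpers, not the real API:

import torch

def fake_quant_int(x, scale, zero_point, bit_width=8, signed=True, narrow=False):
    # Simplified integer range (stand-in for brevitas.function.min_int/max_int)
    if signed:
        min_q = -(2 ** (bit_width - 1)) + int(narrow)
        max_q = 2 ** (bit_width - 1) - 1
    else:
        min_q = 0
        max_q = 2 ** bit_width - 1 - int(narrow)
    # quantize: clamp(round(x / scale + zero_point)), as in the handler above
    q = torch.clamp(torch.round(x / scale + zero_point), min_q, max_q)
    # dequantize: (q - zero_point) * scale
    return (q - zero_point) * scale

x = torch.randn(8)
y = fake_quant_int(x, scale=torch.tensor(0.05), zero_point=torch.tensor(0.0))
print((x - y).abs().max())  # bounded by scale / 2 away from the clamping region
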
+# SPDX-License-Identifier: BSD-3-Clause + from torch.nn import Module import torch.nn as nn from brevitas.export.inference.handler import FloatInferencetHandler +from brevitas.export.inference.handler import FloatWeightInferencetHandler from brevitas.export.inference.handler import IntInferencetHandler +from brevitas.export.inference.handler import IntWeightInferencetHandler from brevitas.export.manager import _set_proxy_export_handler from brevitas.export.manager import _set_proxy_export_mode from brevitas.export.manager import _set_recurrent_layer_export_handler @@ -32,7 +37,7 @@ def _override_weight_caching_mode(m: nn.Module, enabled: bool, metadata_only: bo _override_caching_mode(m, 'weight', enabled, metadata_only) -class inference_mode: +class quant_inference_mode: def __init__(self, model, cache_quant_weight=False, enabled=True): self.model = model @@ -84,7 +89,11 @@ def hook(self, module, inp, out): # Inheritance from BaseManager is not techincally needed class InferenceManager(BaseManager): - handlers = [IntInferencetHandler, FloatInferencetHandler] + handlers = [ + IntInferencetHandler, + FloatInferencetHandler, + IntWeightInferencetHandler, + FloatWeightInferencetHandler] @classmethod def set_export_mode(cls, model: Module, enabled: bool): diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index d76177e52..dc7c704c3 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -136,7 +136,6 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]: if self.is_quant_enabled: # If quant is enabled the priority is: # - export mode - # - cached weight # - quantization flow if self.export_mode: out = self.export_handler(x) @@ -144,11 +143,6 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]: out = out[0] else: out = self.create_quant_tensor(out) - elif self._cached_weight is not None and not self.cache_inference_quant_weight_metadata_only: - if is_dynamo_compiling(): - out = self._cached_weight.value - else: - out = self._cached_weight.quant_tensor else: out = self.tensor_quant(x) if is_dynamo_compiling(): diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index a561dc0c4..f33946079 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -18,7 +18,7 @@ from brevitas.export import export_onnx_qcdq from brevitas.export import export_torch_qcdq -from brevitas.export.inference.manager import inference_mode +from brevitas.export.inference import quant_inference_mode from brevitas.graph.quantize import preprocess_for_quantize from brevitas.graph.target.flexml import preprocess_for_flexml_quantize from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_act_equalization @@ -268,6 +268,7 @@ def parse_type(v, default_type): 'uint_sym_act_for_unsigned_values', default=True, help='Use unsigned act quant when possible (default: enabled)') +add_bool_arg(parser, 'compile', default=False, help='Use torch.compile (default: disabled)') def generate_ref_input(args, device, dtype): @@ -482,13 +483,14 @@ def main(): # Validate the quant_model on the validation dataloader print("Starting validation:") - with torch.no_grad(), inference_mode(quant_model): + with torch.no_grad(), quant_inference_mode(quant_model): param = next(iter(quant_model.parameters())) device, dtype = param.device, 
param.dtype ref_input = generate_ref_input(args, device, dtype) quant_model(ref_input) - quant_model = torch.compile(quant_model, fullgraph=True, dynamic=True) - validate(val_loader, quant_model, stable=dtype != torch.bfloat16) + compiled_model = torch.compile( + quant_model, fullgraph=True, dynamic=True, disable=not args.compile) + validate(val_loader, compiled_model, stable=dtype != torch.bfloat16) if args.export_onnx_qcdq or args.export_torch_qcdq: # Generate reference input tensor to drive the export process diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index 3db0bca9a..5ac69d147 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -29,7 +29,6 @@ from brevitas.graph.base import ModuleToModuleByClass from brevitas.graph.calibrate import bias_correction_mode from brevitas.graph.calibrate import calibration_mode -from brevitas.graph.calibrate import inference_mode from brevitas.graph.calibrate import load_quant_model_mode from brevitas.graph.equalize import activation_equalization_mode from brevitas.graph.gptq import gptq_mode @@ -150,6 +149,7 @@ def main(args): calibration_prompts = CALIBRATION_PROMPTS if args.calibration_prompt_path is not None: calibration_prompts = load_calib_prompts(args.calibration_prompt_path) + print(args.calibration_prompt, len(calibration_prompts)) assert args.calibration_prompt <= len(calibration_prompts) , f"Only {len(calibration_prompts)} prompts are available" calibration_prompts = calibration_prompts[:args.calibration_prompt] @@ -231,6 +231,8 @@ def main(args): non_blacklist[name_to_add] = 1 else: non_blacklist[name_to_add] += 1 + print(f"Blacklisted layers: {set(blacklist)}") + print(f"Non blacklisted layers: {non_blacklist}") # Make sure there all LoRA layers are fused first, otherwise raise an error for m in pipe.unet.modules(): diff --git a/tests/brevitas_end_to_end/test_torchvision_models.py b/tests/brevitas_end_to_end/test_torchvision_models.py index af409f8ab..528319f26 100644 --- a/tests/brevitas_end_to_end/test_torchvision_models.py +++ b/tests/brevitas_end_to_end/test_torchvision_models.py @@ -13,7 +13,7 @@ from brevitas import torch_version from brevitas.export import export_onnx_qcdq from brevitas.export import export_torch_qcdq -from brevitas.export.inference.manager import inference_mode +from brevitas.export.inference import quant_inference_mode from brevitas.graph.calibrate import calibration_mode from brevitas.graph.quantize import layerwise_quantize from brevitas.graph.quantize import quantize @@ -28,19 +28,19 @@ MODEL_LIST = [ 'vit_b_32', 'efficientnet_b0', - # 'mobilenet_v3_small', - # 'mobilenet_v2', - # 'resnet50', - # 'resnet18', - # 'mnasnet0_5', - # 'alexnet', - # 'googlenet', - # 'vgg11', - # 'densenet121', - # 'deeplabv3_resnet50', - # 'fcn_resnet50', - # 'regnet_x_400mf', - # 'squeezenet1_0', + 'mobilenet_v3_small', + 'mobilenet_v2', + 'resnet50', + 'resnet18', + 'mnasnet0_5', + 'alexnet', + 'googlenet', + 'vgg11', + 'densenet121', + 'deeplabv3_resnet50', + 'fcn_resnet50', + 'regnet_x_400mf', + 'squeezenet1_0', 'inception_v3'] @@ -71,7 +71,7 @@ def quantize_float(model): @fixture @parametrize('model_name', MODEL_LIST) -@parametrize('quantize_fn', [quantize_float]) +@parametrize('quantize_fn', [quantize_float, quantize, layerwise_quantize, quantize_flexml]) def torchvision_model(model_name, quantize_fn): inp = torch.randn(BATCH, IN_CH, HEIGHT, WIDTH) @@ -122,28 +122,27 @@ def 
test_torchvision_graph_quantization_flexml_qcdq_onnx( pytest.skip('Model not instantiated') if enable_compile: model_name = test_id.split("-")[1] - if torch_version <= version.parse('2.0'): - pytest.skip("Pytorch 2.0 is required to test compile") + if torch_version <= version.parse('2.2'): + pytest.skip("Pytorch 2.2 is required to test compile") + else: + torch._dynamo.config.capture_scalar_outputs = True if 'vit' in model_name: pytest.skip("QuantMHA not supported with compile") inp = torch.randn(BATCH, IN_CH, HEIGHT, WIDTH) quantize_fn_name = test_id.split("-")[0] - if enable_compile: - torch._dynamo.config.capture_scalar_outputs = True - with torch.no_grad(), inference_mode(torchvision_model): - prehook_non_compiled_out = torchvision_model(inp) - post_hook_non_compiled_out = torchvision_model(inp) - assert torch.allclose(prehook_non_compiled_out, post_hook_non_compiled_out) + with torch.no_grad(), quant_inference_mode(torchvision_model): + prehook_non_compiled_out = torchvision_model(inp) + post_hook_non_compiled_out = torchvision_model(inp) + assert torch.allclose(prehook_non_compiled_out, post_hook_non_compiled_out) + if enable_compile: compiled_model = torch.compile(torchvision_model, fullgraph=True) compiled_out = compiled_model(inp) - # This fails! Compile might needs more small-scoped tests for accuracy evaluation - # assert torch.allclose(post_hook_non_compiled_out, compiled_out) - else: - torchvision_model(inp) + # This fails! Compile might needs more small-scoped tests for accuracy evaluation + # assert torch.allclose(post_hook_non_compiled_out, compiled_out) if quantize_fn_name != 'quantize_float' and not enable_compile: export_onnx_qcdq(torchvision_model, args=inp) From d50a31186df43dc6b319ece62d09e5d999e441d4 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 10 Sep 2024 10:22:45 +0100 Subject: [PATCH 15/25] missing file --- src/brevitas/export/inference/__init__.py | 5 +++++ src/brevitas_examples/stable_diffusion/main.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) create mode 100644 src/brevitas/export/inference/__init__.py diff --git a/src/brevitas/export/inference/__init__.py b/src/brevitas/export/inference/__init__.py new file mode 100644 index 000000000..0e6d113e0 --- /dev/null +++ b/src/brevitas/export/inference/__init__.py @@ -0,0 +1,5 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
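With quant_inference_mode exported from brevitas.export.inference, the evaluation flow used in ptq_evaluate.py and in the test above boils down to the following sketch; quant_model, val_loader and ref_input are caller-provided placeholders, and disable= simply turns torch.compile into a no-op when compilation is not requested:

import torch
from brevitas.export.inference import quant_inference_mode

def validate_compiled(quant_model, val_loader, ref_input, use_compile=False):
    # Sketch of the evaluation pattern; all arguments are placeholders supplied by the caller.
    quant_model.eval()
    with torch.no_grad(), quant_inference_mode(quant_model):
        quant_model(ref_input)  # one reference pass so the inference handlers can prepare and cache
        compiled = torch.compile(quant_model, fullgraph=True, disable=not use_compile)
        for images, _ in val_loader:
            compiled(images)
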
+# SPDX-License-Identifier: BSD-3-Clause + +from .manager import InferenceManager +from .manager import quant_inference_mode diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index 5ac69d147..a1c4fef53 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -149,7 +149,7 @@ def main(args): calibration_prompts = CALIBRATION_PROMPTS if args.calibration_prompt_path is not None: calibration_prompts = load_calib_prompts(args.calibration_prompt_path) - print(args.calibration_prompt, len(calibration_prompts)) + print(args.calibration_prompt, len(calibration_prompts)) assert args.calibration_prompt <= len(calibration_prompts) , f"Only {len(calibration_prompts)} prompts are available" calibration_prompts = calibration_prompts[:args.calibration_prompt] From b4d744d2f0c672a23593e1657a45eb783eaf4f45 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 10 Sep 2024 10:50:43 +0100 Subject: [PATCH 16/25] fix super --- src/brevitas/export/inference/handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py index bd81cadc3..39a047437 100644 --- a/src/brevitas/export/inference/handler.py +++ b/src/brevitas/export/inference/handler.py @@ -71,7 +71,7 @@ def prepare_for_export(self, module): if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only: self.cached_weight = module._cached_weight else: - super.prepare_for_export(module) + super().prepare_for_export(module) def forward(self, x) -> Tuple[torch.Tensor]: if self.cached_weight is not None: From d4ce3390ac0805f44dc56ada1f7f49cd43da2b09 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 10 Sep 2024 10:54:03 +0100 Subject: [PATCH 17/25] correct weight handling --- src/brevitas/export/inference/handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py index 39a047437..d74fde7b7 100644 --- a/src/brevitas/export/inference/handler.py +++ b/src/brevitas/export/inference/handler.py @@ -137,7 +137,7 @@ def forward(self, x) -> Tuple[torch.Tensor]: class FloatWeightInferencetHandler(FloatInferencetHandler): - handled_layer = (ActFloatQuantProxyFromInjector, WeightFloatQuantProxyFromInjector) + handled_layer = WeightFloatQuantProxyFromInjector def prepare_for_export(self, module): if module.is_quant_enabled: From 3ca7505abaec6bffcde0677c5fc32a704d1df6f6 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 10 Sep 2024 12:21:19 +0100 Subject: [PATCH 18/25] parallel tests end_to_end --- noxfile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/noxfile.py b/noxfile.py index ffb1c5fbd..59ff7122f 100644 --- a/noxfile.py +++ b/noxfile.py @@ -174,4 +174,4 @@ def tests_brevitas_end_to_end(session, pytorch): install_pytorch(pytorch, session) install_torchvision(pytorch, session) session.install('--upgrade', '-e', '.[test, ort_integration]') - session.run('pytest', '-v', 'tests/brevitas_end_to_end') + session.run('pytest', '-n', 'logical', '-v', 'tests/brevitas_end_to_end') From 9d82357fb140a7f5571866e25c5d543f5a2c3778 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 11 Sep 2024 02:12:57 +0100 Subject: [PATCH 19/25] Fixes --- src/brevitas/export/inference/handler.py | 10 ++++------ .../imagenet_classification/ptq/ptq_evaluate.py | 3 +-- 2 files changed, 5 insertions(+), 8 
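The "fix super" commit above corrects a classic slip: without parentheses, `super` is the built-in type itself, so `super.prepare_for_export(module)` raises AttributeError instead of dispatching to the parent class. A minimal illustration:

class Base:
    def prepare_for_export(self, module):
        print("base prepare")

class Handler(Base):
    def prepare_for_export(self, module):
        # super.prepare_for_export(module)  # AttributeError: the bare `super` type has no bound parent method
        super().prepare_for_export(module)  # zero-argument super() binds the instance and class

Handler().prepare_for_export(None)  # prints "base prepare"
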
deletions(-) diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py index d74fde7b7..1416014ec 100644 --- a/src/brevitas/export/inference/handler.py +++ b/src/brevitas/export/inference/handler.py @@ -68,10 +68,9 @@ class IntWeightInferencetHandler(IntInferencetHandler): def prepare_for_export(self, module): if module.is_quant_enabled: self.cached_weight = None + super().prepare_for_export(module) if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only: - self.cached_weight = module._cached_weight - else: - super().prepare_for_export(module) + self.cached_weight = module._cached_weight.value def forward(self, x) -> Tuple[torch.Tensor]: if self.cached_weight is not None: @@ -142,10 +141,9 @@ class FloatWeightInferencetHandler(FloatInferencetHandler): def prepare_for_export(self, module): if module.is_quant_enabled: self.cached_weight = None + super().prepare_for_export(module) if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only: - self.cached_weight = module._cached_weight - else: - super().prepare_for_export(module) + self.cached_weight = module._cached_weight.value def forward(self, x) -> Tuple[torch.Tensor]: if self.cached_weight is not None: diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index f33946079..4520bc3e2 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -488,8 +488,7 @@ def main(): device, dtype = param.device, param.dtype ref_input = generate_ref_input(args, device, dtype) quant_model(ref_input) - compiled_model = torch.compile( - quant_model, fullgraph=True, dynamic=True, disable=not args.compile) + compiled_model = torch.compile(quant_model, fullgraph=True, disable=not args.compile) validate(val_loader, compiled_model, stable=dtype != torch.bfloat16) if args.export_onnx_qcdq or args.export_torch_qcdq: From ea4250696a46b0a2122699ae7782146b08559025 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 13 Sep 2024 13:13:34 +0100 Subject: [PATCH 20/25] fix --- src/brevitas/proxy/groupwise_int_runtime_quant.py | 2 +- src/brevitas/proxy/runtime_quant.py | 6 +++--- .../imagenet_classification/ptq/ptq_common.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/brevitas/proxy/groupwise_int_runtime_quant.py b/src/brevitas/proxy/groupwise_int_runtime_quant.py index ec9418e19..453cb3f9b 100644 --- a/src/brevitas/proxy/groupwise_int_runtime_quant.py +++ b/src/brevitas/proxy/groupwise_int_runtime_quant.py @@ -31,7 +31,7 @@ def create_quant_tensor( qt_args: Union[torch.Tensor, Tuple[Any]], x: Optional[GroupwiseIntQuantTensor] = None) -> GroupwiseIntQuantTensor: if x is None: - value, scale, zero_point, bit_width, = qt_args + value, scale, zero_point, bit_width = qt_args out = GroupwiseIntQuantTensor( value, scale, diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 03303bcc8..b2ded7239 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -123,11 +123,11 @@ def internal_forward(self, force_eval): return out def retrieve_attribute(self, attribute, force_eval): - if self.is_quant_enabled: + if self._cached_act is not None: + return getattr(self._cached_act, attribute) + elif self.is_quant_enabled: out = self.internal_forward(force_eval) return getattr(out, 
attribute) - elif self._cached_act is not None: - return getattr(self._cached_act, attribute) elif self._cached_act is None: return None diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index 661328d7e..a6cdd2af7 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -178,7 +178,7 @@ class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat): 'asym': CNNShiftedUint8DynamicActPerTensorFloat}}}, 'po2_scale': { 'stats': { - 'per_group': MXInt8Act}}}}, + 'per_group': {'sym':MXInt8Act} }}}}, 'float': { 'static': { 'float_scale': { From 3dc7dfa6ad6e6afc1b49a5578c625e20e3401b6f Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 13 Sep 2024 13:13:59 +0100 Subject: [PATCH 21/25] precommit fix --- .../imagenet_classification/ptq/ptq_common.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index a6cdd2af7..9e5c90e26 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -178,7 +178,8 @@ class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat): 'asym': CNNShiftedUint8DynamicActPerTensorFloat}}}, 'po2_scale': { 'stats': { - 'per_group': {'sym':MXInt8Act} }}}}, + 'per_group': { + 'sym': MXInt8Act}}}}}, 'float': { 'static': { 'float_scale': { From 04231e469c663d8299d06703c1cf1e1c8b72f669 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 13 Sep 2024 16:22:28 +0100 Subject: [PATCH 22/25] Fix import --- src/brevitas/__init__.py | 6 ++++++ src/brevitas/nn/mixin/base.py | 6 +----- src/brevitas/proxy/parameter_quant.py | 12 ++---------- src/brevitas/proxy/runtime_quant.py | 9 +-------- 4 files changed, 10 insertions(+), 23 deletions(-) diff --git a/src/brevitas/__init__.py b/src/brevitas/__init__.py index eddc35a02..fe46102a7 100644 --- a/src/brevitas/__init__.py +++ b/src/brevitas/__init__.py @@ -23,6 +23,12 @@ else: torch_version = version.parse(torch.__version__) +try: + # Attempt _dynamo import + is_dynamo_compiling = torch._dynamo.is_compiling +except: + is_dynamo_compiling = lambda: False + try: __version__ = get_distribution(__name__).version except DistributionNotFound: diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py index 59b559787..a5c4407fd 100644 --- a/src/brevitas/nn/mixin/base.py +++ b/src/brevitas/nn/mixin/base.py @@ -16,6 +16,7 @@ from torch.nn.utils.rnn import PackedSequence from brevitas import config +from brevitas import is_dynamo_compiling from brevitas import torch_version from brevitas.common import ExportMixin from brevitas.inject import ExtendedInjector @@ -29,11 +30,6 @@ from .utils import filter_kwargs -if torch_version < packaging.version.parse('2.0'): - is_dynamo_compiling = lambda: False -else: - is_dynamo_compiling = torch._dynamo.is_compiling - class QuantProxyMixin(object): __metaclass__ = ABCMeta diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index dc7c704c3..f28233aed 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -7,23 +7,15 @@ from typing import Any, List, Optional, Tuple, Union from warnings import warn -import packaging.version import torch - -from brevitas import torch_version -from 
brevitas.core.function_wrapper.misc import Identity - -if torch_version < packaging.version.parse('2.0'): - is_dynamo_compiling = lambda: False -else: - is_dynamo_compiling = torch._dynamo.is_compiling - from torch import Tensor import torch.nn as nn from typing_extensions import Protocol from typing_extensions import runtime_checkable from brevitas import config +from brevitas import is_dynamo_compiling +from brevitas.core.function_wrapper.misc import Identity from brevitas.function import max_int from brevitas.inject import BaseInjector as Injector from brevitas.quant_tensor import _unpack_quant_tensor diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index b2ded7239..9feb593b4 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -5,15 +5,7 @@ from abc import abstractmethod from typing import Any, Optional, Tuple, Union -import packaging.version import torch - -from brevitas import torch_version - -if torch_version < packaging.version.parse('2.0'): - is_dynamo_compiling = lambda: False -else: - is_dynamo_compiling = torch._dynamo.is_compiling from torch import nn from torch import Tensor from torch.nn import Identity @@ -21,6 +13,7 @@ from typing_extensions import runtime_checkable import brevitas +from brevitas import is_dynamo_compiling from brevitas.quant_tensor import IntQuantTensor from brevitas.quant_tensor import QuantTensor from brevitas.utils.quant_utils import _CachedIO From 762e613c8a5f1b7dec1282312bccd6ccb1b28b7c Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 13 Sep 2024 16:47:10 +0100 Subject: [PATCH 23/25] tests fix --- tests/brevitas_end_to_end/test_torchvision_models.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/brevitas_end_to_end/test_torchvision_models.py b/tests/brevitas_end_to_end/test_torchvision_models.py index 528319f26..80ee9a1b9 100644 --- a/tests/brevitas_end_to_end/test_torchvision_models.py +++ b/tests/brevitas_end_to_end/test_torchvision_models.py @@ -122,8 +122,11 @@ def test_torchvision_graph_quantization_flexml_qcdq_onnx( pytest.skip('Model not instantiated') if enable_compile: model_name = test_id.split("-")[1] + quant_func = test_id.split("-")[0] if torch_version <= version.parse('2.2'): pytest.skip("Pytorch 2.2 is required to test compile") + elif quant_func not in ('quantize_float', 'quantize'): + pytest.skip("Compile is tested only against base float and int quantization functions") else: torch._dynamo.config.capture_scalar_outputs = True if 'vit' in model_name: @@ -141,8 +144,8 @@ def test_torchvision_graph_quantization_flexml_qcdq_onnx( compiled_model = torch.compile(torchvision_model, fullgraph=True) compiled_out = compiled_model(inp) - # This fails! 
Compile might needs more small-scoped tests for accuracy evaluation - # assert torch.allclose(post_hook_non_compiled_out, compiled_out) + print(torch.max(torch.abs(post_hook_non_compiled_out - compiled_out))) + assert torch.allclose(post_hook_non_compiled_out, compiled_out) if quantize_fn_name != 'quantize_float' and not enable_compile: export_onnx_qcdq(torchvision_model, args=inp) From 6b491889a21fe152796d331e1a9bb85388e8094f Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 13 Sep 2024 19:00:19 +0100 Subject: [PATCH 24/25] Added compile test with tolerance --- tests/brevitas_end_to_end/test_torchvision_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/brevitas_end_to_end/test_torchvision_models.py b/tests/brevitas_end_to_end/test_torchvision_models.py index 80ee9a1b9..f00920f3e 100644 --- a/tests/brevitas_end_to_end/test_torchvision_models.py +++ b/tests/brevitas_end_to_end/test_torchvision_models.py @@ -22,6 +22,7 @@ from brevitas_examples.imagenet_classification.ptq.ptq_common import quantize_model from tests.marker import requires_pt_ge +TORCH_COMPILE_ATOL = 0.35 BATCH = 1 HEIGHT, WIDTH = 224, 224 IN_CH = 3 @@ -144,8 +145,7 @@ def test_torchvision_graph_quantization_flexml_qcdq_onnx( compiled_model = torch.compile(torchvision_model, fullgraph=True) compiled_out = compiled_model(inp) - print(torch.max(torch.abs(post_hook_non_compiled_out - compiled_out))) - assert torch.allclose(post_hook_non_compiled_out, compiled_out) + assert torch.allclose(post_hook_non_compiled_out, compiled_out, atol=TORCH_COMPILE_ATOL) if quantize_fn_name != 'quantize_float' and not enable_compile: export_onnx_qcdq(torchvision_model, args=inp) From 24d84c10a203e219db28bd9c95890ca6448a612e Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 17 Sep 2024 02:37:15 +0100 Subject: [PATCH 25/25] Fix test structure, hopefully faster --- .../test_torchvision_models.py | 72 ++++++++++--------- 1 file changed, 40 insertions(+), 32 deletions(-) diff --git a/tests/brevitas_end_to_end/test_torchvision_models.py b/tests/brevitas_end_to_end/test_torchvision_models.py index f00920f3e..09f0b9253 100644 --- a/tests/brevitas_end_to_end/test_torchvision_models.py +++ b/tests/brevitas_end_to_end/test_torchvision_models.py @@ -26,6 +26,9 @@ BATCH = 1 HEIGHT, WIDTH = 224, 224 IN_CH = 3 + +COMPILE_MODEL_LIST = ['efficientnet_b0', 'resnet18', 'fcn_resnet50'] + MODEL_LIST = [ 'vit_b_32', 'efficientnet_b0', @@ -70,11 +73,7 @@ def quantize_float(model): quant_format='float') -@fixture -@parametrize('model_name', MODEL_LIST) -@parametrize('quantize_fn', [quantize_float, quantize, layerwise_quantize, quantize_flexml]) -def torchvision_model(model_name, quantize_fn): - +def shared_quant_fn(model_name, quantize_fn): inp = torch.randn(BATCH, IN_CH, HEIGHT, WIDTH) if torch_version <= version.parse('1.9.1') and model_name == 'regnet_x_400mf': @@ -114,44 +113,53 @@ def torchvision_model(model_name, quantize_fn): return model -@requires_pt_ge('1.8.1') -@parametrize('enable_compile', [True, False]) -def test_torchvision_graph_quantization_flexml_qcdq_onnx( - torchvision_model, enable_compile, request): - test_id = request.node.callspec.id - if torchvision_model is None: +@fixture +@parametrize('model_name', MODEL_LIST) +@parametrize('quantize_fn', [quantize_float, quantize, layerwise_quantize, quantize_flexml]) +def torchvision_model(model_name, quantize_fn): + return shared_quant_fn(model_name, quantize_fn) + + +@fixture +@parametrize('model_name', COMPILE_MODEL_LIST) +@parametrize('quantize_fn', 
[quantize_float, quantize]) +def torchvision_model_compile(model_name, quantize_fn): + return shared_quant_fn(model_name, quantize_fn) + + +@requires_pt_ge('2.2') +def test_torchvision_compile(torchvision_model_compile): + torch._dynamo.config.capture_scalar_outputs = True + if torchvision_model_compile is None: pytest.skip('Model not instantiated') - if enable_compile: - model_name = test_id.split("-")[1] - quant_func = test_id.split("-")[0] - if torch_version <= version.parse('2.2'): - pytest.skip("Pytorch 2.2 is required to test compile") - elif quant_func not in ('quantize_float', 'quantize'): - pytest.skip("Compile is tested only against base float and int quantization functions") - else: - torch._dynamo.config.capture_scalar_outputs = True - if 'vit' in model_name: - pytest.skip("QuantMHA not supported with compile") inp = torch.randn(BATCH, IN_CH, HEIGHT, WIDTH) - quantize_fn_name = test_id.split("-")[0] - with torch.no_grad(), quant_inference_mode(torchvision_model): - prehook_non_compiled_out = torchvision_model(inp) - post_hook_non_compiled_out = torchvision_model(inp) + with torch.no_grad(), quant_inference_mode(torchvision_model_compile): + prehook_non_compiled_out = torchvision_model_compile(inp) + post_hook_non_compiled_out = torchvision_model_compile(inp) + + compiled_model = torch.compile(torchvision_model_compile, fullgraph=True) + compiled_out = compiled_model(inp) + assert torch.allclose(prehook_non_compiled_out, post_hook_non_compiled_out) + assert torch.allclose(post_hook_non_compiled_out, compiled_out, atol=TORCH_COMPILE_ATOL) + - if enable_compile: - compiled_model = torch.compile(torchvision_model, fullgraph=True) - compiled_out = compiled_model(inp) +def test_torchvision_graph_quantization_flexml_qcdq_onnx(torchvision_model, request): + test_id = request.node.callspec.id + if torchvision_model is None: + pytest.skip('Model not instantiated') - assert torch.allclose(post_hook_non_compiled_out, compiled_out, atol=TORCH_COMPILE_ATOL) + inp = torch.randn(BATCH, IN_CH, HEIGHT, WIDTH) + + quantize_fn_name = test_id.split("-")[0] + torchvision_model(inp) - if quantize_fn_name != 'quantize_float' and not enable_compile: + if quantize_fn_name != 'quantize_float': export_onnx_qcdq(torchvision_model, args=inp) -@requires_pt_ge('1.9.1') def test_torchvision_graph_quantization_flexml_qcdq_torch(torchvision_model, request): if torchvision_model is None: pytest.skip('Model not instantiated')
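Finally, the pattern that threads through the whole series is the guarded is_dynamo_compiling helper: QuantTensor wrapping, caching and warnings are skipped while torch.compile traces, and older PyTorch builds simply report compilation as disabled. A hedged sketch of the idea; unpack_quant_output and wrap are hypothetical names, not Brevitas API:

import torch

try:
    is_dynamo_compiling = torch._dynamo.is_compiling  # available on PyTorch 2.x
except AttributeError:
    is_dynamo_compiling = lambda: False  # older PyTorch: never compiling

def unpack_quant_output(qt_args, wrap):
    # Under torch.compile return only the dequantized tensor (first tuple element);
    # otherwise wrap the full tuple, as the proxies do with create_quant_tensor.
    if is_dynamo_compiling():
        return qt_args[0]
    return wrap(qt_args)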