From 3612e90ac04496b77a55656dada0282f78eda1f9 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Wed, 18 Dec 2024 10:15:40 +0000
Subject: [PATCH] Feat (proxy): flag to enable/disable QT return (#1083)

---
 src/brevitas/export/common/handler/qcdq.py |  1 +
 src/brevitas/export/inference/handler.py   | 34 ++++++++++++++++++----
 src/brevitas/export/inference/manager.py   | 11 +++++++
 src/brevitas/export/manager.py             |  6 ++--
 src/brevitas/export/onnx/manager.py        |  2 --
 src/brevitas/nn/quant_avg_pool.py          | 21 +++++++++++--
 src/brevitas/proxy/parameter_quant.py      | 12 +++++---
 src/brevitas/proxy/runtime_quant.py        | 15 +++++++++-
 8 files changed, 84 insertions(+), 18 deletions(-)

diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py
index bbc03b630..39347baad 100644
--- a/src/brevitas/export/common/handler/qcdq.py
+++ b/src/brevitas/export/common/handler/qcdq.py
@@ -796,6 +796,7 @@ def symbolic_execution(
         flat_pre_scale = to_0dim_if_scalar(pre_scale.flatten())
         flat_scale = to_0dim_if_scalar(scale.flatten())
         zp = to_0dim_if_scalar(zero_point.flatten()).expand_as(flat_scale)
+        zp = self.zero_point_with_dtype(signed, output_bit_width, zp)
         x = self.quantize_fn(x, flat_pre_scale, zp, dtype, self.quant_axis(pre_scale))
         clip_symbolic_kwargs = self.int_clip_symbolic_kwargs(
             signed=signed, narrow=False, bit_width=output_bit_width)
diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py
index 32c1ac5ac..c7fc21790 100644
--- a/src/brevitas/export/inference/handler.py
+++ b/src/brevitas/export/inference/handler.py
@@ -9,7 +9,6 @@
 from torch import Tensor
 import torch.nn as nn

-from brevitas import is_dynamo_compiling
 from brevitas.function.ops import max_float
 from brevitas.function.ops import max_int
 from brevitas.function.ops import min_int
@@ -110,6 +109,10 @@ def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]:
 class GroupwiseIntInferenceHandler(IntInferencetHandler):
     handled_layer = GroupwiseActQuantProxyFromInjector

+    def __init__(self):
+        super().__init__()
+        self.skip_create_quant_tensor = False
+
     def prepare_for_export(self, module):
         if module.is_quant_enabled:
             self.module_forward = module.fused_activation_quant_proxy
@@ -117,7 +120,9 @@ def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]:
         x, *other = self.module_forward(x)
-        if is_dynamo_compiling():
+
+        # If we skip quant tensor, we return the flattened version of the groupwise tensor
+        if self.skip_create_quant_tensor:
             start_dim = self.group_dim if self.group_dim != -1 else -2
             x = x.flatten(start_dim, start_dim + 1)
             output_args = tuple([x] + list(other))
@@ -127,6 +132,10 @@ def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]:
 class GroupwiseIntWeightInferenceHandler(IntWeightInferencetHandler):
     handled_layer = GroupwiseWeightQuantProxyFromInjector

+    def __init__(self):
+        super().__init__()
+        self.skip_create_quant_tensor = False
+
     def prepare_for_export(self, module):
         super().prepare_for_export(module)
         if module.is_quant_enabled:
@@ -151,7 +160,9 @@ def forward(self, x: Tensor) -> Tuple[Tensor]:
         else:
             x = self.input_view(x)
             out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point)
-        if is_dynamo_compiling():
+
+        # If we skip quant tensor, we return the flattened version of the groupwise tensor
+        if self.skip_create_quant_tensor:
             out = self.flattened_view(out)
         return out, scale, zero_point, self.bit_width
@@ -242,6 +253,10 @@ def forward(self, x: Tensor) -> Tuple[Tensor]:
 class GroupwiseFloatInferenceHandler(FloatInferencetHandler):
     handled_layer = GroupwiseActFloatQuantProxyFromInjector

+    def __init__(self):
+        super().__init__()
+        self.skip_create_quant_tensor = False
+
     def prepare_for_export(self, module: nn.Module):
         if module.is_quant_enabled:
             self.module_forward = module.fused_activation_quant_proxy
@@ -249,7 +264,9 @@ def forward(self, x: Tensor) -> Tuple[Tensor]:
         x, *other = self.module_forward(x)
-        if is_dynamo_compiling():
+
+        # If we skip quant tensor, we return the flattened version of the groupwise tensor
+        if self.skip_create_quant_tensor:
             start_dim = self.group_dim if self.group_dim != -1 else -2
             x = x.flatten(start_dim, start_dim + 1)
             output_args = tuple([x] + list(other))
@@ -259,6 +276,10 @@ def forward(self, x: Tensor) -> Tuple[Tensor]:
 class GroupwiseFloatWeightInferenceHandler(FloatWeightInferencetHandler):
     handled_layer = GroupwiseWeightFloatQuantProxyFromInjector

+    def __init__(self):
+        super().__init__()
+        self.skip_create_quant_tensor = False
+
     def prepare_for_export(self, module: nn.Module):
         super().prepare_for_export(module)
         if module.is_quant_enabled:
@@ -283,6 +304,9 @@ def forward(self, x: Tensor) -> Tuple[Tensor]:
         else:
             x = self.input_view(x)
             out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point)
-        if is_dynamo_compiling():
+
+        # If we skip quant tensor, we return the flattened version of the groupwise tensor
+        if self.skip_create_quant_tensor:
             out = self.flattened_view(out)
+
         return out, scale, zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values
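Note on the groupwise handlers above: when skip_create_quant_tensor is set, the handler returns the tensor with its (n_groups, group_size) pair of dimensions flattened back into a single axis, instead of a groupwise QuantTensor. A minimal torch-only sketch of that reshape (the shape and the group_dim value are illustrative assumptions, not taken from this patch):

    import torch

    group_dim = -1                              # assumed: grouping over the last axis
    x = torch.randn(2, 8, 16, 4)                # assumed layout: (batch, rows, n_groups, group_size)
    start_dim = group_dim if group_dim != -1 else -2
    flat = x.flatten(start_dim, start_dim + 1)  # merges n_groups and group_size
    assert flat.shape == (2, 8, 64)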
diff --git a/src/brevitas/export/inference/manager.py b/src/brevitas/export/inference/manager.py
index b78a888f2..eeeedabaf 100644
--- a/src/brevitas/export/inference/manager.py
+++ b/src/brevitas/export/inference/manager.py
@@ -1,6 +1,8 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause

+from functools import partial
+
 from torch.nn import Module
 import torch.nn as nn

@@ -42,6 +44,11 @@ def _override_weight_caching_mode(m: nn.Module, enabled: bool, metadata_only: bo
     _override_caching_mode(m, 'weight', enabled, metadata_only)


+def _override_create_quant_tensor(m: nn.Module, state: bool):
+    if hasattr(m, 'skip_create_quant_tensor'):
+        m.skip_create_quant_tensor = state
+
+
 class quant_inference_mode:

     def __init__(self, model, cache_quant_weight=False, enabled=True):
@@ -79,6 +86,8 @@ def __exit__(self, type, value, traceback):
             self.model.apply(
                 lambda m: _override_weight_caching_mode(m, enabled=False, metadata_only=False))
         restore_return_quant_tensor(self.model, self.return_quant_tensor_state)
+        enable_quant_tensor = partial(_override_create_quant_tensor, state=False)
+        self.model.apply(enable_quant_tensor)

     def hook(self, module, inp, out):
         # After one forward pass with caching enabled, we can:
@@ -90,6 +99,8 @@ def hook(self, module, inp, out):
         self.model.apply(InferenceManager.set_export_handler)
         InferenceManager.set_export_mode(self.model, enabled=True)
         self.return_quant_tensor_state = disable_return_quant_tensor(self.model)
+        disable_quant_tensor = partial(_override_create_quant_tensor, state=True)
+        self.model.apply(disable_quant_tensor)


# Inheritance from BaseManager is not technically needed
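The two partial applications above are how the new flag is driven by quant_inference_mode: the first-forward hook sets skip_create_quant_tensor = True on every submodule that exposes it, and __exit__ sets it back to False. A hedged usage sketch (the layer choice and sizes are arbitrary; quant_inference_mode is the context manager patched above):

    import torch
    from brevitas.export.inference.manager import quant_inference_mode
    from brevitas.nn import QuantLinear

    model = QuantLinear(16, 32, bias=False)
    model.eval()
    with torch.no_grad(), quant_inference_mode(model):
        _ = model(torch.randn(1, 16))    # first pass: hook flips the flag on all proxies
        out = model(torch.randn(1, 16))  # later passes: raw tensors, no QuantTensor wrapping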
diff --git a/src/brevitas/export/manager.py b/src/brevitas/export/manager.py
index 7b7e7a145..550711e16 100644
--- a/src/brevitas/export/manager.py
+++ b/src/brevitas/export/manager.py
@@ -3,10 +3,8 @@

 from abc import ABC
 from abc import abstractmethod
-from contextlib import ExitStack
-from functools import partial
 from io import BytesIO
-from typing import Optional, Tuple, Union
+from typing import Tuple, Union

 import torch
 from torch import nn
@@ -20,7 +18,6 @@
 from brevitas.quant_tensor import QuantTensor
 from brevitas.utils.jit_utils import clear_class_registry
 from brevitas.utils.python_utils import patch
-from brevitas.utils.quant_utils import _CachedIO


 class _JitTraceExportWrapper(nn.Module):
@@ -219,6 +216,7 @@ def jit_inference_trace(
         # wrapping with a lambda forces inlining during tracing,
         # converts everything to const and removes unused params/buffers
         traced_model = torch.jit.trace(_JitTraceExportWrapper(module), args)
+
         # Hack to clone the function, otherwise restoring requires_grad
         # on module will break traced_model
         with BytesIO() as tmp:
diff --git a/src/brevitas/export/onnx/manager.py b/src/brevitas/export/onnx/manager.py
index ae3270cc9..cbede2dca 100644
--- a/src/brevitas/export/onnx/manager.py
+++ b/src/brevitas/export/onnx/manager.py
@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: BSD-3-Clause

 from abc import ABC
-from contextlib import ExitStack
 from io import BytesIO
 from typing import Optional, Tuple, Union
 import warnings
@@ -167,7 +166,6 @@ def export_onnx(

             with PatchFp8Ops():
                 torch.onnx.export(module, args, export_target, **onnx_export_kwargs)
-
             # restore the model to previous properties
             module.apply(lambda m: _restore_act_caching_mode(m))
             cls.set_export_mode(module, enabled=False)
diff --git a/src/brevitas/nn/quant_avg_pool.py b/src/brevitas/nn/quant_avg_pool.py
index 7a3f108da..12f324901 100644
--- a/src/brevitas/nn/quant_avg_pool.py
+++ b/src/brevitas/nn/quant_avg_pool.py
@@ -15,6 +15,7 @@
 from brevitas.inject.defaults import RoundTo8bit
 from brevitas.quant_tensor import _unpack_quant_tensor
 from brevitas.quant_tensor import QuantTensor
+from brevitas.utils.quant_utils import _CachedIO

 from .mixin.acc import AccQuantType
 from .mixin.acc import TruncMixin
@@ -38,6 +39,9 @@ def __init__(
         AvgPool2d.__init__(self, kernel_size=kernel_size, stride=stride)
         QuantLayerMixin.__init__(self, return_quant_tensor)
         TruncMixin.__init__(self, trunc_quant=trunc_quant, **kwargs)
+        self.cache_inference_quant_act = False
+        self.cache_quant_io_metadata_only = True
+        self.cache_class = None

     @property
     def channelwise_separable(self) -> bool:
@@ -60,7 +64,12 @@ def forward(self, input: Union[Tensor, QuantTensor]):
         if self.export_mode:
             return self.export_handler(_unpack_quant_tensor(x))

-        if isinstance(x, QuantTensor) and self.is_trunc_quant_enabled:
+        if (isinstance(x, QuantTensor) or
+                self.cache_class is not None) and self.is_trunc_quant_enabled:
+            if self.cache_inference_quant_act:
+                self.cache_class = _CachedIO(x, self.cache_quant_io_metadata_only)
+            if not isinstance(x, QuantTensor):
+                x = self.cache_class.quant_tensor.set(value=x)
             y = AvgPool2d.forward(self, x)
             rescaled_value = y.value * self._avg_scaling
             y = y.set(value=rescaled_value)
@@ -87,6 +96,9 @@ def __init__(
         AdaptiveAvgPool2d.__init__(self, output_size=output_size)
         QuantLayerMixin.__init__(self, return_quant_tensor)
         TruncMixin.__init__(self, trunc_quant=trunc_quant, **kwargs)
+        self.cache_inference_quant_act = False
+        self.cache_quant_io_metadata_only = True
+        self.cache_class = None

     @property
     def channelwise_separable(self) -> bool:
@@ -120,7 +132,12 @@ def forward(self, input: Union[Tensor, QuantTensor]):
             self._set_global_is_quant_layer(False)
             return out

-        if isinstance(x, QuantTensor) and self.is_trunc_quant_enabled:
+        if (isinstance(x, QuantTensor) or
+                self.cache_class is not None) and self.is_trunc_quant_enabled:
+            if self.cache_inference_quant_act:
+                self.cache_class = _CachedIO(x, self.cache_quant_io_metadata_only)
+            if not isinstance(x, QuantTensor):
+                x = self.cache_class.quant_tensor.set(value=x)
             y = AdaptiveAvgPool2d.forward(self, x)
             k_size, stride = self.compute_kernel_size_stride(x.value.shape[2:], y.value.shape[2:])
             reduce_size = reduce(mul, k_size, 1)
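The caching added to the pooling layers above compensates for the new flag: under quant_inference_mode the upstream proxy now hands the pool a plain Tensor, so the pool caches the metadata of the first QuantTensor it sees and re-wraps later raw tensors from it. A sketch of the intended flow (the module composition and shapes are assumptions for illustration):

    import torch
    from brevitas.export.inference.manager import quant_inference_mode
    from brevitas.nn import QuantIdentity, TruncAvgPool2d

    model = torch.nn.Sequential(
        QuantIdentity(return_quant_tensor=True), TruncAvgPool2d(kernel_size=2))
    model.eval()
    with torch.no_grad(), quant_inference_mode(model):
        _ = model(torch.randn(1, 3, 8, 8))    # warm-up pass populates the _CachedIO
        out = model(torch.randn(1, 3, 8, 8))  # plain Tensor input re-wrapped from the cache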
diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py
index 2144edd53..5c4e447d4 100644
--- a/src/brevitas/proxy/parameter_quant.py
+++ b/src/brevitas/proxy/parameter_quant.py
@@ -95,6 +95,7 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None:
         self.cache_inference_quant_weight_metadata_only = False
         self.cache_class = None  # To be redefined by each class
         self.quant_tensor_class = None  # To be redefined by each class
+        self.skip_create_quant_tensor = False

     @property
     def input_view_impl(self):
@@ -132,13 +133,13 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
             # - quantization flow
             if self.export_mode:
                 out = self.export_handler(x)
-                if is_dynamo_compiling():
+                if self.skip_create_quant_tensor:
                     out = out[0]
                 else:
                     out = self.create_quant_tensor(out)
             else:
                 out = self.tensor_quant(x)
-                if is_dynamo_compiling():
+                if self.skip_create_quant_tensor:
                     out = out[0]
                 else:
                     out = self.create_quant_tensor(out)
@@ -159,6 +160,7 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None:
         self.cache_inference_quant_bias = False
         self.cache_inference_quant_bias_metadata_only = False
         self.requires_input_scale = self.quant_injector.requires_input_scale
+        self.skip_create_quant_tensor = False

     @property
     def tracked_parameter_list(self):
@@ -263,7 +265,7 @@ def forward(
                 self._cached_act = cached_inp

         if self.is_quant_enabled:
-            if quant_input is None:
+            if quant_input is None or isinstance(quant_input, Tensor):
                 assert self._cached_act is not None, "No cached quant input found. Enable caching and perform a forward pass"
                 quant_input = self._cached_act
             else:
@@ -274,6 +276,8 @@ def forward(
             impl = self.export_handler if self.export_mode else self.tensor_quant
             out, scale, zero_point, bit_width, pre_scale, pre_zero_point = impl(x, input_bit_width, input_is_signed)
+            if self.skip_create_quant_tensor:
+                return out
             return IntQuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training)
         else:  # quantization disabled
             return x
@@ -356,7 +360,7 @@ def forward(
                 out, out_scale, out_zp, out_bit_width = impl(x, input_scale)
             else:
                 out, out_scale, out_zp, out_bit_width = impl(x)
-            if not is_dynamo_compiling():
+            if not self.skip_create_quant_tensor:
                 out = IntQuantTensor(
                     out, out_scale, out_zp, out_bit_width, self.is_signed, self.training)
             if not self.training and self.cache_inference_quant_bias:
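Outside quant_inference_mode, the flag can also be flipped manually on a single proxy when only the dequantized values are needed, e.g. to keep a traced or compiled graph free of QuantTensor bookkeeping. A minimal sketch (layer sizes are arbitrary; IntQuantTensor is the wrapper being bypassed):

    import torch
    from brevitas.nn import QuantLinear
    from brevitas.quant_tensor import IntQuantTensor

    layer = QuantLinear(8, 4, bias=False)
    layer.weight_quant.skip_create_quant_tensor = True
    w = layer.weight_quant(layer.weight)  # fake-quantized values as a raw Tensor
    assert isinstance(w, torch.Tensor) and not isinstance(w, IntQuantTensor)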
diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py
index 9feb593b4..cff192490 100644
--- a/src/brevitas/proxy/runtime_quant.py
+++ b/src/brevitas/proxy/runtime_quant.py
@@ -98,6 +98,7 @@ def __init__(self, quant_layer, quant_injector):
         self.cache_inference_quant_act = False
         self.cache_quant_io_metadata_only = True
         self.cache_class = None
+        self.skip_create_quant_tensor = False

     @property
     def input_view_impl(self):
@@ -188,7 +189,7 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:

             # If y is an empty QuantTensor, we need to check if this is a passthrough proxy,
             # otherwise return a simple Tensor
-            if is_dynamo_compiling():
+            if self.skip_create_quant_tensor:
                 out = y[0]
             else:
                 # If the second value (i.e., scale) is None, then quant is disabled
@@ -246,10 +247,16 @@ def zero_point(self, force_eval=True):

 class ClampQuantProxyFromInjector(QuantProxyFromInjector, AccQuantProxyProtocol):

+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.skip_create_quant_tensor = False
+
     def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]:
         if self.is_quant_enabled:
             out_tuple = self.tensor_quant(x.value, x.scale, x.bit_width)
             out_value, out_scale, out_zp, out_bit_width = out_tuple
+            if self.skip_create_quant_tensor:
+                return out_value
             return IntQuantTensor(
                 out_value, out_scale, out_zp, out_bit_width, self.is_signed, self.training)
         return x
@@ -257,6 +264,10 @@ def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]:

 class TruncQuantProxyFromInjector(QuantProxyFromInjector, AccQuantProxyProtocol):

+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.skip_create_quant_tensor = False
+
     def bit_width(self):
         if not self.is_quant_enabled:
             return None
@@ -274,6 +285,8 @@ def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]:
             else:
                 out_tuple = self.tensor_quant(x.value, x.scale, x.zero_point, x.bit_width)
             out_value, out_scale, out_zp, out_bit_width = out_tuple
+            if self.skip_create_quant_tensor:
+                return out_value
             return IntQuantTensor(
                 out_value, out_scale, out_zp, out_bit_width, x.signed, self.training)
         else: