diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py
index bbc03b630..39347baad 100644
--- a/src/brevitas/export/common/handler/qcdq.py
+++ b/src/brevitas/export/common/handler/qcdq.py
@@ -796,6 +796,7 @@ def symbolic_execution(
         flat_pre_scale = to_0dim_if_scalar(pre_scale.flatten())
         flat_scale = to_0dim_if_scalar(scale.flatten())
         zp = to_0dim_if_scalar(zero_point.flatten()).expand_as(flat_scale)
+        zp = self.zero_point_with_dtype(signed, output_bit_width, zp)
         x = self.quantize_fn(x, flat_pre_scale, zp, dtype, self.quant_axis(pre_scale))
         clip_symbolic_kwargs = self.int_clip_symbolic_kwargs(
             signed=signed, narrow=False, bit_width=output_bit_width)
diff --git a/src/brevitas/export/inference/manager.py b/src/brevitas/export/inference/manager.py
index 936106884..3c933d940 100644
--- a/src/brevitas/export/inference/manager.py
+++ b/src/brevitas/export/inference/manager.py
@@ -1,6 +1,8 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
+from functools import partial
+
 from torch.nn import Module
 import torch.nn as nn
 
@@ -37,6 +39,11 @@ def _override_weight_caching_mode(m: nn.Module, enabled: bool, metadata_only: bo
     _override_caching_mode(m, 'weight', enabled, metadata_only)
 
 
+def _override_quant_tensor_return_state(m: nn.Module, state: bool):
+    if hasattr(m, 'return_quant_tensor'):
+        m.return_quant_tensor = state
+
+
 class quant_inference_mode:
 
     def __init__(self, model, cache_quant_weight=False, enabled=True):
@@ -74,6 +81,8 @@ def __exit__(self, type, value, traceback):
                 lambda m: _override_weight_caching_mode(m, enabled=False, metadata_only=False))
         InferenceManager.set_export_mode(self.model, enabled=False)
         restore_return_quant_tensor(self.model, self.return_quant_tensor_state)
+        enable_quant_tensor = partial(_override_quant_tensor_return_state, state=True)
+        self.model.apply(enable_quant_tensor)
 
     def hook(self, module, inp, out):
         # After one forward pass with caching enabled, we can:
@@ -85,6 +94,8 @@ def hook(self, module, inp, out):
         self.model.apply(InferenceManager.set_export_handler)
         InferenceManager.set_export_mode(self.model, enabled=True)
         self.return_quant_tensor_state = disable_return_quant_tensor(self.model)
+        disable_quant_tensor = partial(_override_quant_tensor_return_state, state=False)
+        self.model.apply(disable_quant_tensor)
 
 
 # Inheritance from BaseManager is not techincally needed
diff --git a/src/brevitas/export/manager.py b/src/brevitas/export/manager.py
index 7b7e7a145..550711e16 100644
--- a/src/brevitas/export/manager.py
+++ b/src/brevitas/export/manager.py
@@ -3,10 +3,8 @@
 
 from abc import ABC
 from abc import abstractmethod
-from contextlib import ExitStack
-from functools import partial
 from io import BytesIO
-from typing import Optional, Tuple, Union
+from typing import Tuple, Union
 
 import torch
 from torch import nn
@@ -20,7 +18,6 @@
 from brevitas.quant_tensor import QuantTensor
 from brevitas.utils.jit_utils import clear_class_registry
 from brevitas.utils.python_utils import patch
-from brevitas.utils.quant_utils import _CachedIO
 
 
 class _JitTraceExportWrapper(nn.Module):
@@ -219,6 +216,7 @@ def jit_inference_trace(
         # wrapping with a lambda forces inlining during tracing,
         # converts everything to const and removes unused params/buffers
         traced_model = torch.jit.trace(_JitTraceExportWrapper(module), args)
+
         # Hack to clone the function, otherwise restoring requires_grad
         # on module will break traced_model
         with BytesIO() as tmp:
diff --git a/src/brevitas/export/onnx/manager.py b/src/brevitas/export/onnx/manager.py
index ae3270cc9..cbede2dca 100644
--- a/src/brevitas/export/onnx/manager.py
+++ b/src/brevitas/export/onnx/manager.py
@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 from abc import ABC
-from contextlib import ExitStack
 from io import BytesIO
 from typing import Optional, Tuple, Union
 import warnings
@@ -167,7 +166,6 @@ def export_onnx(
             with PatchFp8Ops():
                 torch.onnx.export(module, args, export_target, **onnx_export_kwargs)
 
-            # restore the model to previous properties
             module.apply(lambda m: _restore_act_caching_mode(m))
             cls.set_export_mode(module, enabled=False)
 
diff --git a/src/brevitas/nn/quant_avg_pool.py b/src/brevitas/nn/quant_avg_pool.py
index 7a3f108da..12f324901 100644
--- a/src/brevitas/nn/quant_avg_pool.py
+++ b/src/brevitas/nn/quant_avg_pool.py
@@ -15,6 +15,7 @@
 from brevitas.inject.defaults import RoundTo8bit
 from brevitas.quant_tensor import _unpack_quant_tensor
 from brevitas.quant_tensor import QuantTensor
+from brevitas.utils.quant_utils import _CachedIO
 
 from .mixin.acc import AccQuantType
 from .mixin.acc import TruncMixin
@@ -38,6 +39,9 @@ def __init__(
         AvgPool2d.__init__(self, kernel_size=kernel_size, stride=stride)
         QuantLayerMixin.__init__(self, return_quant_tensor)
         TruncMixin.__init__(self, trunc_quant=trunc_quant, **kwargs)
+        self.cache_inference_quant_act = False
+        self.cache_quant_io_metadata_only = True
+        self.cache_class = None
 
     @property
     def channelwise_separable(self) -> bool:
@@ -60,7 +64,12 @@ def forward(self, input: Union[Tensor, QuantTensor]):
         if self.export_mode:
             return self.export_handler(_unpack_quant_tensor(x))
 
-        if isinstance(x, QuantTensor) and self.is_trunc_quant_enabled:
+        if (isinstance(x, QuantTensor) or
+                self.cache_class is not None) and self.is_trunc_quant_enabled:
+            if self.cache_inference_quant_act:
+                self.cache_class = _CachedIO(x, self.cache_quant_io_metadata_only)
+            if not isinstance(x, QuantTensor):
+                x = self.cache_class.quant_tensor.set(value=x)
             y = AvgPool2d.forward(self, x)
             rescaled_value = y.value * self._avg_scaling
             y = y.set(value=rescaled_value)
@@ -87,6 +96,9 @@ def __init__(
         AdaptiveAvgPool2d.__init__(self, output_size=output_size)
         QuantLayerMixin.__init__(self, return_quant_tensor)
         TruncMixin.__init__(self, trunc_quant=trunc_quant, **kwargs)
+        self.cache_inference_quant_act = False
+        self.cache_quant_io_metadata_only = True
+        self.cache_class = None
 
     @property
     def channelwise_separable(self) -> bool:
@@ -120,7 +132,12 @@ def forward(self, input: Union[Tensor, QuantTensor]):
             self._set_global_is_quant_layer(False)
             return out
 
-        if isinstance(x, QuantTensor) and self.is_trunc_quant_enabled:
+        if (isinstance(x, QuantTensor) or
+                self.cache_class is not None) and self.is_trunc_quant_enabled:
+            if self.cache_inference_quant_act:
+                self.cache_class = _CachedIO(x, self.cache_quant_io_metadata_only)
+            if not isinstance(x, QuantTensor):
+                x = self.cache_class.quant_tensor.set(value=x)
             y = AdaptiveAvgPool2d.forward(self, x)
             k_size, stride = self.compute_kernel_size_stride(x.value.shape[2:], y.value.shape[2:])
             reduce_size = reduce(mul, k_size, 1)
diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py
index 2144edd53..eafc6e953 100644
--- a/src/brevitas/proxy/parameter_quant.py
+++ b/src/brevitas/proxy/parameter_quant.py
@@ -95,6 +95,7 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None:
         self.cache_inference_quant_weight_metadata_only = False
         self.cache_class = None  # To be redefined by each class
         self.quant_tensor_class = None  # To be redefined by each class
+        self.return_quant_tensor = True
 
     @property
     def input_view_impl(self):
@@ -138,7 +139,7 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
                 out = self.create_quant_tensor(out)
         else:
             out = self.tensor_quant(x)
-            if is_dynamo_compiling():
+            if not self.return_quant_tensor:
                 out = out[0]
             else:
                 out = self.create_quant_tensor(out)
@@ -263,7 +264,7 @@ def forward(
             self._cached_act = cached_inp
 
         if self.is_quant_enabled:
-            if quant_input is None:
+            if quant_input is None or isinstance(quant_input, Tensor):
                 assert self._cached_act is not None, "No cached quant input found. Enable caching and perform a forward pass"
                 quant_input = self._cached_act
             else:
@@ -274,6 +275,8 @@ def forward(
             impl = self.export_handler if self.export_mode else self.tensor_quant
             out, scale, zero_point, bit_width, pre_scale, pre_zero_point = impl(
                 x, input_bit_width, input_is_signed)
+            if torch._C._get_tracing_state() is not None:
+                return out
             return IntQuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training)
         else:  # quantization disabled
             return x
@@ -356,7 +359,7 @@ def forward(
                 out, out_scale, out_zp, out_bit_width = impl(x, input_scale)
             else:
                 out, out_scale, out_zp, out_bit_width = impl(x)
-            if not is_dynamo_compiling():
+            if not is_dynamo_compiling() or torch._C._get_tracing_state() is not None:
                 out = IntQuantTensor(
                     out, out_scale, out_zp, out_bit_width, self.is_signed, self.training)
             if not self.training and self.cache_inference_quant_bias:
diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py
index 9feb593b4..5cefc11c9 100644
--- a/src/brevitas/proxy/runtime_quant.py
+++ b/src/brevitas/proxy/runtime_quant.py
@@ -60,6 +60,10 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> QuantTensor:
 
 @runtime_checkable
 class AccQuantProxyProtocol(QuantProxyProtocol, Protocol):
+    def __init__(self):
+        super().__init__()
+        self.return_quant_tensor = True
+
     def forward(self, x: QuantTensor) -> QuantTensor:
         ...
 
@@ -98,6 +102,7 @@ def __init__(self, quant_layer, quant_injector):
         self.cache_inference_quant_act = False
         self.cache_quant_io_metadata_only = True
         self.cache_class = None
+        self.return_quant_tensor = False
 
     @property
     def input_view_impl(self):
@@ -188,7 +193,7 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
 
             # If y is an empty QuantTensor, we need to check if this is a passthrough proxy,
             # otherwise return a simple Tensor
-            if is_dynamo_compiling():
+            if not self.return_quant_tensor:
                 out = y[0]
             else:
                 # If the second value (i.e., scale) is None, then quant is disabled
@@ -250,6 +255,8 @@ def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]:
         if self.is_quant_enabled:
             out_tuple = self.tensor_quant(x.value, x.scale, x.bit_width)
             out_value, out_scale, out_zp, out_bit_width = out_tuple
+            if not self.return_quant_tensor:
+                return out_value
             return IntQuantTensor(
                 out_value, out_scale, out_zp, out_bit_width, self.is_signed, self.training)
         return x
@@ -274,6 +281,8 @@ def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]:
             else:
                 out_tuple = self.tensor_quant(x.value, x.scale, x.zero_point, x.bit_width)
             out_value, out_scale, out_zp, out_bit_width = out_tuple
+            if not self.return_quant_tensor:
+                return out_value
             return IntQuantTensor(
                 out_value, out_scale, out_zp, out_bit_width, x.signed, self.training)
         else:
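
Note on the proxy changes above: the `is_dynamo_compiling()` checks are replaced by an explicit `return_quant_tensor` flag on each proxy, which `quant_inference_mode` now flips model-wide via `Module.apply` (off after the first cached forward pass, back on at context exit). A minimal usage sketch, assuming a toy model; the layers, shapes, and import path are illustrative, not part of this patch:

    import torch
    import torch.nn as nn
    from brevitas.export.inference import quant_inference_mode
    from brevitas.nn import QuantConv2d, QuantIdentity

    model = nn.Sequential(
        QuantIdentity(return_quant_tensor=True),
        QuantConv2d(3, 8, kernel_size=3)).eval()
    inp = torch.randn(1, 3, 32, 32)

    with torch.no_grad(), quant_inference_mode(model):
        model(inp)        # first pass: the hook caches quant metadata and disables QuantTensor returns
        out = model(inp)  # later passes run the export handlers on plain tensors
    # on exit, return_quant_tensor is re-enabled on every submodule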
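The quant_avg_pool.py changes let the truncation path accept plain tensors once quant metadata has been cached: with `cache_inference_quant_act` set, the first QuantTensor input is stored as a `_CachedIO`, and later Tensor inputs are re-wrapped with the cached scale, zero-point, and bit-width. A sketch of the intended flow, assuming illustrative shapes and the default trunc quantizer:

    import torch
    from brevitas.nn import QuantIdentity, TruncAvgPool2d

    act = QuantIdentity(return_quant_tensor=True).eval()
    pool = TruncAvgPool2d(kernel_size=2, return_quant_tensor=True).eval()

    pool.cache_inference_quant_act = True   # cache metadata from the next QuantTensor input
    qx = act(torch.randn(1, 3, 8, 8))
    _ = pool(qx)                            # QuantTensor pass populates pool.cache_class

    pool.cache_inference_quant_act = False  # reuse the cached metadata from now on
    y = pool(qx.value)                      # a plain Tensor gets re-wrapped via the cache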