Feat (proxy): flag to enable/disable QT return
Giuseppe5 committed Dec 17, 2024
1 parent d51087c commit 3ccf130
Showing 7 changed files with 49 additions and 12 deletions.
1 change: 1 addition & 0 deletions src/brevitas/export/common/handler/qcdq.py
@@ -796,6 +796,7 @@ def symbolic_execution(
flat_pre_scale = to_0dim_if_scalar(pre_scale.flatten())
flat_scale = to_0dim_if_scalar(scale.flatten())
zp = to_0dim_if_scalar(zero_point.flatten()).expand_as(flat_scale)
zp = self.zero_point_with_dtype(signed, output_bit_width, zp)
x = self.quantize_fn(x, flat_pre_scale, zp, dtype, self.quant_axis(pre_scale))
clip_symbolic_kwargs = self.int_clip_symbolic_kwargs(
signed=signed, narrow=False, bit_width=output_bit_width)
11 changes: 11 additions & 0 deletions src/brevitas/export/inference/manager.py
@@ -1,6 +1,8 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from functools import partial

from torch.nn import Module
import torch.nn as nn

@@ -37,6 +39,11 @@ def _override_weight_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool):
_override_caching_mode(m, 'weight', enabled, metadata_only)


def _override_quant_tensor_return_state(m: nn.Module, state: bool):
if hasattr(m, 'return_quant_tensor'):
m.return_quant_tensor = state


class quant_inference_mode:

def __init__(self, model, cache_quant_weight=False, enabled=True):
@@ -74,6 +81,8 @@ def __exit__(self, type, value, traceback):
lambda m: _override_weight_caching_mode(m, enabled=False, metadata_only=False))
InferenceManager.set_export_mode(self.model, enabled=False)
restore_return_quant_tensor(self.model, self.return_quant_tensor_state)
enable_quant_tensor = partial(_override_quant_tensor_return_state, state=True)
self.model.apply(enable_quant_tensor)

def hook(self, module, inp, out):
# After one forward pass with caching enabled, we can:
@@ -85,6 +94,8 @@ def hook(self, module, inp, out):
self.model.apply(InferenceManager.set_export_handler)
InferenceManager.set_export_mode(self.model, enabled=True)
self.return_quant_tensor_state = disable_return_quant_tensor(self.model)
disable_quant_tensor = partial(_override_quant_tensor_return_state, state=False)
self.model.apply(disable_quant_tensor)


# Inheritance from BaseManager is not technically needed
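Taken together, the changes in this file make `quant_inference_mode` flip the new proxy-level `return_quant_tensor` flag off once the first cached forward pass triggers the hook, and back on when the context exits. A minimal usage sketch (untested; it assumes the context manager is importable from `brevitas.export.inference`, matching this file's location):

```python
import torch

from brevitas.export.inference import quant_inference_mode
from brevitas.nn import QuantConv2d

# Any Brevitas-quantized module works the same way; QuantConv2d is just an example.
model = torch.nn.Sequential(QuantConv2d(3, 8, kernel_size=3)).eval()
x = torch.randn(1, 3, 32, 32)

with torch.no_grad(), quant_inference_mode(model):
    model(x)      # first pass: caches quant metadata, then the hook enables export mode
    y = model(x)  # later passes: proxies return plain Tensors instead of QuantTensors
```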
6 changes: 2 additions & 4 deletions src/brevitas/export/manager.py
@@ -3,10 +3,8 @@

from abc import ABC
from abc import abstractmethod
from contextlib import ExitStack
from functools import partial
from io import BytesIO
from typing import Optional, Tuple, Union
from typing import Tuple, Union

import torch
from torch import nn
@@ -20,7 +18,6 @@
from brevitas.quant_tensor import QuantTensor
from brevitas.utils.jit_utils import clear_class_registry
from brevitas.utils.python_utils import patch
from brevitas.utils.quant_utils import _CachedIO


class _JitTraceExportWrapper(nn.Module):
@@ -219,6 +216,7 @@ def jit_inference_trace(
# wrapping with a lambda forces inlining during tracing,
# converts everything to const and removes unused params/buffers
traced_model = torch.jit.trace(_JitTraceExportWrapper(module), args)

# Hack to clone the function, otherwise restoring requires_grad
# on module will break traced_model
with BytesIO() as tmp:
2 changes: 0 additions & 2 deletions src/brevitas/export/onnx/manager.py
@@ -2,7 +2,6 @@
# SPDX-License-Identifier: BSD-3-Clause

from abc import ABC
from contextlib import ExitStack
from io import BytesIO
from typing import Optional, Tuple, Union
import warnings
@@ -167,7 +166,6 @@ def export_onnx(

with PatchFp8Ops():
torch.onnx.export(module, args, export_target, **onnx_export_kwargs)

# restore the model to previous properties
module.apply(lambda m: _restore_act_caching_mode(m))
cls.set_export_mode(module, enabled=False)
21 changes: 19 additions & 2 deletions src/brevitas/nn/quant_avg_pool.py
@@ -15,6 +15,7 @@
from brevitas.inject.defaults import RoundTo8bit
from brevitas.quant_tensor import _unpack_quant_tensor
from brevitas.quant_tensor import QuantTensor
from brevitas.utils.quant_utils import _CachedIO

from .mixin.acc import AccQuantType
from .mixin.acc import TruncMixin
@@ -38,6 +39,9 @@ def __init__(
AvgPool2d.__init__(self, kernel_size=kernel_size, stride=stride)
QuantLayerMixin.__init__(self, return_quant_tensor)
TruncMixin.__init__(self, trunc_quant=trunc_quant, **kwargs)
self.cache_inference_quant_act = False
self.cache_quant_io_metadata_only = True
self.cache_class = None

@property
def channelwise_separable(self) -> bool:
@@ -60,7 +64,12 @@ def forward(self, input: Union[Tensor, QuantTensor]):
if self.export_mode:
return self.export_handler(_unpack_quant_tensor(x))

if isinstance(x, QuantTensor) and self.is_trunc_quant_enabled:
if (isinstance(x, QuantTensor) or
self.cache_class is not None) and self.is_trunc_quant_enabled:
if self.cache_inference_quant_act:
self.cache_class = _CachedIO(x, self.cache_quant_io_metadata_only)
if not isinstance(x, QuantTensor):
x = self.cache_class.quant_tensor.set(value=x)
y = AvgPool2d.forward(self, x)
rescaled_value = y.value * self._avg_scaling
y = y.set(value=rescaled_value)
@@ -87,6 +96,9 @@ def __init__(
AdaptiveAvgPool2d.__init__(self, output_size=output_size)
QuantLayerMixin.__init__(self, return_quant_tensor)
TruncMixin.__init__(self, trunc_quant=trunc_quant, **kwargs)
self.cache_inference_quant_act = False
self.cache_quant_io_metadata_only = True
self.cache_class = None

@property
def channelwise_separable(self) -> bool:
@@ -120,7 +132,12 @@ def forward(self, input: Union[Tensor, QuantTensor]):
self._set_global_is_quant_layer(False)
return out

if isinstance(x, QuantTensor) and self.is_trunc_quant_enabled:
if (isinstance(x, QuantTensor) or
self.cache_class is not None) and self.is_trunc_quant_enabled:
if self.cache_inference_quant_act:
self.cache_class = _CachedIO(x, self.cache_quant_io_metadata_only)
if not isinstance(x, QuantTensor):
x = self.cache_class.quant_tensor.set(value=x)
y = AdaptiveAvgPool2d.forward(self, x)
k_size, stride = self.compute_kernel_size_stride(x.value.shape[2:], y.value.shape[2:])
reduce_size = reduce(mul, k_size, 1)
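Both pooling classes now share the same pattern: when a QuantTensor comes through and caching is on, its quantization metadata is stored in `cache_class`; a later plain-Tensor input is re-wrapped from that cache so truncation quantization still applies. A hypothetical sketch of that flow (attribute and class names taken from the diff; untested):

```python
import torch

from brevitas.nn import QuantIdentity, TruncAvgPool2d

act = QuantIdentity(return_quant_tensor=True)
pool = TruncAvgPool2d(kernel_size=2, return_quant_tensor=False)

pool.cache_inference_quant_act = True
qx = act(torch.randn(1, 3, 4, 4))
pool(qx)  # QuantTensor input: metadata is cached in pool.cache_class

pool.cache_inference_quant_act = False
pool(qx.value)  # plain Tensor input: re-wrapped via cache_class.quant_tensor.set
```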
9 changes: 6 additions & 3 deletions src/brevitas/proxy/parameter_quant.py
@@ -95,6 +95,7 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None:
self.cache_inference_quant_weight_metadata_only = False
self.cache_class = None # To be redefined by each class
self.quant_tensor_class = None # To be redefined by each class
self.return_quant_tensor = True

@property
def input_view_impl(self):
@@ -138,7 +139,7 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
out = self.create_quant_tensor(out)
else:
out = self.tensor_quant(x)
if is_dynamo_compiling():
if not self.return_quant_tensor:
out = out[0]
else:
out = self.create_quant_tensor(out)
@@ -263,7 +264,7 @@ def forward(
self._cached_act = cached_inp

if self.is_quant_enabled:
if quant_input is None:
if quant_input is None or isinstance(quant_input, Tensor):
assert self._cached_act is not None, "No cached quant input found. Enable caching and perform a forward pass"
quant_input = self._cached_act
else:
@@ -274,6 +275,8 @@

impl = self.export_handler if self.export_mode else self.tensor_quant
out, scale, zero_point, bit_width, pre_scale, pre_zero_point = impl(x, input_bit_width, input_is_signed)
if torch._C._get_tracing_state() is not None:
return out
return IntQuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training)
else: # quantization disabled
return x
@@ -356,7 +359,7 @@ def forward(
out, out_scale, out_zp, out_bit_width = impl(x, input_scale)
else:
out, out_scale, out_zp, out_bit_width = impl(x)
if not is_dynamo_compiling():
if not is_dynamo_compiling() or torch._C._get_tracing_state() is not None:
out = IntQuantTensor(
out, out_scale, out_zp, out_bit_width, self.is_signed, self.training)
if not self.training and self.cache_inference_quant_bias:
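On the parameter side, the weight proxy now keys its output type on the new `return_quant_tensor` flag (default `True`) instead of on `is_dynamo_compiling()`. A hypothetical check (proxy attribute names from the diff; untested):

```python
import torch

from brevitas.nn import QuantLinear

layer = QuantLinear(8, 4, bias=False)
proxy = layer.weight_quant  # WeightQuantProxyFromInjector

proxy.return_quant_tensor = True
qw = proxy(layer.weight)  # IntQuantTensor carrying scale/zero-point/bit-width

proxy.return_quant_tensor = False
w = proxy(layer.weight)  # plain dequantized Tensor
```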
11 changes: 10 additions & 1 deletion src/brevitas/proxy/runtime_quant.py
@@ -60,6 +60,10 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> QuantTensor:
@runtime_checkable
class AccQuantProxyProtocol(QuantProxyProtocol, Protocol):

def __init__(self):
super().__init__()
self.return_quant_tensor = True

def forward(self, x: QuantTensor) -> QuantTensor:
...

@@ -98,6 +102,7 @@ def __init__(self, quant_layer, quant_injector):
self.cache_inference_quant_act = False
self.cache_quant_io_metadata_only = True
self.cache_class = None
self.return_quant_tensor = False

@property
def input_view_impl(self):
@@ -188,7 +193,7 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
# If y is an empty QuantTensor, we need to check if this is a passthrough proxy,
# otherwise return a simple Tensor

if is_dynamo_compiling():
if not self.return_quant_tensor:
out = y[0]
else:
# If the second value (i.e., scale) is None, then quant is disabled
@@ -250,6 +255,8 @@ def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]:
if self.is_quant_enabled:
out_tuple = self.tensor_quant(x.value, x.scale, x.bit_width)
out_value, out_scale, out_zp, out_bit_width = out_tuple
if not self.return_quant_tensor:
return out_value
return IntQuantTensor(
out_value, out_scale, out_zp, out_bit_width, self.is_signed, self.training)
return x
@@ -274,6 +281,8 @@ def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]:
else:
out_tuple = self.tensor_quant(x.value, x.scale, x.zero_point, x.bit_width)
out_value, out_scale, out_zp, out_bit_width = out_tuple
if not self.return_quant_tensor:
return out_value
return IntQuantTensor(
out_value, out_scale, out_zp, out_bit_width, x.signed, self.training)
else:
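The runtime proxies (activation, truncation, clamping) gain the same flag, with `ActQuantProxyFromInjector` defaulting it to `False`; `quant_inference_mode` re-enables it on exit. A hypothetical sanity check (class and attribute names from the diff; untested):

```python
import torch

from brevitas.nn import QuantReLU

act = QuantReLU(return_quant_tensor=True).eval()
proxy = act.act_quant  # ActQuantProxyFromInjector

proxy.return_quant_tensor = True
print(type(proxy(torch.randn(4))))  # expected: IntQuantTensor

proxy.return_quant_tensor = False
print(type(proxy(torch.randn(4))))  # expected: torch.Tensor
```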
