Feat (proxy): flag to enable/disable QT return (#1083)
Giuseppe5 authored Dec 18, 2024
1 parent 48efcf6 commit 3612e90
Showing 8 changed files with 84 additions and 18 deletions.
1 change: 1 addition & 0 deletions src/brevitas/export/common/handler/qcdq.py
@@ -796,6 +796,7 @@ def symbolic_execution(
         flat_pre_scale = to_0dim_if_scalar(pre_scale.flatten())
         flat_scale = to_0dim_if_scalar(scale.flatten())
         zp = to_0dim_if_scalar(zero_point.flatten()).expand_as(flat_scale)
+        zp = self.zero_point_with_dtype(signed, output_bit_width, zp)
         x = self.quantize_fn(x, flat_pre_scale, zp, dtype, self.quant_axis(pre_scale))
         clip_symbolic_kwargs = self.int_clip_symbolic_kwargs(
             signed=signed, narrow=False, bit_width=output_bit_width)
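
The added line casts the expanded zero point to the integer dtype implied by signed and output_bit_width before quantize_fn consumes it. A minimal sketch of the constraint this enforces, with an illustrative helper body rather than the handler's actual implementation:

import torch

# Illustrative stand-in for the handler's zero_point_with_dtype: a zero point
# handed to a QuantizeLinear-style op must carry the target integer dtype.
def zero_point_with_dtype(signed: bool, bit_width: int, zp: torch.Tensor) -> torch.Tensor:
    if bit_width <= 8:
        dtype = torch.int8 if signed else torch.uint8
    else:
        dtype = torch.int32  # assumption: wider targets fall back to int32
    return zp.to(dtype)

zp = zero_point_with_dtype(signed=True, bit_width=8, zp=torch.zeros(4))
print(zp.dtype)  # torch.int8
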
34 changes: 29 additions & 5 deletions src/brevitas/export/inference/handler.py
@@ -9,7 +9,6 @@
 from torch import Tensor
 import torch.nn as nn

-from brevitas import is_dynamo_compiling
 from brevitas.function.ops import max_float
 from brevitas.function.ops import max_int
 from brevitas.function.ops import min_int
@@ -110,14 +109,20 @@ def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]:
 class GroupwiseIntInferenceHandler(IntInferencetHandler):
     handled_layer = GroupwiseActQuantProxyFromInjector

+    def __init__(self):
+        super().__init__()
+        self.skip_create_quant_tensor = False
+
     def prepare_for_export(self, module):
         if module.is_quant_enabled:
             self.module_forward = module.fused_activation_quant_proxy
             self.group_dim = module.group_dim

     def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]:
         x, *other = self.module_forward(x)
-        if is_dynamo_compiling():
+
+        # If we skip quant tensor, we return the flattened version of the groupwise tensor
+        if self.skip_create_quant_tensor:
             start_dim = self.group_dim if self.group_dim != -1 else -2
             x = x.flatten(start_dim, start_dim + 1)
             output_args = tuple([x] + list(other))
@@ -127,6 +132,10 @@ def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]:
 class GroupwiseIntWeightInferenceHandler(IntWeightInferencetHandler):
     handled_layer = GroupwiseWeightQuantProxyFromInjector

+    def __init__(self):
+        super().__init__()
+        self.skip_create_quant_tensor = False
+
     def prepare_for_export(self, module):
         super().prepare_for_export(module)
         if module.is_quant_enabled:
@@ -151,7 +160,9 @@ def forward(self, x: Tensor) -> Tuple[Tensor]:
         else:
             x = self.input_view(x)
             out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point)
-        if is_dynamo_compiling():
+
+        # If we skip quant tensor, we return the flattened version of the groupwise tensor
+        if self.skip_create_quant_tensor:
             out = self.flattened_view(out)
         return out, scale, zero_point, self.bit_width

@@ -242,14 +253,20 @@ def forward(self, x: Tensor) -> Tuple[Tensor]:
 class GroupwiseFloatInferenceHandler(FloatInferencetHandler):
     handled_layer = GroupwiseActFloatQuantProxyFromInjector

+    def __init__(self):
+        super().__init__()
+        self.skip_create_quant_tensor = False
+
     def prepare_for_export(self, module: nn.Module):
         if module.is_quant_enabled:
             self.module_forward = module.fused_activation_quant_proxy
             self.group_dim = module.group_dim

     def forward(self, x: Tensor) -> Tuple[Tensor]:
         x, *other = self.module_forward(x)
-        if is_dynamo_compiling():
+
+        # If we skip quant tensor, we return the flattened version of the groupwise tensor
+        if self.skip_create_quant_tensor:
             start_dim = self.group_dim if self.group_dim != -1 else -2
             x = x.flatten(start_dim, start_dim + 1)
             output_args = tuple([x] + list(other))
@@ -259,6 +276,10 @@ def forward(self, x: Tensor) -> Tuple[Tensor]:
 class GroupwiseFloatWeightInferenceHandler(FloatWeightInferencetHandler):
     handled_layer = GroupwiseWeightFloatQuantProxyFromInjector

+    def __init__(self):
+        super().__init__()
+        self.skip_create_quant_tensor = False
+
     def prepare_for_export(self, module: nn.Module):
         super().prepare_for_export(module)
         if module.is_quant_enabled:
@@ -283,6 +304,9 @@ def forward(self, x: Tensor) -> Tuple[Tensor]:
         else:
             x = self.input_view(x)
             out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point)
-        if is_dynamo_compiling():
+
+        # If we skip quant tensor, we return the flattened version of the groupwise tensor
+        if self.skip_create_quant_tensor:
             out = self.flattened_view(out)

         return out, scale, zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values
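
Both groupwise handlers now key the flattening off the new skip_create_quant_tensor flag instead of is_dynamo_compiling(). The reshape itself just undoes the grouped view; a minimal sketch with illustrative shapes:

import torch

group_size = 4
w = torch.randn(8, 16)                             # e.g. (out_features, in_features)
grouped = w.view(8, 16 // group_size, group_size)  # groupwise view: (8, 4, 4)
flat = grouped.flatten(1, 2)                       # flatten(start_dim, start_dim + 1) -> (8, 16)
assert torch.equal(flat, w)
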
11 changes: 11 additions & 0 deletions src/brevitas/export/inference/manager.py
@@ -1,6 +1,8 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause

+from functools import partial
+
 from torch.nn import Module
 import torch.nn as nn

@@ -42,6 +44,11 @@ def _override_weight_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool):
     _override_caching_mode(m, 'weight', enabled, metadata_only)


+def _override_create_quant_tensor(m: nn.Module, state: bool):
+    if hasattr(m, 'skip_create_quant_tensor'):
+        m.skip_create_quant_tensor = state
+
+
 class quant_inference_mode:

     def __init__(self, model, cache_quant_weight=False, enabled=True):
@@ -79,6 +86,8 @@ def __exit__(self, type, value, traceback):
             self.model.apply(
                 lambda m: _override_weight_caching_mode(m, enabled=False, metadata_only=False))
             restore_return_quant_tensor(self.model, self.return_quant_tensor_state)
+            enable_quant_tensor = partial(_override_create_quant_tensor, state=False)
+            self.model.apply(enable_quant_tensor)

     def hook(self, module, inp, out):
         # After one forward pass with caching enabled, we can:
@@ -90,6 +99,8 @@ def hook(self, module, inp, out):
             self.model.apply(InferenceManager.set_export_handler)
             InferenceManager.set_export_mode(self.model, enabled=True)
             self.return_quant_tensor_state = disable_return_quant_tensor(self.model)
+            disable_quant_tensor = partial(_override_create_quant_tensor, state=True)
+            self.model.apply(disable_quant_tensor)


 # Inheritance from BaseManager is not technically needed
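
With these hooks, quant_inference_mode sets skip_create_quant_tensor to True across the model once the first forward pass triggers the hook, and resets it to False on exit. A minimal usage sketch, assuming quant_inference_mode is re-exported from brevitas.export.inference and using QuantLinear with return_quant_tensor=True:

import torch
from brevitas.export.inference import quant_inference_mode
from brevitas.nn import QuantLinear

model = QuantLinear(16, 32, bias=True, return_quant_tensor=True)
model.eval()

x = torch.randn(4, 16)
with torch.no_grad(), quant_inference_mode(model):
    model(x)      # first pass caches quant metadata and switches to export mode
    y = model(x)  # later passes skip QuantTensor creation entirely
print(type(y))    # expected: plain torch.Tensor instead of a QuantTensor
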
6 changes: 2 additions & 4 deletions src/brevitas/export/manager.py
@@ -3,10 +3,8 @@

 from abc import ABC
 from abc import abstractmethod
-from contextlib import ExitStack
-from functools import partial
 from io import BytesIO
-from typing import Optional, Tuple, Union
+from typing import Tuple, Union

 import torch
 from torch import nn
@@ -20,7 +18,6 @@
 from brevitas.quant_tensor import QuantTensor
 from brevitas.utils.jit_utils import clear_class_registry
 from brevitas.utils.python_utils import patch
-from brevitas.utils.quant_utils import _CachedIO


 class _JitTraceExportWrapper(nn.Module):
@@ -219,6 +216,7 @@ def jit_inference_trace(
     # wrapping with a lambda forces inlining during tracing,
     # converts everything to const and removes unused params/buffers
     traced_model = torch.jit.trace(_JitTraceExportWrapper(module), args)
+
     # Hack to clone the function, otherwise restoring requires_grad
     # on module will break traced_model
     with BytesIO() as tmp:
2 changes: 0 additions & 2 deletions src/brevitas/export/onnx/manager.py
@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: BSD-3-Clause

 from abc import ABC
-from contextlib import ExitStack
 from io import BytesIO
 from typing import Optional, Tuple, Union
 import warnings
@@ -167,7 +166,6 @@ def export_onnx(

     with PatchFp8Ops():
         torch.onnx.export(module, args, export_target, **onnx_export_kwargs)
-
     # restore the model to previous properties
     module.apply(lambda m: _restore_act_caching_mode(m))
     cls.set_export_mode(module, enabled=False)
21 changes: 19 additions & 2 deletions src/brevitas/nn/quant_avg_pool.py
@@ -15,6 +15,7 @@
 from brevitas.inject.defaults import RoundTo8bit
 from brevitas.quant_tensor import _unpack_quant_tensor
 from brevitas.quant_tensor import QuantTensor
+from brevitas.utils.quant_utils import _CachedIO

 from .mixin.acc import AccQuantType
 from .mixin.acc import TruncMixin
@@ -38,6 +39,9 @@ def __init__(
         AvgPool2d.__init__(self, kernel_size=kernel_size, stride=stride)
         QuantLayerMixin.__init__(self, return_quant_tensor)
         TruncMixin.__init__(self, trunc_quant=trunc_quant, **kwargs)
+        self.cache_inference_quant_act = False
+        self.cache_quant_io_metadata_only = True
+        self.cache_class = None

     @property
     def channelwise_separable(self) -> bool:
@@ -60,7 +64,12 @@ def forward(self, input: Union[Tensor, QuantTensor]):
         if self.export_mode:
             return self.export_handler(_unpack_quant_tensor(x))

-        if isinstance(x, QuantTensor) and self.is_trunc_quant_enabled:
+        if (isinstance(x, QuantTensor) or
+                self.cache_class is not None) and self.is_trunc_quant_enabled:
+            if self.cache_inference_quant_act:
+                self.cache_class = _CachedIO(x, self.cache_quant_io_metadata_only)
+            if not isinstance(x, QuantTensor):
+                x = self.cache_class.quant_tensor.set(value=x)
             y = AvgPool2d.forward(self, x)
             rescaled_value = y.value * self._avg_scaling
             y = y.set(value=rescaled_value)
@@ -87,6 +96,9 @@ def __init__(
         AdaptiveAvgPool2d.__init__(self, output_size=output_size)
         QuantLayerMixin.__init__(self, return_quant_tensor)
         TruncMixin.__init__(self, trunc_quant=trunc_quant, **kwargs)
+        self.cache_inference_quant_act = False
+        self.cache_quant_io_metadata_only = True
+        self.cache_class = None

     @property
     def channelwise_separable(self) -> bool:
@@ -120,7 +132,12 @@ def forward(self, input: Union[Tensor, QuantTensor]):
             self._set_global_is_quant_layer(False)
             return out

-        if isinstance(x, QuantTensor) and self.is_trunc_quant_enabled:
+        if (isinstance(x, QuantTensor) or
+                self.cache_class is not None) and self.is_trunc_quant_enabled:
+            if self.cache_inference_quant_act:
+                self.cache_class = _CachedIO(x, self.cache_quant_io_metadata_only)
+            if not isinstance(x, QuantTensor):
+                x = self.cache_class.quant_tensor.set(value=x)
             y = AdaptiveAvgPool2d.forward(self, x)
             k_size, stride = self.compute_kernel_size_stride(x.value.shape[2:], y.value.shape[2:])
             reduce_size = reduce(mul, k_size, 1)
12 changes: 8 additions & 4 deletions src/brevitas/proxy/parameter_quant.py
@@ -95,6 +95,7 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None:
         self.cache_inference_quant_weight_metadata_only = False
         self.cache_class = None  # To be redefined by each class
         self.quant_tensor_class = None  # To be redefined by each class
+        self.skip_create_quant_tensor = False

     @property
     def input_view_impl(self):
@@ -132,13 +133,13 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
         # - quantization flow
         if self.export_mode:
             out = self.export_handler(x)
-            if is_dynamo_compiling():
+            if self.skip_create_quant_tensor:
                 out = out[0]
             else:
                 out = self.create_quant_tensor(out)
         else:
             out = self.tensor_quant(x)
-            if is_dynamo_compiling():
+            if self.skip_create_quant_tensor:
                 out = out[0]
             else:
                 out = self.create_quant_tensor(out)
@@ -159,6 +160,7 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None:
         self.cache_inference_quant_bias = False
         self.cache_inference_quant_bias_metadata_only = False
         self.requires_input_scale = self.quant_injector.requires_input_scale
+        self.skip_create_quant_tensor = False

     @property
     def tracked_parameter_list(self):
@@ -263,7 +265,7 @@ def forward(
             self._cached_act = cached_inp

         if self.is_quant_enabled:
-            if quant_input is None:
+            if quant_input is None or isinstance(quant_input, Tensor):
                 assert self._cached_act is not None, "No cached quant input found. Enable caching and perform a forward pass"
                 quant_input = self._cached_act
             else:
@@ -274,6 +276,8 @@ def forward(

             impl = self.export_handler if self.export_mode else self.tensor_quant
             out, scale, zero_point, bit_width, pre_scale, pre_zero_point = impl(x, input_bit_width, input_is_signed)
+            if self.skip_create_quant_tensor:
+                return out
             return IntQuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training)
         else:  # quantization disabled
             return x
@@ -356,7 +360,7 @@ def forward(
             out, out_scale, out_zp, out_bit_width = impl(x, input_scale)
         else:
             out, out_scale, out_zp, out_bit_width = impl(x)
-        if not is_dynamo_compiling():
+        if not self.skip_create_quant_tensor:
             out = IntQuantTensor(
                 out, out_scale, out_zp, out_bit_width, self.is_signed, self.training)
             if not self.training and self.cache_inference_quant_bias:
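
Across the weight, bias, and activation proxies the change is uniform: the implicit is_dynamo_compiling() check becomes an explicit skip_create_quant_tensor flag that selects the return type. A stripped-down sketch of that switch; QTensor and Proxy are stand-ins, not Brevitas classes:

from typing import NamedTuple, Union

import torch

class QTensor(NamedTuple):  # stand-in for IntQuantTensor
    value: torch.Tensor
    scale: torch.Tensor
    zero_point: torch.Tensor
    bit_width: int

class Proxy:
    def __init__(self):
        self.skip_create_quant_tensor = False

    def forward(self, x: torch.Tensor) -> Union[torch.Tensor, QTensor]:
        out = (x, torch.tensor(0.1), torch.tensor(0.0), 8)  # fake quant output tuple
        if self.skip_create_quant_tensor:
            return out[0]  # raw tensor, e.g. for export or compilation
        return QTensor(*out)

p = Proxy()
print(type(p.forward(torch.randn(2))).__name__)  # QTensor
p.skip_create_quant_tensor = True
print(type(p.forward(torch.randn(2))).__name__)  # Tensor
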
15 changes: 14 additions & 1 deletion src/brevitas/proxy/runtime_quant.py
@@ -98,6 +98,7 @@ def __init__(self, quant_layer, quant_injector):
         self.cache_inference_quant_act = False
         self.cache_quant_io_metadata_only = True
         self.cache_class = None
+        self.skip_create_quant_tensor = False

     @property
     def input_view_impl(self):
@@ -188,7 +189,7 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
         # If y is an empty QuantTensor, we need to check if this is a passthrough proxy,
         # otherwise return a simple Tensor

-        if is_dynamo_compiling():
+        if self.skip_create_quant_tensor:
             out = y[0]
         else:
             # If the second value (i.e., scale) is None, then quant is disabled
@@ -246,17 +247,27 @@ def zero_point(self, force_eval=True):

 class ClampQuantProxyFromInjector(QuantProxyFromInjector, AccQuantProxyProtocol):

+    def __init__(self):
+        super().__init__()
+        self.skip_create_quant_tensor = False
+
     def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]:
         if self.is_quant_enabled:
             out_tuple = self.tensor_quant(x.value, x.scale, x.bit_width)
             out_value, out_scale, out_zp, out_bit_width = out_tuple
+            if self.skip_create_quant_tensor:
+                return out_value
             return IntQuantTensor(
                 out_value, out_scale, out_zp, out_bit_width, self.is_signed, self.training)
         return x


 class TruncQuantProxyFromInjector(QuantProxyFromInjector, AccQuantProxyProtocol):

+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.skip_create_quant_tensor = False
+
     def bit_width(self):
         if not self.is_quant_enabled:
             return None
@@ -274,6 +285,8 @@ def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]:
             else:
                 out_tuple = self.tensor_quant(x.value, x.scale, x.zero_point, x.bit_width)
             out_value, out_scale, out_zp, out_bit_width = out_tuple
+            if self.skip_create_quant_tensor:
+                return out_value
             return IntQuantTensor(
                 out_value, out_scale, out_zp, out_bit_width, x.signed, self.training)
         else:
