From af0782c41d0a7bf205d9839c9f6dcc853ae50058 Mon Sep 17 00:00:00 2001 From: vishwamartur Date: Fri, 1 Nov 2024 22:09:16 +0530 Subject: [PATCH] Moved DelayWrapper logic to Proxy --- src/brevitas/core/quant/int_base.py | 2 -- src/brevitas/proxy/parameter_quant.py | 4 ++- src/brevitas/proxy/runtime_quant.py | 7 ++++- tests/brevitas/proxy/test_proxy.py | 39 +++++++++++++++++++++++++++ 4 files changed, 48 insertions(+), 4 deletions(-) diff --git a/src/brevitas/core/quant/int_base.py b/src/brevitas/core/quant/int_base.py index 338e5a433..19725efb4 100644 --- a/src/brevitas/core/quant/int_base.py +++ b/src/brevitas/core/quant/int_base.py @@ -60,7 +60,6 @@ def __init__( self.tensor_clamp_impl = tensor_clamp_impl self.signed = signed self.narrow_range = narrow_range - self.delay_wrapper = DelayWrapper(quant_delay_steps) self.input_view_impl = input_view_impl @brevitas.jit.script_method @@ -87,7 +86,6 @@ def forward(self, scale: Tensor, zero_point: Tensor, bit_width: Tensor, x: Tenso y_int = self.to_int(scale, zero_point, bit_width, x) y = y_int - zero_point y = y * scale - y = self.delay_wrapper(x, y) return y diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index f28233aed..cc129d0c8 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -22,6 +22,7 @@ from brevitas.quant_tensor import IntQuantTensor from brevitas.quant_tensor import QuantTensor from brevitas.utils.quant_utils import _CachedIO +from brevitas.utils.quant_utils import DelayWrapper from brevitas.utils.torch_utils import compute_channel_view_shape from .quant_proxy import QuantProxyFromInjector @@ -94,6 +95,7 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None: self.cache_inference_quant_weight_metadata_only = False self.cache_class = None # To be redefined by each class self.quant_tensor_class = None # To be redefined by each class + self.delay_wrapper = DelayWrapper() @property def input_view_impl(self): @@ -136,7 +138,7 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]: else: out = self.create_quant_tensor(out) else: - out = self.tensor_quant(x) + out = self.delay_wrapper(self.tensor_quant)(x) if is_dynamo_compiling(): out = out[0] else: diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 9feb593b4..6a38a5477 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -17,6 +17,7 @@ from brevitas.quant_tensor import IntQuantTensor from brevitas.quant_tensor import QuantTensor from brevitas.utils.quant_utils import _CachedIO +from brevitas.utils.quant_utils import DelayWrapper from .quant_proxy import QuantProxyFromInjector from .quant_proxy import QuantProxyProtocol @@ -90,6 +91,8 @@ def forward(self, x): class ActQuantProxyFromInjectorBase(QuantProxyFromInjector, ActQuantProxyProtocol, ABC): + delay_wrapper: DelayWrapper + def __init__(self, quant_layer, quant_injector): QuantProxyFromInjector.__init__(self, quant_layer, quant_injector) ActQuantProxyProtocol.__init__(self) @@ -98,6 +101,7 @@ def __init__(self, quant_layer, quant_injector): self.cache_inference_quant_act = False self.cache_quant_io_metadata_only = True self.cache_class = None + self.delay_wrapper = DelayWrapper() @property def input_view_impl(self): @@ -184,7 +188,8 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: y = self.apply_input_view(self.fused_activation_quant_proxy.activation_impl(y)) y = (y, None) else: - y = self.fused_activation_quant_proxy(y) + y = self.delay_wrapper(self.fused_activation_quant_proxy)(y) + # If y is an empty QuantTensor, we need to check if this is a passthrough proxy, # otherwise return a simple Tensor diff --git a/tests/brevitas/proxy/test_proxy.py b/tests/brevitas/proxy/test_proxy.py index 28c3eed9e..f6efdc90d 100644 --- a/tests/brevitas/proxy/test_proxy.py +++ b/tests/brevitas/proxy/test_proxy.py @@ -80,3 +80,42 @@ def test_dynamic_act_proxy(self): model.act_quant.disable_quant = True assert model.act_quant.bit_width() is None + + def test_delay_wrapper_weight_proxy(self): + model = QuantLinear(10, 5, weight_quant=Int8WeightPerTensorFloat, quant_delay_steps=2) + # Initially quantization should be disabled + assert model.weight_quant.scale() is None + assert model.weight_quant.zero_point() is None + assert model.weight_quant.bit_width() is None + + # After 1 step, still disabled + model.weight_quant.quant_delay_steps() + assert model.weight_quant.scale() is None + assert model.weight_quant.zero_point() is None + assert model.weight_quant.bit_width() is None + + # After 2 steps, quantization should be enabled + model.weight_quant.quant_delay_steps() + assert model.weight_quant.scale() is not None + assert model.weight_quant.zero_point() is not None + assert model.weight_quant.bit_width() is not None + + def test_delay_wrapper_act_proxy(self): + model = QuantReLU(quant_delay_steps=3) + # Initially quantization should be disabled + assert model.act_quant.scale() is None + assert model.act_quant.zero_point() is None + assert model.act_quant.bit_width() is None + + # After 2 steps, still disabled + model.act_quant.quant_delay_steps() + model.act_quant.quant_delay_steps() + assert model.act_quant.scale() is None + assert model.act_quant.zero_point() is None + assert model.act_quant.bit_width() is None + + # After 3 steps, quantization should be enabled + model.act_quant.quant_delay_steps() + assert model.act_quant.scale() is not None + assert model.act_quant.zero_point() is not None + assert model.act_quant.bit_width() is not None