diff --git a/src/brevitas/core/quant/binary.py b/src/brevitas/core/quant/binary.py index 3a4b7346e..a00645b3f 100644 --- a/src/brevitas/core/quant/binary.py +++ b/src/brevitas/core/quant/binary.py @@ -10,7 +10,6 @@ import brevitas from brevitas.core.bit_width import BitWidthConst from brevitas.core.function_wrapper import TensorClamp -from brevitas.core.quant.delay import DelayWrapper from brevitas.core.utils import StatelessBuffer from brevitas.function.ops_ste import binary_sign_ste @@ -22,7 +21,6 @@ class BinaryQuant(brevitas.jit.ScriptModule): Args: scaling_impl (Module): Module that returns a scale factor. - quant_delay_steps (int): Number of training steps to delay quantization for. Default: 0 Returns: Tuple[Tensor, Tensor, Tensor, Tensor]: Quantized output in de-quantized format, scale, zero-point, bit_width. @@ -48,19 +46,17 @@ class BinaryQuant(brevitas.jit.ScriptModule): Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module. """ - def __init__(self, scaling_impl: Module, signed: bool = True, quant_delay_steps: int = 0): + def __init__(self, scaling_impl: Module, signed: bool = True): super(BinaryQuant, self).__init__() assert signed, "Unsigned binary quant not supported" self.scaling_impl = scaling_impl self.bit_width = BitWidthConst(1) self.zero_point = StatelessBuffer(torch.tensor(0.0)) - self.delay_wrapper = DelayWrapper(quant_delay_steps) @brevitas.jit.script_method def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: scale = self.scaling_impl(x) y = binary_sign_ste(x) * scale - y = self.delay_wrapper(x, y) return y, scale, self.zero_point(), self.bit_width() @@ -74,7 +70,6 @@ class ClampedBinaryQuant(brevitas.jit.ScriptModule): Args: scaling_impl (Module): Module that returns a scale factor. tensor_clamp_impl (Module): Module that performs tensor-wise clamping. Default TensorClamp() - quant_delay_steps (int): Number of training steps to delay quantization for. Default: 0 Returns: Tuple[Tensor, Tensor, Tensor, Tensor]: Quantized output in de-quantized format, scale, zero-point, bit_width. @@ -104,16 +99,11 @@ class ClampedBinaryQuant(brevitas.jit.ScriptModule): Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module. """ - def __init__( - self, - scaling_impl: Module, - tensor_clamp_impl: Module = TensorClamp(), - quant_delay_steps: int = 0): + def __init__(self, scaling_impl: Module, tensor_clamp_impl: Module = TensorClamp()): super(ClampedBinaryQuant, self).__init__() self.scaling_impl = scaling_impl self.bit_width = BitWidthConst(1) self.zero_point = StatelessBuffer(torch.tensor(0.0)) - self.delay_wrapper = DelayWrapper(quant_delay_steps) self.tensor_clamp_impl = tensor_clamp_impl @brevitas.jit.script_method @@ -121,5 +111,4 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: scale = self.scaling_impl(x) y = self.tensor_clamp_impl(x, -scale, scale) y = binary_sign_ste(y) * scale - y = self.delay_wrapper(x, y) return y, scale, self.zero_point(), self.bit_width() diff --git a/src/brevitas/core/quant/int.py b/src/brevitas/core/quant/int.py index 328ad63b3..248931d68 100644 --- a/src/brevitas/core/quant/int.py +++ b/src/brevitas/core/quant/int.py @@ -8,7 +8,6 @@ from torch.nn import Module import brevitas -from brevitas.core.quant.delay import DelayWrapper from brevitas.core.utils import StatelessBuffer from brevitas.function.ops_ste import round_ste @@ -201,12 +200,10 @@ class TruncIntQuant(brevitas.jit.ScriptModule): """ """ - def __init__( - self, float_to_int_impl: Module, bit_width_impl: Module, quant_delay_steps: int = 0): + def __init__(self, float_to_int_impl: Module, bit_width_impl: Module): super(TruncIntQuant, self).__init__() self.msb_clamp_bit_width_impl = bit_width_impl self.float_to_int_impl = float_to_int_impl - self.delay_wrapper = DelayWrapper(quant_delay_steps) @brevitas.jit.script_method def forward(self, x: Tensor, scale: Tensor, zero_point: Tensor, @@ -221,7 +218,6 @@ def forward(self, x: Tensor, scale: Tensor, zero_point: Tensor, y = self.float_to_int_impl(y) y = y - zero_point y = y * scale - y = self.delay_wrapper(x, y) return y, scale, zero_point, output_bit_width diff --git a/src/brevitas/core/quant/int_base.py b/src/brevitas/core/quant/int_base.py index 338e5a433..8e94465c9 100644 --- a/src/brevitas/core/quant/int_base.py +++ b/src/brevitas/core/quant/int_base.py @@ -8,7 +8,6 @@ import brevitas from brevitas.core.function_wrapper import RoundSte from brevitas.core.function_wrapper import TensorClamp -from brevitas.core.quant.delay import DelayWrapper from brevitas.function.ops import max_int from brevitas.function.ops import min_int @@ -24,7 +23,6 @@ class IntQuant(brevitas.jit.ScriptModule): float_to_int_impl (Module): Module that performs the conversion from floating point to integer representation. Default: RoundSte() tensor_clamp_impl (Module): Module that performs clamping. Default: TensorClamp() - quant_delay_steps (int): Number of training steps to delay quantization for. Default: 0 Returns: Tensor: Quantized output in de-quantized format. @@ -48,19 +46,17 @@ class IntQuant(brevitas.jit.ScriptModule): __constants__ = ['signed', 'narrow_range'] def __init__( - self, - narrow_range: bool, - signed: bool, - input_view_impl: Module, - float_to_int_impl: Module = RoundSte(), - tensor_clamp_impl: Module = TensorClamp(), - quant_delay_steps: int = 0): + self, + narrow_range: bool, + signed: bool, + input_view_impl: Module, + float_to_int_impl: Module = RoundSte(), + tensor_clamp_impl: Module = TensorClamp()): super(IntQuant, self).__init__() self.float_to_int_impl = float_to_int_impl self.tensor_clamp_impl = tensor_clamp_impl self.signed = signed self.narrow_range = narrow_range - self.delay_wrapper = DelayWrapper(quant_delay_steps) self.input_view_impl = input_view_impl @brevitas.jit.script_method @@ -87,7 +83,6 @@ def forward(self, scale: Tensor, zero_point: Tensor, bit_width: Tensor, x: Tenso y_int = self.to_int(scale, zero_point, bit_width, x) y = y_int - zero_point y = y * scale - y = self.delay_wrapper(x, y) return y @@ -102,7 +97,6 @@ class DecoupledIntQuant(brevitas.jit.ScriptModule): float_to_int_impl (Module): Module that performs the conversion from floating point to integer representation. Default: RoundSte() tensor_clamp_impl (Module): Module that performs clamping. Default: TensorClamp() - quant_delay_steps (int): Number of training steps to delay quantization for. Default: 0 Returns: Tensor: Quantized output in de-quantized format. @@ -124,19 +118,17 @@ class DecoupledIntQuant(brevitas.jit.ScriptModule): __constants__ = ['signed', 'narrow_range'] def __init__( - self, - narrow_range: bool, - signed: bool, - input_view_impl: Module, - float_to_int_impl: Module = RoundSte(), - tensor_clamp_impl: Module = TensorClamp(), - quant_delay_steps: int = 0): + self, + narrow_range: bool, + signed: bool, + input_view_impl: Module, + float_to_int_impl: Module = RoundSte(), + tensor_clamp_impl: Module = TensorClamp()): super(DecoupledIntQuant, self).__init__() self.float_to_int_impl = float_to_int_impl self.tensor_clamp_impl = tensor_clamp_impl self.signed = signed self.narrow_range = narrow_range - self.delay_wrapper = DelayWrapper(quant_delay_steps) self.input_view_impl = input_view_impl @brevitas.jit.script_method @@ -172,5 +164,4 @@ def forward( y_int = self.to_int(pre_scale, pre_zero_point, bit_width, x) y = y_int - zero_point y = y * scale - y = self.delay_wrapper(x, y) return y diff --git a/src/brevitas/core/quant/ternary.py b/src/brevitas/core/quant/ternary.py index ffaa873de..9fd8f78ce 100644 --- a/src/brevitas/core/quant/ternary.py +++ b/src/brevitas/core/quant/ternary.py @@ -9,7 +9,6 @@ import brevitas from brevitas.core.bit_width import BitWidthConst -from brevitas.core.quant.delay import DelayWrapper from brevitas.core.utils import StatelessBuffer from brevitas.function.ops_ste import ternary_sign_ste @@ -57,7 +56,6 @@ def __init__(self, scaling_impl: Module, threshold: float, quant_delay_steps: in self.threshold = threshold self.bit_width = BitWidthConst(2) self.zero_point = StatelessBuffer(torch.tensor(0.0)) - self.delay_wrapper = DelayWrapper(quant_delay_steps) @brevitas.jit.script_method def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: @@ -65,5 +63,4 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: mask = x.abs().gt(self.threshold * scale) y = mask.float() * ternary_sign_ste(x) y = y * scale - y = self.delay_wrapper(x, y) return y, scale, self.zero_point(), self.bit_width() diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index 5c4e447d4..b29962059 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -16,9 +16,9 @@ from brevitas import config from brevitas import is_dynamo_compiling from brevitas.core.function_wrapper.misc import Identity +from brevitas.core.quant.delay import DelayWrapper from brevitas.function import max_int from brevitas.inject import BaseInjector as Injector -from brevitas.quant_tensor import _unpack_quant_tensor from brevitas.quant_tensor import IntQuantTensor from brevitas.quant_tensor import QuantTensor from brevitas.utils.quant_utils import _CachedIO @@ -96,6 +96,8 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None: self.cache_class = None # To be redefined by each class self.quant_tensor_class = None # To be redefined by each class self.skip_create_quant_tensor = False + quant_delay_steps = quant_injector.quant_delay_steps if 'quant_delay_steps' in quant_injector else None + self.delay_wrapper = DelayWrapper(quant_delay_steps) @property def input_view_impl(self): @@ -138,11 +140,13 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]: else: out = self.create_quant_tensor(out) else: - out = self.tensor_quant(x) + quant_value, *quant_args = self.tensor_quant(x) + quant_args = tuple(quant_args) + quant_value = self.delay_wrapper(x, quant_value) if self.skip_create_quant_tensor: - out = out[0] + out = quant_value else: - out = self.create_quant_tensor(out) + out = self.create_quant_tensor((quant_value,) + quant_args) if not self.training and self.cache_inference_quant_weight and self._cached_weight is None: self._cached_weight = self.cache_class( out.detach(), diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index cff192490..64a8faefe 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -14,6 +14,7 @@ import brevitas from brevitas import is_dynamo_compiling +from brevitas.core.quant.delay import DelayWrapper from brevitas.quant_tensor import IntQuantTensor from brevitas.quant_tensor import QuantTensor from brevitas.utils.quant_utils import _CachedIO @@ -99,6 +100,8 @@ def __init__(self, quant_layer, quant_injector): self.cache_quant_io_metadata_only = True self.cache_class = None self.skip_create_quant_tensor = False + quant_delay_steps = quant_injector.quant_delay_steps if 'quant_delay_steps' in quant_injector else None + self.delay_wrapper = DelayWrapper(quant_delay_steps) @property def input_view_impl(self): @@ -176,31 +179,33 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: y = y.value if self.export_mode: - y = self.fused_activation_quant_proxy.activation_impl(y) - y = self.export_handler(y) + out = self.fused_activation_quant_proxy.activation_impl(y) + out = self.export_handler(out) elif not self.is_quant_enabled: # A tuple helps later with control flows # The second None value is used later # If quant is not enabled, we still apply input_view in the case of groupwise + padding - y = self.apply_input_view(self.fused_activation_quant_proxy.activation_impl(y)) - y = (y, None) + out = self.apply_input_view(self.fused_activation_quant_proxy.activation_impl(y)) + out = (out, None) else: - y = self.fused_activation_quant_proxy(y) + out = self.fused_activation_quant_proxy(y) # If y is an empty QuantTensor, we need to check if this is a passthrough proxy, # otherwise return a simple Tensor - + quant_value, *quant_args = out + quant_args = tuple(quant_args) + quant_value = self.delay_wrapper(y, quant_value) if self.skip_create_quant_tensor: - out = y[0] + out = quant_value else: # If the second value (i.e., scale) is None, then quant is disabled - if y[1] is not None: - out = self.create_quant_tensor(y) + if out[1] is not None: + out = self.create_quant_tensor((quant_value,) + quant_args) elif self.is_passthrough_act and isinstance(x, QuantTensor): # preserve scale/zp/bit/sign even without output quant - y = y[0] - out = self.create_quant_tensor(y, x=x) + out = quant_value + out = self.create_quant_tensor(out, x=x) else: - out = y[0] + out = quant_value if not self.training and self.cache_inference_quant_act and isinstance(out, QuantTensor): cached_out = self.cache_class(out.detach(), self.cache_quant_io_metadata_only) @@ -267,6 +272,8 @@ class TruncQuantProxyFromInjector(QuantProxyFromInjector, AccQuantProxyProtocol) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.skip_create_quant_tensor = False + quant_delay_steps = self.quant_injector.quant_delay_steps if 'quant_delay_steps' in self.quant_injector else None + self.delay_wrapper = DelayWrapper(quant_delay_steps) def bit_width(self): if not self.is_quant_enabled: @@ -285,6 +292,7 @@ def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]: else: out_tuple = self.tensor_quant(x.value, x.scale, x.zero_point, x.bit_width) out_value, out_scale, out_zp, out_bit_width = out_tuple + out_value = self.delay_wrapper(x, out_value) if self.skip_create_quant_tensor: return out_value return IntQuantTensor( diff --git a/tests/brevitas/core/binary_quant_fixture.py b/tests/brevitas/core/binary_quant_fixture.py index 32937c2c4..f7fab6a1c 100644 --- a/tests/brevitas/core/binary_quant_fixture.py +++ b/tests/brevitas/core/binary_quant_fixture.py @@ -10,11 +10,8 @@ __all__ = [ 'binary_quant', 'clamped_binary_quant', - 'delayed_binary_quant', - 'delayed_clamped_binary_quant', 'binary_quant_impl_all', 'binary_quant_all', # noqa - 'delayed_binary_quant_all', # noqa ] @@ -43,21 +40,4 @@ def clamped_binary_quant(scaling_impl_all): return ClampedBinaryQuant(scaling_impl=scaling_impl_all) -@pytest_cases.fixture() -def delayed_binary_quant(scaling_impl_all, quant_delay_steps): - """ - Delayed BinaryQuant with all variants of scaling - """ - return BinaryQuant(scaling_impl=scaling_impl_all, quant_delay_steps=quant_delay_steps) - - -@pytest_cases.fixture() -def delayed_clamped_binary_quant(scaling_impl_all, quant_delay_steps): - """ - ClampedBinaryQuant with all variants of scaling - """ - return ClampedBinaryQuant(scaling_impl=scaling_impl_all, quant_delay_steps=quant_delay_steps) - - fixture_union('binary_quant_all', ['binary_quant', 'clamped_binary_quant']) -fixture_union('delayed_binary_quant_all', ['delayed_binary_quant', 'delayed_clamped_binary_quant']) diff --git a/tests/brevitas/core/shared_quant_fixture.py b/tests/brevitas/core/shared_quant_fixture.py index baf565d06..f2a440823 100644 --- a/tests/brevitas/core/shared_quant_fixture.py +++ b/tests/brevitas/core/shared_quant_fixture.py @@ -9,7 +9,6 @@ from brevitas.core.scaling import ParameterScaling __all__ = [ - 'quant_delay_steps', 'const_scaling_impl', 'parameter_scaling_impl', 'standalone_scaling_init', @@ -18,15 +17,6 @@ ] -@pytest_cases.fixture() -@pytest_cases.parametrize('steps', [1, 10]) -def quant_delay_steps(steps): - """ - Non-zero steps to delay quantization - """ - return steps - - @pytest_cases.fixture() def const_scaling_impl(standalone_scaling_init): """ diff --git a/tests/brevitas/core/ternary_quant_fixture.py b/tests/brevitas/core/ternary_quant_fixture.py index 2cb7ade78..782631817 100644 --- a/tests/brevitas/core/ternary_quant_fixture.py +++ b/tests/brevitas/core/ternary_quant_fixture.py @@ -5,7 +5,7 @@ from brevitas.core.quant import TernaryQuant -__all__ = ['threshold_init', 'ternary_quant', 'delayed_ternary_quant'] +__all__ = ['threshold_init', 'ternary_quant'] @pytest_cases.fixture() @@ -22,14 +22,3 @@ def ternary_quant(scaling_impl_all, threshold_init): Ternary quant with all variants of scaling """ return TernaryQuant(scaling_impl=scaling_impl_all, threshold=threshold_init) - - -@pytest_cases.fixture() -def delayed_ternary_quant(scaling_impl_all, quant_delay_steps, threshold_init): - """ - Delayed TernaryQuant with all variants of scaling - """ - return TernaryQuant( - scaling_impl=scaling_impl_all, - quant_delay_steps=quant_delay_steps, - threshold=threshold_init) diff --git a/tests/brevitas/core/test_binary_quant.py b/tests/brevitas/core/test_binary_quant.py index 4f82e4815..bef166053 100644 --- a/tests/brevitas/core/test_binary_quant.py +++ b/tests/brevitas/core/test_binary_quant.py @@ -57,17 +57,6 @@ def test_output_value(self, binary_quant_all, inp): output, scale, _, _ = binary_quant_all(inp) assert is_binary_output_value_correct(scale, output) - def test_delayed_output_value(self, delayed_binary_quant_all, quant_delay_steps, randn_inp): - """ - Test delayed quantization by a certain number of steps. Because delayed quantization is - stateful, we can't use Hypothesis to generate the input, so we resort to a basic fixture. - """ - for i in range(quant_delay_steps): - output, _, _, _ = delayed_binary_quant_all(randn_inp) - assert (output == randn_inp).all() - output, scale, _, _ = delayed_binary_quant_all(randn_inp) - assert is_binary_output_value_correct(scale, output) - @given(inp=float_tensor_random_shape_st()) def test_output_bit_width(self, binary_quant_all, inp): _, _, _, bit_width = binary_quant_all(inp) diff --git a/tests/brevitas/core/test_ternary_quant.py b/tests/brevitas/core/test_ternary_quant.py index d2b1817fc..2d6d66c1f 100644 --- a/tests/brevitas/core/test_ternary_quant.py +++ b/tests/brevitas/core/test_ternary_quant.py @@ -67,17 +67,6 @@ def test_output_value(self, ternary_quant, inp): output, scale, _, _ = ternary_quant(inp) assert is_ternary_output_value_correct(scale, output) - def test_delayed_output_value(self, delayed_ternary_quant, quant_delay_steps, randn_inp): - """ - Test delayed quantization by a certain number of steps. Because delayed quantization is - stateful, we can't use Hypothesis to generate the input, so we resort to a basic fixture. - """ - for i in range(quant_delay_steps): - output, _, _, _ = delayed_ternary_quant(randn_inp) - assert (output == randn_inp).all() - output, scale, _, _ = delayed_ternary_quant(randn_inp) - assert is_ternary_output_value_correct(scale, output) - @given(inp=float_tensor_random_shape_st()) def test_output_bit_width(self, ternary_quant, inp): _, _, _, bit_width = ternary_quant(inp) diff --git a/tests/brevitas/proxy/test_proxy.py b/tests/brevitas/proxy/test_proxy.py index 28c3eed9e..cb4bedbf9 100644 --- a/tests/brevitas/proxy/test_proxy.py +++ b/tests/brevitas/proxy/test_proxy.py @@ -1,4 +1,5 @@ import pytest +import torch from brevitas.nn import QuantLinear from brevitas.nn.quant_activation import QuantReLU @@ -80,3 +81,11 @@ def test_dynamic_act_proxy(self): model.act_quant.disable_quant = True assert model.act_quant.bit_width() is None + + def test_delay_act_proxy(self): + model = QuantReLU(quant_delay_steps=1) + inp = torch.randn(1, 5) + o = model(inp) + assert torch.allclose(inp, o) + o = model(inp) + assert not torch.allclose(inp, o)