Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Dec 30, 2024
1 parent c340d1a commit e4d25a0
Show file tree
Hide file tree
Showing 9 changed files with 9 additions and 76 deletions.
10 changes: 2 additions & 8 deletions src/brevitas/core/quant/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ class BinaryQuant(brevitas.jit.ScriptModule):
Args:
scaling_impl (Module): Module that returns a scale factor.
quant_delay_steps (int): Number of training steps to delay quantization for. Default: 0
Returns:
Tuple[Tensor, Tensor, Tensor, Tensor]: Quantized output in de-quantized format, scale, zero-point, bit_width.
Expand All @@ -47,7 +46,7 @@ class BinaryQuant(brevitas.jit.ScriptModule):
Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module.
"""

def __init__(self, scaling_impl: Module, signed: bool = True, quant_delay_steps: int = 0):
def __init__(self, scaling_impl: Module, signed: bool = True):
super(BinaryQuant, self).__init__()
assert signed, "Unsigned binary quant not supported"
self.scaling_impl = scaling_impl
Expand All @@ -71,7 +70,6 @@ class ClampedBinaryQuant(brevitas.jit.ScriptModule):
Args:
scaling_impl (Module): Module that returns a scale factor.
tensor_clamp_impl (Module): Module that performs tensor-wise clamping. Default TensorClamp()
quant_delay_steps (int): Number of training steps to delay quantization for. Default: 0
Returns:
Tuple[Tensor, Tensor, Tensor, Tensor]: Quantized output in de-quantized format, scale, zero-point, bit_width.
Expand Down Expand Up @@ -101,11 +99,7 @@ class ClampedBinaryQuant(brevitas.jit.ScriptModule):
Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module.
"""

def __init__(
self,
scaling_impl: Module,
tensor_clamp_impl: Module = TensorClamp(),
quant_delay_steps: int = 0):
def __init__(self, scaling_impl: Module, tensor_clamp_impl: Module = TensorClamp()):
super(ClampedBinaryQuant, self).__init__()
self.scaling_impl = scaling_impl
self.bit_width = BitWidthConst(1)
Expand Down
2 changes: 0 additions & 2 deletions src/brevitas/core/quant/int_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ class IntQuant(brevitas.jit.ScriptModule):
float_to_int_impl (Module): Module that performs the conversion from floating point to
integer representation. Default: RoundSte()
tensor_clamp_impl (Module): Module that performs clamping. Default: TensorClamp()
quant_delay_steps (int): Number of training steps to delay quantization for. Default: 0
Returns:
Tensor: Quantized output in de-quantized format.
Expand Down Expand Up @@ -98,7 +97,6 @@ class DecoupledIntQuant(brevitas.jit.ScriptModule):
float_to_int_impl (Module): Module that performs the conversion from floating point to
integer representation. Default: RoundSte()
tensor_clamp_impl (Module): Module that performs clamping. Default: TensorClamp()
quant_delay_steps (int): Number of training steps to delay quantization for. Default: 0
Returns:
Tensor: Quantized output in de-quantized format.
Expand Down
4 changes: 3 additions & 1 deletion src/brevitas/proxy/parameter_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None:
self.cache_class = None # To be redefined by each class
self.quant_tensor_class = None # To be redefined by each class
self.skip_create_quant_tensor = False
self.delay_wrapper = DelayWrapper(quant_injector.quant_delay_steps)
quant_delay_steps = quant_injector.quant_delay_steps if 'quant_delay_steps' in quant_injector else None
self.delay_wrapper = DelayWrapper(quant_delay_steps)

@property
def input_view_impl(self):
Expand Down Expand Up @@ -140,6 +141,7 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
out = self.create_quant_tensor(out)
else:
quant_value, *quant_args = self.tensor_quant(x)
quant_args = tuple(quant_args)
quant_value = self.delay_wrapper(x, quant_value)
if self.skip_create_quant_tensor:
out = quant_value
Expand Down
4 changes: 3 additions & 1 deletion src/brevitas/proxy/runtime_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ def __init__(self, quant_layer, quant_injector):
self.cache_quant_io_metadata_only = True
self.cache_class = None
self.skip_create_quant_tensor = False
self.delay_wrapper = DelayWrapper(quant_injector.quant_delay_steps)
quant_delay_steps = quant_injector.quant_delay_steps if 'quant_delay_steps' in quant_injector else None
self.delay_wrapper = DelayWrapper(quant_delay_steps)

@property
def input_view_impl(self):
Expand Down Expand Up @@ -191,6 +192,7 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
# If y is an empty QuantTensor, we need to check if this is a passthrough proxy,
# otherwise return a simple Tensor
quant_value, *quant_args = out
quant_args = tuple(quant_args)
quant_value = self.delay_wrapper(y, quant_value)
if self.skip_create_quant_tensor:
out = quant_value
Expand Down
20 changes: 0 additions & 20 deletions tests/brevitas/core/binary_quant_fixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,8 @@
__all__ = [
'binary_quant',
'clamped_binary_quant',
'delayed_binary_quant',
'delayed_clamped_binary_quant',
'binary_quant_impl_all',
'binary_quant_all', # noqa
'delayed_binary_quant_all', # noqa
]


Expand Down Expand Up @@ -43,21 +40,4 @@ def clamped_binary_quant(scaling_impl_all):
return ClampedBinaryQuant(scaling_impl=scaling_impl_all)


@pytest_cases.fixture()
def delayed_binary_quant(scaling_impl_all, quant_delay_steps):
"""
Delayed BinaryQuant with all variants of scaling
"""
return BinaryQuant(scaling_impl=scaling_impl_all, quant_delay_steps=quant_delay_steps)


@pytest_cases.fixture()
def delayed_clamped_binary_quant(scaling_impl_all, quant_delay_steps):
"""
ClampedBinaryQuant with all variants of scaling
"""
return ClampedBinaryQuant(scaling_impl=scaling_impl_all, quant_delay_steps=quant_delay_steps)


fixture_union('binary_quant_all', ['binary_quant', 'clamped_binary_quant'])
fixture_union('delayed_binary_quant_all', ['delayed_binary_quant', 'delayed_clamped_binary_quant'])
10 changes: 0 additions & 10 deletions tests/brevitas/core/shared_quant_fixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from brevitas.core.scaling import ParameterScaling

__all__ = [
'quant_delay_steps',
'const_scaling_impl',
'parameter_scaling_impl',
'standalone_scaling_init',
Expand All @@ -18,15 +17,6 @@
]


@pytest_cases.fixture()
@pytest_cases.parametrize('steps', [1, 10])
def quant_delay_steps(steps):
"""
Non-zero steps to delay quantization
"""
return steps


@pytest_cases.fixture()
def const_scaling_impl(standalone_scaling_init):
"""
Expand Down
13 changes: 1 addition & 12 deletions tests/brevitas/core/ternary_quant_fixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from brevitas.core.quant import TernaryQuant

__all__ = ['threshold_init', 'ternary_quant', 'delayed_ternary_quant']
__all__ = ['threshold_init', 'ternary_quant']


@pytest_cases.fixture()
Expand All @@ -22,14 +22,3 @@ def ternary_quant(scaling_impl_all, threshold_init):
Ternary quant with all variants of scaling
"""
return TernaryQuant(scaling_impl=scaling_impl_all, threshold=threshold_init)


@pytest_cases.fixture()
def delayed_ternary_quant(scaling_impl_all, quant_delay_steps, threshold_init):
"""
Delayed TernaryQuant with all variants of scaling
"""
return TernaryQuant(
scaling_impl=scaling_impl_all,
quant_delay_steps=quant_delay_steps,
threshold=threshold_init)
11 changes: 0 additions & 11 deletions tests/brevitas/core/test_binary_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,6 @@ def test_output_value(self, binary_quant_all, inp):
output, scale, _, _ = binary_quant_all(inp)
assert is_binary_output_value_correct(scale, output)

def test_delayed_output_value(self, delayed_binary_quant_all, quant_delay_steps, randn_inp):
"""
Test delayed quantization by a certain number of steps. Because delayed quantization is
stateful, we can't use Hypothesis to generate the input, so we resort to a basic fixture.
"""
for i in range(quant_delay_steps):
output, _, _, _ = delayed_binary_quant_all(randn_inp)
assert (output == randn_inp).all()
output, scale, _, _ = delayed_binary_quant_all(randn_inp)
assert is_binary_output_value_correct(scale, output)

@given(inp=float_tensor_random_shape_st())
def test_output_bit_width(self, binary_quant_all, inp):
_, _, _, bit_width = binary_quant_all(inp)
Expand Down
11 changes: 0 additions & 11 deletions tests/brevitas/core/test_ternary_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,6 @@ def test_output_value(self, ternary_quant, inp):
output, scale, _, _ = ternary_quant(inp)
assert is_ternary_output_value_correct(scale, output)

def test_delayed_output_value(self, delayed_ternary_quant, quant_delay_steps, randn_inp):
"""
Test delayed quantization by a certain number of steps. Because delayed quantization is
stateful, we can't use Hypothesis to generate the input, so we resort to a basic fixture.
"""
for i in range(quant_delay_steps):
output, _, _, _ = delayed_ternary_quant(randn_inp)
assert (output == randn_inp).all()
output, scale, _, _ = delayed_ternary_quant(randn_inp)
assert is_ternary_output_value_correct(scale, output)

@given(inp=float_tensor_random_shape_st())
def test_output_bit_width(self, ternary_quant, inp):
_, _, _, bit_width = ternary_quant(inp)
Expand Down

0 comments on commit e4d25a0

Please sign in to comment.