Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Oct 1, 2024
1 parent 8e168d2 commit c6d60dc
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/brevitas/core/scaling/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ def __init__(
@brevitas.jit.script_method
def forward(
self, ignored: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
stats = self.parameter_list_stats()
if threshold is None:
threshold = torch.ones(1).type_as(stats)
stats = self.parameter_list_stats()
return self.stats_scaling_impl(stats, threshold)


Expand Down
4 changes: 2 additions & 2 deletions tests/brevitas/core/test_binary_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_binary_quant(self, binary_quant_impl_all, inp, scale_init):
scaling_impl = mock.Mock(return_value=scale_init)
binary_quant = binary_quant_impl_all(scaling_impl)
output, scale, zp, bit_width = binary_quant(inp)
scaling_impl.assert_called_once_with(inp, torch.tensor(1.).type_as(inp))
scaling_impl.assert_called_once_with(inp)
assert is_binary_output_value_correct(scale, output)
assert is_binary_output_sign_correct(inp, output)
assert (scale == scale_init).all()
Expand Down Expand Up @@ -81,4 +81,4 @@ def test_output_zero_point(self, binary_quant_all, inp):
@given(inp=float_tensor_random_shape_st())
def test_output_scale(self, binary_quant_all, scaling_impl_all, inp):
_, scale, _, _ = binary_quant_all(inp)
assert_allclose(scale, scaling_impl_all(inp, torch.tensor(1.).type_as(inp)))
assert_allclose(scale, scaling_impl_all(inp))
4 changes: 2 additions & 2 deletions tests/brevitas/core/test_float_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def test_float_to_quant_float(inp, minifloat_format):
@jit_disabled_for_mock()
def test_scaling_impls_called_once(inp, minifloat_format):
bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format
scaling_impl = mock.Mock(side_effect=lambda x: 1.)
scaling_impl = mock.Mock(side_effect=lambda x, y: 1.)
float_scaling_impl = mock.Mock(side_effect=lambda x, y, z: 1.)
if exponent_bit_width == 0 or mantissa_bit_width == 0:
with pytest.raises(RuntimeError):
Expand Down Expand Up @@ -160,7 +160,7 @@ def test_inner_scale(inp, minifloat_format, scale):
bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format
# set scaling_impl to scale and float_scaling_impl to 1 to use the same scale as we are here
float_scaling_impl = mock.Mock(side_effect=lambda x, y, z: 1.)
scaling_impl = mock.Mock(side_effect=lambda x: scale)
scaling_impl = mock.Mock(side_effect=lambda x, y: scale)
if exponent_bit_width == 0 or mantissa_bit_width == 0:
with pytest.raises(RuntimeError):
float_quant = FloatQuant(
Expand Down

0 comments on commit c6d60dc

Please sign in to comment.