Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix weights mse #1047

Merged
merged 3 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/brevitas/core/scaling/standalone.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def forward(self, ignored: Tensor, threshold: Optional[Tensor] = None) -> Tensor
# workaround to avoid find_ununsed_parameter=True in DDP
stats = stats + 0. * self.value
if self.local_loss_mode:
return self.stats_scaling_impl(stats)
return self.stats_scaling_impl(stats, threshold)
stats = self.restrict_inplace_preprocess(stats)
threshold = self.restrict_inplace_preprocess(threshold)
inplace_tensor_mul(self.value.detach(), stats)
Expand Down
63 changes: 63 additions & 0 deletions tests/brevitas_examples/test_quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,69 @@ def get_qmse(
assert torch.isclose(diff_mse, orig_mse) or (diff_mse > orig_mse)


@pytest.mark.parametrize("quant_granularity", ["per_tensor", "per_channel"])
@jit_disabled_for_local_loss()
def test_layerwise_stats_vs_mse(simple_model, quant_granularity):
"""
We test layerwise quantization, with the weight and activation quantization `mse` parameter
methods.

We test:
- Recostruction error of MSE should be smaller or equal to stats
"""
weight_bit_width = 8
act_bit_width = 8
bias_bit_width = 32
quant_model_mse = quantize_model(
model=deepcopy(simple_model),
backend='layerwise',
weight_bit_width=weight_bit_width,
act_bit_width=act_bit_width,
bias_bit_width=bias_bit_width if bias_bit_width > 0 else None,
weight_quant_granularity=quant_granularity,
act_quant_type='asym',
act_quant_percentile=99.9, # Unused
scale_factor_type='float_scale',
quant_format='int',
weight_param_method='mse',
act_param_method='mse')

quant_model_stats = quantize_model(
model=deepcopy(simple_model),
backend='layerwise',
weight_bit_width=weight_bit_width,
act_bit_width=act_bit_width,
bias_bit_width=bias_bit_width if bias_bit_width > 0 else None,
weight_quant_granularity=quant_granularity,
act_quant_type='asym',
act_quant_percentile=99.9, # Unused
scale_factor_type='float_scale',
quant_format='int',
weight_param_method='stats',
act_param_method='mse')

# We create an input with values linearly scaled between 0 and 1.
input = torch.arange(0, 1, step=1 / (10 * IMAGE_DIM ** 2))
input = input.view(1, 10, IMAGE_DIM, IMAGE_DIM).float()
with torch.no_grad():
with calibration_mode(quant_model_mse):
quant_model_mse(input)
quant_model_mse.eval()
with torch.no_grad():
with calibration_mode(quant_model_stats):
quant_model_stats(input)
quant_model_stats.eval()
weight = simple_model.layers.get_submodule('0').weight
first_conv_layer_mse = quant_model_mse.layers.get_submodule('0')
first_conv_layer_stats = quant_model_stats.layers.get_submodule('0')

l2_stats = ((weight - first_conv_layer_stats.quant_weight().value) ** 2).sum()
l2_mse = ((weight - first_conv_layer_mse.quant_weight().value) ** 2).sum()

# Recostruction error of MSE should be smaller or equal to stats
assert l2_mse - l2_stats <= torch.tensor(1e-5)


@pytest.mark.parametrize("weight_bit_width", [2, 5, 8, 16])
@pytest.mark.parametrize("act_bit_width", [2, 5, 8])
@pytest.mark.parametrize("bias_bit_width", [16, 32])
Expand Down
Loading