diff --git a/tests/brevitas_examples/test_quantize_model.py b/tests/brevitas_examples/test_quantize_model.py
index 4ec5471be..6df94ee23 100644
--- a/tests/brevitas_examples/test_quantize_model.py
+++ b/tests/brevitas_examples/test_quantize_model.py
@@ -4,16 +4,16 @@
 import torch
 import torch.nn as nn
 
+from brevitas.core.function_wrapper.shape import OverOutputChannelView
+from brevitas.core.function_wrapper.shape import OverTensorView
+from brevitas.core.stats.stats_op import MSE
+from brevitas.graph.calibrate import calibration_mode
 from brevitas.nn import QuantConv2d
 from brevitas.nn import QuantLinear
 from brevitas.nn import QuantReLU
 from brevitas.quant_tensor import QuantTensor
 from brevitas_examples.imagenet_classification.ptq.ptq_common import quantize_model
 
-# TODO:
-# - Possibility to use statistics or MSE for scale factor computations for weights and activations.
-# - Percentiles used for the activations' statistics computation during calibration.
-
 # CONSTANTS
 IMAGE_DIM = 16
 
@@ -282,6 +282,90 @@ def test_fx_affine_quantization(simple_model):
     assert last_layer_output.act_quant.bit_width().item() == act_bit_width
 
 
+@pytest.mark.parametrize("weight_bit_width", [2, 8, 16])
+@pytest.mark.parametrize("act_bit_width", [2, 5, 8])
+@pytest.mark.parametrize("bias_bit_width", [16, 32, 0])
+def test_fx_param_method_stats(simple_model, weight_bit_width, bias_bit_width, act_bit_width):
+    """
+    We test FX quantization with the weight and activation quantization `stats` parameter methods.
+    `stats` is the default setting, but we also test it explicitly in case it ever changes from the
+    default.
+
+    We test:
+    - The FX-graph quantized model is a GraphModule.
+    - We can feed data through the model.
+    - That the weight, bias and input/output quantization is toggled as expected.
+    - That setting `None` for the `bias_bit_width` returns a dequantized bias.
+    - That the bit widths are as desired.
+    """
+    fx_model = torch.fx.symbolic_trace(simple_model)
+    quant_model = quantize_model(
+        model=fx_model,
+        backend='fx',
+        weight_bit_width=weight_bit_width,
+        act_bit_width=act_bit_width,
+        bias_bit_width=bias_bit_width if bias_bit_width > 0 else None,
+        weight_quant_granularity='per_tensor',
+        act_quant_percentile=99.9,
+        act_quant_type='sym',
+        scale_factor_type='float_scale',
+        quant_format='int',
+        layerwise_first_last_bit_width=5,
+        weight_param_method='stats',
+        act_param_method='stats',
+    )
+    # Assert it is a GraphModule
+    assert isinstance(quant_model, torch.fx.graph_module.GraphModule)
+
+    # Assert we can feed data of the correct size through the model
+    quant_model(torch.rand(1, 10, IMAGE_DIM, IMAGE_DIM))
+
+    # Get first/last layers for testing their quantization.
+    first_conv_layer = quant_model.get_submodule('0')
+    first_relu_layer = quant_model.get_submodule('1')
+    last_layer = quant_model.get_submodule('6')
+    last_layer_output = quant_model.get_submodule('_6_output_quant')
+
+    # Assert the module types are as desired
+    assert isinstance(first_conv_layer, QuantConv2d)
+    assert isinstance(last_layer, QuantLinear)
+
+    # Check quantization is toggled as expected
+    if bias_bit_width == 0:
+        # If bias_bit_width is set to `None` (represented by 0 in this test's parametrization),
+        # the bias should be dequantized.
+        assert not first_conv_layer.bias_quant.is_quant_enabled
+    else:
+        assert first_conv_layer.bias_quant.is_quant_enabled
+    assert first_conv_layer.weight_quant.is_quant_enabled
+    assert not first_conv_layer.input_quant.is_quant_enabled  # unlike with the layerwise backend, the input quantization is disabled.
+    assert not first_conv_layer.output_quant.is_quant_enabled
+
+    assert not first_relu_layer.input_quant.is_quant_enabled
+    assert first_relu_layer.act_quant.is_quant_enabled  # the output of the "fused" ConvReLU is quantized
+
+    # Assert types are as expected
+    assert isinstance(quant_model.get_submodule('3'), QuantReLU)
+
+    # Assert quantization bit widths are as desired
+    # Biases
+    if bias_bit_width > 0:
+        assert first_conv_layer.bias_quant.tensor_quant.msb_clamp_bit_width_impl.bit_width._buffers[
+            'value'].item() == bias_bit_width
+        assert last_layer.bias_quant.tensor_quant.msb_clamp_bit_width_impl.bit_width._buffers[
+            'value'].item() == bias_bit_width
+    else:
+        # If bias_bit_width is `None`, the quantized bias should return a fully floating point parameter.
+        assert not isinstance(first_conv_layer.quant_bias(), QuantTensor)
+
+    # Weights
+    assert first_conv_layer.weight_quant.bit_width().item() == weight_bit_width
+    assert last_layer.weight_quant.bit_width().item() == weight_bit_width
+    # Activations
+    assert first_relu_layer.act_quant.bit_width().item() == act_bit_width
+    assert last_layer_output.act_quant.bit_width().item() == act_bit_width
+
+
 def test_fx_per_chan_weight_quantization(simple_model):
     """
     We test per-channel weight quantization.
@@ -306,8 +390,10 @@
         act_quant_percentile=99.9,
         act_quant_type='sym',
         scale_factor_type='float_scale',
-        quant_format='float',
+        quant_format='int',
     )
+    # Assert it is a GraphModule
+    assert isinstance(quant_model, torch.fx.graph_module.GraphModule)
 
     # Assert we can feed data of the correct size through the model
     quant_model(torch.rand(1, 10, IMAGE_DIM, IMAGE_DIM))
@@ -414,6 +500,152 @@ def test_invalid_input(simple_model):
 ##########################
 # LAYERWISE MODE TESTING #
 ##########################
+@pytest.mark.parametrize("act_quant_percentile", [1, 50, 99.9])
+def test_layerwise_percentile_for_calibration(simple_model, act_quant_percentile):
+    """
+    We test different values for the percentile used for the activations' statistics computation
+    during calibration.
+    We test if the percentile correctly produces the desired qparams for a `QuantIdentity`
+    when fed a tensor with linearly scaled values between 0 and 1.
+
+    We test:
+    - We can feed data through the model.
+    - The desired qparams manifest under controlled conditions as a result of calibration
+    percentiles.
+    """
+    weight_bit_width = 8
+    act_bit_width = 8
+    bias_bit_width = 32
+
+    quant_model = quantize_model(
+        model=simple_model,
+        backend='layerwise',
+        weight_bit_width=weight_bit_width,
+        act_bit_width=act_bit_width,
+        bias_bit_width=bias_bit_width,
+        weight_quant_granularity='per_tensor',
+        act_quant_percentile=act_quant_percentile,
+        act_quant_type='asym',
+        scale_factor_type='float_scale',
+        quant_format='int',
+    )
+
+    # Assert we can feed data of the correct size through the model
+    # We are also performing calibration
+    quant_model.train()
+    # We create an input with values linearly scaled between 0 and 1.
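+    # With IMAGE_DIM = 16, this yields 10 * 16 * 16 = 2560 evenly spaced values in [0, 1).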
+    input = torch.arange(0, 1, step=1 / (10 * IMAGE_DIM ** 2))
+    input = input.view(1, 10, IMAGE_DIM, IMAGE_DIM).float()
+    with torch.no_grad():
+        with calibration_mode(quant_model):
+            for _ in range(1000):
+                quant_model(input)
+    quant_model.eval()
+
+    # Get the first layer for testing its quantization.
+    first_conv_layer = quant_model.get_submodule('0')
+
+    # We check the calibration. We do so by ensuring that the quantization range is within a few quantization
+    # bins' tolerance of the `act_quant_percentile`. This should be the case, given that we are doing
+    # affine quantization with a non-negative tensor (and so zero_point should be 0), and because
+    # we are doing the calibration with a tensor with all values linearly increasing from 0 to 1.
+    assert torch.isclose(first_conv_layer.input_quant.zero_point(), torch.Tensor([0.]))
+    tolerance = 8  # quantization bins of tolerance, on the "plus side".
+    ideal_range = act_quant_percentile / 100
+    scale = first_conv_layer.input_quant.scale()
+
+    # The quantization range is always smaller than the data covered up to the percentile, because
+    # of how the percentile->qrange calculation happens.
+    assert ideal_range > scale * 255
+    # We make sure the quantization range is still reasonably close to covering the entire data, up
+    # to the provided percentile.
+    assert ideal_range < scale * (255 + tolerance)
+
+
+@pytest.mark.parametrize("quant_granularity", ["per_tensor", "per_channel"])
+def test_layerwise_param_method_mse(simple_model, quant_granularity):
+    """
+    We test layerwise quantization with the weight and activation quantization `mse` parameter
+    methods.
+
+    We test:
+    - We can feed data through the model.
+    - That the stat observer is explicitly MSE.
+    - That the view on the quantization granularity is as desired.
+    - That during calibration, the qparams are derived by finding values that minimize the MSE
+    between the floating point and quantized tensor.
+    """
+    weight_bit_width = 8
+    act_bit_width = 8
+    bias_bit_width = 32
+    quant_model = quantize_model(
+        model=simple_model,
+        backend='layerwise',
+        weight_bit_width=weight_bit_width,
+        act_bit_width=act_bit_width,
+        bias_bit_width=bias_bit_width if bias_bit_width > 0 else None,
+        weight_quant_granularity=quant_granularity,
+        act_quant_type='asym',
+        act_quant_percentile=99.9,  # Unused
+        scale_factor_type='float_scale',
+        quant_format='int',
+        weight_param_method='mse',
+        act_param_method='mse',
+    )
+
+    # Assert we can feed data of the correct size through the model
+    # We are also performing calibration
+    quant_model.train()
+    # We create an input with values linearly scaled between 0 and 1.
+    input = torch.arange(0, 1, step=1 / (10 * IMAGE_DIM ** 2))
+    input = input.view(1, 10, IMAGE_DIM, IMAGE_DIM).float()
+    with torch.no_grad():
+        with calibration_mode(quant_model):
+            quant_model(input)
+    quant_model.eval()
+
+    # Get first/last layers for testing their quantization.
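+    # As in the FX tests above, submodule '0' is the first conv layer and '6' is the final linear layer.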
+    first_conv_layer = quant_model.get_submodule('0')
+    last_layer = quant_model.get_submodule('6')
+
+    # Check that the quant param method module is MSE as it should be
+    # Weights
+    first_weight_param_mod = first_conv_layer.weight_quant.tensor_quant.scaling_impl.parameter_list_stats.stats.stats_impl
+    last_weight_param_mod = last_layer.weight_quant.tensor_quant.scaling_impl.parameter_list_stats.stats.stats_impl
+    assert isinstance(first_weight_param_mod, MSE)
+    assert isinstance(last_weight_param_mod, MSE)
+
+    # Check observation is over tensor or channel as desired
+    def check_dim_of_observation(module: torch.nn.Module, quant_granularity: str):
+        if quant_granularity == 'per_tensor':
+            assert isinstance(module.input_view_shape_impl, OverTensorView)
+        elif quant_granularity == 'per_channel':
+            assert isinstance(module.input_view_shape_impl, OverOutputChannelView)
+
+    # Weight
+    check_dim_of_observation(first_weight_param_mod, quant_granularity)
+    check_dim_of_observation(last_weight_param_mod, quant_granularity)
+
+    # We test the calibrated qparams. We fed in a tensor with linearly scaled values between 0 and 1.
+    # We check that varying the qparams gives an MSE worse than or equal to that of the calibrated qparams.
+    # We assume a convex problem.
+    scale = first_conv_layer.input_quant.scale()
+    zero_point = first_conv_layer.input_quant.zero_point()
+
+    def get_qmse(
+            scale: torch.Tensor, zero_point: torch.Tensor, input: torch.Tensor) -> torch.Tensor:
+        quant_tensor = scale * (
+            torch.clamp(torch.round(input / scale + zero_point), 0, 255) - zero_point)
+        mse = torch.mean((quant_tensor - input) ** 2)
+        return mse
+
+    orig_mse = get_qmse(scale, zero_point, input)
+    for scale_diff in [0.1 * scale, 0, -0.1 * scale]:
+        for zero_diff in [1, 0, -1]:
+            diff_mse = get_qmse(scale + scale_diff, zero_point + zero_diff, input)
+            assert torch.isclose(diff_mse, orig_mse) or (diff_mse > orig_mse)
+
+
 @pytest.mark.parametrize("weight_bit_width", [2, 5, 8, 16])
 @pytest.mark.parametrize("act_bit_width", [2, 5, 8])
 @pytest.mark.parametrize("bias_bit_width", [16, 32])