From 13ec5ddacbe738decda5b7779df00d7f63de0ede Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 23 Jul 2024 13:52:34 +0100 Subject: [PATCH] Use Brevitas FX --- .../brevitas_examples/test_quantize_model.py | 99 ++++++++++--------- 1 file changed, 52 insertions(+), 47 deletions(-) diff --git a/tests/brevitas_examples/test_quantize_model.py b/tests/brevitas_examples/test_quantize_model.py index 6df94ee23..6a7184131 100644 --- a/tests/brevitas_examples/test_quantize_model.py +++ b/tests/brevitas_examples/test_quantize_model.py @@ -4,6 +4,7 @@ import torch import torch.nn as nn +import brevitas from brevitas.core.function_wrapper.shape import OverOutputChannelView from brevitas.core.function_wrapper.shape import OverTensorView from brevitas.core.stats.stats_op import MSE @@ -54,14 +55,24 @@ def simple_model(): """ assert IMAGE_DIM % 2 == 0, "`IMAGE_DIM` should be a multiple of 2" - return nn.Sequential( - nn.Conv2d(10, 16, kernel_size=3, padding=1), - nn.ReLU(), - nn.Conv2d(16, 32, kernel_size=3, padding=1), - nn.ReLU(), - nn.MaxPool2d(2), # downsample from IMAGE_DIM to half that - nn.Flatten(), - nn.Linear(32 * int(IMAGE_DIM / 2) ** 2, 1000)) + + class Model(torch.nn.Module): + + def __init__(self): + super().__init__() + self.layers = nn.Sequential( + nn.Conv2d(10, 16, kernel_size=3, padding=1), + nn.ReLU(), + nn.Conv2d(16, 32, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(2), # downsample from IMAGE_DIM to half that + nn.Flatten(), + nn.Linear(32 * int(IMAGE_DIM / 2) ** 2, 1000)) + + def forward(self, x): + return self.layers(x) + + return Model() ############## @@ -84,7 +95,7 @@ def test_fx_model(simple_model, weight_bit_width, bias_bit_width, act_bit_width) - That setting `None` for the `bias_bit_width` returns a dequantized bias. - That the bit widths are as desired. """ - fx_model = torch.fx.symbolic_trace(simple_model) + fx_model = brevitas.fx.symbolic_trace(simple_model) quant_model = quantize_model( model=fx_model, backend='fx', @@ -98,17 +109,15 @@ def test_fx_model(simple_model, weight_bit_width, bias_bit_width, act_bit_width) quant_format='int', layerwise_first_last_bit_width=5, ) - # Assert it is a GraphModule - assert isinstance(quant_model, torch.fx.graph_module.GraphModule) # Assert we can feed data of the correct size through the model quant_model(torch.rand(1, 10, IMAGE_DIM, IMAGE_DIM)) # Get first/last layer for testing its quantization. 
-    first_conv_layer = quant_model.get_submodule('0')
-    first_relu_layer = quant_model.get_submodule('1')
-    last_layer = quant_model.get_submodule('6')
-    last_layer_output = quant_model.get_submodule('_6_output_quant')
+    first_conv_layer = quant_model.layers.get_submodule('0')
+    first_relu_layer = quant_model.layers.get_submodule('1')
+    last_layer = quant_model.layers.get_submodule('6')
+    last_layer_output = quant_model.get_submodule('layers_6_output_quant')
 
     # Assert the module types are as desired
     assert isinstance(first_conv_layer, QuantConv2d)
@@ -129,7 +138,7 @@ def test_fx_model(simple_model, weight_bit_width, bias_bit_width, act_bit_width)
     assert first_relu_layer.act_quant.is_quant_enabled  # the output of the "fused" ConvReLU is quantized
 
     # Assert types are as expected
-    assert isinstance(quant_model.get_submodule('3'), QuantReLU)
+    assert isinstance(quant_model.layers.get_submodule('3'), QuantReLU)
 
     # Assert quantization bit widths are as desired
     # Biases
@@ -164,7 +173,7 @@ def test_fx_sym_quant(simple_model):
     act_bit_width = 8
     bias_bit_width = 32
 
-    fx_model = torch.fx.symbolic_trace(simple_model)
+    fx_model = brevitas.fx.symbolic_trace(simple_model)
     quant_model = quantize_model(
         model=fx_model,
         backend='fx',
@@ -183,10 +192,10 @@ def test_fx_sym_quant(simple_model):
     quant_model(torch.rand(1, 10, IMAGE_DIM, IMAGE_DIM))
 
     # Get first/last layer for testing its quantization.
-    first_conv_layer = quant_model.get_submodule('0')
-    first_relu_layer = quant_model.get_submodule('1')
-    last_layer = quant_model.get_submodule('6')
-    last_layer_output = quant_model.get_submodule('_6_output_quant')
+    first_conv_layer = quant_model.layers.get_submodule('0')
+    first_relu_layer = quant_model.layers.get_submodule('1')
+    last_layer = quant_model.layers.get_submodule('6')
+    last_layer_output = quant_model.get_submodule('layers_6_output_quant')
 
     # Check quantization is toggled as expected
     assert first_conv_layer.bias_quant.is_quant_enabled
@@ -235,7 +244,7 @@ def test_fx_affine_quantization(simple_model):
     act_bit_width = 8
     bias_bit_width = 32
 
-    fx_model = torch.fx.symbolic_trace(simple_model)
+    fx_model = brevitas.fx.symbolic_trace(simple_model)
     quant_model = quantize_model(
         model=fx_model,
         backend='fx',
@@ -254,9 +263,9 @@ def test_fx_affine_quantization(simple_model):
     quant_model(torch.rand(1, 10, IMAGE_DIM, IMAGE_DIM))
 
     # Get first/last layer for testing its quantization.
-    first_conv_layer = quant_model.get_submodule('0')
-    last_layer = quant_model.get_submodule('6')
-    last_layer_output = quant_model.get_submodule('_6_output_quant')
+    first_conv_layer = quant_model.layers.get_submodule('0')
+    last_layer = quant_model.layers.get_submodule('6')
+    last_layer_output = quant_model.get_submodule('layers_6_output_quant')
 
     # Assert the tensors are unsigned as expected for asymmetric quantization, with zero-points not at 0.
     # Weights
@@ -298,7 +307,7 @@ def test_fx_param_method_stats(simple_model, weight_bit_width, bias_bit_width, a
     - That setting `None` for the `bias_bit_width` returns a dequantized bias.
     - That the bit widths are as desired.
""" - fx_model = torch.fx.symbolic_trace(simple_model) + fx_model = brevitas.fx.symbolic_trace(simple_model) quant_model = quantize_model( model=fx_model, backend='fx', @@ -314,17 +323,15 @@ def test_fx_param_method_stats(simple_model, weight_bit_width, bias_bit_width, a weight_param_method='stats', act_param_method='stats', ) - # Assert it is a GraphModule - assert isinstance(quant_model, torch.fx.graph_module.GraphModule) # Assert we can feed data of the correct size through the model quant_model(torch.rand(1, 10, IMAGE_DIM, IMAGE_DIM)) # Get first/last layer for testing its quantization. - first_conv_layer = quant_model.get_submodule('0') - first_relu_layer = quant_model.get_submodule('1') - last_layer = quant_model.get_submodule('6') - last_layer_output = quant_model.get_submodule('_6_output_quant') + first_conv_layer = quant_model.layers.get_submodule('0') + first_relu_layer = quant_model.layers.get_submodule('1') + last_layer = quant_model.layers.get_submodule('6') + last_layer_output = quant_model.get_submodule('layers_6_output_quant') # Assert the module types are as desired assert isinstance(first_conv_layer, QuantConv2d) @@ -345,7 +352,7 @@ def test_fx_param_method_stats(simple_model, weight_bit_width, bias_bit_width, a assert first_relu_layer.act_quant.is_quant_enabled # the output of the "fused" ConvReLU is quantized # Assert types are as expected - assert isinstance(quant_model.get_submodule('3'), QuantReLU) + assert isinstance(quant_model.layers.get_submodule('3'), QuantReLU) # Assert quantization bit widths are as desired # Biases @@ -379,7 +386,7 @@ def test_fx_per_chan_weight_quantization(simple_model): act_bit_width = 8 bias_bit_width = 32 - fx_model = torch.fx.symbolic_trace(simple_model) + fx_model = brevitas.fx.symbolic_trace(simple_model) quant_model = quantize_model( model=fx_model, backend='fx', @@ -392,16 +399,14 @@ def test_fx_per_chan_weight_quantization(simple_model): scale_factor_type='float_scale', quant_format='int', ) - # Assert it is a GraphModule - assert isinstance(quant_model, torch.fx.graph_module.GraphModule) # Assert we can feed data of the correct size through the model quant_model(torch.rand(1, 10, IMAGE_DIM, IMAGE_DIM)) # Get first/last layer for testing its quantization. - first_conv_layer = quant_model.get_submodule('0') - last_layer = quant_model.get_submodule('6') - last_layer_output = quant_model.get_submodule('_6_output_quant') + first_conv_layer = quant_model.layers.get_submodule('0') + last_layer = quant_model.layers.get_submodule('6') + last_layer_output = quant_model.get_submodule('layers_6_output_quant') # Assert per-channel quantization of weights # 16 is the nb of output channels of first layer of `simple_model` @@ -426,7 +431,7 @@ def test_invalid_input(simple_model): """ We test various invalid inputs, e.g. invalid strings and zero/negative bit widths. """ - fx_model = torch.fx.symbolic_trace(simple_model) + fx_model = brevitas.fx.symbolic_trace(simple_model) with pytest.raises(KeyError): quantize_model( model=fx_model, @@ -543,7 +548,7 @@ def test_layerwise_percentile_for_calibration(simple_model, act_quant_percentile quant_model.eval() # Get first/last layer for testing its quantization. - first_conv_layer = quant_model.get_submodule('0') + first_conv_layer = quant_model.layers.get_submodule('0') # We check the calibration. We do so by ensuring that the quantization range is within a few quantization # bin tolerance of the `act_quant_percentile`. 
This should be the case, given that we are doing @@ -605,8 +610,8 @@ def test_layerwise_param_method_mse(simple_model, quant_granularity): quant_model.eval() # Get first/last layer for testing its quantization. - first_conv_layer = quant_model.get_submodule('0') - last_layer = quant_model.get_submodule('6') + first_conv_layer = quant_model.layers.get_submodule('0') + last_layer = quant_model.layers.get_submodule('6') # Check that the quant param method module is MSE as it should be # Weights @@ -680,14 +685,14 @@ def test_layerwise_10_in_channels_quantize_model( quant_format='int', layerwise_first_last_bit_width=layerwise_first_last_bit_width, ) - assert isinstance(quant_model, nn.Sequential) + assert isinstance(quant_model.layers, nn.Sequential) # Make sure we can feed data through the model _ = quant_model(torch.rand(1, 10, IMAGE_DIM, IMAGE_DIM)) # Get first layer for testing its quantization. # We also test we can feed data through the first layer in isolation - first_layer = quant_model.get_submodule('0') + first_layer = quant_model.layers.get_submodule('0') first_layer_output = first_layer(torch.rand(1, 10, IMAGE_DIM, IMAGE_DIM)) # Assert the module types are as desired @@ -786,8 +791,8 @@ def test_po2_layerwise_quantization(simple_model): quant_model(torch.rand(1, 10, IMAGE_DIM, IMAGE_DIM)) # Get first/last layer for testing its quantization. - first_conv_layer = quant_model.get_submodule('0') - last_layer = quant_model.get_submodule('6') + first_conv_layer = quant_model.layers.get_submodule('0') + last_layer = quant_model.layers.get_submodule('6') # Assert scales are powers of 2 as expected assert torch.isclose(torch.log2(first_conv_layer.input_quant.scale()) % 1, torch.Tensor([0.0]))
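
For context, a minimal standalone sketch of the pattern this patch standardizes on: trace the model with brevitas.fx.symbolic_trace (rather than torch.fx.symbolic_trace) before handing it to quantize_model with backend='fx'. The quantize_model import path and the keyword arguments below mirror the test file but are assumptions for illustration, not something this patch itself adds.

    import torch
    import torch.nn as nn

    import brevitas.fx
    # Assumed import path for the helper exercised by these tests:
    from brevitas_examples.imagenet_classification.ptq.ptq_common import quantize_model

    # A small stand-in for the `simple_model` fixture (10 input channels, as in the tests).
    model = nn.Sequential(
        nn.Conv2d(10, 16, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(16 * 28 * 28, 10))

    # Brevitas' FX entry point wraps torch.fx tracing so the resulting
    # GraphModule stays compatible with the graph-mode (backend='fx')
    # quantization passes applied by quantize_model.
    fx_model = brevitas.fx.symbolic_trace(model)

    quant_model = quantize_model(
        model=fx_model,
        backend='fx',
        weight_bit_width=8,
        act_bit_width=8,
        bias_bit_width=32,
        weight_quant_granularity='per_tensor',
        act_quant_percentile=99.9,
        act_quant_type='sym',
        scale_factor_type='float_scale',
        quant_format='int')

    # The quantized model accepts the same input shape as the original.
    quant_model(torch.rand(1, 10, 28, 28))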