Commit 13ec5dd: Use Brevitas FX
Giuseppe5 committed Jul 23, 2024
1 parent 1f9e2f2 commit 13ec5dd
Showing 1 changed file with 52 additions and 47 deletions.
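The change is mechanical but repeated throughout the file: every test now traces the fixture model with Brevitas' FX frontend instead of stock torch.fx, the `simple_model` fixture wraps its nn.Sequential in a small nn.Module (so quantized layers live under a `layers` attribute), and the `assert isinstance(quant_model, torch.fx.graph_module.GraphModule)` checks are dropped along with the switch. A minimal sketch of the tracing change, assuming only the `brevitas.fx.symbolic_trace` entry point that the diff itself uses:

import torch
import torch.nn as nn

import brevitas

model = nn.Sequential(nn.Conv2d(10, 16, kernel_size=3, padding=1), nn.ReLU())

# Before this commit, the tests traced with stock torch.fx:
#     fx_model = torch.fx.symbolic_trace(model)
# After it, tracing goes through Brevitas' FX frontend instead:
fx_model = brevitas.fx.symbolic_trace(model)
fx_model(torch.rand(1, 10, 8, 8))  # the traced module remains callable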
99 changes: 52 additions & 47 deletions tests/brevitas_examples/test_quantize_model.py
@@ -4,6 +4,7 @@
import torch
import torch.nn as nn

+import brevitas
from brevitas.core.function_wrapper.shape import OverOutputChannelView
from brevitas.core.function_wrapper.shape import OverTensorView
from brevitas.core.stats.stats_op import MSE
@@ -54,14 +55,24 @@ def simple_model():
"""
assert IMAGE_DIM % 2 == 0, "`IMAGE_DIM` should be a multiple of 2"
-    return nn.Sequential(
-        nn.Conv2d(10, 16, kernel_size=3, padding=1),
-        nn.ReLU(),
-        nn.Conv2d(16, 32, kernel_size=3, padding=1),
-        nn.ReLU(),
-        nn.MaxPool2d(2),  # downsample from IMAGE_DIM to half that
-        nn.Flatten(),
-        nn.Linear(32 * int(IMAGE_DIM / 2) ** 2, 1000))
+
+    class Model(torch.nn.Module):
+
+        def __init__(self):
+            super().__init__()
+            self.layers = nn.Sequential(
+                nn.Conv2d(10, 16, kernel_size=3, padding=1),
+                nn.ReLU(),
+                nn.Conv2d(16, 32, kernel_size=3, padding=1),
+                nn.ReLU(),
+                nn.MaxPool2d(2),  # downsample from IMAGE_DIM to half that
+                nn.Flatten(),
+                nn.Linear(32 * int(IMAGE_DIM / 2) ** 2, 1000))
+
+        def forward(self, x):
+            return self.layers(x)
+
+    return Model()


##############
@@ -84,7 +95,7 @@ def test_fx_model(simple_model, weight_bit_width, bias_bit_width, act_bit_width)
- That setting `None` for the `bias_bit_width` returns a dequantized bias.
- That the bit widths are as desired.
"""
-    fx_model = torch.fx.symbolic_trace(simple_model)
+    fx_model = brevitas.fx.symbolic_trace(simple_model)
quant_model = quantize_model(
model=fx_model,
backend='fx',
@@ -98,17 +109,15 @@ def test_fx_model(simple_model, weight_bit_width, bias_bit_width, act_bit_width)
quant_format='int',
layerwise_first_last_bit_width=5,
)
-    # Assert it is a GraphModule
-    assert isinstance(quant_model, torch.fx.graph_module.GraphModule)

# Assert we can feed data of the correct size through the model
quant_model(torch.rand(1, 10, IMAGE_DIM, IMAGE_DIM))

# Get first/last layer for testing its quantization.
-    first_conv_layer = quant_model.get_submodule('0')
-    first_relu_layer = quant_model.get_submodule('1')
-    last_layer = quant_model.get_submodule('6')
-    last_layer_output = quant_model.get_submodule('_6_output_quant')
+    first_conv_layer = quant_model.layers.get_submodule('0')
+    first_relu_layer = quant_model.layers.get_submodule('1')
+    last_layer = quant_model.layers.get_submodule('6')
+    last_layer_output = quant_model.get_submodule('layers_6_output_quant')

# Assert the module types are as desired
assert isinstance(first_conv_layer, QuantConv2d)
@@ -129,7 +138,7 @@ def test_fx_model(simple_model, weight_bit_width, bias_bit_width, act_bit_width)
assert first_relu_layer.act_quant.is_quant_enabled # the output of the "fused" ConvReLU is quantized

# Assert types are as expected
-    assert isinstance(quant_model.get_submodule('3'), QuantReLU)
+    assert isinstance(quant_model.layers.get_submodule('3'), QuantReLU)

# Assert quantization bit widths are as desired
# Biases
@@ -164,7 +173,7 @@ def test_fx_sym_quant(simple_model):
act_bit_width = 8
bias_bit_width = 32

-    fx_model = torch.fx.symbolic_trace(simple_model)
+    fx_model = brevitas.fx.symbolic_trace(simple_model)
quant_model = quantize_model(
model=fx_model,
backend='fx',
@@ -183,10 +192,10 @@ def test_fx_sym_quant(simple_model):
quant_model(torch.rand(1, 10, IMAGE_DIM, IMAGE_DIM))

# Get first/last layer for testing its quantization.
-    first_conv_layer = quant_model.get_submodule('0')
-    first_relu_layer = quant_model.get_submodule('1')
-    last_layer = quant_model.get_submodule('6')
-    last_layer_output = quant_model.get_submodule('_6_output_quant')
+    first_conv_layer = quant_model.layers.get_submodule('0')
+    first_relu_layer = quant_model.layers.get_submodule('1')
+    last_layer = quant_model.layers.get_submodule('6')
+    last_layer_output = quant_model.get_submodule('layers_6_output_quant')

# Check quantization is toggled as expected
assert first_conv_layer.bias_quant.is_quant_enabled
@@ -235,7 +244,7 @@ def test_fx_affine_quantization(simple_model):
act_bit_width = 8
bias_bit_width = 32

-    fx_model = torch.fx.symbolic_trace(simple_model)
+    fx_model = brevitas.fx.symbolic_trace(simple_model)
quant_model = quantize_model(
model=fx_model,
backend='fx',
@@ -254,9 +263,9 @@ def test_fx_affine_quantization(simple_model):
quant_model(torch.rand(1, 10, IMAGE_DIM, IMAGE_DIM))

# Get first/last layer for testing its quantization.
-    first_conv_layer = quant_model.get_submodule('0')
-    last_layer = quant_model.get_submodule('6')
-    last_layer_output = quant_model.get_submodule('_6_output_quant')
+    first_conv_layer = quant_model.layers.get_submodule('0')
+    last_layer = quant_model.layers.get_submodule('6')
+    last_layer_output = quant_model.get_submodule('layers_6_output_quant')

# Assert the tensors are unsigned as expected for asymmetric quantization, with zero-points not at 0.
# Weights
@@ -298,7 +307,7 @@ def test_fx_param_method_stats(simple_model, weight_bit_width, bias_bit_width, act_bit_width)
- That setting `None` for the `bias_bit_width` returns a dequantized bias.
- That the bit widths are as desired.
"""
-    fx_model = torch.fx.symbolic_trace(simple_model)
+    fx_model = brevitas.fx.symbolic_trace(simple_model)
quant_model = quantize_model(
model=fx_model,
backend='fx',
@@ -314,17 +323,15 @@ def test_fx_param_method_stats(simple_model, weight_bit_width, bias_bit_width, act_bit_width)
weight_param_method='stats',
act_param_method='stats',
)
-    # Assert it is a GraphModule
-    assert isinstance(quant_model, torch.fx.graph_module.GraphModule)

# Assert we can feed data of the correct size through the model
quant_model(torch.rand(1, 10, IMAGE_DIM, IMAGE_DIM))

# Get first/last layer for testing its quantization.
-    first_conv_layer = quant_model.get_submodule('0')
-    first_relu_layer = quant_model.get_submodule('1')
-    last_layer = quant_model.get_submodule('6')
-    last_layer_output = quant_model.get_submodule('_6_output_quant')
+    first_conv_layer = quant_model.layers.get_submodule('0')
+    first_relu_layer = quant_model.layers.get_submodule('1')
+    last_layer = quant_model.layers.get_submodule('6')
+    last_layer_output = quant_model.get_submodule('layers_6_output_quant')

# Assert the module types are as desired
assert isinstance(first_conv_layer, QuantConv2d)
@@ -345,7 +352,7 @@ def test_fx_param_method_stats(simple_model, weight_bit_width, bias_bit_width, act_bit_width)
assert first_relu_layer.act_quant.is_quant_enabled # the output of the "fused" ConvReLU is quantized

# Assert types are as expected
-    assert isinstance(quant_model.get_submodule('3'), QuantReLU)
+    assert isinstance(quant_model.layers.get_submodule('3'), QuantReLU)

# Assert quantization bit widths are as desired
# Biases
@@ -379,7 +386,7 @@ def test_fx_per_chan_weight_quantization(simple_model):
act_bit_width = 8
bias_bit_width = 32

-    fx_model = torch.fx.symbolic_trace(simple_model)
+    fx_model = brevitas.fx.symbolic_trace(simple_model)
quant_model = quantize_model(
model=fx_model,
backend='fx',
@@ -392,16 +399,14 @@ def test_fx_per_chan_weight_quantization(simple_model):
scale_factor_type='float_scale',
quant_format='int',
)
-    # Assert it is a GraphModule
-    assert isinstance(quant_model, torch.fx.graph_module.GraphModule)

# Assert we can feed data of the correct size through the model
quant_model(torch.rand(1, 10, IMAGE_DIM, IMAGE_DIM))

# Get first/last layer for testing its quantization.
-    first_conv_layer = quant_model.get_submodule('0')
-    last_layer = quant_model.get_submodule('6')
-    last_layer_output = quant_model.get_submodule('_6_output_quant')
+    first_conv_layer = quant_model.layers.get_submodule('0')
+    last_layer = quant_model.layers.get_submodule('6')
+    last_layer_output = quant_model.get_submodule('layers_6_output_quant')

# Assert per-channel quantization of weights
# 16 is the nb of output channels of first layer of `simple_model`
@@ -426,7 +431,7 @@ def test_invalid_input(simple_model):
"""
We test various invalid inputs, e.g. invalid strings and zero/negative bit widths.
"""
-    fx_model = torch.fx.symbolic_trace(simple_model)
+    fx_model = brevitas.fx.symbolic_trace(simple_model)
with pytest.raises(KeyError):
quantize_model(
model=fx_model,
@@ -543,7 +548,7 @@ def test_layerwise_percentile_for_calibration(simple_model, act_quant_percentile):
quant_model.eval()

# Get first/last layer for testing its quantization.
-    first_conv_layer = quant_model.get_submodule('0')
+    first_conv_layer = quant_model.layers.get_submodule('0')

# We check the calibration. We do so by ensuring that the quantization range is within a few quantization
# bin tolerance of the `act_quant_percentile`. This should be the case, given that we are doing
@@ -605,8 +610,8 @@ def test_layerwise_param_method_mse(simple_model, quant_granularity):
quant_model.eval()

# Get first/last layer for testing its quantization.
-    first_conv_layer = quant_model.get_submodule('0')
-    last_layer = quant_model.get_submodule('6')
+    first_conv_layer = quant_model.layers.get_submodule('0')
+    last_layer = quant_model.layers.get_submodule('6')

# Check that the quant param method module is MSE as it should be
# Weights
@@ -680,14 +685,14 @@ def test_layerwise_10_in_channels_quantize_model(
quant_format='int',
layerwise_first_last_bit_width=layerwise_first_last_bit_width,
)
-    assert isinstance(quant_model, nn.Sequential)
+    assert isinstance(quant_model.layers, nn.Sequential)

# Make sure we can feed data through the model
_ = quant_model(torch.rand(1, 10, IMAGE_DIM, IMAGE_DIM))

# Get first layer for testing its quantization.
# We also test we can feed data through the first layer in isolation
-    first_layer = quant_model.get_submodule('0')
+    first_layer = quant_model.layers.get_submodule('0')
first_layer_output = first_layer(torch.rand(1, 10, IMAGE_DIM, IMAGE_DIM))

# Assert the module types are as desired
@@ -786,8 +791,8 @@ def test_po2_layerwise_quantization(simple_model):
quant_model(torch.rand(1, 10, IMAGE_DIM, IMAGE_DIM))

# Get first/last layer for testing its quantization.
-    first_conv_layer = quant_model.get_submodule('0')
-    last_layer = quant_model.get_submodule('6')
+    first_conv_layer = quant_model.layers.get_submodule('0')
+    last_layer = quant_model.layers.get_submodule('6')

# Assert scales are powers of 2 as expected
assert torch.isclose(torch.log2(first_conv_layer.input_quant.scale()) % 1, torch.Tensor([0.0]))
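All of the `get_submodule` churn above follows from the new fixture: with the layers nested under `Model.layers`, submodules that used to sit at the Sequential's top level move one level down, and the names that `quantize_model` derives for inserted output quantizers pick up the prefix accordingly (`_6_output_quant` becomes `layers_6_output_quant`, as the diff shows). A standalone sketch of the addressing change, using only standard torch.nn.Module APIs and a trimmed copy of the fixture's Model class:

import torch.nn as nn

class Model(nn.Module):

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(10, 16, kernel_size=3, padding=1),
            nn.ReLU())

    def forward(self, x):
        return self.layers(x)

model = Model()

# Previously the first conv sat at the top level: model.get_submodule('0').
# Now it is nested one level down; the two lookups below are equivalent:
first_conv = model.get_submodule('layers.0')
first_conv = model.layers.get_submodule('0')
assert isinstance(first_conv, nn.Conv2d)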
