Add tests for checking calibration percentile performance, and testing MSE for qparam calibration.
OscarSavolainen committed Apr 23, 2024
1 parent 545cccc commit 147f516
Showing 1 changed file with 237 additions and 5 deletions.
242 changes: 237 additions & 5 deletions tests/brevitas_examples/test_quantize_model.py
@@ -4,16 +4,16 @@
import torch
import torch.nn as nn

from brevitas.core.function_wrapper.shape import OverOutputChannelView
from brevitas.core.function_wrapper.shape import OverTensorView
from brevitas.core.stats.stats_op import MSE
from brevitas.graph.calibrate import calibration_mode
from brevitas.nn import QuantConv2d
from brevitas.nn import QuantLinear
from brevitas.nn import QuantReLU
from brevitas.quant_tensor import QuantTensor
from brevitas_examples.imagenet_classification.ptq.ptq_common import quantize_model

# TODO:
# - Possibility to use statistics or MSE for scale factor computations for weights and activations.
# - Percentiles used for the activations' statistics computation during calibration.

# CONSTANTS
IMAGE_DIM = 16

@@ -282,6 +282,90 @@ def test_fx_affine_quantization(simple_model):
assert last_layer_output.act_quant.bit_width().item() == act_bit_width


@pytest.mark.parametrize("weight_bit_width", [2, 8, 16])
@pytest.mark.parametrize("act_bit_width", [2, 5, 8])
@pytest.mark.parametrize("bias_bit_width", [16, 32, 0])
def test_fx_param_method_stats(simple_model, weight_bit_width, bias_bit_width, act_bit_width):
"""
We test FX quantization with the `stats` parameter method for weight and activation
quantization. `stats` is the default, but we set it explicitly in case the default ever changes.
We test:
- The FX-graph, quantized model is a GraphModule.
- We can feed data through the model.
- That the weight, bias and input/output quantization is toggled as expected.
- That setting `None` for the `bias_bit_width` returns a dequantized bias.
- That the bit widths are as desired.
"""
fx_model = torch.fx.symbolic_trace(simple_model)
quant_model = quantize_model(
model=fx_model,
backend='fx',
weight_bit_width=weight_bit_width,
act_bit_width=act_bit_width,
bias_bit_width=bias_bit_width if bias_bit_width > 0 else None,
weight_quant_granularity='per_tensor',
act_quant_percentile=99.9,
act_quant_type='sym',
scale_factor_type='float_scale',
quant_format='int',
layerwise_first_last_bit_width=5,
weight_param_method='stats',
act_param_method='stats',
)
# Assert it is a GraphModule
assert isinstance(quant_model, torch.fx.graph_module.GraphModule)

# Assert we can feed data of the correct size through the model
quant_model(torch.rand(1, 10, IMAGE_DIM, IMAGE_DIM))

# Get first/last layer for testing its quantization.
first_conv_layer = quant_model.get_submodule('0')
first_relu_layer = quant_model.get_submodule('1')
last_layer = quant_model.get_submodule('6')
last_layer_output = quant_model.get_submodule('_6_output_quant')

# Assert the module types are as desired
assert isinstance(first_conv_layer, QuantConv2d)
assert isinstance(last_layer, QuantLinear)

# Check quantization is toggled as expected
if bias_bit_width == 0:
# If bias_bit_width is `None` (parametrized as 0 in this test), the bias should be
# dequantized.
assert not first_conv_layer.bias_quant.is_quant_enabled
else:
assert first_conv_layer.bias_quant.is_quant_enabled
assert first_conv_layer.weight_quant.is_quant_enabled
assert not first_conv_layer.input_quant.is_quant_enabled # unlike with the layerwise backend, the input quantization is disabled.
assert not first_conv_layer.output_quant.is_quant_enabled

assert not first_relu_layer.input_quant.is_quant_enabled
assert first_relu_layer.act_quant.is_quant_enabled # the output of the "fused" ConvReLU is quantized

# Assert types are as expected
assert isinstance(quant_model.get_submodule('3'), QuantReLU)

# Assert quantization bit widths are as desired
# Biases
if bias_bit_width > 0:
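# The bias bit width is stored as a buffer named 'value' on the msb-clamp bit-width module,
# so we read it from the module's buffers directly.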
assert first_conv_layer.bias_quant.tensor_quant.msb_clamp_bit_width_impl.bit_width._buffers[
'value'].item() == bias_bit_width
assert last_layer.bias_quant.tensor_quant.msb_clamp_bit_width_impl.bit_width._buffers[
'value'].item() == bias_bit_width
else:
# If bias_bit_width is `None`, the bias should be returned as a plain floating-point tensor, not a QuantTensor.
assert not isinstance(first_conv_layer.quant_bias(), QuantTensor)

# Weights
assert first_conv_layer.weight_quant.bit_width().item() == weight_bit_width
assert last_layer.weight_quant.bit_width().item() == weight_bit_width
# Activations
assert first_relu_layer.act_quant.bit_width().item() == act_bit_width
assert last_layer_output.act_quant.bit_width().item() == act_bit_width


def test_fx_per_chan_weight_quantization(simple_model):
"""
We test per-channel weight quantization.
@@ -306,8 +390,10 @@ def test_fx_per_chan_weight_quantization(simple_model):
act_quant_percentile=99.9,
act_quant_type='sym',
scale_factor_type='float_scale',
quant_format='float',
quant_format='int',
)
# Assert it is a GraphModule
assert isinstance(quant_model, torch.fx.graph_module.GraphModule)

# Assert we can feed data of the correct size through the model
quant_model(torch.rand(1, 10, IMAGE_DIM, IMAGE_DIM))
@@ -414,6 +500,152 @@ def test_invalid_input(simple_model):
##########################
# LAYERWISE MODE TESTING #
##########################
@pytest.mark.parametrize("act_quant_percentile", [1, 50, 99.9])
def test_layerwise_percentile_for_calibration(simple_model, act_quant_percentile):
"""
We test different values for the percentile used for the activations' statistics computation
during calibration.
We test if the percentile correctly produces the desired qparams for a `QuantIdentity`
when fed a tensor with linearly scaled values between 0 and 1.
We test:
- We can feed data through the model.
- The desired qparams manifest under controlled conditions as a result of calibration
percentiles.
"""
weight_bit_width = 8
act_bit_width = 8
bias_bit_width = 32

quant_model = quantize_model(
model=simple_model,
backend='layerwise',
weight_bit_width=weight_bit_width,
act_bit_width=act_bit_width,
bias_bit_width=bias_bit_width,
weight_quant_granularity='per_tensor',
act_quant_percentile=act_quant_percentile,
act_quant_type='asym',
scale_factor_type='float_scale',
quant_format='int',
)

# Assert we can feed data of the correct size through the model
# We are also performing calibration
quant_model.train()
# We create an input with values linearly scaled between 0 and 1.
input = torch.arange(0, 1, step=1 / (10 * IMAGE_DIM ** 2))
input = input.view(1, 10, IMAGE_DIM, IMAGE_DIM).float()
with torch.no_grad():
with calibration_mode(quant_model):
for _ in range(1000):
quant_model(input)
quant_model.eval()

# Get first/last layer for testing its quantization.
first_conv_layer = quant_model.get_submodule('0')

# We check the calibration. We do so by ensuring that the quantization range is within a few quantization
# bin tolerance of the `act_quant_percentile`. This should be the case, given that we are doing
# affine quantization with a strictly positive tensor (and so zero_point should be 0), and because
# we are doing the calibration with a tensor with all values linearly increasing from 0 to 1.
assert torch.isclose(first_conv_layer.input_quant.zero_point(), torch.Tensor([0.]))
tolerance = 8 # quantization bins of tolerance, on the "plus side".
ideal_range = act_quant_percentile / 100
scale = first_conv_layer.input_quant.scale()

# The quantization range is always smaller than the data covered up to the percentile, because
# of how the percentile->qrange calculation happens.
assert ideal_range > scale * 255
# We make sure the quantization range is still reasonably close to covering the entire data, up
# to the provided percentile.
assert ideal_range < scale * (255 + tolerance)
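# As a concrete illustration: with act_quant_percentile=50 the data covered up to the percentile
# spans roughly [0, 0.5], so the assertions above check that the calibrated scale sits just
# below 0.5 / 255 ≈ 0.00196.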


@pytest.mark.parametrize("quant_granularity", ["per_tensor", "per_channel"])
def test_layerwise_param_method_mse(simple_model, quant_granularity):
"""
We test layerwise quantization with the `mse` parameter method for weight and activation
quantization.
We test:
- We can feed data through the model.
- That the stats observer is explicitly MSE.
- That the view on the quantization granularity is as desired.
- That during calibration, the qparams are derived by finding values that minimize the MSE
between the floating point and quantized tensor.
"""
weight_bit_width = 8
act_bit_width = 8
bias_bit_width = 32
quant_model = quantize_model(
model=simple_model,
backend='layerwise',
weight_bit_width=weight_bit_width,
act_bit_width=act_bit_width,
bias_bit_width=bias_bit_width if bias_bit_width > 0 else None,
weight_quant_granularity=quant_granularity,
act_quant_type='asym',
act_quant_percentile=99.9, # Unused
scale_factor_type='float_scale',
quant_format='int',
weight_param_method='mse',
act_param_method='mse',
)

# Assert we can feed data of the correct size through the model
# We are also performing calibration
quant_model.train()
# We create an input with values linearly scaled between 0 and 1.
input = torch.arange(0, 1, step=1 / (10 * IMAGE_DIM ** 2))
input = input.view(1, 10, IMAGE_DIM, IMAGE_DIM).float()
with torch.no_grad():
with calibration_mode(quant_model):
quant_model(input)
quant_model.eval()

# Get first/last layer for testing its quantization.
first_conv_layer = quant_model.get_submodule('0')
last_layer = quant_model.get_submodule('6')

# Check that the quant param method module is MSE as it should be
# Weights
first_weight_param_mod = first_conv_layer.weight_quant.tensor_quant.scaling_impl.parameter_list_stats.stats.stats_impl
last_weight_param_mod = last_layer.weight_quant.tensor_quant.scaling_impl.parameter_list_stats.stats.stats_impl
assert isinstance(first_weight_param_mod, MSE)
assert isinstance(last_weight_param_mod, MSE)

# Check observation is over tensor or channel as desired
def check_dim_of_observation(module: torch.nn.Module, quant_granularity: str):
if quant_granularity == 'per_tensor':
assert isinstance(module.input_view_shape_impl, OverTensorView)
elif quant_granularity == 'per_channel':
assert isinstance(module.input_view_shape_impl, OverOutputChannelView)

# Weight
check_dim_of_observation(first_weight_param_mod, quant_granularity)
check_dim_of_observation(last_weight_param_mod, quant_granularity)

# We test the calibrated qparams. We fed in a tensor with linearly scaled values between 0 and 1.
# We check that perturbing the qparams gives an MSE that is worse than or equal to that of the calibrated qparams.
# We assume a convex problem.
scale = first_conv_layer.input_quant.scale()
zero_point = first_conv_layer.input_quant.zero_point()

def get_qmse(
scale: torch.Tensor, zero_point: torch.Tensor, input: torch.Tensor) -> torch.Tensor:
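# Fake-quantize with the given qparams: map the input onto the 8-bit integer grid, round,
# clamp to [0, 255], dequantize, and return the mean squared error against the original tensor.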
quant_tensor = scale * (
torch.clamp(torch.round(input / scale + zero_point), 0, 255) - zero_point)
mse = torch.mean((quant_tensor - input) ** 2)
return mse

orig_mse = get_qmse(scale, zero_point, input)
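# Probe a small grid around the calibrated qparams: ±10% of the scale and ±1 bin of the zero-point.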
for scale_diff in [0.1 * scale, 0, -0.1 * scale]:
for zero_diff in [1, 0, -1]:
diff_mse = get_qmse(scale + scale_diff, zero_point + zero_diff, input)
assert torch.isclose(diff_mse, orig_mse) or (diff_mse > orig_mse)


@pytest.mark.parametrize("weight_bit_width", [2, 5, 8, 16])
@pytest.mark.parametrize("act_bit_width", [2, 5, 8])
@pytest.mark.parametrize("bias_bit_width", [16, 32])
