diff --git a/tests/brevitas/graph/test_gpxq.py b/tests/brevitas/graph/test_gpxq.py
index 4293b8582..8f5235767 100644
--- a/tests/brevitas/graph/test_gpxq.py
+++ b/tests/brevitas/graph/test_gpxq.py
@@ -1,8 +1,6 @@
 # Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
-from functools import partial
-
 import pytest
 import torch
 import torch.nn as nn
@@ -16,25 +14,13 @@
 
 
 def apply_gpfq(
-        calib_loader: DataLoader,
-        model: nn.Module,
-        act_order: bool,
-        use_quant_activations: bool = True,
-        accumulator_bit_width: int = 32,
-        a2q_layer_filter_fnc=lambda x: True):
+        calib_loader: DataLoader, model: nn.Module, act_order: bool, use_quant_activations: bool):
     model.eval()
     dtype = next(model.parameters()).dtype
     device = next(model.parameters()).device
     with torch.no_grad():
-        # use A2GPFQ if accumulator is less than 32 is specified
-        with gpfq_mode(
-                model,
-                use_quant_activations=use_quant_activations,
-                act_order=act_order,
-                use_gpfa2q=accumulator_bit_width < 32,
-                accumulator_bit_width=accumulator_bit_width,
-                a2q_layer_filter_fnc=a2q_layer_filter_fnc,
-        ) as gpfq:
+        with gpfq_mode(model, use_quant_activations=use_quant_activations,
+                       act_order=act_order) as gpfq:
             gpfq_model = gpfq.model
             for _ in range(gpfq.num_layers):
                 for _, (images, _) in enumerate(calib_loader):
@@ -64,44 +50,20 @@ def apply_gptq(
                     gptq.update()
 
 
-def custom_layer_filter_fnc(layer: nn.Module) -> bool:
-    if isinstance(layer, nn.Conv2d) and layer.in_channels == 3:
-        return False
-    elif isinstance(layer, nn.ConvTranspose2d) and layer.in_channels == 3:
-        return False
-    return True
-
-
 apply_gpxq_func_map = {"gpfq": apply_gpfq, "gptq": apply_gptq}
 
 
 @pytest.mark.parametrize("act_order", [True, False])
 @pytest.mark.parametrize("use_quant_activations", [True, False])
-@pytest.mark.parametrize("acc_bit_width", [32, 24, 16, 12])
 @pytest.mark.parametrize("apply_gpxq_tuple", apply_gpxq_func_map.items())
-def test_toymodels(
-        toy_quant_model, act_order, use_quant_activations, acc_bit_width, apply_gpxq_tuple,
-        request):
+def test_toymodels(toy_quant_model, act_order, use_quant_activations, apply_gpxq_tuple, request):
 
     test_id = request.node.callspec.id
 
-    input_quant = test_id.split('-')[1]
-    weight_quant = test_id.split('-')[2]
-
-    if ('MXFloat' in input_quant or 'MXInt' in weight_quant) and acc_bit_width < 32:
-        pytest.skip("MX quant does not support accumulator-aware quantization.")
 
     torch.manual_seed(SEED)
 
     name, apply_gpxq = apply_gpxq_tuple
-    if (name == 'gptq' and acc_bit_width < 32):
-        pytest.skip("GPTQ does not support accumulator-aware quantization.")
-
-    if name == 'gpfq':
-        filter_func = custom_layer_filter_fnc
-        apply_gpxq = partial(
-            apply_gpxq, accumulator_bit_width=acc_bit_width, a2q_layer_filter_fnc=filter_func)
-
     model_class = toy_quant_model
     model = model_class()
     if 'mha' in test_id:
@@ -122,20 +84,8 @@ def test_toymodels(
                 act_order=act_order,
                 use_quant_activations=use_quant_activations)
 
-    elif (name == 'gpfq') and (acc_bit_width < 32) and (not use_quant_activations or
-                                                        input_quant == 'None'):
-        # GPFA2Q requires that the quant activations are used. GPFA2Q.single_layer_update will
-        # raise a ValueError if GPFA2Q.quant_input is None (also see GPxQ.process_input). This will
-        # happen when `use_quant_activations=False` or when the input to a model is not quantized
-        with pytest.raises(ValueError):
-            apply_gpxq(
-                calib_loader=calib_loader,
-                model=model,
-                act_order=act_order,
-                use_quant_activations=use_quant_activations)
-    else:
-        apply_gpxq(
-            calib_loader=calib_loader,
-            model=model,
-            act_order=act_order,
-            use_quant_activations=use_quant_activations)
+    apply_gpxq(
+        calib_loader=calib_loader,
+        model=model,
+        act_order=act_order,
+        use_quant_activations=use_quant_activations)