Fix (tests): updating GPxQ testing
i-colbert committed Oct 8, 2024
1 parent 10a5c7e commit 00e7d1c
Showing 1 changed file with 9 additions and 59 deletions.
tests/brevitas/graph/test_gpxq.py: 68 changes (9 additions & 59 deletions)
@@ -1,8 +1,6 @@
 # Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
-from functools import partial
-
 import pytest
 import torch
 import torch.nn as nn
@@ -16,25 +14,13 @@
 
 
 def apply_gpfq(
-        calib_loader: DataLoader,
-        model: nn.Module,
-        act_order: bool,
-        use_quant_activations: bool = True,
-        accumulator_bit_width: int = 32,
-        a2q_layer_filter_fnc=lambda x: True):
+        calib_loader: DataLoader, model: nn.Module, act_order: bool, use_quant_activations: bool):
     model.eval()
     dtype = next(model.parameters()).dtype
     device = next(model.parameters()).device
     with torch.no_grad():
-        # use A2GPFQ if accumulator is less than 32 is specified
-        with gpfq_mode(
-                model,
-                use_quant_activations=use_quant_activations,
-                act_order=act_order,
-                use_gpfa2q=accumulator_bit_width < 32,
-                accumulator_bit_width=accumulator_bit_width,
-                a2q_layer_filter_fnc=a2q_layer_filter_fnc,
-        ) as gpfq:
+        with gpfq_mode(model, use_quant_activations=use_quant_activations,
+                       act_order=act_order) as gpfq:
             gpfq_model = gpfq.model
             for _ in range(gpfq.num_layers):
                 for _, (images, _) in enumerate(calib_loader):
@@ -64,44 +50,20 @@ def apply_gptq(
             gptq.update()
 
 
-def custom_layer_filter_fnc(layer: nn.Module) -> bool:
-    if isinstance(layer, nn.Conv2d) and layer.in_channels == 3:
-        return False
-    elif isinstance(layer, nn.ConvTranspose2d) and layer.in_channels == 3:
-        return False
-    return True
-
-
 apply_gpxq_func_map = {"gpfq": apply_gpfq, "gptq": apply_gptq}
 
 
 @pytest.mark.parametrize("act_order", [True, False])
 @pytest.mark.parametrize("use_quant_activations", [True, False])
-@pytest.mark.parametrize("acc_bit_width", [32, 24, 16, 12])
 @pytest.mark.parametrize("apply_gpxq_tuple", apply_gpxq_func_map.items())
-def test_toymodels(
-        toy_quant_model, act_order, use_quant_activations, acc_bit_width, apply_gpxq_tuple,
-        request):
+def test_toymodels(toy_quant_model, act_order, use_quant_activations, apply_gpxq_tuple, request):
 
     test_id = request.node.callspec.id
-    input_quant = test_id.split('-')[1]
-    weight_quant = test_id.split('-')[2]
 
-    if ('MXFloat' in input_quant or 'MXInt' in weight_quant) and acc_bit_width < 32:
-        pytest.skip("MX quant does not support accumulator-aware quantization.")
-
     torch.manual_seed(SEED)
 
     name, apply_gpxq = apply_gpxq_tuple
 
-    if (name == 'gptq' and acc_bit_width < 32):
-        pytest.skip("GPTQ does not support accumulator-aware quantization.")
-
-    if name == 'gpfq':
-        filter_func = custom_layer_filter_fnc
-        apply_gpxq = partial(
-            apply_gpxq, accumulator_bit_width=acc_bit_width, a2q_layer_filter_fnc=filter_func)
-
     model_class = toy_quant_model
     model = model_class()
     if 'mha' in test_id:
@@ -122,20 +84,8 @@ def test_toymodels(
                 act_order=act_order,
                 use_quant_activations=use_quant_activations)
 
-    elif (name == 'gpfq') and (acc_bit_width < 32) and (not use_quant_activations or
-                                                        input_quant == 'None'):
-        # GPFA2Q requires that the quant activations are used. GPFA2Q.single_layer_update will
-        # raise a ValueError if GPFA2Q.quant_input is None (also see GPxQ.process_input). This will
-        # happen when `use_quant_activations=False` or when the input to a model is not quantized
-        with pytest.raises(ValueError):
-            apply_gpxq(
-                calib_loader=calib_loader,
-                model=model,
-                act_order=act_order,
-                use_quant_activations=use_quant_activations)
-    else:
-        apply_gpxq(
-            calib_loader=calib_loader,
-            model=model,
-            act_order=act_order,
-            use_quant_activations=use_quant_activations)
+    apply_gpxq(
+        calib_loader=calib_loader,
+        model=model,
+        act_order=act_order,
+        use_quant_activations=use_quant_activations)
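
For reference, a minimal sketch of the calibration loop that the simplified apply_gpfq helper now runs, adapted from the new side of the diff. The function name run_gpfq_calibration and the model/loader arguments are illustrative placeholders, not part of the test file.

import torch

from brevitas.graph.gpfq import gpfq_mode


def run_gpfq_calibration(model, calib_loader, act_order=False, use_quant_activations=True):
    # Sketch only: mirrors the simplified test helper, without the removed
    # accumulator-aware (GPFA2Q) arguments.
    model.eval()
    dtype = next(model.parameters()).dtype
    device = next(model.parameters()).device
    with torch.no_grad():
        with gpfq_mode(model, use_quant_activations=use_quant_activations,
                       act_order=act_order) as gpfq:
            gpfq_model = gpfq.model
            # One pass over the calibration data per layer, then update that layer.
            for _ in range(gpfq.num_layers):
                for images, _ in calib_loader:
                    gpfq_model(images.to(device=device, dtype=dtype))
                gpfq.update()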
