From f671dd1e35b3ad969fc6561cfc107fbd6ddf2a64 Mon Sep 17 00:00:00 2001 From: Fabian Grob Date: Fri, 26 Jan 2024 02:57:24 -0800 Subject: [PATCH] Feat (tests): adding tests for channel splitting --- .../benchmark/ptq_benchmark_torchvision.py | 3 + .../imagenet_classification/ptq/ptq_common.py | 8 +- .../ptq/ptq_evaluate.py | 3 + tests/brevitas/graph/equalization_fixtures.py | 47 ++++++- .../brevitas/graph/test_channel_splitting.py | 121 ++++++++++++++++++ tests/brevitas/graph/test_equalization.py | 14 +- 6 files changed, 190 insertions(+), 6 deletions(-) create mode 100644 tests/brevitas/graph/test_channel_splitting.py diff --git a/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py b/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py index 3a1581c3c..668eee22c 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py +++ b/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py @@ -339,6 +339,9 @@ def validate_config(config_namespace): config_namespace.gpfa2q) if multiple_gpxqs > 1: is_valid = False + elif multiple_gpxqs == 0: + # no gpxq algorithm, set act order to None + config_namespace.gpxq_act_order = None if config_namespace.act_equalization == 'layerwise' and config_namespace.target_backend == 'fx': is_valid = False diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index 0b626e485..541f0085f 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -197,8 +197,10 @@ def layerwise_bit_width_fn_weight(module): act_bit_width_dict = {} if quant_format == 'int' and backend == 'layerwise': weight_bit_width_dict['weight_bit_width'] = layerwise_bit_width_fn_weight - act_bit_width_dict['act_bit_width'] = layerwise_bit_width_fn_act - + if act_bit_width is not None: + act_bit_width_dict['act_bit_width'] = layerwise_bit_width_fn_act + else: + act_bit_width_dict['act_bit_width'] = None elif quant_format == 'int' and backend != 'layerwise': weight_bit_width_dict['weight_bit_width'] = weight_bit_width act_bit_width_dict['act_bit_width'] = act_bit_width @@ -291,7 +293,7 @@ def kwargs_prefix(prefix, weight_kwargs): act_bit_width_dict['mantissa_bit_width'] = act_mantissa_bit_width # Retrieve base input, weight, and bias quantizers - bias_quant = BIAS_BIT_WIDTH_MAP[bias_bit_width] + bias_quant = BIAS_BIT_WIDTH_MAP[bias_bit_width] if act_bit_width is not None else None weight_quant = WEIGHT_QUANT_MAP[weight_quant_format][weight_scale_type][weight_param_method][ weight_quant_granularity][weight_quant_type] weight_quant = weight_quant.let(**weight_bit_width_dict) diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index a32badd82..740a207ac 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -256,6 +256,9 @@ def main(): else: act_quant_calib_config = args.act_quant_calibration_type + if args.act_bit_width == 0: + args.act_bit_width = None + config = ( f"{args.model_name}_" f"{args.target_backend}_" diff --git a/tests/brevitas/graph/equalization_fixtures.py b/tests/brevitas/graph/equalization_fixtures.py index 263543a82..4750fc96d 100644 --- a/tests/brevitas/graph/equalization_fixtures.py +++ b/tests/brevitas/graph/equalization_fixtures.py @@ -301,6 +301,49 @@ def forward(self, x): return ResidualSrcsAndSinkModel +@pytest_cases.fixture +def convgroupconv_model(): + + class ConvGroupConvModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.conv = nn.Conv2d(3, 16, kernel_size=3) + self.conv_0 = nn.Conv2d(16, 32, kernel_size=1, groups=2) + self.conv_1 = nn.Conv2d(32, 64, kernel_size=1, groups=4) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.conv(x) + x = self.relu(x) + x = self.conv_0(x) + x = self.relu(x) + x = self.conv_1(x) + return x + + return ConvGroupConvModel + + +@pytest_cases.fixture +def convtranspose_model(): + + class ConvTransposeModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.relu = nn.ReLU() + self.conv_0 = nn.ConvTranspose2d(in_channels=3, out_channels=8, kernel_size=3) + self.conv_1 = nn.ConvTranspose2d(in_channels=8, out_channels=32, kernel_size=3) + + def forward(self, x): + x = self.conv_0(x) + x = self.relu(x) + x = self.conv_1(x) + return x + + return ConvTransposeModel + + list_of_fixtures = [ 'residual_model', 'srcsinkconflict_model', @@ -309,7 +352,9 @@ def forward(self, x): 'convdepthconv_model', 'linearmha_model', 'mhalinear_model', - 'layernormmha_model'] + 'layernormmha_model', + 'convgroupconv_model', + 'convtranspose_model'] toy_model = fixture_union('toy_model', list_of_fixtures, ids=list_of_fixtures) diff --git a/tests/brevitas/graph/test_channel_splitting.py b/tests/brevitas/graph/test_channel_splitting.py new file mode 100644 index 000000000..30a3dd8d3 --- /dev/null +++ b/tests/brevitas/graph/test_channel_splitting.py @@ -0,0 +1,121 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +import torch + +from brevitas.fx import symbolic_trace +from brevitas.graph.channel_splitting import _clean_regions +from brevitas.graph.channel_splitting import _split +from brevitas.graph.equalize import _extract_regions +from brevitas.graph.fixed_point import MergeBatchNorm + +from .equalization_fixtures import * + +no_split_models = ( + 'mul_model', + 'bnconv_model', + 'convdepthconv_model', + 'linearmha_model', + 'layernormmha_model', + 'convgroupconv_model', + 'vit_b_32', + 'shufflenet_v2_x0_5', + 'googlenet', + 'inception_v3') + +SPLIT_RATIO = 0.1 + + +@pytest.mark.parametrize('split_input', [False, True]) +def test_toymodels(toy_model, split_input, request): + test_id = request.node.callspec.id + + torch.manual_seed(SEED) + + model_class = toy_model + model = model_class() + if 'mha' in test_id: + inp = torch.randn(IN_SIZE_LINEAR) + else: + inp = torch.randn(IN_SIZE_CONV) + + model.eval() + expected_out = model(inp) + + model = symbolic_trace(model) + # merge BN before applying channel splitting + model = MergeBatchNorm().apply(model) + + # save model's state dict to check if channel splitting was done or not + old_state_dict = model.state_dict() + + regions = _extract_regions(model) + regions = _clean_regions(regions) + if model_class in no_split_models: + assert len(regions) == 0 + else: + model = _split(model, regions, split_ratio=SPLIT_RATIO, split_input=split_input) + + out = model(inp) + assert torch.allclose(expected_out, out, atol=ATOL) + + modified_sources = {source for region in regions for source in region.srcs_names} + # avoiding checking the same module multiple times + modified_sinks = { + sink for region in regions for sink in region.sinks_names} - modified_sources + for module in modified_sources: + if 'mha' in module: + module += '.out_proj' + weight_name = module + '.weight' + assert not torch.equal(old_state_dict[weight_name], model.state_dict()[weight_name]) + bias_name = module + '.bias' + # not all modules have bias and they only differ when splitting output channels + if bias_name in old_state_dict.keys() and not split_input: + assert not torch.equal(old_state_dict[bias_name], model.state_dict()[bias_name]) + for module in modified_sinks: + weight_name = module + '.weight' + assert not torch.equal(old_state_dict[weight_name], model.state_dict()[weight_name]) + + +@pytest.mark.parametrize('split_input', [False, True]) +def test_torchvision_models(model_coverage: tuple, split_input: bool, request): + model_class = request.node.callspec.id.split('-')[0] + + model, coverage = model_coverage + + torch.manual_seed(SEED) + inp = torch.randn(IN_SIZE_CONV) + + model.eval() + expected_out = model(inp) + + model = symbolic_trace(model) + # merge BN before applying channel splitting + model = MergeBatchNorm().apply(model) + + old_state_dict = model.state_dict() + + regions = _extract_regions(model) + regions = _clean_regions(regions) + if model_class in no_split_models: + assert len(regions) == 0 + else: + model = _split(model, regions, split_ratio=SPLIT_RATIO, split_input=split_input) + + out = model(inp) + assert torch.allclose(expected_out, out, atol=ATOL) + + modified_sources = {source for region in regions for source in region.srcs_names} + # avoiding checking the same module multiple times + modified_sinks = { + sink for region in regions for sink in region.sinks_names} - modified_sources + for module in modified_sources: + weight_name = module + '.weight' + assert not torch.equal(old_state_dict[weight_name], model.state_dict()[weight_name]) + bias_name = module + '.bias' + # not all modules have bias and they only differ when splitting output channels + if bias_name in old_state_dict.keys() and not split_input: + assert not torch.equal(old_state_dict[bias_name], model.state_dict()[bias_name]) + for module in modified_sinks: + weight_name = module + '.weight' + assert not torch.equal(old_state_dict[weight_name], model.state_dict()[weight_name]) diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py index 4f713211c..89759b41a 100644 --- a/tests/brevitas/graph/test_equalization.py +++ b/tests/brevitas/graph/test_equalization.py @@ -137,7 +137,11 @@ def test_models(toy_model, merge_bias, request): assert torch.allclose(expected_out, out, atol=ATOL) # Check that at least one region performs "true" equalization # If all shapes are scalar, no equalization has been performed - assert all([shape != () for shape in shape_scale_regions]) + if 'convgroupconv' in test_id: + with pytest.raises(AssertionError): + assert all([shape != () for shape in shape_scale_regions]) + else: + assert all([shape != () for shape in shape_scale_regions]) @pytest_cases.parametrize("layerwise", [True, False]) @@ -167,7 +171,13 @@ def test_act_equalization_models(toy_model, layerwise, request): assert torch.allclose(expected_out, out, atol=ATOL) # This region is made up of a residual branch, so no regions are found for act equalization - if 'srcsinkconflict_mode' not in test_id: + if 'convgroupconv' in test_id: + with pytest.raises(AssertionError): + assert len(regions) > 0 + # Check that at least one region performs "true" equalization + # If all shapes are scalar, no equalization has been performed + assert all([shape != () for shape in shape_scale_regions]) + else: assert len(regions) > 0 # Check that at least one region performs "true" equalization # If all shapes are scalar, no equalization has been performed