Feat (tests): adding tests for channel splitting
fabianandresgrob committed Jan 26, 2024
1 parent 9f9f259 commit f671dd1
Showing 6 changed files with 190 additions and 6 deletions.
@@ -339,6 +339,9 @@ def validate_config(config_namespace):
config_namespace.gpfa2q)
if multiple_gpxqs > 1:
is_valid = False
elif multiple_gpxqs == 0:
# no gpxq algorithm, set act order to None
config_namespace.gpxq_act_order = None

if config_namespace.act_equalization == 'layerwise' and config_namespace.target_backend == 'fx':
is_valid = False
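For context, the hunk above is part of the benchmark's config validation: multiple_gpxqs counts how many GPxQ-style algorithms (GPTQ, GPFQ, GPFA2Q) are enabled, more than one is rejected, and the new branch clears the act-order flag when none is selected. A stand-alone sketch of that rule, not part of this commit and with hypothetical names:

def validate_gpxq_flags(gptq: bool, gpfq: bool, gpfa2q: bool, act_order):
    # Hypothetical sketch: at most one GPxQ algorithm may be enabled, and the
    # act-order option only means something when one of them is.
    enabled = int(gptq) + int(gpfq) + int(gpfa2q)
    if enabled > 1:
        return False, act_order  # conflicting algorithms requested
    if enabled == 0:
        act_order = None  # no GPxQ algorithm, so act order is irrelevant
    return True, act_order
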
@@ -197,8 +197,10 @@ def layerwise_bit_width_fn_weight(module):
act_bit_width_dict = {}
if quant_format == 'int' and backend == 'layerwise':
weight_bit_width_dict['weight_bit_width'] = layerwise_bit_width_fn_weight
-            act_bit_width_dict['act_bit_width'] = layerwise_bit_width_fn_act
+            if act_bit_width is not None:
+                act_bit_width_dict['act_bit_width'] = layerwise_bit_width_fn_act
+            else:
+                act_bit_width_dict['act_bit_width'] = None
elif quant_format == 'int' and backend != 'layerwise':
weight_bit_width_dict['weight_bit_width'] = weight_bit_width
act_bit_width_dict['act_bit_width'] = act_bit_width
@@ -291,7 +293,7 @@ def kwargs_prefix(prefix, weight_kwargs):
act_bit_width_dict['mantissa_bit_width'] = act_mantissa_bit_width

# Retrieve base input, weight, and bias quantizers
-    bias_quant = BIAS_BIT_WIDTH_MAP[bias_bit_width]
+    bias_quant = BIAS_BIT_WIDTH_MAP[bias_bit_width] if act_bit_width is not None else None
weight_quant = WEIGHT_QUANT_MAP[weight_quant_format][weight_scale_type][weight_param_method][
weight_quant_granularity][weight_quant_type]
weight_quant = weight_quant.let(**weight_bit_width_dict)
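Taken together, the two hunks above let act_bit_width be None, which effectively means weight-only quantization: the layerwise activation bit-width function is only installed when a bit width is actually given, and the bias quantizer, which is looked up by bit width, is dropped along with it. A rough stand-alone sketch of the convention, with assumed names, not part of this commit:

def select_act_and_bias_quant(act_bit_width, bias_bit_width, bias_bit_width_map):
    # Sketch only: None disables activation quantization and, with it, the
    # bias quantizer that would otherwise be picked from the lookup table.
    act_kwargs = {'act_bit_width': act_bit_width}
    bias_quant = bias_bit_width_map[bias_bit_width] if act_bit_width is not None else None
    return act_kwargs, bias_quant
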
@@ -256,6 +256,9 @@ def main():
else:
act_quant_calib_config = args.act_quant_calibration_type

if args.act_bit_width == 0:
args.act_bit_width = None

config = (
f"{args.model_name}_"
f"{args.target_backend}_"
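The evaluation-script hunk above handles the same convention from the command line: 0 is the user-facing way to request no activation quantization and is normalized to None before the config string is built. A self-contained sketch of that normalization, not part of this commit (the flag spelling is assumed; the diff only shows args.act_bit_width):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--act-bit-width', type=int, default=8)
args = parser.parse_args(['--act-bit-width', '0'])

if args.act_bit_width == 0:
    args.act_bit_width = None  # downstream code only checks `is not None`

assert args.act_bit_width is None
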
47 changes: 46 additions & 1 deletion tests/brevitas/graph/equalization_fixtures.py
@@ -301,6 +301,49 @@ def forward(self, x):
return ResidualSrcsAndSinkModel


@pytest_cases.fixture
def convgroupconv_model():

class ConvGroupConvModel(nn.Module):

def __init__(self) -> None:
super().__init__()
self.conv = nn.Conv2d(3, 16, kernel_size=3)
self.conv_0 = nn.Conv2d(16, 32, kernel_size=1, groups=2)
self.conv_1 = nn.Conv2d(32, 64, kernel_size=1, groups=4)
self.relu = nn.ReLU()

def forward(self, x):
x = self.conv(x)
x = self.relu(x)
x = self.conv_0(x)
x = self.relu(x)
x = self.conv_1(x)
return x

return ConvGroupConvModel


@pytest_cases.fixture
def convtranspose_model():

class ConvTransposeModel(nn.Module):

def __init__(self) -> None:
super().__init__()
self.relu = nn.ReLU()
self.conv_0 = nn.ConvTranspose2d(in_channels=3, out_channels=8, kernel_size=3)
self.conv_1 = nn.ConvTranspose2d(in_channels=8, out_channels=32, kernel_size=3)

def forward(self, x):
x = self.conv_0(x)
x = self.relu(x)
x = self.conv_1(x)
return x

return ConvTransposeModel


list_of_fixtures = [
'residual_model',
'srcsinkconflict_model',
@@ -309,7 +352,9 @@ def forward(self, x):
'convdepthconv_model',
'linearmha_model',
'mhalinear_model',
-    'layernormmha_model']
+    'layernormmha_model',
+    'convgroupconv_model',
+    'convtranspose_model']

toy_model = fixture_union('toy_model', list_of_fixtures, ids=list_of_fixtures)

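Both new fixtures return a model class rather than an instance and are appended to the toy_model fixture union, so every test that requests toy_model now also covers the grouped-convolution and transposed-convolution cases. A minimal sketch, not part of this commit, of how such a union fixture is consumed (test name hypothetical, imports assumed from this fixtures file):

def test_toy_model_smoke(toy_model, request):
    # fixture_union parametrizes `toy_model` over every listed fixture and
    # puts the fixture's name into the test id
    model = toy_model()  # the fixtures yield the class, so instantiate here
    test_id = request.node.callspec.id  # e.g. 'convgroupconv_model'
    assert isinstance(model, torch.nn.Module)
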
121 changes: 121 additions & 0 deletions tests/brevitas/graph/test_channel_splitting.py
@@ -0,0 +1,121 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

import torch

from brevitas.fx import symbolic_trace
from brevitas.graph.channel_splitting import _clean_regions
from brevitas.graph.channel_splitting import _split
from brevitas.graph.equalize import _extract_regions
from brevitas.graph.fixed_point import MergeBatchNorm

from .equalization_fixtures import *

no_split_models = (
'mul_model',
'bnconv_model',
'convdepthconv_model',
'linearmha_model',
'layernormmha_model',
'convgroupconv_model',
'vit_b_32',
'shufflenet_v2_x0_5',
'googlenet',
'inception_v3')

SPLIT_RATIO = 0.1


@pytest.mark.parametrize('split_input', [False, True])
def test_toymodels(toy_model, split_input, request):
test_id = request.node.callspec.id

torch.manual_seed(SEED)

model_class = toy_model
model = model_class()
if 'mha' in test_id:
inp = torch.randn(IN_SIZE_LINEAR)
else:
inp = torch.randn(IN_SIZE_CONV)

model.eval()
expected_out = model(inp)

model = symbolic_trace(model)
# merge BN before applying channel splitting
model = MergeBatchNorm().apply(model)

# save model's state dict to check if channel splitting was done or not
old_state_dict = model.state_dict()

regions = _extract_regions(model)
regions = _clean_regions(regions)
if model_class in no_split_models:
assert len(regions) == 0
else:
model = _split(model, regions, split_ratio=SPLIT_RATIO, split_input=split_input)

out = model(inp)
assert torch.allclose(expected_out, out, atol=ATOL)

modified_sources = {source for region in regions for source in region.srcs_names}
    # subtract the sources so the same module is not checked twice
modified_sinks = {
sink for region in regions for sink in region.sinks_names} - modified_sources
for module in modified_sources:
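        # nn.MultiheadAttention keeps its output projection in a submodule named out_proj, so redirect the weight check there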
if 'mha' in module:
module += '.out_proj'
weight_name = module + '.weight'
assert not torch.equal(old_state_dict[weight_name], model.state_dict()[weight_name])
bias_name = module + '.bias'
        # not every module has a bias, and biases only change when output channels are split
if bias_name in old_state_dict.keys() and not split_input:
assert not torch.equal(old_state_dict[bias_name], model.state_dict()[bias_name])
for module in modified_sinks:
weight_name = module + '.weight'
assert not torch.equal(old_state_dict[weight_name], model.state_dict()[weight_name])


@pytest.mark.parametrize('split_input', [False, True])
def test_torchvision_models(model_coverage: tuple, split_input: bool, request):
model_class = request.node.callspec.id.split('-')[0]

model, coverage = model_coverage

torch.manual_seed(SEED)
inp = torch.randn(IN_SIZE_CONV)

model.eval()
expected_out = model(inp)

model = symbolic_trace(model)
# merge BN before applying channel splitting
model = MergeBatchNorm().apply(model)

old_state_dict = model.state_dict()

regions = _extract_regions(model)
regions = _clean_regions(regions)
if model_class in no_split_models:
assert len(regions) == 0
else:
model = _split(model, regions, split_ratio=SPLIT_RATIO, split_input=split_input)

out = model(inp)
assert torch.allclose(expected_out, out, atol=ATOL)

modified_sources = {source for region in regions for source in region.srcs_names}
    # subtract the sources so the same module is not checked twice
modified_sinks = {
sink for region in regions for sink in region.sinks_names} - modified_sources
for module in modified_sources:
weight_name = module + '.weight'
assert not torch.equal(old_state_dict[weight_name], model.state_dict()[weight_name])
bias_name = module + '.bias'
        # not every module has a bias, and biases only change when output channels are split
if bias_name in old_state_dict.keys() and not split_input:
assert not torch.equal(old_state_dict[bias_name], model.state_dict()[bias_name])
for module in modified_sinks:
weight_name = module + '.weight'
assert not torch.equal(old_state_dict[weight_name], model.state_dict()[weight_name])
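
The tests above only check channel splitting end to end: after _split, the traced model's output must match the original, and the weights (and, when output channels are split, the biases) of every source and sink in a region must have changed. As background, here is a stand-alone sketch of the identity the transform relies on for a single source/sink pair of linear layers; it is not the brevitas.graph.channel_splitting implementation, which works on the regions extracted from the traced graph, and SPLIT_RATIO above simply controls how many channels get this treatment:

import torch
import torch.nn as nn


@torch.no_grad()
def split_one_channel(src: nn.Linear, sink: nn.Linear, channel: int):
    # Duplicate one output channel of the source with its weights (and bias)
    # halved, and duplicate the matching input channel of the sink unchanged.
    # The composed function is preserved; a ReLU in between is fine because
    # relu(x / 2) == relu(x) / 2.
    src_w, sink_w = src.weight.clone(), sink.weight.clone()
    src_b = src.bias.clone() if src.bias is not None else None
    src_w[channel] /= 2
    new_src = nn.Linear(src.in_features, src.out_features + 1, bias=src_b is not None)
    new_src.weight.copy_(torch.cat([src_w, src_w[channel:channel + 1]], dim=0))
    if src_b is not None:
        src_b[channel] /= 2
        new_src.bias.copy_(torch.cat([src_b, src_b[channel:channel + 1]]))
    new_sink = nn.Linear(sink.in_features + 1, sink.out_features, bias=sink.bias is not None)
    new_sink.weight.copy_(torch.cat([sink_w, sink_w[:, channel:channel + 1]], dim=1))
    if sink.bias is not None:
        new_sink.bias.copy_(sink.bias)
    return new_src, new_sink


src, sink = nn.Linear(4, 6), nn.Linear(6, 3)
x = torch.randn(2, 4)
expected = sink(torch.relu(src(x)))
new_src, new_sink = split_one_channel(src, sink, channel=2)
assert torch.allclose(expected, new_sink(torch.relu(new_src(x))), atol=1e-6)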
14 changes: 12 additions & 2 deletions tests/brevitas/graph/test_equalization.py
@@ -137,7 +137,11 @@ def test_models(toy_model, merge_bias, request):
assert torch.allclose(expected_out, out, atol=ATOL)
# Check that at least one region performs "true" equalization
# If all shapes are scalar, no equalization has been performed
-    assert all([shape != () for shape in shape_scale_regions])
+    if 'convgroupconv' in test_id:
+        with pytest.raises(AssertionError):
+            assert all([shape != () for shape in shape_scale_regions])
+    else:
+        assert all([shape != () for shape in shape_scale_regions])


@pytest_cases.parametrize("layerwise", [True, False])
@@ -167,7 +171,13 @@ def test_act_equalization_models(toy_model, layerwise, request):
assert torch.allclose(expected_out, out, atol=ATOL)

# This region is made up of a residual branch, so no regions are found for act equalization
-    if 'srcsinkconflict_mode' not in test_id:
+    if 'convgroupconv' in test_id:
+        with pytest.raises(AssertionError):
+            assert len(regions) > 0
+            # Check that at least one region performs "true" equalization
+            # If all shapes are scalar, no equalization has been performed
+            assert all([shape != () for shape in shape_scale_regions])
+    else:
assert len(regions) > 0
# Check that at least one region performs "true" equalization
# If all shapes are scalar, no equalization has been performed
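In the equalization tests above, the new grouped-convolution fixture is handled by wrapping the existing assertions in pytest.raises(AssertionError): this appears to document that the grouped-convolution toy model is not yet properly equalized, while still failing loudly if that behaviour ever changes. The pattern in isolation, with made-up values:

import pytest


def test_known_limitation():
    regions = []  # hypothetical: no equalization regions were found
    with pytest.raises(AssertionError):
        assert len(regions) > 0

pytest.mark.xfail(strict=True) on the whole test expresses a similar intent, but wrapping a single assert keeps the rest of the test's checks active.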