Commit 8676a6d

Feat (tests): adding tests for channel splitting

1 parent c117bd7

6 files changed: +190 -6 lines

src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py (+3)

@@ -339,6 +339,9 @@ def validate_config(config_namespace):
         config_namespace.gpfa2q)
     if multiple_gpxqs > 1:
         is_valid = False
+    elif multiple_gpxqs == 0:
+        # no gpxq algorithm, set act order to None
+        config_namespace.gpxq_act_order = None

     if config_namespace.act_equalization == 'layerwise' and config_namespace.target_backend == 'fx':
         is_valid = False
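The new elif branch normalizes the act-order option away when no GPxQ algorithm is requested, so benchmark configurations that differ only in a meaningless flag collapse into one. Below is a minimal sketch of the flag-counting pattern; the gptq and gpfq attribute names are assumptions for illustration, since only gpfa2q appears in the hunk above.

from types import SimpleNamespace

def normalize_gpxq(config):
    # the GPxQ variants are boolean flags, so summing counts enabled ones
    # (gptq and gpfq are hypothetical names here)
    multiple_gpxqs = sum([config.gptq, config.gpfq, config.gpfa2q])
    if multiple_gpxqs > 1:
        return None  # invalid: at most one GPxQ algorithm may run
    if multiple_gpxqs == 0:
        # act order only matters to a GPxQ run, so neutralize it
        config.gpxq_act_order = None
    return config

config = SimpleNamespace(gptq=False, gpfq=False, gpfa2q=False, gpxq_act_order=True)
assert normalize_gpxq(config).gpxq_act_order is None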

src/brevitas_examples/imagenet_classification/ptq/ptq_common.py (+5 -3)

@@ -197,8 +197,10 @@ def layerwise_bit_width_fn_weight(module):
     act_bit_width_dict = {}
     if quant_format == 'int' and backend == 'layerwise':
         weight_bit_width_dict['weight_bit_width'] = layerwise_bit_width_fn_weight
-        act_bit_width_dict['act_bit_width'] = layerwise_bit_width_fn_act
-
+        if act_bit_width is not None:
+            act_bit_width_dict['act_bit_width'] = layerwise_bit_width_fn_act
+        else:
+            act_bit_width_dict['act_bit_width'] = None
     elif quant_format == 'int' and backend != 'layerwise':
         weight_bit_width_dict['weight_bit_width'] = weight_bit_width
         act_bit_width_dict['act_bit_width'] = act_bit_width

@@ -291,7 +293,7 @@ def kwargs_prefix(prefix, weight_kwargs):
     act_bit_width_dict['mantissa_bit_width'] = act_mantissa_bit_width

     # Retrieve base input, weight, and bias quantizers
-    bias_quant = BIAS_BIT_WIDTH_MAP[bias_bit_width]
+    bias_quant = BIAS_BIT_WIDTH_MAP[bias_bit_width] if act_bit_width is not None else None
     weight_quant = WEIGHT_QUANT_MAP[weight_quant_format][weight_scale_type][weight_param_method][
         weight_quant_granularity][weight_quant_type]
     weight_quant = weight_quant.let(**weight_bit_width_dict)
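Both hunks serve one goal: treating act_bit_width=None as weight-only quantization. Activations stay unquantized, and since the bias quantizer here is only selected when activations are quantized, it is disabled together with them. A hedged sketch of the pattern, not the Brevitas API:

def build_quant_config(weight_bit_width, act_bit_width):
    # hypothetical helper illustrating the conditional wiring above
    config = {'weight_bit_width': weight_bit_width}
    # None propagates through so downstream code skips activation quantization
    config['act_bit_width'] = act_bit_width
    # bias quantization is only meaningful when activations are quantized too
    config['bias_quant'] = 'Int32Bias' if act_bit_width is not None else None
    return config

assert build_quant_config(8, None)['bias_quant'] is None  # weight-only mode
assert build_quant_config(8, 8)['bias_quant'] == 'Int32Bias'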

src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py (+3)

@@ -256,6 +256,9 @@ def main():
     else:
         act_quant_calib_config = args.act_quant_calibration_type

+    if args.act_bit_width == 0:
+        args.act_bit_width = None
+
     config = (
         f"{args.model_name}_"
         f"{args.target_backend}_"

tests/brevitas/graph/equalization_fixtures.py (+46 -1)

@@ -301,6 +301,49 @@ def forward(self, x):
         return ResidualSrcsAndSinkModel


+@pytest_cases.fixture
+def convgroupconv_model():
+
+    class ConvGroupConvModel(nn.Module):
+
+        def __init__(self) -> None:
+            super().__init__()
+            self.conv = nn.Conv2d(3, 16, kernel_size=3)
+            self.conv_0 = nn.Conv2d(16, 32, kernel_size=1, groups=2)
+            self.conv_1 = nn.Conv2d(32, 64, kernel_size=1, groups=4)
+            self.relu = nn.ReLU()
+
+        def forward(self, x):
+            x = self.conv(x)
+            x = self.relu(x)
+            x = self.conv_0(x)
+            x = self.relu(x)
+            x = self.conv_1(x)
+            return x
+
+    return ConvGroupConvModel
+
+
+@pytest_cases.fixture
+def convtranspose_model():
+
+    class ConvTransposeModel(nn.Module):
+
+        def __init__(self) -> None:
+            super().__init__()
+            self.relu = nn.ReLU()
+            self.conv_0 = nn.ConvTranspose2d(in_channels=3, out_channels=8, kernel_size=3)
+            self.conv_1 = nn.ConvTranspose2d(in_channels=8, out_channels=32, kernel_size=3)
+
+        def forward(self, x):
+            x = self.conv_0(x)
+            x = self.relu(x)
+            x = self.conv_1(x)
+            return x
+
+    return ConvTransposeModel
+
+
 list_of_fixtures = [
     'residual_model',
     'srcsinkconflict_model',

@@ -309,7 +352,9 @@ def forward(self, x):
     'convdepthconv_model',
     'linearmha_model',
     'mhalinear_model',
-    'layernormmha_model']
+    'layernormmha_model',
+    'convgroupconv_model',
+    'convtranspose_model']

 toy_model = fixture_union('toy_model', list_of_fixtures, ids=list_of_fixtures)
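The two new fixtures target weight layouts that graph transforms must treat specially: grouped convolutions partition their input channels, and transposed convolutions keep output channels on dim 1 rather than dim 0. A quick check of both layouts in plain PyTorch, shown here for orientation:

import torch.nn as nn

# grouped conv: weight is (out_channels, in_channels // groups, kH, kW),
# so channels cannot be duplicated across group boundaries
grouped = nn.Conv2d(16, 32, kernel_size=1, groups=2)
assert grouped.weight.shape == (32, 8, 1, 1)

# transposed conv: weight is (in_channels, out_channels // groups, kH, kW),
# so the output-channel axis is dim 1 rather than dim 0
transposed = nn.ConvTranspose2d(in_channels=8, out_channels=32, kernel_size=3)
assert transposed.weight.shape == (8, 32, 3, 3)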

New test file in tests/brevitas/graph (+121)

@@ -0,0 +1,121 @@
+# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+
+import torch
+
+from brevitas.fx import symbolic_trace
+from brevitas.graph.channel_splitting import _clean_regions
+from brevitas.graph.channel_splitting import _split
+from brevitas.graph.equalize import _extract_regions
+from brevitas.graph.fixed_point import MergeBatchNorm
+
+from .equalization_fixtures import *
+
+no_split_models = (
+    'mul_model',
+    'bnconv_model',
+    'convdepthconv_model',
+    'linearmha_model',
+    'layernormmha_model',
+    'convgroupconv_model',
+    'vit_b_32',
+    'shufflenet_v2_x0_5',
+    'googlenet',
+    'inception_v3')
+
+SPLIT_RATIO = 0.1
+
+
+@pytest.mark.parametrize('split_input', [False, True])
+def test_toymodels(toy_model, split_input, request):
+    test_id = request.node.callspec.id
+
+    torch.manual_seed(SEED)
+
+    model_class = toy_model
+    model = model_class()
+    if 'mha' in test_id:
+        inp = torch.randn(IN_SIZE_LINEAR)
+    else:
+        inp = torch.randn(IN_SIZE_CONV)
+
+    model.eval()
+    expected_out = model(inp)
+
+    model = symbolic_trace(model)
+    # merge BN before applying channel splitting
+    model = MergeBatchNorm().apply(model)
+
+    # save model's state dict to check if channel splitting was done or not
+    old_state_dict = model.state_dict()
+
+    regions = _extract_regions(model)
+    regions = _clean_regions(regions)
+    if model_class in no_split_models:
+        assert len(regions) == 0
+    else:
+        model = _split(model, regions, split_ratio=SPLIT_RATIO, split_input=split_input)
+
+        out = model(inp)
+        assert torch.allclose(expected_out, out, atol=ATOL)
+
+        modified_sources = {source for region in regions for source in region.srcs_names}
+        # avoiding checking the same module multiple times
+        modified_sinks = {
+            sink for region in regions for sink in region.sinks_names} - modified_sources
+        for module in modified_sources:
+            if 'mha' in module:
+                module += '.out_proj'
+            weight_name = module + '.weight'
+            assert not torch.equal(old_state_dict[weight_name], model.state_dict()[weight_name])
+            bias_name = module + '.bias'
+            # not all modules have bias and they only differ when splitting output channels
+            if bias_name in old_state_dict.keys() and not split_input:
+                assert not torch.equal(old_state_dict[bias_name], model.state_dict()[bias_name])
+        for module in modified_sinks:
+            weight_name = module + '.weight'
+            assert not torch.equal(old_state_dict[weight_name], model.state_dict()[weight_name])
+
+
+@pytest.mark.parametrize('split_input', [False, True])
+def test_torchvision_models(model_coverage: tuple, split_input: bool, request):
+    model_class = request.node.callspec.id.split('-')[0]
+
+    model, coverage = model_coverage
+
+    torch.manual_seed(SEED)
+    inp = torch.randn(IN_SIZE_CONV)
+
+    model.eval()
+    expected_out = model(inp)
+
+    model = symbolic_trace(model)
+    # merge BN before applying channel splitting
+    model = MergeBatchNorm().apply(model)
+
+    old_state_dict = model.state_dict()
+
+    regions = _extract_regions(model)
+    regions = _clean_regions(regions)
+    if model_class in no_split_models:
+        assert len(regions) == 0
+    else:
+        model = _split(model, regions, split_ratio=SPLIT_RATIO, split_input=split_input)
+
+        out = model(inp)
+        assert torch.allclose(expected_out, out, atol=ATOL)
+
+        modified_sources = {source for region in regions for source in region.srcs_names}
+        # avoiding checking the same module multiple times
+        modified_sinks = {
+            sink for region in regions for sink in region.sinks_names} - modified_sources
+        for module in modified_sources:
+            weight_name = module + '.weight'
+            assert not torch.equal(old_state_dict[weight_name], model.state_dict()[weight_name])
+            bias_name = module + '.bias'
+            # not all modules have bias and they only differ when splitting output channels
+            if bias_name in old_state_dict.keys() and not split_input:
+                assert not torch.equal(old_state_dict[bias_name], model.state_dict()[bias_name])
+        for module in modified_sinks:
+            weight_name = module + '.weight'
+            assert not torch.equal(old_state_dict[weight_name], model.state_dict()[weight_name])
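The allclose assertion in these tests encodes the core invariant: channel splitting must be function-preserving. Roughly, one output channel of a source layer is split into scaled copies, and the sink's matching input channels absorb the change. A hand-rolled linear-layer illustration of that invariant, not the Brevitas implementation:

import torch

torch.manual_seed(0)
w1 = torch.randn(4, 3)  # source layer: 3 -> 4
w2 = torch.randn(2, 4)  # sink layer:   4 -> 2
x = torch.randn(5, 3)

# split output channel 0 of the source into two halved copies
w1_split = torch.cat([w1, w1[0:1] / 2], dim=0)
w1_split[0] = w1[0] / 2
# duplicate the sink's matching input channel so contributions still sum up
w2_split = torch.cat([w2, w2[:, 0:1]], dim=1)

out = x @ w1.t() @ w2.t()
out_split = x @ w1_split.t() @ w2_split.t()
assert torch.allclose(out, out_split, atol=1e-6)  # function preserved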

tests/brevitas/graph/test_equalization.py (+12 -2)

@@ -137,7 +137,11 @@ def test_models(toy_model, merge_bias, request):
     assert torch.allclose(expected_out, out, atol=ATOL)
     # Check that at least one region performs "true" equalization
     # If all shapes are scalar, no equalization has been performed
-    assert all([shape != () for shape in shape_scale_regions])
+    if 'convgroupconv' in test_id:
+        with pytest.raises(AssertionError):
+            assert all([shape != () for shape in shape_scale_regions])
+    else:
+        assert all([shape != () for shape in shape_scale_regions])


 @pytest_cases.parametrize("layerwise", [True, False])

@@ -167,7 +171,13 @@ def test_act_equalization_models(toy_model, layerwise, request):
     assert torch.allclose(expected_out, out, atol=ATOL)

     # This region is made up of a residual branch, so no regions are found for act equalization
-    if 'srcsinkconflict_mode' not in test_id:
+    if 'convgroupconv' in test_id:
+        with pytest.raises(AssertionError):
+            assert len(regions) > 0
+            # Check that at least one region performs "true" equalization
+            # If all shapes are scalar, no equalization has been performed
+            assert all([shape != () for shape in shape_scale_regions])
+    else:
         assert len(regions) > 0
         # Check that at least one region performs "true" equalization
         # If all shapes are scalar, no equalization has been performed
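Wrapping the original assertion in pytest.raises(AssertionError) acts as a compact expected-failure marker: the grouped-conv model is known not to be equalized yet, so the test passes exactly while that limitation holds and starts failing once equalization begins to handle it. The same pattern in isolation:

import pytest

def test_known_limitation():
    shape_scale_regions = [()]  # scalar scale shapes: no real equalization
    with pytest.raises(AssertionError):
        assert all([shape != () for shape in shape_scale_regions])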
