From b37258ac85460c338d7e4e02c9dec6d5290703c1 Mon Sep 17 00:00:00 2001 From: Fabian Grob Date: Thu, 25 Jan 2024 01:56:24 -0800 Subject: [PATCH] Fix (Channel-Splitting): clean up --- src/brevitas/graph/channel_splitting.py | 62 ++++++++----------- src/brevitas/graph/quantize.py | 6 +- .../brevitas/graph/test_channel_splitting.py | 22 +++++-- 3 files changed, 48 insertions(+), 42 deletions(-) diff --git a/src/brevitas/graph/channel_splitting.py b/src/brevitas/graph/channel_splitting.py index 891eeac70..6cc649c2b 100644 --- a/src/brevitas/graph/channel_splitting.py +++ b/src/brevitas/graph/channel_splitting.py @@ -16,7 +16,7 @@ from brevitas.graph.equalize import Region from brevitas.graph.equalize import transpose -__all__ = ['RegionwiseChannelSplitting'] +__all__ = ['GraphChannelSplitting'] _conv = ( nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d) @@ -36,13 +36,13 @@ def _channels_to_split( modules = sinks if split_input else sources _get_axis = _get_input_axis if split_input else _get_output_axis # the modules are all of the same shape so we can just take the first one - single_module = next(iter(modules.values())) + single_module = next(iter(modules)) num_channels = single_module.weight.shape[_get_axis(single_module)] splits_per_layer = int(math.ceil(split_ratio * num_channels)) - module_to_channels = {} + all_channels = [] if split_criterion == 'maxabs': - for name, module in modules.items(): + for module in modules: # get input/output axis of module axis = _get_axis(module) # transpose to have axis as first dimension @@ -50,10 +50,10 @@ def _channels_to_split( # flatten all but first dimension and get max per channel max_per_channel = _channel_maxabs(weight_t.reshape(weight_t.size(0), -1)) channels_sorted = torch.argsort(max_per_channel, descending=True) - module_to_channels[name] = channels_sorted[:splits_per_layer] + all_channels.append(channels_sorted[:splits_per_layer]) - # return tensor with the indices to split - channels_to_split = torch.cat(list(module_to_channels.values())) + # return tensor with the unique indices to split + channels_to_split = torch.cat(all_channels) return torch.unique(channels_to_split) @@ -119,17 +119,17 @@ def _split_channels_region( split_input: bool) -> None: if not split_input: # splitting output channels - for name, module in sources.items(): + for module in sources: _split_channels(module, channels_to_split, split_input=False) - for name, module in sinks.items(): + for module in sinks: # duplicating input_channels for all modules in the sink _split_channels(module, channels_to_split, split_factor=1, split_input=True) else: # input channels are split in half, output channels duplicated - for name, module in sinks.items(): + for module in sinks: _split_channels(module, channels_to_split, split_input=True) - for name, module in sources.items(): + for module in sources: # duplicating output_channels for all modules in the source _split_channels(module, channels_to_split, split_factor=1, split_input=False) @@ -160,7 +160,7 @@ def _is_supported(srcs: List[nn.Module], sinks: List[nn.Module]) -> bool: if any(map(_is_mha, sinks)): return False elif any(map(_is_mha, srcs)): - # for all mha in sources, we need to unwrap them access the code directly + # we need to access the weights of the out_proj layers in mha, therefore unwrap srcs = _unwrap_mha(srcs) # check if OCs of sources are all equal @@ -176,19 +176,11 @@ def _is_supported(srcs: List[nn.Module], sinks: List[nn.Module]) -> bool: return srcs_ocs == sinks_ics -def _unwrap_mha( - sources: Union[List[nn.Module], Dict[str, nn.Module]] -) -> Union[List[nn.Module], Dict[str, nn.Module]]: - if isinstance(sources, List): - for i, source in enumerate(sources): - if _is_mha(source): - sources[i] = source.out_proj - return sources - elif isinstance(sources, Dict): - for i, source in sources.items(): - if _is_mha(source): - sources[i] = source.out_proj - return sources +def _unwrap_mha(sources: List[nn.Module]) -> List[nn.Module]: + for i, source in enumerate(sources): + if _is_mha(source): + sources[i] = source.out_proj + return sources def _split( @@ -198,11 +190,11 @@ def _split( split_input: bool, split_criterion: str = 'maxabs') -> GraphModule: for i, region in enumerate(regions): - sources = {src: region.get_module_from_name(src) for src in region.srcs_names} - sinks = {sink: region.get_module_from_name(sink) for sink in region.sinks_names} + sources = [region.get_module_from_name(src) for src in region.srcs_names] + sinks = [region.get_module_from_name(sink) for sink in region.sinks_names] # check for mha in sources and unwrap it for out_proj - if any(map(_is_mha, sources.values())): + if any(map(_is_mha, sources)): sources = _unwrap_mha(sources) # get channels to split @@ -231,11 +223,11 @@ def _clean_regions(regions: List[Region]) -> List[Region]: source_modules = dict() sink_modules = dict() for i, region in enumerate(regions): - sources = {src: region.get_module_from_name(src) for src in region.srcs_names} - sinks = {sink: region.get_module_from_name(sink) for sink in region.sinks_names} + sources = [region.get_module_from_name(src) for src in region.srcs_names] + sinks = [region.get_module_from_name(sink) for sink in region.sinks_names] # a module cannot be in the sources (or sinks) of multiple regions - for src in sources.keys(): + for src in sources: # if not yet in the dict, instantiate new list for keeping track if src not in source_modules: source_modules[src] = [i] @@ -243,7 +235,7 @@ def _clean_regions(regions: List[Region]) -> List[Region]: # we know the module has been in sources before, so region needs to be deleted source_modules[src].append(i) regions_to_del.update({*source_modules[src]}) - for sink in sinks.keys(): + for sink in sinks: if sink not in sink_modules: sink_modules[sink] = [i] else: @@ -251,7 +243,7 @@ def _clean_regions(regions: List[Region]) -> List[Region]: regions_to_del.update({*sink_modules[sink]}) # check for other unsupported - if not _is_supported(list(sources.values()), list(sinks.values())): + if not _is_supported(srcs=sources, sinks=sinks): # add region to be deleted regions_to_del.add(i) @@ -259,10 +251,10 @@ def _clean_regions(regions: List[Region]) -> List[Region]: return regions -class RegionwiseChannelSplitting(GraphTransform): +class GraphChannelSplitting(GraphTransform): def __init__(self, split_ratio=0.02, split_criterion='maxabs', split_input=True): - super(RegionwiseChannelSplitting, self).__init__() + super(GraphChannelSplitting, self).__init__() self.split_ratio = split_ratio self.split_criterion = split_criterion diff --git a/src/brevitas/graph/quantize.py b/src/brevitas/graph/quantize.py index 1538200a4..bd7bc04e6 100644 --- a/src/brevitas/graph/quantize.py +++ b/src/brevitas/graph/quantize.py @@ -8,7 +8,7 @@ from brevitas.core.scaling.standalone import ParameterScaling from brevitas.fx.brevitas_tracer import symbolic_trace from brevitas.graph.base import ModuleToModuleByClass -from brevitas.graph.channel_splitting import RegionwiseChannelSplitting +from brevitas.graph.channel_splitting import GraphChannelSplitting from brevitas.graph.equalize import EqualizeGraph from brevitas.graph.fixed_point import CollapseConsecutiveConcats from brevitas.graph.fixed_point import MergeBatchNorm @@ -289,8 +289,8 @@ def preprocess_for_quantize( merge_bias=equalize_merge_bias, bias_shrinkage=equalize_bias_shrinkage, scale_computation_type=equalize_scale_computation).apply(model) - if channel_splitting_ratio: - model = RegionwiseChannelSplitting( + if channel_splitting_ratio > 0: + model = GraphChannelSplitting( split_ratio=channel_splitting_ratio, split_criterion=channel_splitting_criterion, split_input=channel_splitting_split_input).apply(model) diff --git a/tests/brevitas/graph/test_channel_splitting.py b/tests/brevitas/graph/test_channel_splitting.py index 6043c7a9b..30a3dd8d3 100644 --- a/tests/brevitas/graph/test_channel_splitting.py +++ b/tests/brevitas/graph/test_channel_splitting.py @@ -11,6 +11,18 @@ from .equalization_fixtures import * +no_split_models = ( + 'mul_model', + 'bnconv_model', + 'convdepthconv_model', + 'linearmha_model', + 'layernormmha_model', + 'convgroupconv_model', + 'vit_b_32', + 'shufflenet_v2_x0_5', + 'googlenet', + 'inception_v3') + SPLIT_RATIO = 0.1 @@ -39,8 +51,8 @@ def test_toymodels(toy_model, split_input, request): regions = _extract_regions(model) regions = _clean_regions(regions) - if not len(regions) > 0: - pytest.skip(reason='No regions supported.') + if model_class in no_split_models: + assert len(regions) == 0 else: model = _split(model, regions, split_ratio=SPLIT_RATIO, split_input=split_input) @@ -67,6 +79,8 @@ def test_toymodels(toy_model, split_input, request): @pytest.mark.parametrize('split_input', [False, True]) def test_torchvision_models(model_coverage: tuple, split_input: bool, request): + model_class = request.node.callspec.id.split('-')[0] + model, coverage = model_coverage torch.manual_seed(SEED) @@ -83,8 +97,8 @@ def test_torchvision_models(model_coverage: tuple, split_input: bool, request): regions = _extract_regions(model) regions = _clean_regions(regions) - if not len(regions) > 0: - pytest.skip(reason='No regions supported.') + if model_class in no_split_models: + assert len(regions) == 0 else: model = _split(model, regions, split_ratio=SPLIT_RATIO, split_input=split_input)