Fix (Channel-Splitting): clean up
fabianandresgrob committed Jan 26, 2024
1 parent d2bcdd3 commit 331bebc
Showing 3 changed files with 48 additions and 42 deletions.
62 changes: 27 additions & 35 deletions src/brevitas/graph/channel_splitting.py
@@ -16,7 +16,7 @@
from brevitas.graph.equalize import Region
from brevitas.graph.equalize import transpose

__all__ = ['RegionwiseChannelSplitting']
__all__ = ['GraphChannelSplitting']

_conv = (
nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)
@@ -36,24 +36,24 @@ def _channels_to_split(
modules = sinks if split_input else sources
_get_axis = _get_input_axis if split_input else _get_output_axis
# the modules are all of the same shape so we can just take the first one
single_module = next(iter(modules.values()))
single_module = next(iter(modules))
num_channels = single_module.weight.shape[_get_axis(single_module)]
splits_per_layer = int(math.ceil(split_ratio * num_channels))

module_to_channels = {}
all_channels = []
if split_criterion == 'maxabs':
for name, module in modules.items():
for module in modules:
# get input/output axis of module
axis = _get_axis(module)
# transpose to have axis as first dimension
weight_t = transpose(module.weight, axis)
# flatten all but first dimension and get max per channel
max_per_channel = _channel_maxabs(weight_t.reshape(weight_t.size(0), -1))
channels_sorted = torch.argsort(max_per_channel, descending=True)
module_to_channels[name] = channels_sorted[:splits_per_layer]
all_channels.append(channels_sorted[:splits_per_layer])

# return tensor with the indices to split
channels_to_split = torch.cat(list(module_to_channels.values()))
# return tensor with the unique indices to split
channels_to_split = torch.cat(all_channels)
return torch.unique(channels_to_split)


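For reference, a minimal sketch of what the `maxabs` criterion above computes for a single module, assuming the channel axis is dimension 0 (`maxabs_channels` is an illustrative helper, not part of the Brevitas API):

```python
# Illustrative sketch only: how the 'maxabs' split criterion ranks channels
# for one module. Assumes the channel axis is dimension 0.
import math

import torch


def maxabs_channels(weight: torch.Tensor, split_ratio: float) -> torch.Tensor:
    # flatten everything except the channel dimension and take max |w| per channel
    per_channel = weight.reshape(weight.size(0), -1).abs().max(dim=1).values
    splits_per_layer = int(math.ceil(split_ratio * weight.size(0)))
    # the channels with the largest magnitudes are selected for splitting
    return torch.argsort(per_channel, descending=True)[:splits_per_layer]


w = torch.randn(64, 32, 3, 3)               # e.g. a Conv2d weight, out_channels first
print(maxabs_channels(w, split_ratio=0.1))  # indices of the ceil(0.1 * 64) = 7 largest channels
```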
@@ -119,17 +119,17 @@ def _split_channels_region(
split_input: bool) -> None:
if not split_input:
# splitting output channels
for name, module in sources.items():
for module in sources:
_split_channels(module, channels_to_split, split_input=False)
for name, module in sinks.items():
for module in sinks:
# duplicating input_channels for all modules in the sink
_split_channels(module, channels_to_split, split_factor=1, split_input=True)
else:
# input channels are split in half, output channels duplicated
for name, module in sinks.items():
for module in sinks:
_split_channels(module, channels_to_split, split_input=True)

for name, module in sources.items():
for module in sources:
# duplicating output_channels for all modules in the source
_split_channels(module, channels_to_split, split_factor=1, split_input=False)

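`_split_channels` itself is unchanged by this commit and not shown in the diff. A rough sketch of the invariant behind the calls above, under the assumption that splitting halves the selected output channel of the source while `split_factor=1` merely duplicates the matching input channel of the sink:

```python
# Sketch of the idea only (toy linear layers, not Brevitas code): splitting an
# output channel of the source in half and duplicating the matching input
# channel of the sink leaves the composed function unchanged.
import torch


def split_source_output(weight: torch.Tensor, c: int) -> torch.Tensor:
    halved = weight.clone()
    halved[c] = halved[c] / 2                            # original channel keeps half the weight
    return torch.cat([halved, halved[c:c + 1]], dim=0)   # appended copy carries the other half


def duplicate_sink_input(weight: torch.Tensor, c: int) -> torch.Tensor:
    return torch.cat([weight, weight[:, c:c + 1]], dim=1)  # reuse the same input column


src = torch.randn(4, 3)   # source: 3 -> 4
snk = torch.randn(2, 4)   # sink:   4 -> 2
x = torch.randn(3)

before = snk @ (src @ x)
after = duplicate_sink_input(snk, 1) @ (split_source_output(src, 1) @ x)
print(torch.allclose(before, after))  # True, up to floating-point error
```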
@@ -160,7 +160,7 @@ def _is_supported(srcs: List[nn.Module], sinks: List[nn.Module]) -> bool:
if any(map(_is_mha, sinks)):
return False
elif any(map(_is_mha, srcs)):
# for all mha in sources, we need to unwrap them access the code directly
# we need to access the weights of the out_proj layers in mha, therefore unwrap
srcs = _unwrap_mha(srcs)

# check if OCs of sources are all equal
@@ -176,19 +176,11 @@ def _is_supported(srcs: List[nn.Module], sinks: List[nn.Module]) -> bool:
return srcs_ocs == sinks_ics


def _unwrap_mha(
sources: Union[List[nn.Module], Dict[str, nn.Module]]
) -> Union[List[nn.Module], Dict[str, nn.Module]]:
if isinstance(sources, List):
for i, source in enumerate(sources):
if _is_mha(source):
sources[i] = source.out_proj
return sources
elif isinstance(sources, Dict):
for i, source in sources.items():
if _is_mha(source):
sources[i] = source.out_proj
return sources
def _unwrap_mha(sources: List[nn.Module]) -> List[nn.Module]:
for i, source in enumerate(sources):
if _is_mha(source):
sources[i] = source.out_proj
return sources


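The unwrapping above relies on `nn.MultiheadAttention` exposing its output projection as a Linear-like submodule; a quick check of that assumption:

```python
# Quick check (illustrative, not part of the diff): the out_proj of an
# nn.MultiheadAttention is a Linear-like module whose weight tensor can be
# split along its channel axes like any other Linear.
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)
print(type(mha.out_proj).__name__, tuple(mha.out_proj.weight.shape))
# e.g. NonDynamicallyQuantizableLinear (8, 8)
```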
def _split(
@@ -198,11 +190,11 @@ def _split(
split_input: bool,
split_criterion: str = 'maxabs') -> GraphModule:
for i, region in enumerate(regions):
sources = {src: region.get_module_from_name(src) for src in region.srcs_names}
sinks = {sink: region.get_module_from_name(sink) for sink in region.sinks_names}
sources = [region.get_module_from_name(src) for src in region.srcs_names]
sinks = [region.get_module_from_name(sink) for sink in region.sinks_names]

# check for mha in sources and unwrap it for out_proj
if any(map(_is_mha, sources.values())):
if any(map(_is_mha, sources)):
sources = _unwrap_mha(sources)

# get channels to split
@@ -231,38 +223,38 @@ def _clean_regions(regions: List[Region]) -> List[Region]:
source_modules = dict()
sink_modules = dict()
for i, region in enumerate(regions):
sources = {src: region.get_module_from_name(src) for src in region.srcs_names}
sinks = {sink: region.get_module_from_name(sink) for sink in region.sinks_names}
sources = [region.get_module_from_name(src) for src in region.srcs_names]
sinks = [region.get_module_from_name(sink) for sink in region.sinks_names]

# a module cannot be in the sources (or sinks) of multiple regions
for src in sources.keys():
for src in sources:
# if not yet in the dict, instantiate new list for keeping track
if src not in source_modules:
source_modules[src] = [i]
else:
# we know the module has been in sources before, so region needs to be deleted
source_modules[src].append(i)
regions_to_del.update({*source_modules[src]})
for sink in sinks.keys():
for sink in sinks:
if sink not in sink_modules:
sink_modules[sink] = [i]
else:
sink_modules[sink].append(i)
regions_to_del.update({*sink_modules[sink]})

# check for other unsupported cases
if not _is_supported(list(sources.values()), list(sinks.values())):
if not _is_supported(srcs=sources, sinks=sinks):
# add region to be deleted
regions_to_del.add(i)

regions = [regions[i] for i, _ in enumerate(regions) if i not in regions_to_del]
return regions


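A compact sketch of the duplicate-region rule implemented above: if a module shows up in the sources (or sinks) of more than one region, every region that contains it is scheduled for deletion (the region contents below are made up for illustration).

```python
# Illustrative sketch of the bookkeeping in _clean_regions.
seen_sources: dict = {}
regions_to_del: set = set()
region_sources = {0: ['conv1'], 1: ['conv1', 'conv2'], 2: ['conv3']}

for i, srcs in region_sources.items():
    for src in srcs:
        seen_sources.setdefault(src, []).append(i)
        if len(seen_sources[src]) > 1:
            # the module is already a source of another region: drop all of them
            regions_to_del.update(seen_sources[src])

print(regions_to_del)  # {0, 1}; region 2 survives
```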
class RegionwiseChannelSplitting(GraphTransform):
class GraphChannelSplitting(GraphTransform):

def __init__(self, split_ratio=0.02, split_criterion='maxabs', split_input=True):
super(RegionwiseChannelSplitting, self).__init__()
super(GraphChannelSplitting, self).__init__()

self.split_ratio = split_ratio
self.split_criterion = split_criterion
6 changes: 3 additions & 3 deletions src/brevitas/graph/quantize.py
@@ -8,7 +8,7 @@
from brevitas.core.scaling.standalone import ParameterScaling
from brevitas.fx.brevitas_tracer import symbolic_trace
from brevitas.graph.base import ModuleToModuleByClass
from brevitas.graph.channel_splitting import RegionwiseChannelSplitting
from brevitas.graph.channel_splitting import GraphChannelSplitting
from brevitas.graph.equalize import EqualizeGraph
from brevitas.graph.fixed_point import CollapseConsecutiveConcats
from brevitas.graph.fixed_point import MergeBatchNorm
@@ -289,8 +289,8 @@ def preprocess_for_quantize(
merge_bias=equalize_merge_bias,
bias_shrinkage=equalize_bias_shrinkage,
scale_computation_type=equalize_scale_computation).apply(model)
if channel_splitting_ratio:
model = RegionwiseChannelSplitting(
if channel_splitting_ratio > 0:
model = GraphChannelSplitting(
split_ratio=channel_splitting_ratio,
split_criterion=channel_splitting_criterion,
split_input=channel_splitting_split_input).apply(model)
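A usage sketch of the renamed transform, assuming a symbolically traced model; the torchvision model and the 0.02 ratio are arbitrary choices, while the import paths and parameter names are taken from the diff above:

```python
# Usage sketch only; whether a given model needs further preprocessing before
# channel splitting is not covered by this diff.
from torchvision.models import resnet18

from brevitas.fx.brevitas_tracer import symbolic_trace
from brevitas.graph.channel_splitting import GraphChannelSplitting

model = symbolic_trace(resnet18())
model = GraphChannelSplitting(
    split_ratio=0.02,
    split_criterion='maxabs',
    split_input=True).apply(model)
```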
22 changes: 18 additions & 4 deletions tests/brevitas/graph/test_channel_splitting.py
@@ -11,6 +11,18 @@

from .equalization_fixtures import *

no_split_models = (
'mul_model',
'bnconv_model',
'convdepthconv_model',
'linearmha_model',
'layernormmha_model',
'convgroupconv_model',
'vit_b_32',
'shufflenet_v2_x0_5',
'googlenet',
'inception_v3')

SPLIT_RATIO = 0.1


@@ -39,8 +51,8 @@ def test_toymodels(toy_model, split_input, request):

regions = _extract_regions(model)
regions = _clean_regions(regions)
if not len(regions) > 0:
pytest.skip(reason='No regions supported.')
if model_class in no_split_models:
assert len(regions) == 0
else:
model = _split(model, regions, split_ratio=SPLIT_RATIO, split_input=split_input)

@@ -67,6 +79,8 @@ def test_toymodels(toy_model, split_input, request):

@pytest.mark.parametrize('split_input', [False, True])
def test_torchvision_models(model_coverage: tuple, split_input: bool, request):
model_class = request.node.callspec.id.split('-')[0]

model, coverage = model_coverage

torch.manual_seed(SEED)
@@ -83,8 +97,8 @@ def test_torchvision_models(model_coverage: tuple, split_input: bool, request):

regions = _extract_regions(model)
regions = _clean_regions(regions)
if not len(regions) > 0:
pytest.skip(reason='No regions supported.')
if model_class in no_split_models:
assert len(regions) == 0
else:
model = _split(model, regions, split_ratio=SPLIT_RATIO, split_input=split_input)

