Feat (Channel-Splitting): add support for graph channel splitting
1 parent ab27d7f · commit 4bde0c5
Showing 5 changed files with 332 additions and 6 deletions.
@@ -0,0 +1,285 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

import math
from typing import List, Tuple, Union

import torch
import torch.nn as nn

from brevitas.fx import GraphModule
from brevitas.graph.base import GraphTransform
from brevitas.graph.equalize import _channel_maxabs
from brevitas.graph.equalize import _extract_regions
from brevitas.graph.equalize import _get_input_axis
from brevitas.graph.equalize import _get_output_axis
from brevitas.graph.equalize import Region
from brevitas.graph.equalize import transpose

__all__ = ['GraphChannelSplitting']

_conv = (
    nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)

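# normalization layers keep per-channel state (affine parameters, running statistics)
# that this transform does not split, so regions containing them are rejected below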
_unsupported_layers = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.LayerNorm)


def _channels_to_split(
        sources: List[nn.Module],
        sinks: List[nn.Module],
        split_criterion: str,
        split_ratio: float,
        split_input: bool) -> torch.Tensor:
    """
    Compute the channel indices to split based on `split_criterion`.
    """
    modules = sinks if split_input else sources
    _get_axis = _get_input_axis if split_input else _get_output_axis
    # all modules share the same channel count along the relevant axis
    # (enforced by `_is_supported`), so we can just inspect the first one
    single_module = next(iter(modules))
    num_channels = single_module.weight.shape[_get_axis(single_module)]
    splits_per_layer = int(math.ceil(split_ratio * num_channels))
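    # e.g. with split_ratio=0.02 and 64 channels: int(math.ceil(0.02 * 64)) = ceil(1.28) = 2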

    all_channels = []
    if split_criterion == 'maxabs':
        for module in modules:
            # get the input/output axis of the module
            axis = _get_axis(module)
            # transpose so that the axis to split is the first dimension
            weight_t = transpose(module.weight, axis)
            # flatten all but the first dimension and take the max abs value per channel
            max_per_channel = _channel_maxabs(weight_t.reshape(weight_t.size(0), -1))
            channels_sorted = torch.argsort(max_per_channel, descending=True)
            all_channels.append(channels_sorted[:splits_per_layer])
    else:
        raise ValueError(f'Split criterion {split_criterion} is not supported')

    # return a tensor with the unique indices to split
    channels_to_split = torch.cat(all_channels)
    return torch.unique(channels_to_split)


# no_grad is needed because the weights are modified in place
@torch.no_grad()
def _split_channels(
        module: nn.Module,
        channels_to_split: torch.Tensor,
        split_input: bool = False,
        split_factor: float = 0.5) -> None:
    """
    Split the weight channels of `module` as proposed in https://arxiv.org/abs/1901.09504.
    `split_factor` determines how the split channels are scaled and `channels_to_split`
    holds the channel indices to split. If `split_input=True`, the input channels of the
    module are split; otherwise the output channels are.
    """
    weight = module.weight.data
    bias = module.bias.data if module.bias is not None else None
    num_added_channels = len(channels_to_split)

    _get_axis = _get_input_axis if split_input else _get_output_axis
    axis = _get_axis(module)
    # save the shape of the module weights
    orig_shape = list(weight.shape)
    weight_t = transpose(weight, axis)
    # flatten to 2d
    weight_t = weight_t.reshape(weight_t.size(0), -1)
    for channel_id in channels_to_split:
        # scale the channel by split_factor and append a copy of it
        weight_t[channel_id, :] *= split_factor
        split_channel = weight_t[channel_id, :]
        # expand so we can stack
        split_channel = split_channel.expand(1, split_channel.size(0))
        weight_t = torch.cat([weight_t, split_channel], dim=0)

        if bias is not None and not split_input:
            bias[channel_id] *= split_factor
            split_channel = bias[channel_id:channel_id + 1]
            bias = torch.cat((bias, split_channel))

    # reshape weight_t back to the original shape plus the added channels
    del orig_shape[axis]
    weight_t = weight_t.reshape(weight_t.size(0), *orig_shape)
    weight_t = transpose(weight_t, axis)
    module.weight.data = weight_t
    if bias is not None:
        module.bias.data = bias

    if isinstance(module, _conv):
        if split_input:
            module.in_channels += num_added_channels
        else:
            module.out_channels += num_added_channels
    elif isinstance(module, nn.Linear):
        if split_input:
            module.in_features += num_added_channels
        else:
            module.out_features += num_added_channels


def _split_channels_region(
        sources: List[nn.Module],
        sinks: List[nn.Module],
        channels_to_split: torch.Tensor,
        split_input: bool) -> None:
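    # functional equivalence: if a source channel emits activation a and a sink consumes
    # it with weight v, then v*a == v*(a/2) + v*(a/2) (output split, sinks duplicate with
    # factor 1) and v*a == (v/2)*a + (v/2)*a (input split, sources duplicate with factor 1)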
    if not split_input:
        # split the output channels of the sources
        for module in sources:
            _split_channels(module, channels_to_split, split_input=False)
        # duplicate the matching input channels of all modules in the sinks
        for module in sinks:
            _split_channels(module, channels_to_split, split_factor=1, split_input=True)
    else:
        # split the input channels of the sinks
        for module in sinks:
            _split_channels(module, channels_to_split, split_input=True)
        # duplicate the matching output channels of all modules in the sources
        for module in sources:
            _split_channels(module, channels_to_split, split_factor=1, split_input=False)


def _is_groupwise(module: nn.Module) -> bool:
    # only Conv layers can be groupwise; splitting channels of a grouped conv
    # would break its group structure, so these regions are filtered out
    return isinstance(module, _conv) and module.groups > 1


def _is_unsupported(module: nn.Module) -> bool:
    return isinstance(module, _unsupported_layers)


def _is_mha(module: nn.Module) -> bool:
    return isinstance(module, nn.MultiheadAttention)


def _is_supported(srcs: List[nn.Module], sinks: List[nn.Module]) -> bool:
    # groupwise convolutions are not supported, so filter them out
    if any(map(_is_groupwise, srcs + sinks)):
        return False

    # filter out unsupported layers
    if any(map(_is_unsupported, sinks + srcs)):
        return False

    # mha can only be in the sources
    if any(map(_is_mha, sinks)):
        return False
    elif any(map(_is_mha, srcs)):
        # we need to access the weights of the out_proj layers in mha, so unwrap them
        srcs = _unwrap_mha(srcs)

    # check that the output channels of all sources are equal
    srcs_ocs = set(module.weight.shape[_get_output_axis(module)] for module in srcs)
    if len(srcs_ocs) > 1:
        return False

    # check that the input channels of all sinks are equal
    sinks_ics = set(module.weight.shape[_get_input_axis(module)] for module in sinks)
    if len(sinks_ics) > 1:
        return False

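    # both sets now hold at most one element; the region is only valid if the shared
    # source output-channel count matches the shared sink input-channel count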
    return srcs_ocs == sinks_ics


def _unwrap_mha(sources: List[nn.Module]) -> List[nn.Module]:
    for i, source in enumerate(sources):
        if _is_mha(source):
            sources[i] = source.out_proj
    return sources


def _split(
        model: GraphModule,
        regions: List[Region],
        split_ratio: float,
        split_input: bool,
        split_criterion: str = 'maxabs') -> GraphModule:
    for region in regions:
        sources = [region.get_module_from_name(src) for src in region.srcs_names]
        sinks = [region.get_module_from_name(sink) for sink in region.sinks_names]

        # check for mha in the sources and unwrap it to its out_proj layer
        if any(map(_is_mha, sources)):
            sources = _unwrap_mha(sources)

        # get the channels to split
        channels_to_split = _channels_to_split(
            sources=sources,
            sinks=sinks,
            split_criterion=split_criterion,
            split_ratio=split_ratio,
            split_input=split_input)
        # split/duplicate the channels
        _split_channels_region(
            sources=sources,
            sinks=sinks,
            channels_to_split=channels_to_split,
            split_input=split_input)

    return model


def _clean_regions(regions: List[Region]) -> List[Region]:
    """
    Given a list of regions, remove all regions that are not compatible with channel splitting.
    """
    # idea: map each module to the regions it appears in and delete regions that share modules
    regions_to_del = set()
    source_modules = dict()
    sink_modules = dict()
    for i, region in enumerate(regions):
        sources = [region.get_module_from_name(src) for src in region.srcs_names]
        sinks = [region.get_module_from_name(sink) for sink in region.sinks_names]

        # a module cannot be in the sources (or sinks) of multiple regions
        for src in sources:
            # if not yet in the dict, instantiate a new list to keep track of its regions
            if src not in source_modules:
                source_modules[src] = [i]
            else:
                # the module already appeared in the sources of another region, so every
                # region sharing it (including the first one) needs to be deleted
                source_modules[src].append(i)
                regions_to_del.update({*source_modules[src]})
        for sink in sinks:
            if sink not in sink_modules:
                sink_modules[sink] = [i]
            else:
                sink_modules[sink].append(i)
                regions_to_del.update({*sink_modules[sink]})

        # check for other unsupported conditions
        if not _is_supported(srcs=sources, sinks=sinks):
            # add the region to be deleted
            regions_to_del.add(i)

    regions = [region for i, region in enumerate(regions) if i not in regions_to_del]
    return regions


class GraphChannelSplitting(GraphTransform):

    def __init__(
            self,
            split_ratio: float = 0.02,
            split_criterion: str = 'maxabs',
            split_input: bool = True):
        super(GraphChannelSplitting, self).__init__()

        self.split_ratio = split_ratio
        self.split_criterion = split_criterion
        self.split_input = split_input

    def apply(
            self,
            model: GraphModule,
            return_regions: bool = False
    ) -> Union[Tuple[GraphModule, List[Region]], GraphModule]:
        regions = _extract_regions(model)
        regions = _clean_regions(regions)
        if len(regions) > 0:
            model = _split(
                model=model,
                regions=regions,
                split_ratio=self.split_ratio,
                split_criterion=self.split_criterion,
                split_input=self.split_input)
        if return_regions:
            return model, regions
        else:
            return model
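For reference, a minimal usage sketch of the new transform. This is illustrative only: the model is a toy example, `symbolic_trace` is assumed to be the FX tracing entry point exposed by `brevitas.fx`, and the import path for `GraphChannelSplitting` is a guess, since the diff does not show the new file's name.

import torch.nn as nn

from brevitas.fx import symbolic_trace  # assumed tracing entry point
from brevitas.graph.channel_splitting import GraphChannelSplitting  # assumed module path

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Conv2d(16, 32, 3))
graph_model = symbolic_trace(model)

# split the 2% largest-magnitude input channels of each sink ahead of quantization
splitting = GraphChannelSplitting(split_ratio=0.02, split_criterion='maxabs', split_input=True)
graph_model, regions = splitting.apply(graph_model, return_regions=True)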