diff --git a/src/brevitas/graph/quantize.py b/src/brevitas/graph/quantize.py index 4584a4eb8..9be9fb812 100644 --- a/src/brevitas/graph/quantize.py +++ b/src/brevitas/graph/quantize.py @@ -268,6 +268,7 @@ def preprocess_for_quantize( channel_splitting=False, channel_splitting_ratio=0.02, channel_splitting_grid_aware=False, + channel_splitting_split_iteratively=False, channel_splitting_criterion: str = 'maxabs', channel_splitting_weight_bit_width: int = 8): @@ -296,6 +297,7 @@ def preprocess_for_quantize( split_ratio=channel_splitting_ratio, grid_aware=channel_splitting_grid_aware, split_criterion=channel_splitting_criterion, + split_iteratively=channel_splitting_split_iteratively, weight_bit_width=channel_splitting_weight_bit_width).apply(model) model.train(training_state) return model diff --git a/src/brevitas/ptq_algorithms/channel_splitting.py b/src/brevitas/ptq_algorithms/channel_splitting.py index 7189e79be..a752b24b6 100644 --- a/src/brevitas/ptq_algorithms/channel_splitting.py +++ b/src/brevitas/ptq_algorithms/channel_splitting.py @@ -392,13 +392,19 @@ def split_channels_iteratively(module, split_ratio, grid_aware, bit_width): class LayerwiseChannelSplitting(GraphTransform): def __init__( - self, split_ratio=0.02, split_criterion='maxabs', grid_aware=False, weight_bit_width=8): + self, + split_ratio=0.02, + split_criterion='maxabs', + grid_aware=False, + split_iteratively=False, + weight_bit_width=8): super(LayerwiseChannelSplitting, self).__init__() self.grid_aware = grid_aware self.split_ratio = split_ratio self.split_criterion = split_criterion self.weight_bit_width = weight_bit_width + self.split_iteratively = split_iteratively def _is_supported_module(self, graph_model: GraphModule, node: Node) -> bool: if node.op == 'call_module': @@ -413,16 +419,26 @@ def apply(self, graph_model: GraphModule): for node in graph_model.graph.nodes: if self._is_supported_module(graph_model, node): module = get_module(graph_model, node.target) - # we only split input channels - # channels_to_split = _channels_to_split({}, {node.target: module}, split_ratio=self.split_ratio, split_input=True, split_criterion=self.split_criterion) - # split the channels in the module - # _split_channels(module, channels_to_split, grid_aware=self.grid_aware, split_input=True, bit_width=self.weight_bit_width) + if self.split_iteratively: + channels_to_split = split_channels_iteratively( + module, + split_ratio=self.split_ratio, + grid_aware=self.grid_aware, + bit_width=self.weight_bit_width) + else: + # we only split input channels + channels_to_split = _channels_to_split({}, {node.target: module}, + split_ratio=self.split_ratio, + split_input=True, + split_criterion=self.split_criterion) + # split the channels in the module + _split_channels( + module, + channels_to_split, + grid_aware=self.grid_aware, + split_input=True, + bit_width=self.weight_bit_width) # add node to split modules - channels_to_split = split_channels_iteratively( - module, - split_ratio=self.split_ratio, - grid_aware=self.grid_aware, - bit_width=self.weight_bit_width) split_modules[module] = torch.tensor(channels_to_split) for module, channels_to_split in split_modules.items(): diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index 4e2262eb0..08c0636b7 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -242,6 +242,13 @@ 'split-input', default=False, help='Input Channels Splitting for channel splitting (default: disabled)') +add_bool_arg( + parser, + 'split-iteratively', + default=False, + help= + 'Input Channels are split iteratively, allows the same channel to be split multiple times (default: disabled)' +) def main(): @@ -313,6 +320,7 @@ def main(): f"Channel Splitting: {args.channel_splitting} - " f"Split Ratio: {args.split_ratio} - " f"Grid Aware: {args.grid_aware} - " + f"Split Iteratively: {args.split_iteratively} - " f"Merge BN: {not args.calibrate_bn}") # Get model-specific configurations about input shapes and normalization @@ -360,6 +368,7 @@ def main(): merge_bn=not args.calibrate_bn, channel_splitting=args.channel_splitting, channel_splitting_grid_aware=args.grid_aware, + channel_splitting_split_iteratively=args.split_iteratively, channel_splitting_ratio=args.split_ratio, channel_splitting_weight_bit_width=args.weight_bit_width) else: