diff --git a/src/brevitas/graph/channel_splitting.py b/src/brevitas/graph/channel_splitting.py index 6cc649c2b..8978366eb 100644 --- a/src/brevitas/graph/channel_splitting.py +++ b/src/brevitas/graph/channel_splitting.py @@ -57,6 +57,7 @@ def _channels_to_split( return torch.unique(channels_to_split) +# decorator is needed to modify the weights in-place using a view @torch.no_grad() def _split_channels( module: nn.Module, @@ -253,7 +254,11 @@ def _clean_regions(regions: List[Region]) -> List[Region]: class GraphChannelSplitting(GraphTransform): - def __init__(self, split_ratio=0.02, split_criterion='maxabs', split_input=True): + def __init__( + self, + split_ratio: float = 0.02, + split_criterion: str = 'maxabs', + split_input: bool = True): super(GraphChannelSplitting, self).__init__() self.split_ratio = split_ratio @@ -262,7 +267,7 @@ def __init__(self, split_ratio=0.02, split_criterion='maxabs', split_input=True) def apply( self, - model, + model: GraphModule, return_regions: bool = False ) -> Union[Tuple[GraphModule, Set[Tuple[str]]], GraphModule]: regions = _extract_regions(model) diff --git a/src/brevitas/graph/quantize.py b/src/brevitas/graph/quantize.py index bd7bc04e6..4ac324c8b 100644 --- a/src/brevitas/graph/quantize.py +++ b/src/brevitas/graph/quantize.py @@ -265,8 +265,8 @@ def preprocess_for_quantize( merge_bn=True, equalize_bias_shrinkage: str = 'vaiq', equalize_scale_computation: str = 'maxabs', - channel_splitting_ratio=0.0, - channel_splitting_split_input=True, + channel_splitting_ratio: float = 0.0, + channel_splitting_split_input: bool = True, channel_splitting_criterion: str = 'maxabs'): training_state = model.training diff --git a/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py b/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py index 604c6bea5..668eee22c 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py +++ b/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py @@ -105,8 +105,7 @@ def unique(sequence): 'accumulator_bit_width': [16], # Accumulator bit width, only in combination with GPFA2Q 'act_quant_percentile': [99.999], # Activation Quantization Percentile 'uint_sym_act_for_unsigned_values': [True], # Whether to use unsigned act quant when possible - 'channel_splitting_ratio': [0.0 - ], # Channel Splitting ratio, 0.0 means no splitting is performed + 'channel_splitting_ratio': [0.0], # Channel Splitting ratio, 0.0 means no splitting 'split_input': [True], # Whether to split the input channels when applying channel splitting 'merge_bn': [True]} # Whether to merge BN layers