Skip to content

Commit

Permalink
Status update Channel splitting
Browse files Browse the repository at this point in the history
  • Loading branch information
fabianandresgrob committed Jan 2, 2024
1 parent 5d7d0ac commit 8180eb0
Show file tree
Hide file tree
Showing 7 changed files with 340 additions and 74 deletions.
32 changes: 17 additions & 15 deletions src/brevitas/graph/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from brevitas.graph.standardize import TorchFunctionalToModule
from brevitas.nn import quant_layer
import brevitas.nn as qnn
from brevitas.ptq_algorithms.channel_splitting import ChannelSplitting
from brevitas.ptq_algorithms.channel_splitting import LayerwiseChannelSplitting
from brevitas.quant import Int8ActPerTensorFloat
from brevitas.quant import Int8ActPerTensorFloatMinMaxInit
from brevitas.quant import Int8WeightPerTensorFloat
Expand Down Expand Up @@ -268,33 +268,35 @@ def preprocess_for_quantize(
channel_splitting=False,
channel_splitting_ratio=0.02,
channel_splitting_grid_aware=False,
channel_splitting_criterion: str = 'maxabs'):
channel_splitting_criterion: str = 'maxabs',
channel_splitting_weight_bit_width: int = 8):

training_state = model.training
model.eval()

if trace_model:
model = symbolic_trace(model)
model = TorchFunctionalToModule().apply(model)
model = DuplicateSharedStatelessModule().apply(model)
# model = TorchFunctionalToModule().apply(model)
# model = DuplicateSharedStatelessModule().apply(model)
if relu6_to_relu:
model = ModuleToModuleByClass(nn.ReLU6, nn.ReLU).apply(model)
model = MeanMethodToAdaptiveAvgPool2d().apply(model)
model = CollapseConsecutiveConcats().apply(model)
model = MoveSplitBatchNormBeforeCat().apply(model)
# model = MeanMethodToAdaptiveAvgPool2d().apply(model)
# model = CollapseConsecutiveConcats().apply(model)
# model = MoveSplitBatchNormBeforeCat().apply(model)
if merge_bn:
model = MergeBatchNorm().apply(model)
model = RemoveStochasticModules().apply(model)
model = EqualizeGraph(
iterations=equalize_iters,
merge_bias=equalize_merge_bias,
bias_shrinkage=equalize_bias_shrinkage,
scale_computation_type=equalize_scale_computation).apply(model)
# model = RemoveStochasticModules().apply(model)
# model = EqualizeGraph(
# iterations=equalize_iters,
# merge_bias=equalize_merge_bias,
# bias_shrinkage=equalize_bias_shrinkage,
# scale_computation_type=equalize_scale_computation).apply(model)
if channel_splitting:
model = ChannelSplitting(
model = LayerwiseChannelSplitting(
split_ratio=channel_splitting_ratio,
grid_aware=channel_splitting_grid_aware,
split_criterion=channel_splitting_criterion).apply(model)
split_criterion=channel_splitting_criterion,
weight_bit_width=channel_splitting_weight_bit_width).apply(model)
model.train(training_state)
return model

Expand Down
1 change: 1 addition & 0 deletions src/brevitas/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,5 @@
from .quant_upsample import QuantUpsample
from .quant_upsample import QuantUpsamplingBilinear2d
from .quant_upsample import QuantUpsamplingNearest2d
from .split_layer import ChannelSplitModule
from .target import flexml
35 changes: 35 additions & 0 deletions src/brevitas/nn/split_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from inspect import signature

import torch

# Parameter names that are recognised as "the input tensor" of a wrapped
# layer's ``forward``. The first parameter (in signature order) whose name
# appears here is the one that gets its channels duplicated.
INPUT_NAMES = ['input', 'inp', 'query', 'x']


class ChannelSplitModule(torch.nn.Module):
    """Wrap ``layer`` and append copies of selected channels to its input.

    On every call the input tensor is extended along ``dim=1`` with the
    channels listed in ``channels_to_duplicate`` and the result is passed to
    the wrapped layer in place of the original input.

    NOTE(review): presumably the wrapped layer has already been modified to
    accept the additional input channels -- confirm against the caller.

    Args:
        layer: Module whose ``forward`` receives the extended input.
        channels_to_duplicate: Index tensor (it is used as the ``index``
            argument of ``torch.index_select``) selecting the ``dim=1``
            entries to append.
    """

    def __init__(self, layer, channels_to_duplicate) -> None:
        super().__init__()

        self.layer = layer
        self.channels_to_duplicate = channels_to_duplicate

    def forward(self, *args, **kwargs):
        """Duplicate the selected input channels and call the wrapped layer.

        The call arguments are bound against ``self.layer.forward`` so the
        input tensor can be located by name regardless of whether it was
        passed positionally or as a keyword.

        Returns:
            Whatever the wrapped layer returns for the extended input.

        Raises:
            TypeError: If the arguments do not match the wrapped layer's
                ``forward`` signature (raised by ``Signature.bind``).
            IndexError: If none of the wrapped layer's parameters is named
                like one of ``INPUT_NAMES``. The exception type is kept for
                backward compatibility with the previous implementation.
        """
        # Normalise args + kwargs + defaults into one name -> value mapping.
        bound_arguments = signature(self.layer.forward).bind(*args, **kwargs)
        bound_arguments.apply_defaults()
        arguments = bound_arguments.arguments

        input_name = next((name for name in arguments if name in INPUT_NAMES), None)
        if input_name is None:
            raise IndexError(
                f"{type(self.layer).__name__}.forward has no parameter named one of "
                f"{INPUT_NAMES}; cannot locate the input tensor")
        x = arguments[input_name]

        # Keep the index on the input's device; the moved tensor is cached on
        # the module so later calls on the same device do not copy again.
        self.channels_to_duplicate = self.channels_to_duplicate.to(x.device)

        duplicated = torch.index_select(x, dim=1, index=self.channels_to_duplicate)
        arguments[input_name] = torch.cat([x, duplicated], dim=1)

        # Forward through ``.args`` / ``.kwargs`` rather than splatting every
        # value positionally: keyword-only parameters must be passed by name,
        # and ``*args`` / ``**kwargs`` parameters must be unpacked again.
        return self.layer(*bound_arguments.args, **bound_arguments.kwargs)
Loading

0 comments on commit 8180eb0

Please sign in to comment.