Skip to content

Commit

Permalink
add flag for iterative channel splitting
Browse files Browse the repository at this point in the history
  • Loading branch information
fabianandresgrob committed Jan 2, 2024
1 parent 8180eb0 commit bc6e12f
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 10 deletions.
2 changes: 2 additions & 0 deletions src/brevitas/graph/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ def preprocess_for_quantize(
channel_splitting=False,
channel_splitting_ratio=0.02,
channel_splitting_grid_aware=False,
channel_splitting_split_iteratively=False,
channel_splitting_criterion: str = 'maxabs',
channel_splitting_weight_bit_width: int = 8):

Expand Down Expand Up @@ -296,6 +297,7 @@ def preprocess_for_quantize(
split_ratio=channel_splitting_ratio,
grid_aware=channel_splitting_grid_aware,
split_criterion=channel_splitting_criterion,
split_iteratively=channel_splitting_split_iteratively,
weight_bit_width=channel_splitting_weight_bit_width).apply(model)
model.train(training_state)
return model
Expand Down
36 changes: 26 additions & 10 deletions src/brevitas/ptq_algorithms/channel_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,13 +392,19 @@ def split_channels_iteratively(module, split_ratio, grid_aware, bit_width):
class LayerwiseChannelSplitting(GraphTransform):

def __init__(
self, split_ratio=0.02, split_criterion='maxabs', grid_aware=False, weight_bit_width=8):
self,
split_ratio=0.02,
split_criterion='maxabs',
grid_aware=False,
split_iteratively=False,
weight_bit_width=8):
super(LayerwiseChannelSplitting, self).__init__()

self.grid_aware = grid_aware
self.split_ratio = split_ratio
self.split_criterion = split_criterion
self.weight_bit_width = weight_bit_width
self.split_iteratively = split_iteratively

def _is_supported_module(self, graph_model: GraphModule, node: Node) -> bool:
if node.op == 'call_module':
Expand All @@ -413,16 +419,26 @@ def apply(self, graph_model: GraphModule):
for node in graph_model.graph.nodes:
if self._is_supported_module(graph_model, node):
module = get_module(graph_model, node.target)
# we only split input channels
# channels_to_split = _channels_to_split({}, {node.target: module}, split_ratio=self.split_ratio, split_input=True, split_criterion=self.split_criterion)
# split the channels in the module
# _split_channels(module, channels_to_split, grid_aware=self.grid_aware, split_input=True, bit_width=self.weight_bit_width)
if self.split_iteratively:
channels_to_split = split_channels_iteratively(
module,
split_ratio=self.split_ratio,
grid_aware=self.grid_aware,
bit_width=self.weight_bit_width)
else:
# we only split input channels
channels_to_split = _channels_to_split({}, {node.target: module},
split_ratio=self.split_ratio,
split_input=True,
split_criterion=self.split_criterion)
# split the channels in the module
_split_channels(
module,
channels_to_split,
grid_aware=self.grid_aware,
split_input=True,
bit_width=self.weight_bit_width)
# add node to split modules
channels_to_split = split_channels_iteratively(
module,
split_ratio=self.split_ratio,
grid_aware=self.grid_aware,
bit_width=self.weight_bit_width)
split_modules[module] = torch.tensor(channels_to_split)

for module, channels_to_split in split_modules.items():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,13 @@
'split-input',
default=False,
help='Input Channels Splitting for channel splitting (default: disabled)')
add_bool_arg(
parser,
'split-iteratively',
default=False,
help=
'Input Channels are split iteratively, allows the same channel to be split multiple times (default: disabled)'
)


def main():
Expand Down Expand Up @@ -313,6 +320,7 @@ def main():
f"Channel Splitting: {args.channel_splitting} - "
f"Split Ratio: {args.split_ratio} - "
f"Grid Aware: {args.grid_aware} - "
f"Split Iteratively: {args.split_iteratively} - "
f"Merge BN: {not args.calibrate_bn}")

# Get model-specific configurations about input shapes and normalization
Expand Down Expand Up @@ -360,6 +368,7 @@ def main():
merge_bn=not args.calibrate_bn,
channel_splitting=args.channel_splitting,
channel_splitting_grid_aware=args.grid_aware,
channel_splitting_split_iteratively=args.split_iteratively,
channel_splitting_ratio=args.split_ratio,
channel_splitting_weight_bit_width=args.weight_bit_width)
else:
Expand Down

0 comments on commit bc6e12f

Please sign in to comment.