Fix BN folding for PyTorch ConvTranspose2d with groups>1 (sony#907)
elad-c authored Dec 31, 2023
1 parent 7db1ae7 commit e3fae76
Showing 3 changed files with 17 additions and 3 deletions.
@@ -19,7 +19,7 @@
 from model_compression_toolkit.core.common import BaseNode
 from model_compression_toolkit.core.common.substitutions.batchnorm_folding import BatchNormalizationFolding, BatchNormalizationForwardFolding
 from model_compression_toolkit.core.pytorch.constants import KERNEL, BIAS, GAMMA, BETA, MOVING_MEAN, MOVING_VARIANCE, \
-    EPSILON, USE_BIAS, GROUPS, IN_CHANNELS
+    EPSILON, USE_BIAS, GROUPS, IN_CHANNELS, OUT_CHANNELS
 
 
 def batchnorm_folding_node_matchers() -> [BaseNode, BaseNode]:
@@ -66,7 +66,15 @@ def update_kernel_for_bn_folding_fn(conv_node: BaseNode,
         _scale = weights_scale[None, :, None, None]
     else:
         _scale = weights_scale[:, None, None, None]
-    return kernel * _scale, KERNEL
+    if conv_node.type == ConvTranspose2d and conv_node.framework_attr[GROUPS] > 1:
+        # A PyTorch ConvTranspose2d kernel with groups stacks the groups on the in_channels axis, so reshape the
+        # kernel so that the groups are stacked on the out_channels axis to match the scale vector (then reshape
+        # back to the original shape).
+        _in_channels = int(conv_node.framework_attr[IN_CHANNELS] / conv_node.framework_attr[GROUPS])
+        _out_channels = conv_node.framework_attr[OUT_CHANNELS]
+        return (kernel.reshape((_in_channels, _out_channels, -1, 1)) * _scale).reshape(kernel.shape), KERNEL
+    else:
+        return kernel * _scale, KERNEL
 
 
 def update_weights_for_bn_forward_folding_fn(conv_node: BaseNode,
@@ -85,7 +93,7 @@ def update_weights_for_bn_forward_folding_fn(conv_node: BaseNode,
     Returns:
         The modified convolution node's weight/kernel.
     """
-    if conv_node.type == Conv2d and conv_node.framework_attr['groups'] > 1:
+    if conv_node.type == Conv2d and conv_node.framework_attr[GROUPS] > 1:
         bias_update = (kernel * bias_factor[:, None, None, None]).flatten()
         _scale = weights_scale[:, None, None, None]
     elif conv_node.type == ConvTranspose2d:
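For intuition, here is a minimal, self-contained fold-and-compare sketch (illustrative, not part of the commit) using the depthwise-style layer from the new tests, where in_channels equals groups so each reshaped row lines up with exactly one output channel. All names below are hypothetical:

import torch
import torch.nn as nn

torch.manual_seed(0)

conv = nn.ConvTranspose2d(3, 3, kernel_size=(4, 2), groups=3)  # kernel shape (3, 1, 4, 2)
bn = nn.BatchNorm2d(3)
# Non-trivial BN statistics, so a wrong fold would show up in the comparison.
bn.weight.data = torch.randn(3)
bn.bias.data = torch.randn(3)
bn.running_mean = torch.randn(3)
bn.running_var = torch.rand(3) + 0.5
conv.eval()
bn.eval()

with torch.no_grad():
    # BN scale per output channel: gamma / sqrt(var + eps).
    weights_scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    _scale = weights_scale[None, :, None, None]        # out_channels on axis 1

    kernel = conv.weight                               # (in_channels, out_channels // groups, kH, kW)
    in_per_group = conv.in_channels // conv.groups     # the commit's _in_channels
    folded_kernel = (kernel.reshape((in_per_group, conv.out_channels, -1, 1)) * _scale).reshape(kernel.shape)
    folded_bias = bn.bias + (conv.bias - bn.running_mean) * weights_scale

    folded = nn.ConvTranspose2d(3, 3, kernel_size=(4, 2), groups=3)
    folded.eval()
    folded.weight.copy_(folded_kernel)
    folded.bias.copy_(folded_bias)

    x = torch.randn(2, 3, 32, 32)
    assert torch.allclose(bn(conv(x)), folded(x), atol=1e-5)

The reshape is the whole trick: the grouped kernel has shape (in_channels, out_channels // groups, kH, kW), so the per-output-channel scale vector cannot broadcast against axis 1 directly; regrouping to (in_channels // groups, out_channels, -1, 1) exposes every output channel on axis 1 before scaling.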
@@ -50,10 +50,14 @@ class BNFoldingNetTest(BasePytorchTest):
     """
     def __init__(self, unit_test, test_layer, functional, fold_applied=True, float_reconstruction_error=1e-6):
         super().__init__(unit_test, float_reconstruction_error)
+        self.input_channels = test_layer.in_channels
         self.test_layer = test_layer
         self.fold_applied = fold_applied
         self.functional = functional
 
+    def create_inputs_shape(self):
+        return [[self.val_batch_size, self.input_channels, 32, 32]]
+
     def create_feature_network(self, input_shape):
         return BNFoldingNet(self.test_layer, self.functional, self.fold_applied)
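The new grouped layers no longer have three input channels, so the harness cannot keep assuming a fixed input shape; create_inputs_shape now derives the channel count from the layer under test. A quick illustration of the constraint (hypothetical shapes, not from the test suite):

import torch
import torch.nn as nn

layer = nn.ConvTranspose2d(6, 9, kernel_size=(5, 4), groups=3)
layer(torch.randn(1, 6, 32, 32))     # OK: channel dim matches layer.in_channels
# layer(torch.randn(1, 3, 32, 32))   # would raise a RuntimeError: the layer expects 6 input channels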
tests/pytorch_tests/model_tests/test_feature_models_runner.py (2 additions, 0 deletions)

@@ -146,6 +146,8 @@ def test_bn_folding(self):
             BNFoldingNetTest(self, nn.Conv2d(3, 3, kernel_size=(3, 1), groups=3),
                              functional, fold_applied=False).run_test()  # DW-Conv test
             BNFoldingNetTest(self, nn.ConvTranspose2d(3, 2, kernel_size=(1, 3)), functional, fold_applied=False).run_test()
+        BNFoldingNetTest(self, nn.ConvTranspose2d(6, 9, kernel_size=(5, 4), groups=3), False).run_test()
+        BNFoldingNetTest(self, nn.ConvTranspose2d(3, 3, kernel_size=(4, 2), groups=3), False).run_test()
 
     def test_bn_forward_folding(self):
         """
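For reference, the kernel layouts the two new cases exercise (a hypothetical shape check, not part of the test suite):

import torch.nn as nn

a = nn.ConvTranspose2d(6, 9, kernel_size=(5, 4), groups=3)
b = nn.ConvTranspose2d(3, 3, kernel_size=(4, 2), groups=3)
print(a.weight.shape)   # torch.Size([6, 3, 5, 4]): (in_channels, out_channels // groups, kH, kW)
print(b.weight.shape)   # torch.Size([3, 1, 4, 2]): depthwise-style, in_channels == groups

The first case has two input channels per group and the second has one, so together they cover both grouped-kernel layouts the reshape in update_kernel_for_bn_folding_fn has to handle.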
