diff --git a/src/brevitas/quant/solver/parameter.py b/src/brevitas/quant/solver/parameter.py
index 67c3a56c1..cc36cc72d 100644
--- a/src/brevitas/quant/solver/parameter.py
+++ b/src/brevitas/quant/solver/parameter.py
@@ -128,23 +128,19 @@ def reshaped_scaling_shape(module):
         return module.weight.shape
 
     @value
-    def expanded_scaling_shape(module, input_channel_dim, group_size=None):
+    def expanded_scaling_shape(module, group_dim, group_size=None):
         assert group_size is not None, "Per Group scaling requires group size"
         size = list(module.weight.shape)
-        size[input_channel_dim] = (size[input_channel_dim] + group_size - 1) // group_size
-        size.insert(input_channel_dim + 1, group_size)
+        size[group_dim] = (size[group_dim] + group_size - 1) // group_size
+        size.insert(group_dim + 1, group_size)
         return size
 
     @value
-    def input_channel_dim(module):
-        return 1 if not hasattr(module, 'transposed') or not module.transposed else 0
-
-    @value
-    def padding(module, input_channel_dim, group_size):
+    def padding(module, group_dim, group_size):
         padding = [0, 0] * len(module.weight.shape)
         size = list(module.weight.shape)
-        if size[input_channel_dim] % group_size != 0:
-            padding[2 * input_channel_dim] = group_size - size[input_channel_dim] % group_size
+        if size[group_dim] % group_size != 0:
+            padding[2 * group_dim] = group_size - size[group_dim] % group_size
         padding = list(reversed(padding))
         return padding
 