Commit

remove batchnorm support
Giuseppe5 committed Dec 21, 2023
1 parent 704772e · commit a429084
Showing 1 changed file with 1 addition and 8 deletions.
9 changes: 1 addition & 8 deletions src/brevitas/graph/equalize.py
@@ -425,7 +425,7 @@ def _no_equalize():
         axis = _get_input_axis(module)
         act_sink_axes[name] = _get_act_axis(module)
         # If module is not supported, do not perform graph equalization
-        if not isinstance(module, _supported_layers):
+        if not isinstance(module, _supported_layers) or module in _batch_norm:
             return _no_equalize()
         # For MultiheadAttention, we support only self-attetion
         if isinstance(module, nn.MultiheadAttention) and module.in_proj_weight is not None:
@@ -553,13 +553,6 @@ def _no_equalize():
         # one (i.e., no equalization)
         partial_scaling[indexes.start:indexes.end] = scaling_factors[indexes.offset:indexes.offset +
                                                                      channel_range]
-        if isinstance(module, _batch_norm):
-            # We re-compute the bias as function of running_mean and running_var to adjust the
-            # additive factor for equalization.
-            additive_factor = module.running_mean.data * module.weight.data / torch.sqrt(
-                module.running_var.data + module.eps)
-            _update_weights(
-                module, module.bias.clone() + additive_factor * (partial_scaling - 1), attr='bias')
         _update_weights(
             module,
             module.weight.clone() * torch.reshape(partial_scaling, sink_broadcast_size),
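For context, the deleted block handled batch-norm layers that appear as sinks during cross-layer equalization. In eval mode, BatchNorm computes weight * (x - running_mean) / sqrt(running_var + eps) + bias per channel; when equalization divides the incoming activations by a per-channel factor s and multiplies the sink weight by s, the subtracted running-mean term is also multiplied by s, so the bias has to absorb additive_factor * (s - 1), with additive_factor = running_mean * weight / sqrt(running_var + eps). The sketch below reproduces that compensation on a standalone BatchNorm2d to show the output is preserved; it is illustrative only, and the module, seed, and scaling tensor are made up rather than taken from Brevitas.

    import torch
    import torch.nn as nn

    # Hypothetical stand-ins for a batch-norm sink and its equalization scales;
    # none of these names come from the commit itself.
    torch.manual_seed(0)
    bn = nn.BatchNorm2d(4).eval()
    with torch.no_grad():
        bn.running_mean.uniform_(-1.0, 1.0)
        bn.running_var.uniform_(0.5, 2.0)
        bn.weight.uniform_(0.5, 1.5)
        bn.bias.uniform_(-0.5, 0.5)

    scaling = torch.rand(4) + 0.5            # per-channel equalization factors s
    x = torch.randn(2, 4, 8, 8)
    y_ref = bn(x)                            # output before any rescaling

    # The source side of equalization divides the activations by s ...
    x_scaled = x / scaling.view(1, -1, 1, 1)

    # ... so the sink multiplies its weight by s and, because BatchNorm also
    # subtracts weight * running_mean / sqrt(running_var + eps), compensates the
    # bias by that term times (s - 1), mirroring the deleted block.
    with torch.no_grad():
        additive_factor = bn.running_mean * bn.weight / torch.sqrt(bn.running_var + bn.eps)
        bn.bias += additive_factor * (scaling - 1)
        bn.weight *= scaling

    y_eq = bn(x_scaled)
    print(torch.allclose(y_ref, y_eq, atol=1e-5))  # True: the BN output is unchanged

With this branch gone, any region whose sinks include a batch-norm module now falls through to _no_equalize() via the updated check in the first hunk, so the compensation above is no longer applied anywhere.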
