Commit 3a0f658

fix beta in fused BN
calad0i committed Jan 10, 2024
1 parent 5993704 commit 3a0f658
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions src/HGQ/layers/batchnorm_base.py
@@ -24,9 +24,6 @@ def __init__(
         gamma_constraint=None,
         **kwargs,
     ):
-        if center and not getattr(self, "use_bias", False):
-            warn(f'`center` in fused BatchNorm can only be used if `use_bias` is True. Setting center to False.')
-            center = False

         super().__init__(**kwargs)
         self.axis = [axis] if not isinstance(axis, (list, tuple)) else list(axis)
@@ -56,6 +53,11 @@ def _post_build(self, input_shape):
         self._reduction_axis = tuple([i for i in range(len(input_shape)) if i not in self.axis])
         output_shape = self.compute_output_shape(input_shape)
         shape = tuple([output_shape[i] for i in self.axis])
+
+        if self.center and not getattr(self, "use_bias", False):
+            warn(f'`center` in fused BatchNorm can only be used if `use_bias` is True. Setting center to False.', stacklevel=3)
+            self.center = False
+
         if self.center:
             self.bn_beta: tf.Variable = self.add_weight(
                 name="bn_beta",
@@ -103,7 +105,7 @@ def fused_bias(self):
         if not self.center:
             return self.bias
         scale = self.bn_gamma * tf.math.rsqrt(self.moving_variance + self.epsilon)
-        return self.bias - self.moving_mean * scale
+        return self.bias - self.moving_mean * scale + self.bn_beta

     def adapt_fused_bn_kernel_bw_bits(self, x: tf.Tensor):
         """Adapt the bitwidth of the kernel quantizer to the input tensor, such that each input is represented with approximately the same number of bits after fused batchnormalization."""
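As a sanity check of the math, here is a minimal sketch (not the repository's code) of the folding that the `fused_bias` property performs. It assumes, consistent with the expression in the diff, that the layer bias is applied after normalization and is therefore not scaled; the names `z`, `y_ref`, and `y_fused` are illustrative only. The point of the commit is the final `+ beta` term: without it, a non-zero learned offset is silently dropped whenever `center=True`.

import numpy as np
import tensorflow as tf

rng = np.random.default_rng(0)
n = 8
z = tf.constant(rng.normal(size=(4, n)))          # stand-in for the layer's pre-bias output
bias = tf.constant(rng.normal(size=n))            # layer bias, added after normalization here
gamma = tf.constant(rng.normal(size=n))           # bn_gamma
beta = tf.constant(rng.normal(size=n))            # bn_beta
mean = tf.constant(rng.normal(size=n))            # moving_mean
var = tf.constant(rng.uniform(0.5, 2.0, size=n))  # moving_variance
eps = 1e-3

scale = gamma * tf.math.rsqrt(var + eps)          # mirrors the diff's scale expression

# Unfused reference: normalize, apply beta, then add the layer bias.
y_ref = scale * (z - mean) + beta + bias

# Fused form: a single multiply-add. Before this commit the fused bias
# lacked the `+ beta` term, so beta never reached the output.
fused_bias = bias - mean * scale + beta
y_fused = scale * z + fused_bias

assert np.allclose(y_ref.numpy(), y_fused.numpy())

Under this assumption the two forms are algebraically identical, which is why the fix is a one-term change to `fused_bias` rather than a change to the forward pass.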
