diff --git a/keras/src/layers/normalization/batch_normalization.py b/keras/src/layers/normalization/batch_normalization.py index efc92cd160a..8b0160344ee 100644 --- a/keras/src/layers/normalization/batch_normalization.py +++ b/keras/src/layers/normalization/batch_normalization.py @@ -166,6 +166,12 @@ def __init__( self.gamma_constraint = constraints.get(gamma_constraint) self.supports_masking = True + self.gamma = None + self.beta = None + self.moving_mean = None + self.moving_variance = None + self._reduction_axes = None + def build(self, input_shape): shape = (input_shape[self.axis],) if self.scale: @@ -202,9 +208,11 @@ def build(self, input_shape): trainable=False, autocast=False, ) + self.input_spec = InputSpec( ndim=len(input_shape), axes={self.axis: input_shape[self.axis]} ) + reduction_axes = list(range(len(input_shape))) del reduction_axes[self.axis] self._reduction_axes = reduction_axes @@ -230,10 +238,12 @@ def call(self, inputs, training=None, mask=None): # out BN for mixed precision. inputs = ops.cast(inputs, "float32") + moving_mean = ops.cast(self.moving_mean, inputs.dtype) + moving_variance = ops.cast(self.moving_variance, inputs.dtype) + if training and self.trainable: mean, variance = self._moments(inputs, mask) - moving_mean = ops.cast(self.moving_mean, inputs.dtype) - moving_variance = ops.cast(self.moving_variance, inputs.dtype) + self.moving_mean.assign( moving_mean * self.momentum + mean * (1.0 - self.momentum) ) @@ -242,8 +252,6 @@ def call(self, inputs, training=None, mask=None): + variance * (1.0 - self.momentum) ) else: - moving_mean = ops.cast(self.moving_mean, inputs.dtype) - moving_variance = ops.cast(self.moving_variance, inputs.dtype) mean = moving_mean variance = moving_variance