Use ops.rsqrt, improve normalization layers and enable ops fusion in tflite (#892)

* Add `rsqrt` to numpy backend

* Improve normalization

* Fix order bug

* Update LayerNormalization

* Improve unit test coverage

* Use np native
james77777778 authored Sep 16, 2023
1 parent 74a4e7f commit c663efd
Showing 10 changed files with 143 additions and 60 deletions.
4 changes: 4 additions & 0 deletions keras_core/backend/numpy/math.py
@@ -298,3 +298,7 @@ def istft(
else:
end = expected_output_len
return x[..., start:end]


def rsqrt(x):
return 1.0 / np.sqrt(x)
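The new helper lets the numpy backend serve `ops.rsqrt` like the other backends, which is why the skip in `math_test.py` further down can be removed. A quick hedged sanity check, not part of the diff, assuming a keras_core installation with any backend:

import numpy as np
from keras_core import ops

x = np.array([[1.0, 4.0, 9.0], [16.0, 25.0, 36.0]], dtype="float32")
# rsqrt(x) should agree with 1 / sqrt(x) regardless of the active backend.
np.testing.assert_allclose(ops.rsqrt(x), 1.0 / np.sqrt(x), rtol=1e-6)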
31 changes: 19 additions & 12 deletions keras_core/layers/normalization/batch_normalization.py
@@ -201,21 +201,21 @@ def call(self, inputs, training=None, mask=None):
mean, variance = ops.moments(
inputs, axes=self._reduction_axes, keepdims=True
)
outputs = (inputs - mean) / ops.sqrt(variance + self.epsilon)
mean = ops.squeeze(mean, self._reduction_axes)
variance = ops.squeeze(variance, self._reduction_axes)
moving_mean = ops.cast(self.moving_mean, inputs.dtype)
moving_variance = ops.cast(self.moving_variance, inputs.dtype)
self.moving_mean.assign(
ops.cast(
moving_mean * self.momentum + mean * (1.0 - self.momentum),
moving_mean * self.momentum
+ ops.squeeze(mean, self._reduction_axes)
* (1.0 - self.momentum),
inputs.dtype,
)
)
self.moving_variance.assign(
ops.cast(
moving_variance * self.momentum
+ variance * (1.0 - self.momentum),
+ ops.squeeze(variance, self._reduction_axes)
* (1.0 - self.momentum),
inputs.dtype,
)
)
@@ -224,17 +224,24 @@ def call(self, inputs, training=None, mask=None):
moving_variance = ops.cast(self.moving_variance, inputs.dtype)
moving_mean = ops.reshape(moving_mean, broadcast_shape)
moving_variance = ops.reshape(moving_variance, broadcast_shape)
outputs = (inputs - moving_mean) / ops.sqrt(
moving_variance + self.epsilon
)
mean = moving_mean
variance = moving_variance

inv = ops.rsqrt(variance + self.epsilon)
if self.scale:
gamma = ops.reshape(self.gamma, broadcast_shape)
gamma = ops.cast(gamma, outputs.dtype)
outputs = outputs * gamma
gamma = ops.cast(gamma, inputs.dtype)
inv = inv * gamma

res = -mean * inv
if self.center:
beta = ops.reshape(self.beta, broadcast_shape)
beta = ops.cast(beta, outputs.dtype)
outputs = outputs + beta
beta = ops.cast(beta, inputs.dtype)
res = res + beta

# Note: Folding BatchNormalization depends on the precise order of ops
# that are generated by the expression below
outputs = inputs * inv + res
return ops.cast(outputs, input_dtype)

def get_config(self):
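The rewritten expression is algebraically identical to the previous `(inputs - mean) / sqrt(variance + epsilon) * gamma + beta`, but it now emits a single multiply-add, `inputs * inv + res`, which is the op pattern the TFLite converter can fold into the preceding kernel (per the in-code note on BatchNormalization folding). A minimal numpy sketch of the equivalence; the variable names are illustrative, not the layer's internals:

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3)).astype("float32")
mean, var = x.mean(axis=0, keepdims=True), x.var(axis=0, keepdims=True)
gamma, beta, eps = 1.5, 0.25, 1e-3

old = (x - mean) / np.sqrt(var + eps) * gamma + beta  # previous formulation

inv = gamma / np.sqrt(var + eps)  # ops.rsqrt(var + eps) scaled by gamma
res = beta - mean * inv           # centering folded into a single offset
new = x * inv + res               # the fusable multiply-add

np.testing.assert_allclose(old, new, rtol=1e-5)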
33 changes: 11 additions & 22 deletions keras_core/layers/normalization/group_normalization.py
@@ -171,37 +171,26 @@ def _apply_normalization(self, reshaped_inputs, input_shape):
axis = -2 if self.axis == -1 else self.axis - 1
group_reduction_axes.pop(axis)

broadcast_shape = self._create_broadcast_shape(input_shape)
mean, variance = ops.moments(
reshaped_inputs, axes=group_reduction_axes, keepdims=True
)
gamma, beta = self._get_reshaped_weights(input_shape)

# Compute the batch normalization.
inv = 1 / ops.sqrt(variance + self.epsilon)

if gamma is not None:
inv = ops.multiply(inv, gamma)

if beta is not None:
x = beta - ops.multiply(mean, inv)
else:
x = -ops.multiply(mean, inv)

normalized_inputs = reshaped_inputs * ops.cast(
inv, reshaped_inputs.dtype
) + ops.cast(x, reshaped_inputs.dtype)
normalized_inputs = ops.cast(normalized_inputs, reshaped_inputs.dtype)
return normalized_inputs

def _get_reshaped_weights(self, input_shape):
broadcast_shape = self._create_broadcast_shape(input_shape)
gamma = None
beta = None
inv = ops.rsqrt(variance + self.epsilon)
if self.scale:
gamma = ops.reshape(self.gamma, broadcast_shape)
gamma = ops.cast(gamma, reshaped_inputs.dtype)
inv = inv * gamma

res = -mean * inv
if self.center:
beta = ops.reshape(self.beta, broadcast_shape)
return gamma, beta
beta = ops.cast(beta, reshaped_inputs.dtype)
res = res + beta

normalized_inputs = reshaped_inputs * inv + res
return normalized_inputs

def _create_broadcast_shape(self, input_shape):
broadcast_shape = [1] * len(input_shape)
45 changes: 45 additions & 0 deletions keras_core/layers/normalization/group_normalization_test.py
@@ -41,6 +41,51 @@ def test_groupnorm(self):
supports_masking=True,
)

def test_undefined_dim_error(self):
inputs = layers.Input(shape=(2, 2, 2, None))
layer = layers.GroupNormalization()
with self.assertRaisesRegex(
ValueError,
(
"input tensor should have a defined dimension but the layer "
"received an input with shape"
),
):
_ = layer(inputs)

def test_groups_bigger_than_dim_error(self):
inputs = np.ones(shape=(2, 2, 2, 4))
layer = layers.GroupNormalization(groups=5)
with self.assertRaisesRegex(
ValueError,
"cannot be more than the number of channels",
):
_ = layer(inputs)

def test_groups_not_a_multiple_of_dim_error(self):
inputs = np.ones(shape=(2, 2, 2, 4))
layer = layers.GroupNormalization(groups=3)
with self.assertRaisesRegex(
ValueError,
"must be a multiple of the number of channels",
):
_ = layer(inputs)

def test_groups_instance_norm(self):
# GroupNormalization with groups=-1 will become InstanceNormalization
instance_norm_layer_1 = layers.GroupNormalization(
groups=-1, axis=-1, scale=False, center=False
)
instance_norm_layer_2 = layers.GroupNormalization(
groups=4, axis=-1, scale=False, center=False
)
inputs = np.array([[[-1.0, 1.0, 0, 2.0], [1.0, 3.0, -4, -2.0]]])

outputs_1 = instance_norm_layer_1(inputs)
outputs_2 = instance_norm_layer_2(inputs)

self.assertAllClose(outputs_1, outputs_2)

def test_correctness_instance_norm(self):
instance_norm_layer = layers.GroupNormalization(
groups=4, axis=-1, scale=False, center=False
38 changes: 17 additions & 21 deletions keras_core/layers/normalization/layer_normalization.py
@@ -206,33 +206,29 @@ def _broadcast(v):
if self.rms_scaling:
# Calculate outputs with only variance and gamma if rms scaling
# is enabled
# Calculate the variance along last axis (layer activations).
# Calculate the variance along self.axis (layer activations).
variance = ops.var(inputs, axis=self.axis, keepdims=True)
inv = 1 / ops.sqrt(variance + self.epsilon)
outputs = inputs * ops.cast(inv, inputs.dtype) * self.gamma
inv = ops.rsqrt(variance + self.epsilon)

outputs = inputs * inv * ops.cast(self.gamma, inputs.dtype)
else:
# Calculate the mean & variance along last axis (layer activations).
# Calculate the mean & variance along self.axis (layer activations).
mean, variance = ops.moments(inputs, axes=self.axis, keepdims=True)
inv = 1 / ops.sqrt(variance + self.epsilon)
scale, offset = _broadcast(self.gamma), _broadcast(self.beta)
if scale is not None:
scale = ops.cast(scale, inputs.dtype)
inv = inv * scale
x = -mean * inv
if offset is not None:
offset = ops.cast(offset, inputs.dtype)
x = offset + x

outputs = inputs * ops.cast(inv, inputs.dtype) + ops.cast(
x, inputs.dtype
)
gamma, beta = _broadcast(self.gamma), _broadcast(self.beta)

inv = ops.rsqrt(variance + self.epsilon)
if gamma is not None:
gamma = ops.cast(gamma, inputs.dtype)
inv = inv * gamma

outputs = ops.cast(outputs, input_dtype)
res = -mean * inv
if beta is not None:
beta = ops.cast(beta, inputs.dtype)
res = res + beta

# If some components of the shape got lost due to adjustments, fix that.
outputs = ops.reshape(outputs, ops.shape(inputs))
outputs = inputs * inv + res

return outputs
return ops.cast(outputs, input_dtype)

def compute_output_shape(self, input_shape):
return input_shape
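LayerNormalization now follows the same `inv` / `res` pattern as BatchNormalization above, and the `rms_scaling` branch routes through `ops.rsqrt` as well. A hedged usage sketch of that branch, with epsilon passed explicitly so the reference does not depend on the layer's default:

import numpy as np
from keras_core import layers

x = np.random.normal(size=(2, 4)).astype("float32")
layer = layers.LayerNormalization(rms_scaling=True, epsilon=1e-3)
y = layer(x)

# With rms_scaling the inputs are not mean-centered: outputs are
# x * rsqrt(var(x, axis=-1) + epsilon) * gamma, with gamma initialized to ones.
ref = x / np.sqrt(x.var(axis=-1, keepdims=True) + 1e-3)
np.testing.assert_allclose(np.asarray(y), ref, rtol=1e-4)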
10 changes: 10 additions & 0 deletions keras_core/layers/normalization/layer_normalization_test.py
@@ -83,6 +83,16 @@ def test_ln_basics(self):
supports_masking=True,
)

def test_invalid_axis(self):
with self.assertRaisesRegex(
TypeError,
(
"Expected an int or a list/tuple of ints for the argument "
"'axis'"
),
):
layers.LayerNormalization(axis={"axis": -1})

def test_correctness(self):
layer = layers.LayerNormalization(dtype="float32")
layer.build(input_shape=(2, 2, 2))
26 changes: 26 additions & 0 deletions keras_core/layers/normalization/spectral_normalization_test.py
@@ -20,6 +20,32 @@ def test_basic_spectralnorm(self):
expected_num_losses=0,
supports_masking=False,
)
self.run_layer_test(
layers.SpectralNormalization,
init_kwargs={"layer": layers.Embedding(10, 4)},
input_data=np.random.randint(10, size=(10,)),
expected_output_shape=(10, 4),
expected_num_trainable_weights=1,
expected_num_non_trainable_weights=1,
expected_num_seed_generators=0,
expected_num_losses=0,
supports_masking=False,
run_training_check=False,
)

def test_invalid_power_iterations(self):
with self.assertRaisesRegex(
ValueError, "`power_iterations` should be greater than zero."
):
layers.SpectralNormalization(layers.Dense(2), power_iterations=0)

def test_invalid_layer(self):
layer = layers.SpectralNormalization(layers.ReLU())
inputs = np.ones(shape=(4, 2))
with self.assertRaisesRegex(
ValueError, "object has no attribute 'kernel' nor 'embeddings'"
):
layer(inputs)

def test_apply_layer(self):
images = np.ones((1, 2, 2, 1))
2 changes: 1 addition & 1 deletion keras_core/layers/normalization/unit_normalization.py
@@ -45,7 +45,7 @@ def call(self, inputs):
x = ops.cast(inputs, self.compute_dtype)

square_sum = ops.sum(ops.square(x), axis=self.axis, keepdims=True)
x_inv_norm = 1 / ops.sqrt(ops.maximum(square_sum, 1e-12))
x_inv_norm = ops.rsqrt(ops.maximum(square_sum, 1e-12))
return ops.multiply(x, x_inv_norm)

def compute_output_shape(self, input_shape):
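UnitNormalization gets the same one-line swap to `ops.rsqrt`; behavior is unchanged. A brief hedged check that outputs still have unit L2 norm along the normalized axis:

import numpy as np
from keras_core import layers

x = np.random.normal(size=(2, 3)).astype("float32")
y = layers.UnitNormalization(axis=-1)(x)
# Each row is scaled by rsqrt(max(sum(x**2), 1e-12)), so its L2 norm is ~1.
np.testing.assert_allclose(np.linalg.norm(np.asarray(y), axis=-1), 1.0, rtol=1e-5)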
10 changes: 10 additions & 0 deletions keras_core/layers/normalization/unit_normalization_test.py
@@ -29,6 +29,16 @@ def test_un_basics(self):
supports_masking=True,
)

def test_invalid_axis(self):
with self.assertRaisesRegex(
TypeError,
(
"Invalid value for `axis` argument: expected an int or a "
"list/tuple of ints."
),
):
layers.UnitNormalization(axis={"axis": -1})

def test_correctness(self):
layer = layers.UnitNormalization(axis=-1)
inputs = np.random.normal(size=(2, 3))
4 changes: 0 additions & 4 deletions keras_core/ops/math_test.py
@@ -831,10 +831,6 @@ def test_istft(
ref = ref[..., truncated_len:-truncated_len]
self.assertAllClose(output, ref, atol=1e-5, rtol=1e-5)

@pytest.mark.skipif(
backend.backend() == "numpy",
reason="Numpy does not support rsqrt.",
)
def test_rsqrt(self):
x = np.array([[1, 4, 9], [16, 25, 36]], dtype="float32")
self.assertAllClose(kmath.rsqrt(x), 1 / np.sqrt(x))
