From 778ccd72fe5d74e8eedc7d38dfb57561821b7851 Mon Sep 17 00:00:00 2001
From: Mohamed Abu El-Nasr <64566340+abuelnasr0@users.noreply.github.com>
Date: Tue, 7 May 2024 00:33:49 +0300
Subject: [PATCH] Fix rope scaling factor (#1605)

* Fix rope scaling factor

* Fix format

* Add tests

* Fix format
---
 .../src/layers/modeling/rotary_embedding.py  | 13 ++--
 .../layers/modeling/rotary_embedding_test.py | 73 +++++++++++++++++++
 2 files changed, 80 insertions(+), 6 deletions(-)

diff --git a/keras_nlp/src/layers/modeling/rotary_embedding.py b/keras_nlp/src/layers/modeling/rotary_embedding.py
index 8e375ba5a..515f60667 100644
--- a/keras_nlp/src/layers/modeling/rotary_embedding.py
+++ b/keras_nlp/src/layers/modeling/rotary_embedding.py
@@ -35,7 +35,8 @@ class RotaryEmbedding(keras.layers.Layer):
     Args:
         max_wavelength: int. The maximum angular wavelength of the sine/cosine
             curves.
-        scaling_factor: float. The scaling factor used to scale frequency range.
+        scaling_factor: float. The scaling factor used to scale positions of
+            the tokens.
         sequence_axis: int. Sequence axis in the input tensor.
         feature_axis: int. Feature axis in the input tensor.
         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
@@ -125,6 +126,7 @@ def _compute_cos_sin_embedding(self, inputs, start_index=0, positions=None):
         else:
             positions = ops.cast(positions, "float32")
 
+        positions = positions / ops.cast(self.scaling_factor, "float32")
         freq = ops.einsum("i,j->ij", positions, inverse_freq)
         embedding = ops.stack((freq, freq), axis=-2)
         embedding = ops.reshape(
@@ -143,12 +145,11 @@ def _compute_cos_sin_embedding(self, inputs, start_index=0, positions=None):
         return cos_emb, sin_emb
 
     def _get_inverse_freq(self, rotary_dim):
-        freq_range = ops.arange(0, rotary_dim, 2, dtype="float32")
-        freq_range = freq_range / ops.cast(self.scaling_factor, "float32")
-        inverse_freq = 1.0 / (
-            self.max_wavelength
-            ** (freq_range / ops.cast(rotary_dim, "float32"))
+        freq_range = ops.divide(
+            ops.arange(0, rotary_dim, 2, dtype="float32"),
+            ops.cast(rotary_dim, "float32"),
         )
+        inverse_freq = 1.0 / (self.max_wavelength**freq_range)
         return inverse_freq
 
     def get_config(self):
diff --git a/keras_nlp/src/layers/modeling/rotary_embedding_test.py b/keras_nlp/src/layers/modeling/rotary_embedding_test.py
index 3ce776323..6b10cb053 100644
--- a/keras_nlp/src/layers/modeling/rotary_embedding_test.py
+++ b/keras_nlp/src/layers/modeling/rotary_embedding_test.py
@@ -168,3 +168,76 @@ def test_positions_array(self):
         got = layer(x, positions=positions)
 
         np.testing.assert_allclose(expected, ops.convert_to_numpy(got))
+
+    def test_rope_scaling(self):
+        # Reference values computed from Huggingface llama implementation
+        # With `scaling_factor` = 2.0
+        # from transformers.models.llama.modeling_llama import (
+        #     LlamaLinearScalingRotaryEmbedding,apply_rotary_pos_emb
+        # )
+        # import torch
+        # torch.set_printoptions(precision=9)
+        # rotary_emb = LlamaLinearScalingRotaryEmbedding(
+        #     dim=4, max_position_embeddings=3, scaling_factor=2.0
+        # )
+        # query = torch.ones((1, 2, 3, 4)) # [bsz, num_heads, seq_len, head_dim]
+        # cos, sin = rotary_emb(
+        #     query, torch.unsqueeze(torch.arange(3, dtype=torch.int32), 0)
+        # )
+        # query, _ = apply_rotary_pos_emb(query, query, cos, sin)
+        # print(query.transpose(1, 2))
+        expected = [
+            [
+                [
+                    [1.000000000, 1.000000000, 1.000000000, 1.000000000],
+                    [1.000000000, 1.000000000, 1.000000000, 1.000000000],
+                ],
+                [
+                    [0.398157001, 0.994987488, 1.357008100, 1.004987478],
+                    [0.398157001, 0.994987488, 1.357008100, 1.004987478],
+                ],
+                [
+                    [-0.301168621, 0.989950180, 1.381773233, 1.009949803],
+                    [-0.301168621, 0.989950180, 1.381773233, 1.009949803],
+                ],
+            ]
+        ]
+
+        layer = RotaryEmbedding(scaling_factor=2.0)
+        self.assertAllClose(
+            layer(ops.ones((1, 3, 2, 4))),
+            ops.convert_to_tensor(expected),
+        )
+
+    def test_rope_scaling_with_kv_cache(self):
+        # Reference values computed from Huggingface llama implementation
+        # With `scaling_factor` = 5.0
+        # from transformers.models.llama.modeling_llama import (
+        #     LlamaLinearScalingRotaryEmbedding,apply_rotary_pos_emb
+        # )
+        # import torch
+        # torch.set_printoptions(precision=9)
+        # rotary_emb = LlamaLinearScalingRotaryEmbedding(
+        #     dim=4, max_position_embeddings=3, scaling_factor=5.0
+        # )
+
+        # query = torch.ones((1, 2, 1, 4)) # [bsz, num_heads, seq_len, head_dim]
+        # cos, sin = rotary_emb(
+        #     query, torch.unsqueeze(torch.arange(12, 13, dtype=torch.int32), 0)
+        # )
+        # query, _ = apply_rotary_pos_emb(query, query, cos, sin)
+        # query.transpose(1, 2)
+        expected = [
+            [
+                [
+                    [-1.412856817, 0.975714266, -0.061930716, 1.023709655],
+                    [-1.412856817, 0.975714266, -0.061930716, 1.023709655],
+                ]
+            ]
+        ]
+
+        layer = RotaryEmbedding(scaling_factor=5.0)
+        self.assertAllClose(
+            layer(ops.ones((1, 1, 2, 4)), start_index=12),
+            ops.convert_to_tensor(expected),
+        )
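
The change above moves `scaling_factor` off the frequency exponent and onto the token positions: before the patch, `_get_inverse_freq` divided the exponent range by `scaling_factor`, which rescales the per-feature rotation wavelengths; after the patch, positions are divided by `scaling_factor` in `_compute_cos_sin_embedding`, which is the linear ("position interpolation") RoPE scaling that `LlamaLinearScalingRotaryEmbedding` in Hugging Face transformers implements and that the new tests use as a reference. Below is a minimal NumPy sketch of the two formulas, using illustrative values (head_dim=4, max_wavelength=10000, scaling_factor=2.0) rather than the layer itself:

import numpy as np

max_wavelength = 10000.0
scaling_factor = 2.0
head_dim = 4

positions = np.arange(3, dtype="float32")
# Per-feature exponent 2i/d for i = 0, 1, ..., d/2 - 1.
exponent = np.arange(0, head_dim, 2, dtype="float32") / head_dim

# Old behaviour: scaling_factor rescales the exponent, i.e. the wavelengths.
old_angles = np.einsum(
    "i,j->ij", positions, 1.0 / max_wavelength ** (exponent / scaling_factor)
)

# New behaviour: scaling_factor stretches the positions and leaves the
# frequencies untouched, so position p behaves like position p / scaling_factor.
new_angles = np.einsum(
    "i,j->ij", positions / scaling_factor, 1.0 / max_wavelength**exponent
)

print(old_angles)
print(new_angles)

With the corrected form, a sequence extended by `scaling_factor` is mapped back into the position range the model saw during training, which is the usual motivation for linear RoPE scaling.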