Fix rope scaling factor (keras-team#1605)
* Fix rope scaling factor

* Fix format

* Add tests

* Fix format
abuelnasr0 committed May 6, 2024
1 parent 026c6ed commit 778ccd7
Showing 2 changed files with 80 additions and 6 deletions.
13 changes: 7 additions & 6 deletions keras_nlp/src/layers/modeling/rotary_embedding.py
@@ -35,7 +35,8 @@ class RotaryEmbedding(keras.layers.Layer):
     Args:
         max_wavelength: int. The maximum angular wavelength of the sine/cosine
             curves.
-        scaling_factor: float. The scaling factor used to scale frequency range.
+        scaling_factor: float. The scaling factor used to scale positions of
+            the tokens.
         sequence_axis: int. Sequence axis in the input tensor.
         feature_axis: int. Feature axis in the input tensor.
         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
@@ -125,6 +126,7 @@ def _compute_cos_sin_embedding(self, inputs, start_index=0, positions=None):
         else:
             positions = ops.cast(positions, "float32")

+        positions = positions / ops.cast(self.scaling_factor, "float32")
         freq = ops.einsum("i,j->ij", positions, inverse_freq)
         embedding = ops.stack((freq, freq), axis=-2)
         embedding = ops.reshape(
@@ -143,12 +145,11 @@ def _compute_cos_sin_embedding(self, inputs, start_index=0, positions=None):
         return cos_emb, sin_emb

     def _get_inverse_freq(self, rotary_dim):
-        freq_range = ops.arange(0, rotary_dim, 2, dtype="float32")
-        freq_range = freq_range / ops.cast(self.scaling_factor, "float32")
-        inverse_freq = 1.0 / (
-            self.max_wavelength
-            ** (freq_range / ops.cast(rotary_dim, "float32"))
+        freq_range = ops.divide(
+            ops.arange(0, rotary_dim, 2, dtype="float32"),
+            ops.cast(rotary_dim, "float32"),
         )
+        inverse_freq = 1.0 / (self.max_wavelength**freq_range)
         return inverse_freq

     def get_config(self):
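In short: before this commit, `scaling_factor` divided the frequency range inside `_get_inverse_freq`, warping the wavelengths of the sine/cosine curves; after it, the factor divides the token positions instead, which is linear RoPE position scaling as in Hugging Face's `LlamaLinearScalingRotaryEmbedding`. A minimal NumPy sketch of the corrected math, as an illustration only (it uses the Hugging Face half-split feature layout rather than the layer's interleaved pairs, and its names are not from this diff):

import numpy as np

def rope_cos_sin(seq_len, rotary_dim, max_wavelength=10000.0, scaling_factor=1.0):
    # Inverse frequencies depend only on the rotary dimension and
    # max_wavelength; scaling_factor no longer appears here (that was the bug).
    freq_range = np.arange(0, rotary_dim, 2, dtype="float32") / rotary_dim
    inverse_freq = 1.0 / (max_wavelength**freq_range)
    # Linear scaling: stretch the position axis by dividing the positions.
    positions = np.arange(seq_len, dtype="float32") / scaling_factor
    freq = np.einsum("i,j->ij", positions, inverse_freq)
    emb = np.concatenate([freq, freq], axis=-1)
    return np.cos(emb), np.sin(emb)

# With scaling_factor=2.0, position p gets the angles that position p/2 had
# unscaled, stretching the usable context window.
cos, sin = rope_cos_sin(seq_len=3, rotary_dim=4, scaling_factor=2.0)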
73 changes: 73 additions & 0 deletions keras_nlp/src/layers/modeling/rotary_embedding_test.py
@@ -168,3 +168,76 @@ def test_positions_array(self):
         got = layer(x, positions=positions)

         np.testing.assert_allclose(expected, ops.convert_to_numpy(got))
+
+    def test_rope_scaling(self):
+        # Reference values computed from the Hugging Face Llama implementation
+        # with `scaling_factor=2.0`:
+        # from transformers.models.llama.modeling_llama import (
+        #     LlamaLinearScalingRotaryEmbedding, apply_rotary_pos_emb
+        # )
+        # import torch
+        # torch.set_printoptions(precision=9)
+        # rotary_emb = LlamaLinearScalingRotaryEmbedding(
+        #     dim=4, max_position_embeddings=3, scaling_factor=2.0
+        # )
+        # query = torch.ones((1, 2, 3, 4))  # [bsz, num_heads, seq_len, head_dim]
+        # cos, sin = rotary_emb(
+        #     query, torch.unsqueeze(torch.arange(3, dtype=torch.int32), 0)
+        # )
+        # query, _ = apply_rotary_pos_emb(query, query, cos, sin)
+        # print(query.transpose(1, 2))
+        expected = [
+            [
+                [
+                    [1.000000000, 1.000000000, 1.000000000, 1.000000000],
+                    [1.000000000, 1.000000000, 1.000000000, 1.000000000],
+                ],
+                [
+                    [0.398157001, 0.994987488, 1.357008100, 1.004987478],
+                    [0.398157001, 0.994987488, 1.357008100, 1.004987478],
+                ],
+                [
+                    [-0.301168621, 0.989950180, 1.381773233, 1.009949803],
+                    [-0.301168621, 0.989950180, 1.381773233, 1.009949803],
+                ],
+            ]
+        ]
+
+        layer = RotaryEmbedding(scaling_factor=2.0)
+        self.assertAllClose(
+            layer(ops.ones((1, 3, 2, 4))),
+            ops.convert_to_tensor(expected),
+        )
+
+    def test_rope_scaling_with_kv_cache(self):
+        # Reference values computed from the Hugging Face Llama implementation
+        # with `scaling_factor=5.0`:
+        # from transformers.models.llama.modeling_llama import (
+        #     LlamaLinearScalingRotaryEmbedding, apply_rotary_pos_emb
+        # )
+        # import torch
+        # torch.set_printoptions(precision=9)
+        # rotary_emb = LlamaLinearScalingRotaryEmbedding(
+        #     dim=4, max_position_embeddings=3, scaling_factor=5.0
+        # )
+
+        # query = torch.ones((1, 2, 1, 4))  # [bsz, num_heads, seq_len, head_dim]
+        # cos, sin = rotary_emb(
+        #     query, torch.unsqueeze(torch.arange(12, 13, dtype=torch.int32), 0)
+        # )
+        # query, _ = apply_rotary_pos_emb(query, query, cos, sin)
+        # query.transpose(1, 2)
+        expected = [
+            [
+                [
+                    [-1.412856817, 0.975714266, -0.061930716, 1.023709655],
+                    [-1.412856817, 0.975714266, -0.061930716, 1.023709655],
+                ]
+            ]
+        ]
+
+        layer = RotaryEmbedding(scaling_factor=5.0)
+        self.assertAllClose(
+            layer(ops.ones((1, 1, 2, 4)), start_index=12),
+            ops.convert_to_tensor(expected),
+        )
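The kv-cache test pins down `start_index`: a single new token passed with `start_index=12` must be rotated exactly as position 12 of a full sequence, so cached autoregressive decoding stays consistent with a full-sequence pass. A small sketch of that equivalence (the import paths and shapes below are assumptions inferred from this diff, not part of it):

from keras import ops
from keras_nlp.src.layers.modeling.rotary_embedding import RotaryEmbedding

layer = RotaryEmbedding(scaling_factor=5.0)

# Full pass over 13 positions; keep only the rotation applied at position 12.
full = layer(ops.ones((1, 13, 2, 4)))[:, 12:13, :, :]

# Cached-decoding step: one token, told it sits at absolute position 12.
step = layer(ops.ones((1, 1, 2, 4)), start_index=12)

# The two should match up to float tolerance.
print(ops.convert_to_numpy(full) - ops.convert_to_numpy(step))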
