Fix the rotary embedding computation in LLaMA (keras-team#1544)
* Fix rotary embedding computation in LLaMA

Also run the reverse embedding stage in compute_dtype instead of full precision. This matches the Hugging Face implementation and brings the numerics closer to it.

* Don't cast start_index; save rope keys

* Remove underscore from num_key_value_heads
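
For context, the core of the fix is that RoPE must be applied to queries and keys at their absolute positions, which is what the `start_index` argument carries during cached decoding. Below is a minimal NumPy sketch of rotary embeddings with such an offset (not the KerasNLP implementation; it uses the interleaved pairing convention for brevity and assumes an even head dimension):

import numpy as np

def rope(x, start_index=0, base=10000):
    # x: [seq_len, head_dim] for a single head; head_dim must be even.
    seq_len, head_dim = x.shape
    # Absolute positions of these tokens within the full sequence.
    positions = np.arange(start_index, start_index + seq_len)
    inv_freq = 1.0 / base ** (np.arange(0, head_dim, 2) / head_dim)
    angles = positions[:, None] * inv_freq[None, :]  # [seq_len, head_dim // 2]
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, 0::2], x[:, 1::2]
    out = np.empty_like(x)
    out[:, 0::2] = x1 * cos - x2 * sin
    out[:, 1::2] = x1 * sin + x2 * cos
    return out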
tirthasheshpatel authored Apr 3, 2024
1 parent f246a4e commit 9ac3335
Showing 4 changed files with 17 additions and 12 deletions.
22 changes: 15 additions & 7 deletions keras_nlp/models/llama/llama_attention.py
@@ -131,17 +131,29 @@ def call(
         cache_update_index=None,
         training=None,
     ):
+        start_index = (
+            cache_update_index if cache_update_index is not None else 0
+        )
+
         query = self._query_dense(hidden_states)
 
+        # Compute RoPE for queries
+        query = self.rotary_embedding_layer(query, start_index=start_index)
+
+        def _compute_key_value(x):
+            key, value = self._key_dense(x), self._value_dense(x)
+            # Compute RoPE for keys
+            key = self.rotary_embedding_layer(key, start_index=start_index)
+            return key, value
+
         if cache is not None:
             key_cache = cache[:, 0, ...]
             value_cache = cache[:, 1, ...]
             if cache_update_index is None:
                 key = key_cache
                 value = value_cache
             else:
-                key_update = self._key_dense(hidden_states)
-                value_update = self._value_dense(hidden_states)
+                key_update, value_update = _compute_key_value(hidden_states)
                 start = [0, cache_update_index, 0, 0]
                 key = ops.slice_update(key_cache, start, key_update)
                 value = ops.slice_update(value_cache, start, value_update)
@@ -153,11 +165,7 @@ def call(
                     f"`None`. Received: cache={cache}, "
                     f"cache_update_index={cache_update_index}"
                 )
-            key = self._key_dense(hidden_states)
-            value = self._value_dense(hidden_states)
-
-            query = self.rotary_embedding_layer(query)
-            key = self.rotary_embedding_layer(key)
+            key, value = _compute_key_value(hidden_states)
 
         # [batch_shape, seq_len, num_key_value_heads, head_dim]
         # -> [batch_shape, seq_len, num_heads, head_dim]
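
The practical effect of "save rope keys" above: keys are rotated at their absolute positions before being written to the key cache, so cached entries are already position-correct and never need re-rotation on later decode steps. A hedged sketch of that pattern, reusing the rope() helper from the earlier sketch (shapes and names are illustrative, not the layer's internals):

import numpy as np

max_len, head_dim = 16, 8
key_cache = np.zeros((max_len, head_dim), dtype="float32")

def cache_keys(key_proj, cache_update_index):
    # Rotate the new keys at their absolute positions, then store them.
    new_len = key_proj.shape[0]
    rotated = rope(key_proj, start_index=cache_update_index)
    key_cache[cache_update_index : cache_update_index + new_len] = rotated
    # Everything already in the cache is position-correct; just slice it.
    return key_cache[: cache_update_index + new_len]

# Prefill with 4 tokens, then decode one token at absolute position 4.
keys = cache_keys(np.random.rand(4, head_dim).astype("float32"), 0)
keys = cache_keys(np.random.rand(1, head_dim).astype("float32"), 4)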
1 change: 1 addition & 0 deletions keras_nlp/models/llama/llama_backbone.py
@@ -109,6 +109,7 @@ def __init__(
             tie_weights=False,
             embeddings_initializer=_llama_kernel_initializer(stddev=0.01),
             dtype=dtype,
+            reverse_dtype=dtype,
             name="token_embedding",
         )
         self.transformer_layers = []
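
The reverse_dtype=dtype argument controls the dtype of the reverse embedding stage, i.e. the tied projection from hidden states back to vocabulary logits. A minimal sketch of what that stage computes, assuming a tied embedding matrix; float16 stands in for a mixed-precision compute dtype here, since plain NumPy has no bfloat16:

import numpy as np

def reverse_embedding(hidden, embedding_matrix, reverse_dtype="float16"):
    # hidden: [batch, seq_len, hidden_dim]; embedding_matrix: [vocab_size, hidden_dim].
    # Per the commit message, keep the logits projection in the compute
    # dtype rather than upcasting to full precision.
    hidden = hidden.astype(reverse_dtype)
    weights = embedding_matrix.astype(reverse_dtype)
    return hidden @ weights.T  # [batch, seq_len, vocab_size] logits

logits = reverse_embedding(np.random.rand(1, 3, 8), np.random.rand(32, 8))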
5 changes: 0 additions & 5 deletions keras_nlp/models/mistral/mistral_attention.py
@@ -144,11 +144,6 @@ def call(
         start_index = (
             cache_update_index if cache_update_index is not None else 0
         )
-        # If `cache_update_index` is a tensor, RotaryEmbedding expects it
-        # to have dtype `self.compute_dtype`.
-        start_index = ops.cast(
-            start_index, self.rotary_embedding_layer.compute_dtype
-        )
 
         query = self._query_dense(hidden_states)
 
1 change: 1 addition & 0 deletions keras_nlp/models/mistral/mistral_backbone.py
@@ -121,6 +121,7 @@ def __init__(
             tie_weights=False,
             embeddings_initializer=_mistral_kernel_initializer(stddev=0.01),
             dtype=dtype,
+            reverse_dtype=dtype,
             name="token_embedding",
         )
         self.transformer_layers = []
