Skip to content

Commit

Permalink
Fixes for the LLaMA backbone + add dropout (keras-team#1499)
Browse files Browse the repository at this point in the history
* Firxes for the LLaMA backbone + add dropout

* Address review comments

CachedLlamaAttention -> LlamaAttention and make parameter state public in the attention layer

* Remove self._hidden_dim and self._head_dim
  • Loading branch information
tirthasheshpatel committed Mar 13, 2024
1 parent 5136876 commit a59a26f
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 122 deletions.
120 changes: 63 additions & 57 deletions keras_nlp/models/llama/llama_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,34 +18,33 @@


class LlamaAttention(keras.layers.Layer):
"""Grouped query attention for Llama models"""
"""A cached grounded query attention layer with sliding window."""

def __init__(
self,
num_query_heads,
num_key_value_heads,
rope_max_wavelength=10000,
rope_scaling_factor=1.0,
kernel_initializer="glorot_uniform",
rope_max_wavelength=10000,
max_sequence_length=512,
dropout=0,
**kwargs,
):
super().__init__(**kwargs)
self.num_query_heads = num_query_heads
self.num_key_value_heads = num_key_value_heads
self.dropout = dropout

self.num_key_value_groups = num_query_heads // num_key_value_heads
self.rope_max_wavelength = rope_max_wavelength

self.kernel_initializer = keras.initializers.get(kernel_initializer)
self.max_sequence_length = max_sequence_length
self.kernel_initializer = keras.initializers.get(
clone_initializer(kernel_initializer)
)

self.rope_scaling_factor = rope_scaling_factor
self.rope_max_wavelength = rope_max_wavelength

def build(self, inputs_shape):
self.hidden_dim = inputs_shape[-1]
self.attn_head_size = self.hidden_dim // self.num_query_heads

# Einsum variables:
# b = batch size
# q = query length
Expand All @@ -54,27 +53,40 @@ def build(self, inputs_shape):
# u = num query heads
# v = num key/value heads
# h = head dim
hidden_dim = inputs_shape[-1]
head_dim = hidden_dim // self.num_query_heads
self._norm_factor = ops.sqrt(ops.cast(head_dim, self.compute_dtype))

self._query_dense = keras.layers.EinsumDense(
equation="bqm,muh->bquh",
output_shape=(None, self.num_query_heads, self.attn_head_size),
kernel_initializer=clone_initializer(self.kernel_initializer),
output_shape=(None, self.num_query_heads, head_dim),
kernel_initializer=self.kernel_initializer,
dtype=self.dtype_policy,
name="query",
)
self._query_dense.build(inputs_shape)

self._key_dense = keras.layers.EinsumDense(
equation="bkm,mvh->bkvh",
output_shape=(None, self.num_key_value_heads, self.attn_head_size),
kernel_initializer=clone_initializer(self.kernel_initializer),
output_shape=(
None,
self.num_key_value_heads,
head_dim,
),
kernel_initializer=self.kernel_initializer,
dtype=self.dtype_policy,
name="key",
)
self._key_dense.build(inputs_shape)

self._value_dense = keras.layers.EinsumDense(
equation="bkm,mvh->bkvh",
output_shape=(None, self.num_key_value_heads, self.attn_head_size),
kernel_initializer=clone_initializer(self.kernel_initializer),
output_shape=(
None,
self.num_key_value_heads,
head_dim,
),
kernel_initializer=self.kernel_initializer,
dtype=self.dtype_policy,
name="value",
)
Expand All @@ -86,21 +98,28 @@ def build(self, inputs_shape):
name="attention_softmax",
)

self._dropout_layer = keras.layers.Dropout(
rate=self.dropout,
dtype=self.dtype_policy,
)

self._output_dense = keras.layers.EinsumDense(
equation="bqm,mh->bqh",
output_shape=(None, self.hidden_dim),
kernel_initializer=clone_initializer(self.kernel_initializer),
equation="bquh,uhm->bqm",
output_shape=(None, hidden_dim),
kernel_initializer=self.kernel_initializer,
dtype=self.dtype_policy,
name="attention_output",
)
self._output_dense.build(inputs_shape)
self._output_dense.build((None, None, self.num_query_heads, head_dim))

self._rotary_embedding_layer = RotaryEmbedding(
self.rotary_embedding_layer = RotaryEmbedding(
max_wavelength=self.rope_max_wavelength,
scaling_factor=self.rope_scaling_factor,
dtype=self.dtype_policy,
)
self._rotary_embedding_layer.build(inputs_shape)

self._dot_product_equation = "bquh,bkuh->buqk"
self._combine_equation = "buqk,bkuh->bquh"

self.built = True

Expand All @@ -110,6 +129,7 @@ def call(
attention_mask=None,
cache=None,
cache_update_index=None,
training=None,
):
query = self._query_dense(hidden_states)

Expand All @@ -136,75 +156,61 @@ def call(
key = self._key_dense(hidden_states)
value = self._value_dense(hidden_states)

query = self._rotary_embedding_layer(query)
key = self._rotary_embedding_layer(key)
query = self.rotary_embedding_layer(query)
key = self.rotary_embedding_layer(key)

key = ops.tile(key, [1, 1, self.num_key_value_groups, 1])
value = ops.tile(value, [1, 1, self.num_key_value_groups, 1])
# [batch_shape, seq_len, num_key_value_heads, head_dim]
# -> [batch_shape, seq_len, num_heads, head_dim]
key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)

attention_output, attention_scores = self._compute_attention(
attention_output = self._compute_attention(
query, key, value, attention_mask
)

attention_output_shape = ops.shape(attention_output)

attention_output = ops.reshape(
attention_output,
[
attention_output_shape[0],
attention_output_shape[1],
self.hidden_dim,
],
attention_output = self._dropout_layer(
attention_output, training=training
)

attention_output = self._output_dense(attention_output)

if cache is not None:
return (attention_output, cache)
return attention_output, cache
return attention_output

def _masked_softmax(self, attention_scores, attention_mask=None):
if attention_mask is not None:
mask_expansion_axis = -3
for _ in range(
len(attention_scores.shape) - len(attention_mask.shape)
):
attention_mask = ops.expand_dims(
attention_mask, axis=mask_expansion_axis
)
return self._softmax(attention_scores, attention_mask)
return self._softmax(
attention_scores, attention_mask[:, None, :, :]
)
return self._softmax(attention_scores)

def _compute_attention(self, query, key, value, attention_mask=None):
attention_scores = ops.einsum("aecd,abcd->acbe", key, query)

norm_factor = ops.sqrt(
ops.convert_to_tensor(self.attn_head_size, self.compute_dtype)
)
attention_scores = ops.einsum(self._dot_product_equation, query, key)

attention_scores /= norm_factor
attention_scores = attention_scores / self._norm_factor
attention_scores = self._masked_softmax(
attention_scores, attention_mask
)
attention_scores = ops.cast(attention_scores, self.compute_dtype)
attention_output = ops.einsum(
"acbe,aecd->abcd", attention_scores, value
self._combine_equation, attention_scores, value
)

return attention_output, attention_scores
return attention_output

def get_config(self):
config = super().get_config()
config.update(
{
"num_query_heads": self.num_query_heads,
"hidden_dim": self.hidden_dim,
"num_key_value_heads": self.num_key_value_heads,
"rope_max_wavelength": self.rope_max_wavelength,
"rope_scaling_factor": self.rope_scaling_factor,
"kernel_initializer": keras.initializers.serialize(
self.kernel_initializer
),
"rope_max_wavelength": self.rope_max_wavelength,
"rope_scaling_factor": self.rope_scaling_factor,
"num_key_value_heads": self.num_key_value_heads,
"max_sequence_length": self.max_sequence_length,
"dropout": self.dropout,
}
)
return config
Loading

0 comments on commit a59a26f

Please sign in to comment.