From a59a26fc7de85b9652b434cae380a73f748634f6 Mon Sep 17 00:00:00 2001 From: Tirth Patel Date: Wed, 13 Mar 2024 14:06:01 -0700 Subject: [PATCH] Fixes for the LLaMA backbone + add dropout (#1499) * Firxes for the LLaMA backbone + add dropout * Address review comments CachedLlamaAttention -> LlamaAttention and make parameter state public in the attention layer * Remove self._hidden_dim and self._head_dim --- keras_nlp/models/llama/llama_attention.py | 120 +++++++++--------- keras_nlp/models/llama/llama_backbone.py | 108 ++++++++++------ keras_nlp/models/llama/llama_backbone_test.py | 1 - keras_nlp/models/llama/llama_decoder.py | 75 +++++++---- 4 files changed, 182 insertions(+), 122 deletions(-) diff --git a/keras_nlp/models/llama/llama_attention.py b/keras_nlp/models/llama/llama_attention.py index 529e73b009..33ffcef209 100644 --- a/keras_nlp/models/llama/llama_attention.py +++ b/keras_nlp/models/llama/llama_attention.py @@ -18,34 +18,33 @@ class LlamaAttention(keras.layers.Layer): - """Grouped query attention for Llama models""" + """A cached grounded query attention layer with sliding window.""" def __init__( self, num_query_heads, num_key_value_heads, + rope_max_wavelength=10000, rope_scaling_factor=1.0, kernel_initializer="glorot_uniform", - rope_max_wavelength=10000, - max_sequence_length=512, + dropout=0, **kwargs, ): super().__init__(**kwargs) self.num_query_heads = num_query_heads self.num_key_value_heads = num_key_value_heads + self.dropout = dropout self.num_key_value_groups = num_query_heads // num_key_value_heads + self.rope_max_wavelength = rope_max_wavelength - self.kernel_initializer = keras.initializers.get(kernel_initializer) - self.max_sequence_length = max_sequence_length + self.kernel_initializer = keras.initializers.get( + clone_initializer(kernel_initializer) + ) self.rope_scaling_factor = rope_scaling_factor - self.rope_max_wavelength = rope_max_wavelength def build(self, inputs_shape): - self.hidden_dim = inputs_shape[-1] - self.attn_head_size = self.hidden_dim // self.num_query_heads - # Einsum variables: # b = batch size # q = query length @@ -54,18 +53,27 @@ def build(self, inputs_shape): # u = num query heads # v = num key/value heads # h = head dim + hidden_dim = inputs_shape[-1] + head_dim = hidden_dim // self.num_query_heads + self._norm_factor = ops.sqrt(ops.cast(head_dim, self.compute_dtype)) + self._query_dense = keras.layers.EinsumDense( equation="bqm,muh->bquh", - output_shape=(None, self.num_query_heads, self.attn_head_size), - kernel_initializer=clone_initializer(self.kernel_initializer), + output_shape=(None, self.num_query_heads, head_dim), + kernel_initializer=self.kernel_initializer, dtype=self.dtype_policy, name="query", ) self._query_dense.build(inputs_shape) + self._key_dense = keras.layers.EinsumDense( equation="bkm,mvh->bkvh", - output_shape=(None, self.num_key_value_heads, self.attn_head_size), - kernel_initializer=clone_initializer(self.kernel_initializer), + output_shape=( + None, + self.num_key_value_heads, + head_dim, + ), + kernel_initializer=self.kernel_initializer, dtype=self.dtype_policy, name="key", ) @@ -73,8 +81,12 @@ def build(self, inputs_shape): self._value_dense = keras.layers.EinsumDense( equation="bkm,mvh->bkvh", - output_shape=(None, self.num_key_value_heads, self.attn_head_size), - kernel_initializer=clone_initializer(self.kernel_initializer), + output_shape=( + None, + self.num_key_value_heads, + head_dim, + ), + kernel_initializer=self.kernel_initializer, dtype=self.dtype_policy, name="value", ) @@ -86,21 +98,28 @@ def build(self, inputs_shape): name="attention_softmax", ) + self._dropout_layer = keras.layers.Dropout( + rate=self.dropout, + dtype=self.dtype_policy, + ) + self._output_dense = keras.layers.EinsumDense( - equation="bqm,mh->bqh", - output_shape=(None, self.hidden_dim), - kernel_initializer=clone_initializer(self.kernel_initializer), + equation="bquh,uhm->bqm", + output_shape=(None, hidden_dim), + kernel_initializer=self.kernel_initializer, dtype=self.dtype_policy, name="attention_output", ) - self._output_dense.build(inputs_shape) + self._output_dense.build((None, None, self.num_query_heads, head_dim)) - self._rotary_embedding_layer = RotaryEmbedding( + self.rotary_embedding_layer = RotaryEmbedding( max_wavelength=self.rope_max_wavelength, scaling_factor=self.rope_scaling_factor, dtype=self.dtype_policy, ) - self._rotary_embedding_layer.build(inputs_shape) + + self._dot_product_equation = "bquh,bkuh->buqk" + self._combine_equation = "buqk,bkuh->bquh" self.built = True @@ -110,6 +129,7 @@ def call( attention_mask=None, cache=None, cache_update_index=None, + training=None, ): query = self._query_dense(hidden_states) @@ -136,75 +156,61 @@ def call( key = self._key_dense(hidden_states) value = self._value_dense(hidden_states) - query = self._rotary_embedding_layer(query) - key = self._rotary_embedding_layer(key) + query = self.rotary_embedding_layer(query) + key = self.rotary_embedding_layer(key) - key = ops.tile(key, [1, 1, self.num_key_value_groups, 1]) - value = ops.tile(value, [1, 1, self.num_key_value_groups, 1]) + # [batch_shape, seq_len, num_key_value_heads, head_dim] + # -> [batch_shape, seq_len, num_heads, head_dim] + key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2) + value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2) - attention_output, attention_scores = self._compute_attention( + attention_output = self._compute_attention( query, key, value, attention_mask ) - attention_output_shape = ops.shape(attention_output) - - attention_output = ops.reshape( - attention_output, - [ - attention_output_shape[0], - attention_output_shape[1], - self.hidden_dim, - ], + attention_output = self._dropout_layer( + attention_output, training=training ) attention_output = self._output_dense(attention_output) if cache is not None: - return (attention_output, cache) + return attention_output, cache return attention_output def _masked_softmax(self, attention_scores, attention_mask=None): if attention_mask is not None: - mask_expansion_axis = -3 - for _ in range( - len(attention_scores.shape) - len(attention_mask.shape) - ): - attention_mask = ops.expand_dims( - attention_mask, axis=mask_expansion_axis - ) - return self._softmax(attention_scores, attention_mask) + return self._softmax( + attention_scores, attention_mask[:, None, :, :] + ) + return self._softmax(attention_scores) def _compute_attention(self, query, key, value, attention_mask=None): - attention_scores = ops.einsum("aecd,abcd->acbe", key, query) - - norm_factor = ops.sqrt( - ops.convert_to_tensor(self.attn_head_size, self.compute_dtype) - ) + attention_scores = ops.einsum(self._dot_product_equation, query, key) - attention_scores /= norm_factor + attention_scores = attention_scores / self._norm_factor attention_scores = self._masked_softmax( attention_scores, attention_mask ) attention_scores = ops.cast(attention_scores, self.compute_dtype) attention_output = ops.einsum( - "acbe,aecd->abcd", attention_scores, value + self._combine_equation, attention_scores, value ) - return attention_output, attention_scores + return attention_output def get_config(self): config = super().get_config() config.update( { "num_query_heads": self.num_query_heads, - "hidden_dim": self.hidden_dim, + "num_key_value_heads": self.num_key_value_heads, + "rope_max_wavelength": self.rope_max_wavelength, + "rope_scaling_factor": self.rope_scaling_factor, "kernel_initializer": keras.initializers.serialize( self.kernel_initializer ), - "rope_max_wavelength": self.rope_max_wavelength, - "rope_scaling_factor": self.rope_scaling_factor, - "num_key_value_heads": self.num_key_value_heads, - "max_sequence_length": self.max_sequence_length, + "dropout": self.dropout, } ) return config diff --git a/keras_nlp/models/llama/llama_backbone.py b/keras_nlp/models/llama/llama_backbone.py index 733d9ef434..b5383d528a 100644 --- a/keras_nlp/models/llama/llama_backbone.py +++ b/keras_nlp/models/llama/llama_backbone.py @@ -11,14 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +# import copy + from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.backend import ops from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding from keras_nlp.models.backbone import Backbone -from keras_nlp.models.llama.llama_decoder import LlamaDecoder + +# from keras_nlp.models.llama.llama_presets import backbone_presets +from keras_nlp.models.llama.llama_decoder import LlamaTransformerDecoder from keras_nlp.models.llama.llama_layernorm import LlamaLayerNorm +# from keras_nlp.utils.python_utils import classproperty + def _llama_kernel_initializer(stddev=0.02): return keras.initializers.RandomNormal(stddev=stddev) @@ -27,41 +34,64 @@ def _llama_kernel_initializer(stddev=0.02): @keras_nlp_export("keras_nlp.models.LlamaBackbone") class LlamaBackbone(Backbone): """ - LLaMA core network with hyperparameters. + The Llama Transformer core architecture with hyperparameters. This network implements a Transformer-based decoder network, - LLaMA, as described in ["LLaMA: Open Foundation and Fine-Tuned Language Models"](https://arxiv.org/abs/2302.13971). + Llama, as described in + ["Llama 7B"](https://arxiv.org/pdf/2310.06825.pdf). + It includes the embedding lookups and transformer layers. The default constructor gives a fully customizable, randomly initialized - LLaMA model with any number of layers, heads, and embedding - dimensions. This backbone also supports LLaMA2 checkpoints. + Llama model with any number of layers, heads, and embedding + dimensions. To load preset architectures and weights, use the `from_preset` + constructor. Args: - vocabulary_size: int. The size of the token vocabulary. - num_layers: int. The number of transformer layers. - num_query_heads: int. The number of attention heads for each transformer. - The hidden size must be divisible by the number of attention heads. - hidden_dim: int. The size of the transformer encoding and pooler layers. - intermediate_dim: int. The output dimension of the first Dense layer in - a two-layer feedforward network for each transformer. - num_key_value_heads: int. This is the number of key_value heads that - should be used to implement Grouped Query Attention. If num_key_value_heads=num_attention_heads, - the model will use Multi Head Attention (MHA), if num_key_value_heads=1 - the model will use Multi Query Attention (MQA) - rope_scaling_factor: float. The scaling factor for calculation of rotary - embedding - rope_max_wavelength: int. The maximum angular wavelength of the - sine/cosine curves, for rotary embeddings. - layer_norm_epsilon: float. a value added to the denominator for - numerical stability. - max_sequence_length: int. The maximum sequence length that this encoder - can consume. If `None`, `max_sequence_length` uses the value from - sequence length. This determines the variable shape for positional - embeddings. + vocabulary_size (int): The size of the token vocabulary. + num_layers (int): The number of transformer layers. + num_query_heads (int): The number of query attention heads for + each transformer. + hidden_dim (int): The size of the transformer encoding and pooling layers. + intermediate_dim (int): The output dimension of the first Dense layer in a + three-layer feedforward network for each transformer. + num_key_value_heads (int): The number of key and value attention heads for + each transformer. + rope_max_wavelength (int, optional): The maximum angular wavelength of the + sine/cosine curves, for rotary embeddings. Defaults to `10000`. + rope_scaling_factor (float, optional): The scaling factor for calculation + of roatary embedding. Defaults to `1.0`. + layer_norm_epsilon (float, optional): Epsilon for the layer normalization + layers in the transformer decoder. Defaults to `1e-6`. dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use for model computations and weights. Note that some computations, such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. + + Examples: + + ```python + input_data = { + "token_ids": np.ones(shape=(1, 12), dtype="int32"), + "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), + } + + # Pretrained Llama decoder. + model = keras_nlp.models.LlamaBackbone.from_preset("llama7b_base_en") + model(input_data) + + # Randomly initialized Llama decoder with custom config. + model = keras_nlp.models.LlamaBackbone( + vocabulary_size=10, + hidden_dim=512, + num_layers=2, + num_query_heads=32, + num_key_value_heads=8, + intermediate_dim=1024, + layer_norm_epsilon=1e-6, + dtype="float32" + ) + model(input_data) + ``` """ def __init__( @@ -72,10 +102,10 @@ def __init__( hidden_dim, intermediate_dim, num_key_value_heads, - rope_scaling_factor=1.0, rope_max_wavelength=10000, - layer_norm_epsilon=1e-5, - max_sequence_length=4096, + rope_scaling_factor=1.0, + layer_norm_epsilon=1e-6, + dropout=0, dtype=None, **kwargs, ): @@ -83,31 +113,31 @@ def __init__( self.token_embedding = ReversibleEmbedding( input_dim=vocabulary_size, output_dim=hidden_dim, - embeddings_initializer=_llama_kernel_initializer(stddev=0.01), tie_weights=False, + embeddings_initializer=_llama_kernel_initializer(stddev=0.01), dtype=dtype, name="token_embedding", ) self.transformer_layers = [] for i in range(num_layers): - layer = LlamaDecoder( + layer = LlamaTransformerDecoder( intermediate_dim=intermediate_dim, num_query_heads=num_query_heads, num_key_value_heads=num_key_value_heads, - rope_scaling_factor=rope_scaling_factor, - max_sequence_length=max_sequence_length, rope_max_wavelength=rope_max_wavelength, + rope_scaling_factor=rope_scaling_factor, layer_norm_epsilon=layer_norm_epsilon, activation=ops.silu, kernel_initializer=_llama_kernel_initializer(stddev=0.02), + dropout=dropout, dtype=dtype, name=f"transformer_layer_{i}", ) self.transformer_layers.append(layer) self.layer_norm = LlamaLayerNorm( - dtype=dtype, epsilon=layer_norm_epsilon, - name="layer_norm", + dtype=dtype, + name="sequence_output_layernorm", ) # === Functional Model === @@ -140,8 +170,8 @@ def __init__( self.rope_max_wavelength = rope_max_wavelength self.num_key_value_heads = num_key_value_heads self.rope_scaling_factor = rope_scaling_factor - self.max_sequence_length = max_sequence_length self.layer_norm_epsilon = layer_norm_epsilon + self.dropout = dropout def get_config(self): config = super().get_config() @@ -155,8 +185,12 @@ def get_config(self): "rope_max_wavelength": self.rope_max_wavelength, "rope_scaling_factor": self.rope_scaling_factor, "num_key_value_heads": self.num_key_value_heads, - "max_sequence_length": self.max_sequence_length, "layer_norm_epsilon": self.layer_norm_epsilon, + "dropout": self.dropout, } ) return config + + # @classproperty + # def presets(cls): + # return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/llama/llama_backbone_test.py b/keras_nlp/models/llama/llama_backbone_test.py index efff972c6b..56d8c44bd3 100644 --- a/keras_nlp/models/llama/llama_backbone_test.py +++ b/keras_nlp/models/llama/llama_backbone_test.py @@ -28,7 +28,6 @@ def setUp(self): "num_key_value_heads": 2, "hidden_dim": 8, "intermediate_dim": 8, - "max_sequence_length": 10, } self.input_data = { "token_ids": ops.ones((2, 5), dtype="int32"), diff --git a/keras_nlp/models/llama/llama_decoder.py b/keras_nlp/models/llama/llama_decoder.py index 3b9d6906b8..1ef247c575 100644 --- a/keras_nlp/models/llama/llama_decoder.py +++ b/keras_nlp/models/llama/llama_decoder.py @@ -24,20 +24,20 @@ from keras_nlp.utils.keras_utils import clone_initializer -class LlamaDecoder(keras.layers.Layer): - """Llama decoder block.""" +class LlamaTransformerDecoder(keras.layers.Layer): + """A Transformer decoder layer for the Llama backbone.""" def __init__( self, intermediate_dim, num_query_heads, num_key_value_heads, + rope_max_wavelength=10000, rope_scaling_factor=1.0, - activation="relu", + activation="silu", layer_norm_epsilon=1e-5, kernel_initializer="glorot_uniform", - rope_max_wavelength=10000, - max_sequence_length=512, + dropout=0, **kwargs, ): super().__init__(**kwargs) @@ -48,37 +48,50 @@ def __init__( self.rope_max_wavelength = rope_max_wavelength self.rope_scaling_factor = rope_scaling_factor - self.max_sequence_length = max_sequence_length + self.dropout = dropout + self.activation = keras.activations.get(activation) self.layer_norm_epsilon = layer_norm_epsilon self.kernel_initializer = keras.initializers.get(kernel_initializer) + self.supports_masking = True + def build(self, decoder_sequence_shape): + self._decoder_sequence_shape = decoder_sequence_shape self.hidden_dim = decoder_sequence_shape[-1] - # Self attention layers. + # Self attention layer. self._self_attention_layer = LlamaAttention( num_query_heads=self.num_query_heads, num_key_value_heads=self.num_key_value_heads, rope_max_wavelength=self.rope_max_wavelength, - max_sequence_length=self.max_sequence_length, rope_scaling_factor=self.rope_scaling_factor, kernel_initializer=clone_initializer(self.kernel_initializer), + dropout=self.dropout, dtype=self.dtype_policy, + name="self_attention", ) self._self_attention_layer.build(decoder_sequence_shape) self._self_attention_layernorm = LlamaLayerNorm( epsilon=self.layer_norm_epsilon, dtype=self.dtype_policy, + name="self_attention_layernorm", ) self._self_attention_layernorm.build(decoder_sequence_shape) + self._self_attention_dropout = keras.layers.Dropout( + rate=self.dropout, + dtype=self.dtype_policy, + name="self_attention_dropout", + ) # Feedforward layers. self._feedforward_intermediate_dense = keras.layers.Dense( self.intermediate_dim, kernel_initializer=clone_initializer(self.kernel_initializer), + use_bias=False, dtype=self.dtype_policy, + name="feedforward_intermediate_dense", ) self._feedforward_intermediate_dense.build(decoder_sequence_shape) @@ -86,23 +99,30 @@ def build(self, decoder_sequence_shape): self.intermediate_dim, activation=self.activation, kernel_initializer=clone_initializer(self.kernel_initializer), + use_bias=False, dtype=self.dtype_policy, + name="feedforward_gate_dense", ) self._feedforward_gate_dense.build(decoder_sequence_shape) self._feedforward_output_dense = keras.layers.Dense( self.hidden_dim, kernel_initializer=clone_initializer(self.kernel_initializer), + use_bias=False, dtype=self.dtype_policy, + name="feedforward_output_dense", ) - intermediate_shape = list(decoder_sequence_shape) - intermediate_shape[-1] = self.intermediate_dim - self._feedforward_output_dense.build(tuple(intermediate_shape)) + self._feedforward_output_dense.build( + self._feedforward_gate_dense.compute_output_shape( + decoder_sequence_shape + ) + ) self._feedforward_layernorm = LlamaLayerNorm( epsilon=self.layer_norm_epsilon, dtype=self.dtype_policy, + name="feedforward_layernorm", ) self._feedforward_layernorm.build(decoder_sequence_shape) @@ -115,6 +135,7 @@ def call( decoder_attention_mask=None, self_attention_cache=None, self_attention_cache_update_index=None, + training=None, ): self_attention_mask = self._compute_self_attention_mask( decoder_sequence=decoder_sequence, @@ -125,10 +146,9 @@ def call( ) residual = decoder_sequence - x = self._self_attention_layernorm( - decoder_sequence, - ) + x = self._self_attention_layernorm(decoder_sequence) + # Self attention block. x = self._self_attention_layer( hidden_states=x, attention_mask=self_attention_mask, @@ -139,6 +159,8 @@ def call( if self_attention_cache is not None: x, self_attention_cache = x + x = self._self_attention_dropout(x, training=training) + x = x + residual residual = x @@ -152,7 +174,7 @@ def call( decoder_output = x + residual if self_attention_cache is not None: - return (decoder_output, self_attention_cache) + return decoder_output, self_attention_cache return decoder_output def _compute_self_attention_mask( @@ -160,8 +182,8 @@ def _compute_self_attention_mask( decoder_sequence, decoder_padding_mask, decoder_attention_mask, - self_attention_cache=None, - self_attention_cache_update_index=None, + self_attention_cache, + self_attention_cache_update_index, ): decoder_mask = merge_padding_and_attention_mask( decoder_sequence, decoder_padding_mask, decoder_attention_mask @@ -174,16 +196,16 @@ def _compute_self_attention_mask( if self_attention_cache is not None: input_length = ops.shape(self_attention_cache)[2] + cache_update_index = ( + 0 + if self_attention_cache_update_index is None + else self_attention_cache_update_index + ) + causal_mask = compute_causal_mask( - batch_size, - input_length, - output_length, - ( - 0 - if self_attention_cache_update_index is None - else self_attention_cache_update_index - ), + batch_size, input_length, output_length, cache_update_index ) + return ( ops.minimum(decoder_mask, causal_mask) if decoder_mask is not None @@ -198,17 +220,16 @@ def get_config(self): config.update( { "intermediate_dim": self.intermediate_dim, - "hidden_dim": self.hidden_dim, "num_query_heads": self.num_query_heads, "rope_max_wavelength": self.rope_max_wavelength, "rope_scaling_factor": self.rope_scaling_factor, "num_key_value_heads": self.num_key_value_heads, - "max_sequence_length": self.max_sequence_length, "activation": keras.activations.serialize(self.activation), "layer_norm_epsilon": self.layer_norm_epsilon, "kernel_initializer": keras.initializers.serialize( self.kernel_initializer ), + "dropout": self.dropout, } ) return config