Commit

Merge branch 'keras-team:master' into master
SamanehSaadat authored May 20, 2024
2 parents 323abb9 + 294304b commit 8de2584
Showing 45 changed files with 3,999 additions and 82 deletions.
13 changes: 13 additions & 0 deletions keras_nlp/api/models/__init__.py
@@ -128,6 +128,12 @@
GPTNeoXPreprocessor,
)
from keras_nlp.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer
from keras_nlp.src.models.llama3.llama3_backbone import Llama3Backbone
from keras_nlp.src.models.llama3.llama3_causal_lm_preprocessor import (
Llama3CausalLMPreprocessor,
)
from keras_nlp.src.models.llama3.llama3_preprocessor import Llama3Preprocessor
from keras_nlp.src.models.llama3.llama3_tokenizer import Llama3Tokenizer
from keras_nlp.src.models.llama.llama_backbone import LlamaBackbone
from keras_nlp.src.models.llama.llama_causal_lm import LlamaCausalLM
from keras_nlp.src.models.llama.llama_causal_lm_preprocessor import (
@@ -152,6 +158,13 @@
)
from keras_nlp.src.models.opt.opt_preprocessor import OPTPreprocessor
from keras_nlp.src.models.opt.opt_tokenizer import OPTTokenizer
from keras_nlp.src.models.phi3.phi3_backbone import Phi3Backbone
from keras_nlp.src.models.phi3.phi3_causal_lm import Phi3CausalLM
from keras_nlp.src.models.phi3.phi3_causal_lm_preprocessor import (
Phi3CausalLMPreprocessor,
)
from keras_nlp.src.models.phi3.phi3_preprocessor import Phi3Preprocessor
from keras_nlp.src.models.phi3.phi3_tokenizer import Phi3Tokenizer
from keras_nlp.src.models.preprocessor import Preprocessor
from keras_nlp.src.models.roberta.roberta_backbone import RobertaBackbone
from keras_nlp.src.models.roberta.roberta_classifier import RobertaClassifier
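For reference, every symbol added in this file is re-exported through the public `keras_nlp.models` namespace. A minimal import sketch, assuming (as the surrounding exports suggest) that `keras_nlp/api/models` backs that namespace; only names visible in this diff are listed:

from keras_nlp.models import (
    Llama3Backbone,
    Llama3CausalLMPreprocessor,
    Llama3Preprocessor,
    Llama3Tokenizer,
    Phi3Backbone,
    Phi3CausalLM,
    Phi3CausalLMPreprocessor,
    Phi3Preprocessor,
    Phi3Tokenizer,
)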
7 changes: 6 additions & 1 deletion keras_nlp/src/layers/modeling/cached_multi_head_attention.py
@@ -65,6 +65,8 @@ class CachedMultiHeadAttention(keras.layers.MultiHeadAttention):
`cache` (usually the index of the current token being processed
when running generation). If `cache_update_index=None` while `cache`
is set, the cache will not be updated.
training: a boolean indicating whether the layer should behave in
training mode or in inference mode.
Returns:
An `(attention_output, cache)` tuple. `attention_output` is the result
@@ -83,6 +85,7 @@ def call(
attention_mask=None,
cache=None,
cache_update_index=None,
training=None,
):
if (
hasattr(self, "_build_from_signature")
@@ -133,7 +136,9 @@ def call(
attention_scores = self._masked_softmax(
attention_scores, attention_mask
)
attention_scores = self._dropout_layer(attention_scores)
attention_scores = self._dropout_layer(
attention_scores, training=training
)

attention_output = ops.einsum(
self._combine_equation, attention_scores, value
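The change above threads a `training` argument through `CachedMultiHeadAttention.call` so the attention-score dropout is only active in training mode. A minimal usage sketch (not part of the diff; module path and call pattern mirror the test below):

from keras_nlp.src.backend import random
from keras_nlp.src.layers.modeling.cached_multi_head_attention import (
    CachedMultiHeadAttention,
)

layer = CachedMultiHeadAttention(num_heads=2, key_dim=4, dropout=0.1)
x = random.uniform(shape=(2, 5, 8))  # (batch, seq_len, num_heads * key_dim)

train_out = layer(x, x, training=True)   # dropout applied to attention scores
infer_out = layer(x, x, training=False)  # dropout skipped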
27 changes: 27 additions & 0 deletions keras_nlp/src/layers/modeling/cached_multi_head_attention_test.py
@@ -91,3 +91,30 @@ def call(outputs, cache):

self.assertAllClose(output, no_loop_outputs)
self.assertAllClose(output_cache, no_loop_cache)

def test_training_propagation(self):
batch_size = 2
seq_len = 5
num_heads = 2
key_dim = 4
hidden_dim = num_heads * key_dim

input_shape = (batch_size, seq_len, hidden_dim)
x = random.uniform(shape=input_shape)

layer = CachedMultiHeadAttention(
num_heads=num_heads,
key_dim=key_dim,
dropout=0.99999, # Zeros out the outputs after the dropout layer
)
outputs = layer(x, x, training=True)

# Custom computation with the dropout rate set to about 1.0
value = layer._value_dense(x)
attention_scores = ops.zeros((batch_size, num_heads, seq_len, seq_len))
attention_output = ops.einsum(
layer._combine_equation, attention_scores, value
)
attention_output = layer._output_dense(attention_output)

self.assertAllClose(outputs, attention_output, atol=1e-5)
6 changes: 4 additions & 2 deletions keras_nlp/src/layers/modeling/f_net_encoder.py
@@ -134,12 +134,14 @@ def build(self, inputs_shape):
)
self.built = True

def call(self, inputs):
def call(self, inputs, training=None):
"""Forward pass of the FNetEncoder.
Args:
inputs: a Tensor. The input data to TransformerEncoder, should be
of shape [batch_size, sequence_length, feature_dim].
training: a boolean indicating whether the layer should behave in
training mode or in inference mode.
Returns:
A Tensor of the same shape as the `inputs`.
@@ -160,7 +162,7 @@ def add_and_norm(input1, input2, norm_layer):
def feed_forward(input):
x = self._intermediate_dense(input)
x = self._output_dense(x)
return self._output_dropout(x)
return self._output_dropout(x, training=training)

mixing_output = fourier_transform(inputs)

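`FNetEncoder.call` gains the same `training` flag and forwards it to its dropout layer. A minimal sketch (not part of the diff; constructor arguments and call pattern mirror the test below):

from keras_nlp.src.backend import random
from keras_nlp.src.layers.modeling.f_net_encoder import FNetEncoder

encoder = FNetEncoder(intermediate_dim=64, dropout=0.1)
x = random.uniform(shape=(2, 4, 6))

y_train = encoder(x, training=True)   # dropout active
y_infer = encoder(x, training=False)  # deterministic, same shape as x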
32 changes: 32 additions & 0 deletions keras_nlp/src/layers/modeling/f_net_encoder_test.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_nlp.src.backend import ops
from keras_nlp.src.backend import random
from keras_nlp.src.layers.modeling.f_net_encoder import FNetEncoder
from keras_nlp.src.tests.test_case import TestCase
@@ -42,3 +43,34 @@ def test_value_error_when_invalid_kernel_initializer(self):
dropout=0.5,
kernel_initializer="Invalid",
)

def test_training_propagation(self):
x = random.uniform(shape=(2, 4, 6))
layer = FNetEncoder(
intermediate_dim=4,
dropout=0.99999, # Zeros out the outputs after the dropout layer
)
outputs = layer(x, training=True)

# Custom computation with the dropout rate set to about 1.0
def fourier_transform(input):
# Apply FFT on the input and take the real part.
input_dtype = input.dtype
# FFT transforms do not support float16.
input = ops.cast(input, "float32")
real_in, imaginary_in = (input, ops.zeros_like(input))
real_out, _ = ops.fft2((real_in, imaginary_in))
return ops.cast(real_out, input_dtype)

def add_and_norm(input1, input2, norm_layer):
return norm_layer(input1 + input2)

mixing_output = fourier_transform(x)
mixing_output = add_and_norm(x, mixing_output, layer._mixing_layer_norm)
x = add_and_norm(
mixing_output,
ops.zeros_like(mixing_output),
layer._output_layer_norm,
)

self.assertAllClose(outputs, x, atol=1e-5)
12 changes: 9 additions & 3 deletions keras_nlp/src/layers/modeling/transformer_decoder.py
@@ -279,6 +279,7 @@ def call(
cross_attention_cache=None,
cross_attention_cache_update_index=None,
use_causal_mask=True,
training=None,
):
"""Forward pass of the TransformerDecoder.
@@ -315,6 +316,9 @@
`None` (reuse a previously computed `cross_attention_cache`).
use_causal_mask: bool, defaults to `True`. If true, a causal mask
(masking out future input) is applied on the decoder sequence.
training: a boolean indicating whether the layer should behave in
training mode or in inference mode.
Returns:
One of three things, depending on call arguments:
- `outputs`, if `self_attention_cache` is `None`.
@@ -385,12 +389,13 @@ def call(
attention_mask=self_attention_mask,
cache=self_attention_cache,
cache_update_index=self_attention_cache_update_index,
training=training,
)
if self_attention_cache is None:
x = attention_output
else:
x, self_attention_cache = attention_output
x = self._self_attention_dropout(x)
x = self._self_attention_dropout(x, training=training)
x = x + residual
if not self.normalize_first:
x = self._self_attention_layer_norm(x)
@@ -412,12 +417,13 @@
attention_mask=cross_attention_mask,
cache=cross_attention_cache,
cache_update_index=cross_attention_cache_update_index,
training=training,
)
if cross_attention_cache is None:
x = attention_output
else:
x, cross_attention_cache = attention_output
x = self._cross_attention_dropout(x)
x = self._cross_attention_dropout(x, training=training)
x = x + residual
if not self.normalize_first:
x = self._cross_attention_layer_norm(x)
@@ -428,7 +434,7 @@
x = self._feedforward_layer_norm(x)
x = self._feedforward_intermediate_dense(x)
x = self._feedforward_output_dense(x)
x = self._feedforward_dropout(x)
x = self._feedforward_dropout(x, training=training)
x = x + residual
if not self.normalize_first:
x = self._feedforward_layer_norm(x)
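With the additions above, a single call flag now controls every dropout in the decoder block. A minimal sketch (not part of the diff; mirrors the new test below):

from keras_nlp.src.backend import random
from keras_nlp.src.layers.modeling.transformer_decoder import TransformerDecoder

decoder = TransformerDecoder(intermediate_dim=4, num_heads=2, dropout=0.1)
decoder_sequence = random.uniform(shape=[1, 4, 6])
encoder_sequence = random.uniform(shape=[1, 4, 6])

# `training` now reaches the self-attention, cross-attention, and
# feedforward dropouts.
outputs = decoder(decoder_sequence, encoder_sequence, training=True)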
17 changes: 17 additions & 0 deletions keras_nlp/src/layers/modeling/transformer_decoder_test.py
@@ -103,6 +103,23 @@ def test_value_error_when_invalid_kernel_inititalizer(self):
kernel_initializer="Invalid",
)

def test_training_propagation(self):
decoder = TransformerDecoder(
intermediate_dim=4,
num_heads=2,
dropout=0.99999, # Zeros out the outputs after the dropout layer
)
decoder_sequence = random.uniform(shape=[1, 4, 6])
encoder_sequence = random.uniform(shape=[1, 4, 6])
outputs = decoder(decoder_sequence, encoder_sequence, training=True)

# Custom computation with dropout rates set to about 1.0
x = decoder_sequence
x = decoder._self_attention_layer_norm(x)
x = decoder._feedforward_layer_norm(x)

self.assertAllClose(outputs, x, atol=1e-5)

def test_mask_propagation(self):
decoder = TransformerDecoder(
intermediate_dim=4,
11 changes: 8 additions & 3 deletions keras_nlp/src/layers/modeling/transformer_encoder.py
@@ -182,7 +182,9 @@ def build(self, inputs_shape):
)
self.built = True

def call(self, inputs, padding_mask=None, attention_mask=None):
def call(
self, inputs, padding_mask=None, attention_mask=None, training=None
):
"""Forward pass of the TransformerEncoder.
Args:
@@ -194,6 +196,8 @@ def call(self, inputs, padding_mask=None, attention_mask=None):
attention_mask: a boolean Tensor. Customized mask used to mask out
certain tokens. `attention_mask` should have shape
[batch_size, sequence_length, sequence_length].
training: a boolean indicating whether the layer should behave in
training mode or in inference mode.
Returns:
A Tensor of the same shape as the `inputs`.
@@ -213,8 +217,9 @@ def call(self, inputs, padding_mask=None, attention_mask=None):
query=x,
value=x,
attention_mask=self_attention_mask,
training=training,
)
x = self._self_attention_dropout(x)
x = self._self_attention_dropout(x, training=training)
x = x + residual
if not self.normalize_first:
x = self._self_attention_layer_norm(x)
@@ -225,7 +230,7 @@ def call(self, inputs, padding_mask=None, attention_mask=None):
x = self._feedforward_layer_norm(x)
x = self._feedforward_intermediate_dense(x)
x = self._feedforward_output_dense(x)
x = self._feedforward_dropout(x)
x = self._feedforward_dropout(x, training=training)
x = x + residual
if not self.normalize_first:
x = self._feedforward_layer_norm(x)
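The encoder gets the same treatment: `call` accepts `training` and passes it to the attention layer and both dropouts. A minimal sketch (not part of the diff; mirrors the new test below):

from keras_nlp.src.backend import random
from keras_nlp.src.layers.modeling.transformer_encoder import TransformerEncoder

encoder = TransformerEncoder(intermediate_dim=4, num_heads=2, dropout=0.1)
inputs = random.uniform(shape=[1, 4, 6])

y_train = encoder(inputs, training=True)   # dropout applied
y_infer = encoder(inputs, training=False)  # dropout skipped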
16 changes: 16 additions & 0 deletions keras_nlp/src/layers/modeling/transformer_encoder_test.py
@@ -83,6 +83,22 @@ def test_value_error_when_invalid_kernel_inititalizer(self):
kernel_initializer="Invalid",
)

def test_training_propagation(self):
encoder = TransformerEncoder(
intermediate_dim=4,
num_heads=2,
dropout=0.99999, # Zeros out the outputs after the dropout layer
)
inputs = random.uniform(shape=[1, 4, 6])
outputs = encoder(inputs, training=True)

# Custom computation with dropout rates set to about 1.0
x = inputs
x = encoder._self_attention_layer_norm(x)
x = encoder._feedforward_layer_norm(x)

self.assertAllClose(outputs, x, atol=1e-5)

def test_mask_propagation(self):
encoder = TransformerEncoder(
intermediate_dim=4,
7 changes: 7 additions & 0 deletions keras_nlp/src/models/__init__.py
@@ -143,6 +143,13 @@
)
from keras_nlp.src.models.opt.opt_preprocessor import OPTPreprocessor
from keras_nlp.src.models.opt.opt_tokenizer import OPTTokenizer
from keras_nlp.src.models.phi3.phi3_backbone import Phi3Backbone
from keras_nlp.src.models.phi3.phi3_causal_lm import Phi3CausalLM
from keras_nlp.src.models.phi3.phi3_causal_lm_preprocessor import (
Phi3CausalLMPreprocessor,
)
from keras_nlp.src.models.phi3.phi3_preprocessor import Phi3Preprocessor
from keras_nlp.src.models.phi3.phi3_tokenizer import Phi3Tokenizer
from keras_nlp.src.models.preprocessor import Preprocessor
from keras_nlp.src.models.roberta.roberta_backbone import RobertaBackbone
from keras_nlp.src.models.roberta.roberta_classifier import RobertaClassifier
12 changes: 3 additions & 9 deletions keras_nlp/src/models/llama/llama_causal_lm.py
@@ -19,7 +19,6 @@
from keras_nlp.src.models.llama.llama_causal_lm_preprocessor import (
LlamaCausalLMPreprocessor,
)
from keras_nlp.src.utils.python_utils import classproperty
from keras_nlp.src.utils.tensor_utils import any_equal


@@ -46,6 +45,9 @@ class LlamaCausalLM(CausalLM):
should be preprocessed before calling the model.
"""

backbone_cls = LlamaBackbone
preprocessor_cls = LlamaCausalLMPreprocessor

def __init__(self, backbone, preprocessor=None, **kwargs):
# === Layers ===
self.backbone = backbone
@@ -61,14 +63,6 @@ def __init__(self, backbone, preprocessor=None, **kwargs):
**kwargs,
)

@classproperty
def backbone_cls(cls):
return LlamaBackbone

@classproperty
def preprocessor_cls(cls):
return LlamaCausalLMPreprocessor

def call_with_cache(
self,
token_ids,
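The refactor above (and the tokenizer change below) replaces `classproperty`-based accessors with plain class attributes. A self-contained sketch of why the two are equivalent for class-level lookup; the classes here are placeholders, and the `classproperty` helper is a minimal stand-in for the removed `keras_nlp.src.utils.python_utils.classproperty`:

class classproperty(property):
    # Like @property, but resolved against the class rather than an instance.
    def __get__(self, _, owner_cls):
        return self.fget(owner_cls)

class Backbone:
    pass

class OldStyleTask:
    @classproperty
    def backbone_cls(cls):
        return Backbone

class NewStyleTask:
    backbone_cls = Backbone  # same lookup result, no descriptor machinery

assert OldStyleTask.backbone_cls is NewStyleTask.backbone_cls is Backbone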
7 changes: 0 additions & 7 deletions keras_nlp/src/models/llama/llama_tokenizer.py
@@ -11,14 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy

from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.models.llama.llama_presets import backbone_presets
from keras_nlp.src.tokenizers.sentence_piece_tokenizer import (
SentencePieceTokenizer,
)
from keras_nlp.src.utils.python_utils import classproperty


@keras_nlp_export("keras_nlp.models.LlamaTokenizer")
@@ -85,7 +82,3 @@ def set_proto(self, proto):
self.start_token_id = None
self.end_token_id = None
self.pad_token_id = None

@classproperty
def presets(cls):
return copy.deepcopy(backbone_presets)
20 changes: 20 additions & 0 deletions keras_nlp/src/models/llama3/__init__.py
@@ -0,0 +1,20 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_nlp.src.models.llama3.llama3_backbone import Llama3Backbone
from keras_nlp.src.models.llama3.llama3_presets import backbone_presets
from keras_nlp.src.models.llama3.llama3_tokenizer import Llama3Tokenizer
from keras_nlp.src.utils.preset_utils import register_presets

register_presets(backbone_presets, (Llama3Backbone, Llama3Tokenizer))
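This mirrors the removal of the per-class `presets` classproperty in `llama_tokenizer.py` above: preset metadata is now registered centrally when the module is imported. A hedged sketch of how registered presets are typically consumed; preset names are not shown in this diff, so none are hard-coded here:

import keras_nlp

# Names registered by the `register_presets` call above are expected to be
# listed on both classes.
print(sorted(keras_nlp.models.Llama3Backbone.presets.keys()))
print(sorted(keras_nlp.models.Llama3Tokenizer.presets.keys()))

# Any listed name can then be passed to `from_preset`, e.g.:
# keras_nlp.models.Llama3Backbone.from_preset("<registered_preset_name>")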