Commit

Merge branch 'keras-team:master' into task-upload
SamanehSaadat committed Apr 10, 2024
2 parents f89d795 + ab649f5 commit 069303c
Showing 6 changed files with 43 additions and 40 deletions.
34 changes: 19 additions & 15 deletions keras_nlp/layers/modeling/reversible_embedding.py
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import numpy as np
-
 from keras_nlp.api_export import keras_nlp_export
 from keras_nlp.backend import keras
 from keras_nlp.backend import ops
@@ -50,8 +48,7 @@ class ReversibleEmbedding(keras.layers.Embedding):
         mask_zero: Boolean, whether or not the input value 0 is a special
             "padding" value that should be masked out.
         reverse_dtype: The dtype for the reverse projection computation.
-            For stability, it is usually best to use full precision even when
-            working with half or mixed precision training.
+            Defaults to the `compute_dtype` of the layer.
         **kwargs: other keyword arguments passed to `keras.layers.Embedding`,
             including `name`, `trainable`, `dtype` etc.
@@ -92,7 +89,7 @@ def __init__(
         embeddings_regularizer=None,
         embeddings_constraint=None,
         mask_zero=False,
-        reverse_dtype="float32",
+        reverse_dtype=None,
         **kwargs,
     ):
         super().__init__(
@@ -124,8 +121,9 @@ def call(self, inputs, reverse=False):
                 kernel = ops.transpose(ops.convert_to_tensor(self.embeddings))
             else:
                 kernel = self.reverse_embeddings
-            inputs = ops.cast(inputs, self.reverse_dtype)
-            kernel = ops.cast(kernel, self.reverse_dtype)
+            if self.reverse_dtype is not None:
+                inputs = ops.cast(inputs, self.reverse_dtype)
+                kernel = ops.cast(kernel, self.reverse_dtype)
             return ops.matmul(inputs, kernel)
 
         return super().call(inputs)
@@ -140,18 +138,24 @@ def get_config(self):
         )
         return config
 
+    def save_own_variables(self, store):
+        if not self.built:
+            return
+        super().save_own_variables(store)
+        # Before Keras 3.2, the reverse weight is saved in the super() call.
+        # After Keras 3.2, the reverse weight must be saved manually.
+        if len(store.keys()) < len(self.weights):
+            # Store the reverse embedding as the last weight.
+            store[str(len(store.keys()))] = self.reverse_embeddings
+
     def load_own_variables(self, store):
         if not self.built:
             self.build()
-        self.embeddings.assign(store["0"])
+        super().load_own_variables(store)
         if not self.tie_weights:
-            # Handle the case where saved weights are tied, but the layer
-            # weights untied. We can simply assign the embedding weights to both
-            # variables in this case.
-            if len(store.keys()) == 1:
-                self.reverse_embeddings.assign(np.transpose(store["0"]))
-            else:
-                self.reverse_embeddings.assign(store["1"])
+            # Last weight in the store is the reverse embedding weights.
+            key = str(len(store.keys()) - 1)
+            self.reverse_embeddings.assign(store[key])
 
     def compute_output_spec(self, inputs, reverse=False):
         output_shape = list(inputs.shape)
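With this change, `reverse_dtype` defaults to `None`, so the reverse projection runs in the layer's `compute_dtype` unless a dtype is requested explicitly; the Llama and Mistral backbones below drop their `reverse_dtype=dtype` argument for the same reason. A minimal sketch of the resulting behavior (layer sizes and the final `print` are illustrative, not taken from this commit):

    from keras_nlp.backend import ops
    from keras_nlp.layers import ReversibleEmbedding

    layer = ReversibleEmbedding(input_dim=100, output_dim=16)
    hidden = ops.ones(shape=(4, 10, 16))

    # Default (reverse_dtype=None): no extra cast, the matmul runs in the
    # layer's compute_dtype.
    logits = layer(hidden, reverse=True)  # shape (4, 10, 100)

    # Opting back into a full-precision reverse projection, as the old
    # default did, is still possible:
    stable = ReversibleEmbedding(
        input_dim=100, output_dim=16, reverse_dtype="float32"
    )
    logits_fp32 = stable(hidden, reverse=True)
    print(logits.shape, logits_fp32.dtype)
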
41 changes: 22 additions & 19 deletions keras_nlp/layers/modeling/reversible_embedding_test.py
@@ -44,6 +44,28 @@ def test_layer_behaviors_tied(self, tie_weights):
             expected_num_trainable_weights=1 if tie_weights else 2,
         )
 
+    @parameterized.named_parameters(
+        ("tie_weights", True),
+        ("untie_weights", False),
+    )
+    def test_saving(self, tie_weights):
+        input_data = random.randint(minval=0, maxval=100, shape=(4, 10))
+        model = keras.Sequential(
+            [
+                ReversibleEmbedding(
+                    input_dim=100,
+                    output_dim=32,
+                    tie_weights=tie_weights,
+                )
+            ]
+        )
+        path = os.path.join(self.get_temp_dir(), "model.keras")
+        model_output = model(input_data)
+        model.save(path, save_format="keras_v3")
+        restored_model = keras.models.load_model(path)
+        restored_output = restored_model(input_data)
+        self.assertAllClose(model_output, restored_output)
+
     def test_correctness(self):
         layer = ReversibleEmbedding(input_dim=3, output_dim=2)
         layer.build()
@@ -57,25 +79,6 @@ def test_correctness(self):
         out = layer(np.array(([[1.0, 1.0]])), reverse=True)
         self.assertAllClose(out, np.array([[0.0, 4.0, 6.0]]))
 
-    def test_tied_checkpoint_untied_weights(self):
-        embedding = ReversibleEmbedding(100, 16, tie_weights=True)
-        inputs = keras.Input(shape=(10,), dtype="int32")
-        hidden_states = embedding(inputs)
-        outputs = embedding(hidden_states, reverse=True)
-        tied_model = keras.Model(inputs, outputs)
-        path = os.path.join(self.get_temp_dir(), "checkpoint.weights.h5")
-        tied_model.save_weights(path)
-
-        embedding = ReversibleEmbedding(100, 16, tie_weights=False)
-        inputs = keras.Input(shape=(10,), dtype="int32")
-        hidden_states = embedding(inputs)
-        outputs = embedding(hidden_states, reverse=True)
-        untied_model = keras.Model(inputs, outputs)
-        untied_model.load_weights(path)
-
-        input_data = ops.ones(shape=(4, 10), dtype="int32")
-        self.assertAllClose(untied_model(input_data), tied_model(input_data))
-
     def test_reverse_dtype(self):
         embedding = ReversibleEmbedding(100, 16, reverse_dtype="float32")
         input_data = ops.ones(shape=(4, 10, 16))
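The new `save_own_variables`/`load_own_variables` pair relies on the reverse embedding always being the last entry in the weight store when weights are untied. A rough sketch of that layout, using a plain dict as a stand-in for the real checkpoint store and assuming Keras >= 3.2 (where the base `Embedding` saves only its own embedding matrix):

    from keras_nlp.layers import ReversibleEmbedding

    layer = ReversibleEmbedding(input_dim=10, output_dim=4, tie_weights=False)
    layer.build()

    store = {}  # stand-in for the checkpoint store
    layer.save_own_variables(store)

    # "0" holds the embeddings; the reverse embeddings are appended under the
    # highest key, which is what load_own_variables reads back.
    assert list(store.keys())[-1] == str(len(layer.weights) - 1)
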
1 change: 0 additions & 1 deletion keras_nlp/models/llama/llama_backbone.py
@@ -109,7 +109,6 @@ def __init__(
             tie_weights=False,
             embeddings_initializer=_llama_kernel_initializer(stddev=0.01),
             dtype=dtype,
-            reverse_dtype=dtype,
             name="token_embedding",
         )
         self.transformer_layers = []
1 change: 0 additions & 1 deletion keras_nlp/models/mistral/mistral_backbone.py
@@ -121,7 +121,6 @@ def __init__(
             tie_weights=False,
             embeddings_initializer=_mistral_kernel_initializer(stddev=0.01),
             dtype=dtype,
-            reverse_dtype=dtype,
             name="token_embedding",
         )
         self.transformer_layers = []
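Dropping the explicit `reverse_dtype=dtype` leaves these backbones on the new default, so the output projection simply follows the backbone's own dtype policy. A hypothetical configuration, not taken from this commit, just to illustrate the effect:

    from keras_nlp.models import MistralBackbone

    backbone = MistralBackbone(
        vocabulary_size=1000,
        num_layers=2,
        num_query_heads=8,
        num_key_value_heads=4,
        hidden_dim=64,
        intermediate_dim=128,
        dtype="bfloat16",
    )
    # The token embedding's reverse projection now runs in the same dtype as
    # the rest of the backbone (bfloat16 here) rather than a separately
    # specified reverse_dtype.
    print(backbone.token_embedding.reverse_dtype)  # None
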
4 changes: 1 addition & 3 deletions keras_nlp/samplers/sampler.py
@@ -145,10 +145,8 @@ def compute_probabilities(self, logits):
         This will always be done in full precision, regardless of dtype, and
         scale by `temperature`.
         """
-        logits_dtype = logits.dtype
         logits = ops.cast(logits, "float32")
-        probs = keras.activations.softmax(logits / self.temperature)
-        return ops.cast(probs, logits_dtype)
+        return keras.activations.softmax(logits / self.temperature)
 
     def run_loop(
         self, cond, body, model=None, loop_vars=None, maximum_iterations=None
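With the cast back to the incoming dtype removed, `compute_probabilities` now keeps the probabilities in float32 instead of returning them in the logits dtype. A small sketch of the resulting contract (the sampler choice and values are illustrative only):

    import numpy as np

    from keras_nlp.backend import ops
    from keras_nlp.samplers import GreedySampler

    sampler = GreedySampler(temperature=2.0)
    logits = ops.cast(np.array([[1.0, 2.0, 3.0]]), "float16")

    # Softmax of logits / temperature, computed and returned in float32.
    probs = sampler.compute_probabilities(logits)
    print(probs.dtype, ops.convert_to_numpy(probs))
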
2 changes: 1 addition & 1 deletion keras_nlp/version_utils.py
@@ -15,7 +15,7 @@
 from keras_nlp.api_export import keras_nlp_export
 
 # Unique source of truth for the version number.
-__version__ = "0.9.0"
+__version__ = "0.10.0"
 
 
 @keras_nlp_export("keras_nlp.version")
