
Commit

Do the reverse embedding in the same dtype as the input embedding (ke…
mattdangerw authored Apr 10, 2024
1 parent c157ac2 commit ab649f5
Showing 4 changed files with 6 additions and 10 deletions.
10 changes: 5 additions & 5 deletions keras_nlp/layers/modeling/reversible_embedding.py
@@ -48,8 +48,7 @@ class ReversibleEmbedding(keras.layers.Embedding):
         mask_zero: Boolean, whether or not the input value 0 is a special
             "padding" value that should be masked out.
         reverse_dtype: The dtype for the reverse projection computation.
-            For stability, it is usually best to use full precision even when
-            working with half or mixed precision training.
+            Defaults to the `compute_dtype` of the layer.
         **kwargs: other keyword arguments passed to `keras.layers.Embedding`,
             including `name`, `trainable`, `dtype` etc.
@@ -90,7 +89,7 @@ def __init__(
         embeddings_regularizer=None,
         embeddings_constraint=None,
         mask_zero=False,
-        reverse_dtype="float32",
+        reverse_dtype=None,
         **kwargs,
     ):
         super().__init__(
@@ -122,8 +121,9 @@ def call(self, inputs, reverse=False):
                 kernel = ops.transpose(ops.convert_to_tensor(self.embeddings))
             else:
                 kernel = self.reverse_embeddings
-            inputs = ops.cast(inputs, self.reverse_dtype)
-            kernel = ops.cast(kernel, self.reverse_dtype)
+            if self.reverse_dtype is not None:
+                inputs = ops.cast(inputs, self.reverse_dtype)
+                kernel = ops.cast(kernel, self.reverse_dtype)
             return ops.matmul(inputs, kernel)

         return super().call(inputs)
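Net effect of this file's changes: the reverse projection now runs in the layer's compute_dtype unless reverse_dtype is passed explicitly. A minimal sketch of both behaviors, assuming keras-nlp on a Keras 3 backend (the vocab size, width, and dtypes below are illustrative, not from the diff):

    import numpy as np
    import keras_nlp

    # New default: reverse_dtype=None, so the reverse projection stays in
    # the layer's compute_dtype (float16 here), with no extra casts.
    embedding = keras_nlp.layers.ReversibleEmbedding(
        input_dim=100,
        output_dim=16,
        dtype="float16",
    )
    hidden = embedding(np.array([[1, 2, 3]]))  # shape (1, 3, 16), float16
    logits = embedding(hidden, reverse=True)   # shape (1, 3, 100), float16

    # The pre-commit behavior is now opt-in: upcast the reverse
    # projection to full precision for numerical stability.
    stable = keras_nlp.layers.ReversibleEmbedding(
        input_dim=100,
        output_dim=16,
        dtype="float16",
        reverse_dtype="float32",
    )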
1 change: 0 additions & 1 deletion keras_nlp/models/llama/llama_backbone.py
@@ -109,7 +109,6 @@ def __init__(
             tie_weights=False,
             embeddings_initializer=_llama_kernel_initializer(stddev=0.01),
             dtype=dtype,
-            reverse_dtype=dtype,
             name="token_embedding",
         )
         self.transformer_layers = []
1 change: 0 additions & 1 deletion keras_nlp/models/mistral/mistral_backbone.py
@@ -121,7 +121,6 @@ def __init__(
             tie_weights=False,
             embeddings_initializer=_mistral_kernel_initializer(stddev=0.01),
             dtype=dtype,
-            reverse_dtype=dtype,
             name="token_embedding",
         )
         self.transformer_layers = []
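The llama and mistral edits are the same cleanup: passing reverse_dtype=dtype is redundant now that a layer built with dtype=dtype reverses in that dtype by default. Roughly, the updated call site reduces to the following sketch (vocabulary_size, hidden_dim, and the bfloat16 value are hypothetical stand-ins for the backbones' constructor arguments; the initializer is omitted):

    import keras_nlp

    # Hypothetical values standing in for the backbone constructor args.
    vocabulary_size, hidden_dim, dtype = 32000, 4096, "bfloat16"

    token_embedding = keras_nlp.layers.ReversibleEmbedding(
        input_dim=vocabulary_size,
        output_dim=hidden_dim,
        tie_weights=False,
        dtype=dtype,
        name="token_embedding",
    )
    # No reverse_dtype=dtype needed: with the new default of None, the
    # reverse projection already runs in the layer's dtype.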
4 changes: 1 addition & 3 deletions keras_nlp/samplers/sampler.py
@@ -145,10 +145,8 @@ def compute_probabilities(self, logits):
         This will always be done in full precision, regardless of dtype, and
         scaled by `temperature`.
         """
-        logits_dtype = logits.dtype
         logits = ops.cast(logits, "float32")
-        probs = keras.activations.softmax(logits / self.temperature)
-        return ops.cast(probs, logits_dtype)
+        return keras.activations.softmax(logits / self.temperature)

     def run_loop(
         self, cond, body, model=None, loop_vars=None, maximum_iterations=None
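The sampler change keeps the upcast to float32 but drops the cast back, so compute_probabilities now returns float32 probabilities regardless of the incoming logits dtype. A minimal sketch of the observable difference, assuming a keras-nlp sampler (the choice of GreedySampler and the shapes are illustrative):

    import keras
    from keras import ops
    import keras_nlp

    sampler = keras_nlp.samplers.GreedySampler(temperature=1.0)

    # Half-precision logits, as a mixed-precision model would produce.
    logits = ops.cast(keras.random.uniform((2, 100)), "float16")
    probs = sampler.compute_probabilities(logits)

    # After this commit, probs is float32; before, it was cast back to
    # float16 to match the input logits.
    print(probs.dtype)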
