Skip to content

Commit

Permalink
Update generate_step function.
Browse files Browse the repository at this point in the history
  • Loading branch information
SamanehSaadat committed May 22, 2024
1 parent e5dfe55 commit f3c4c2e
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions keras_nlp/src/models/falcon/falcon_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from keras_nlp.src.models.falcon.falcon_causal_lm_preprocessor import (
FalconCausalLMPreprocessor,
)
from keras_nlp.src.utils.tensor_utils import any_equal


@keras_nlp_export("keras_nlp.models.FalconCausalLM")
Expand Down Expand Up @@ -266,11 +267,10 @@ def next(prompt, cache, index):

# Compute an output padding mask with the token ids we updated.
if stop_token_ids is not None:
# Build a mask of `end_token_id` locations not in the original
# Build a mask of stop token locations not in the original
# prompt (not in locations where `padding_mask` is True).
end_locations = ops.logical_and(
ops.equal(token_ids, stop_token_ids),
ops.logical_not(padding_mask),
end_locations = any_equal(
token_ids, stop_token_ids, ops.logical_not(padding_mask)
)
end_locations = ops.cast(end_locations, "int32")
# Use cumsum to get ones in all locations after end_locations.
Expand Down

0 comments on commit f3c4c2e

Please sign in to comment.