Skip to content

Commit

Permalink
Fixing stream stopping at wrong location (#898)
Browse files: browse the repository at this point in the history
Fixes #896
  • Loading branch information
isamu-isozaki committed May 17, 2024
1 parent 78852b0 commit 159d1ec
Showing 1 changed file with 19 additions and 9 deletions.
28 changes: 19 additions & 9 deletions outlines/generate/api.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -340,15 +340,6 @@ def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]:
return
generated_token_ids = sequence.token_ids[:, -num_generated:]
generated_sequences = self.tokenizer.decode(generated_token_ids)
next_tokens = [
token[len(sequence) :] if not stop else ""
for token, sequence, stop in zip(
generated_sequences,
previously_generated_sequences,
is_stop_at_reached,
)
]
previously_generated_sequences = generated_sequences
if stop_sequences:
is_stop_at_reached = [
stop
Expand All @@ -360,6 +351,25 @@ def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]:
)
]

generated_sequences = [
self.format_sequence(
self.strip_stop_sequences(sequence, stop_sequences)
)
if stop
else sequence
for sequence, stop in zip(
generated_sequences, is_stop_at_reached
)
]
next_tokens = [
token[len(sequence) :]
for token, sequence, stop in zip(
generated_sequences,
previously_generated_sequences,
is_stop_at_reached,
)
]
previously_generated_sequences = generated_sequences
# We reshape the output to (batch_size, sample_size)
output: List[List[str]] = list()
for i in range(batch_size):
Expand Down

0 comments on commit 159d1ec

Please sign in to comment.