diff --git a/outlines/generate/api.py b/outlines/generate/api.py index 3f4f182d2..51a995664 100644 --- a/outlines/generate/api.py +++ b/outlines/generate/api.py @@ -340,15 +340,6 @@ def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]: return generated_token_ids = sequence.token_ids[:, -num_generated:] generated_sequences = self.tokenizer.decode(generated_token_ids) - next_tokens = [ - token[len(sequence) :] if not stop else "" - for token, sequence, stop in zip( - generated_sequences, - previously_generated_sequences, - is_stop_at_reached, - ) - ] - previously_generated_sequences = generated_sequences if stop_sequences: is_stop_at_reached = [ stop @@ -360,6 +351,25 @@ def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]: ) ] + generated_sequences = [ + self.format_sequence( + self.strip_stop_sequences(sequence, stop_sequences) + ) + if stop + else sequence + for sequence, stop in zip( + generated_sequences, is_stop_at_reached + ) + ] + next_tokens = [ + token[len(sequence) :] + for token, sequence, stop in zip( + generated_sequences, + previously_generated_sequences, + is_stop_at_reached, + ) + ] + previously_generated_sequences = generated_sequences # We reshape the output to (batch_size, sample_size) output: List[List[str]] = list() for i in range(batch_size):