diff --git a/exllamav2/generator/dynamic.py b/exllamav2/generator/dynamic.py index da579e57..1802fa57 100644 --- a/exllamav2/generator/dynamic.py +++ b/exllamav2/generator/dynamic.py @@ -1856,7 +1856,8 @@ def emit( suppressed_text = None, suppressed_tokens = None, stop_token: int = None, - stop_string: str = None + stop_string: str = None, + rem_held_text: str = None ): r = { "job": self, @@ -1919,6 +1920,22 @@ def emit( "accepted_draft_tokens": self.accepted_draft_tokens, "rejected_draft_tokens": self.rejected_draft_tokens }) + if eos_reason == "stop_string": + self.held_text = rem_held_text + rh = {} + if self.held_text: + rh.update({ "text": self.held_text }) + if self.held_tokens: + rh.update({ "token_ids": self.held_tokens.torch().clone() }) + if self.held_probs: + rh.update({ "token_probs": self.held_probs.torch().clone() }) + if self.held_k_tokens: + rh.update({ "top_k_tokens": self.held_k_tokens.torch().clone() }) + rh.update({ "top_k_probs": self.held_k_probs.torch().clone() }) + if self.held_logits: + rh.update({ "logits": self.held_logits.torch().clone() }) + if rh: + r.update({ "held": rh }) if self.identifier is not None: r.update({ "identifier": self.identifier }) @@ -1926,11 +1943,6 @@ def emit( results.append(r) return emit_eos, next_token - # End on stop tokens - - if next_token.item() in self.stop_tokens: - return emit(results, emit_eos = True, eos_reason = "stop_token", stop_token = next_token.item()) - # Decode and buffer output id_to_piece = self.generator.tokenizer.get_id_to_piece_list(self.decode_special_tokens) @@ -1950,6 +1962,11 @@ def emit( if self.return_logits: self.held_logits.append(logits[:1, :, :]) + # End on stop tokens + + if next_token.item() in self.stop_tokens: + return emit(results, emit_eos = True, eos_reason = "stop_token", stop_token = next_token.item()) + # Stop if we reach max_new_tokens if self.new_tokens >= self.max_new_tokens - self.generator.num_draft_tokens: @@ -2052,7 +2069,14 @@ def rewind_checkpoint(): self.held_text = self.held_text[:match] for s in self.stop_strings: if held.startswith(s): - return emit(results, emit_eos = True, emit_held = True, eos_reason = "stop_string", stop_string = s) + return emit( + results, + emit_eos = True, + emit_held = True, + eos_reason = "stop_string", + stop_string = s, + rem_held_text = held + ) assert False, "Detected stop string but couldn't identify it (logic error)" if match == -2: return emit(results)