Skip to content

Commit

Permalink
Dynamic gen: Return held output with last results
Browse files Browse the repository at this point in the history
  • Loading branch information
turboderp committed Jul 21, 2024
1 parent 304e021 commit 81cd6b7
Showing 1 changed file with 31 additions and 7 deletions.
38 changes: 31 additions & 7 deletions exllamav2/generator/dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1856,7 +1856,8 @@ def emit(
suppressed_text = None,
suppressed_tokens = None,
stop_token: int = None,
stop_string: str = None
stop_string: str = None,
rem_held_text: str = None
):
r = {
"job": self,
Expand Down Expand Up @@ -1919,18 +1920,29 @@ def emit(
"accepted_draft_tokens": self.accepted_draft_tokens,
"rejected_draft_tokens": self.rejected_draft_tokens
})
if eos_reason == "stop_string":
self.held_text = rem_held_text
rh = {}
if self.held_text:
rh.update({ "text": self.held_text })
if self.held_tokens:
rh.update({ "token_ids": self.held_tokens.torch().clone() })
if self.held_probs:
rh.update({ "token_probs": self.held_probs.torch().clone() })
if self.held_k_tokens:
rh.update({ "top_k_tokens": self.held_k_tokens.torch().clone() })
rh.update({ "top_k_probs": self.held_k_probs.torch().clone() })
if self.held_logits:
rh.update({ "logits": self.held_logits.torch().clone() })
if rh:
r.update({ "held": rh })

if self.identifier is not None:
r.update({ "identifier": self.identifier })

results.append(r)
return emit_eos, next_token

# End on stop tokens

if next_token.item() in self.stop_tokens:
return emit(results, emit_eos = True, eos_reason = "stop_token", stop_token = next_token.item())

# Decode and buffer output

id_to_piece = self.generator.tokenizer.get_id_to_piece_list(self.decode_special_tokens)
Expand All @@ -1950,6 +1962,11 @@ def emit(
if self.return_logits:
self.held_logits.append(logits[:1, :, :])

# End on stop tokens

if next_token.item() in self.stop_tokens:
return emit(results, emit_eos = True, eos_reason = "stop_token", stop_token = next_token.item())

# Stop if we reach max_new_tokens

if self.new_tokens >= self.max_new_tokens - self.generator.num_draft_tokens:
Expand Down Expand Up @@ -2052,7 +2069,14 @@ def rewind_checkpoint():
self.held_text = self.held_text[:match]
for s in self.stop_strings:
if held.startswith(s):
return emit(results, emit_eos = True, emit_held = True, eos_reason = "stop_string", stop_string = s)
return emit(
results,
emit_eos = True,
emit_held = True,
eos_reason = "stop_string",
stop_string = s,
rem_held_text = held
)
assert False, "Detected stop string but couldn't identify it (logic error)"
if match == -2:
return emit(results)
Expand Down

0 comments on commit 81cd6b7

Please sign in to comment.