From 7ebb0b203a02486855bf357d64f7e80ab18c59e8 Mon Sep 17 00:00:00 2001 From: Glenn Maynard Date: Mon, 24 Jun 2024 22:13:13 +0000 Subject: [PATCH 1/2] Fix logprobs when multiple tokens are returned at once. --- backends/exllamav2/model.py | 81 ++++++++++++++++++-------- endpoints/OAI/types/chat_completion.py | 5 +- endpoints/OAI/utils/chat_completion.py | 25 ++++---- endpoints/OAI/utils/completion.py | 7 ++- 4 files changed, 80 insertions(+), 38 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index b28cfd87..741501f7 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -767,21 +767,25 @@ def get_special_tokens( } def get_logprobs(self, token_ids: torch.Tensor, token_probs: torch.Tensor): - top_tokens = [ - self.tokenizer.extended_id_to_piece.get( - index, self.tokenizer.id_to_piece[index] - ) - for index in token_ids.flatten().tolist() - ] + logprobs = [] + for token_idx in range(token_ids.shape[1]): + top_tokens = [ + self.tokenizer.extended_id_to_piece.get( + index, self.tokenizer.id_to_piece[index] + ) + for index in token_ids[0, token_idx].tolist() + ] - top_values = torch.log(token_probs).flatten().tolist() + top_values = torch.log(token_probs[0, token_idx]).tolist() - # Cannot return -inf in JSON - cleaned_values = [ - -1000 if value == float("-inf") else value for value in top_values - ] + # Cannot return -inf in JSON + cleaned_values = [ + -1000 if value == float("-inf") else value for value in top_values + ] - return dict(zip_longest(top_tokens, cleaned_values)) + logprobs.append(dict(zip_longest(top_tokens, cleaned_values))) + + return logprobs async def generate(self, prompt: str, **kwargs): """Generate a response to a prompt""" @@ -793,8 +797,9 @@ async def generate(self, prompt: str, **kwargs): "text": "", "prompt_tokens": 0, "generation_tokens": 0, + "tokens": [], "offset": [], - "token_probs": {}, + "token_probs": [], "logprobs": [], } @@ -811,13 +816,14 @@ async def generate(self, prompt: str, **kwargs): if len(generations) > 0: for generation in generations: joined_generation["text"] += unwrap(generation.get("text"), "") - joined_generation["offset"].append(unwrap(generation.get("offset"), -1)) - joined_generation["token_probs"].update( - unwrap(generation.get("token_probs"), {}) + joined_generation["tokens"].extend(unwrap(generation.get("tokens"), [])) + joined_generation["offset"].extend(unwrap(generation.get("offset"), [])) + joined_generation["token_probs"].extend( + unwrap(generation.get("token_probs"), []) ) # Include empty logprob dicts for index preservation - joined_generation["logprobs"].append( + joined_generation["logprobs"].extend( unwrap(generation.get("logprobs"), {}) ) @@ -1145,7 +1151,6 @@ async def generate_gen( "text": chunk, "prompt_tokens": context_len, "generated_tokens": generated_tokens, - "offset": len(full_response), } if request_logprobs > 0: @@ -1164,11 +1169,41 @@ async def generate_gen( logprobs = self.get_logprobs(top_tokens, top_probs) generation["logprobs"] = logprobs - # The first logprob is the selected token prob - generation["token_probs"] = { - token: logprobs[token] - for token in list(logprobs.keys())[:1] - } + token_ids = unwrap( + result.get("token_ids"), + torch.empty(0), + ) + + token_probs = unwrap( + result.get("token_probs"), + torch.empty(0), + ) + + if token_ids.numel() > 0 and token_probs.numel() > 0: + token_ids = token_ids.flatten().tolist() + token_probs = token_probs.flatten().tolist() + + tokens = [ + self.tokenizer.extended_id_to_piece.get( + index, 
self.tokenizer.id_to_piece[index] + ) + for index in token_ids + ] + + generation["tokens"] = tokens + generation["token_probs"] = [ + math.log(prob) for prob in token_probs + ] + + # Calculate the offset of each token in the output, + # working backwards from the end. + offsets = [] + token_offset = 0 + for token in tokens: + token_offset += len(token) + offsets.append(len(full_response) - token_offset) + offsets.reverse() + generation["offset"] = offsets yield generation diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py index be5cfea9..5b88e625 100644 --- a/endpoints/OAI/types/chat_completion.py +++ b/endpoints/OAI/types/chat_completion.py @@ -9,11 +9,14 @@ class ChatCompletionLogprob(BaseModel): token: str logprob: float + + +class ChatCompletionLogprobChoice(ChatCompletionLogprob): top_logprobs: Optional[List["ChatCompletionLogprob"]] = None class ChatCompletionLogprobs(BaseModel): - content: List[ChatCompletionLogprob] = Field(default_factory=list) + content: List[ChatCompletionLogprobChoice] = Field(default_factory=list) class ChatCompletionMessage(BaseModel): diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index 9e82b1b6..a61952a7 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -22,6 +22,7 @@ from endpoints.OAI.types.chat_completion import ( ChatCompletionLogprobs, ChatCompletionLogprob, + ChatCompletionLogprobChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionRespChoice, @@ -46,22 +47,24 @@ def _create_response(generations: List[dict], model_name: Optional[str]): logprob_response = None - token_probs = unwrap(generation.get("token_probs"), {}) + tokens = unwrap(generation.get("tokens"), []) + token_probs = unwrap(generation.get("token_probs"), []) + logprobs = unwrap(generation.get("logprobs"), []) if token_probs: - logprobs = unwrap(generation.get("logprobs"), []) - collected_token_probs = [] - for index, token in enumerate(token_probs.keys()): - top_logprobs = [ - ChatCompletionLogprob(token=token, logprob=logprob) - for token, logprob in logprobs[index].items() + for output_token, token_logprob, top_logprobs in zip( + tokens, token_probs, logprobs, strict=True + ): + completion_logprobs = [ + ChatCompletionLogprob(token=token, logprob=token_logprob) + for token, token_logprob in top_logprobs.items() ] collected_token_probs.append( - ChatCompletionLogprob( - token=token, - logprob=token_probs[token], - top_logprobs=top_logprobs, + ChatCompletionLogprobChoice( + token=output_token, + logprob=token_logprob, + top_logprobs=completion_logprobs, ) ) diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py index 2b5dfbf2..65905954 100644 --- a/endpoints/OAI/utils/completion.py +++ b/endpoints/OAI/utils/completion.py @@ -35,15 +35,16 @@ def _create_response(generations: Union[dict, List[dict]], model_name: str = "") for index, generation in enumerate(generations): logprob_response = None - token_probs = unwrap(generation.get("token_probs"), {}) + tokens = unwrap(generation.get("tokens"), []) + token_probs = unwrap(generation.get("token_probs"), []) if token_probs: logprobs = unwrap(generation.get("logprobs"), []) offset = unwrap(generation.get("offset"), []) logprob_response = CompletionLogProbs( text_offset=offset if isinstance(offset, list) else [offset], - token_logprobs=token_probs.values(), - tokens=token_probs.keys(), + token_logprobs=token_probs, + tokens=tokens, top_logprobs=logprobs if 
isinstance(logprobs, list) else [logprobs], ) From 958e22297267c2b68225057b497a019b6e39d489 Mon Sep 17 00:00:00 2001 From: Glenn Maynard Date: Tue, 25 Jun 2024 01:36:18 +0000 Subject: [PATCH 2/2] Logprobs fixes for streaming chat/completions. This also brings the two chat/completions code paths back into alignment. --- endpoints/OAI/utils/chat_completion.py | 44 +++++++++++++++----------- 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index a61952a7..447021e0 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -47,12 +47,12 @@ def _create_response(generations: List[dict], model_name: Optional[str]): logprob_response = None - tokens = unwrap(generation.get("tokens"), []) token_probs = unwrap(generation.get("token_probs"), []) - logprobs = unwrap(generation.get("logprobs"), []) if token_probs: + tokens = unwrap(generation.get("tokens"), []) + logprobs = unwrap(generation.get("logprobs"), []) collected_token_probs = [] - for output_token, token_logprob, top_logprobs in zip( + for generated_token, generated_token_logprob, top_logprobs in zip( tokens, token_probs, logprobs, strict=True ): completion_logprobs = [ @@ -62,8 +62,8 @@ def _create_response(generations: List[dict], model_name: Optional[str]): collected_token_probs.append( ChatCompletionLogprobChoice( - token=output_token, - logprob=token_logprob, + token=generated_token, + logprob=generated_token_logprob, top_logprobs=completion_logprobs, ) ) @@ -112,22 +112,30 @@ def _create_stream_chunk( role="assistant", content=unwrap(generation.get("text"), "") ) + logprob_response = None + token_probs = unwrap(generation.get("token_probs"), {}) if token_probs: - logprobs = unwrap(generation.get("logprobs"), {}) - top_logprobs = [ - ChatCompletionLogprob(token=token, logprob=logprob) - for token, logprob in logprobs.items() - ] - - generated_token = next(iter(token_probs)) - token_prob_response = ChatCompletionLogprob( - token=generated_token, - logprob=token_probs[generated_token], - top_logprobs=top_logprobs, - ) + tokens = unwrap(generation.get("tokens"), []) + logprobs = unwrap(generation.get("logprobs"), []) + collected_token_probs = [] + for generated_token, generated_token_logprob, top_logprobs in zip( + tokens, token_probs, logprobs, strict=True + ): + completion_logprobs = [ + ChatCompletionLogprob(token=token, logprob=token_logprob) + for token, token_logprob in top_logprobs.items() + ] + + collected_token_probs.append( + ChatCompletionLogprobChoice( + token=generated_token, + logprob=generated_token_logprob, + top_logprobs=completion_logprobs, + ) + ) - logprob_response = ChatCompletionLogprobs(content=[token_prob_response]) + logprob_response = ChatCompletionLogprobs(content=collected_token_probs) choice = ChatCompletionStreamChoice( index=index,
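
The sketch below is reviewer commentary, not part of either patch. With hypothetical helper names and a plain-dict payload (rather than the Pydantic models above), it illustrates how the parallel `tokens`, `token_probs`, and `logprobs` lists the backend now emits per generation can be paired into OpenAI-style per-token logprob entries — mirroring the strict zip() used in _create_response() — plus one way to derive character offsets by walking the token list in reverse from the end of the accumulated response text:

# Illustrative sketch only: helper names and the dict payload are hypothetical,
# not tabbyAPI's actual API. Assumes the three lists are index-aligned, which
# is what zip(..., strict=True) in _create_response() enforces.
from typing import Dict, List


def pair_token_logprobs(
    tokens: List[str],
    token_probs: List[float],
    logprobs: List[Dict[str, float]],
) -> List[dict]:
    """Pair each generated token with its logprob and its top-logprob dict."""
    return [
        {
            "token": token,
            "logprob": token_logprob,
            "top_logprobs": [
                {"token": t, "logprob": lp} for t, lp in top.items()
            ],
        }
        for token, token_logprob, top in zip(
            tokens, token_probs, logprobs, strict=True
        )
    ]


def trailing_token_offsets(tokens: List[str], full_response: str) -> List[int]:
    """Character offset of each token, assuming the tokens form the tail of
    full_response; walks the tokens in reverse, back from the end of the text."""
    offsets = []
    consumed = 0
    for token in reversed(tokens):
        consumed += len(token)
        offsets.append(len(full_response) - consumed)
    offsets.reverse()
    return offsets


if __name__ == "__main__":
    # "Hello" starts at 0, " wor" at 5, "ld" at 9 in "Hello world".
    print(trailing_token_offsets(["Hello", " wor", "ld"], "Hello world"))

The streaming path uses the same pairing per chunk, since _create_stream_chunk() now builds the same ChatCompletionLogprobChoice list as the non-streaming response.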