diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index b28cfd87..3e194bbd 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -767,21 +767,25 @@ def get_special_tokens(
         }

     def get_logprobs(self, token_ids: torch.Tensor, token_probs: torch.Tensor):
-        top_tokens = [
-            self.tokenizer.extended_id_to_piece.get(
-                index, self.tokenizer.id_to_piece[index]
-            )
-            for index in token_ids.flatten().tolist()
-        ]
+        logprobs = []
+        for token_idx in range(token_ids.shape[1]):
+            top_tokens = [
+                self.tokenizer.extended_id_to_piece.get(
+                    index, self.tokenizer.id_to_piece[index]
+                )
+                for index in token_ids[0, token_idx].tolist()
+            ]

-        top_values = torch.log(token_probs).flatten().tolist()
+            top_values = torch.log(token_probs[0, token_idx]).tolist()

-        # Cannot return -inf in JSON
-        cleaned_values = [
-            -1000 if value == float("-inf") else value for value in top_values
-        ]
+            # Cannot return -inf in JSON
+            cleaned_values = [
+                -1000 if value == float("-inf") else value for value in top_values
+            ]

-        return dict(zip_longest(top_tokens, cleaned_values))
+            logprobs.append(dict(zip_longest(top_tokens, cleaned_values)))
+
+        return logprobs

     async def generate(self, prompt: str, **kwargs):
         """Generate a response to a prompt"""
@@ -793,8 +797,9 @@ async def generate(self, prompt: str, **kwargs):
             "text": "",
             "prompt_tokens": 0,
             "generation_tokens": 0,
+            "tokens": [],
             "offset": [],
-            "token_probs": {},
+            "token_probs": [],
             "logprobs": [],
         }

@@ -811,13 +816,14 @@ async def generate(self, prompt: str, **kwargs):
         if len(generations) > 0:
             for generation in generations:
                 joined_generation["text"] += unwrap(generation.get("text"), "")
-                joined_generation["offset"].append(unwrap(generation.get("offset"), -1))
-                joined_generation["token_probs"].update(
-                    unwrap(generation.get("token_probs"), {})
+                joined_generation["tokens"].extend(unwrap(generation.get("tokens"), []))
+                joined_generation["offset"].extend(unwrap(generation.get("offset"), []))
+                joined_generation["token_probs"].extend(
+                    unwrap(generation.get("token_probs"), [])
                 )

                 # Include empty logprob dicts for index preservation
-                joined_generation["logprobs"].append(
+                joined_generation["logprobs"].extend(
                     unwrap(generation.get("logprobs"), {})
                 )

@@ -1145,7 +1151,6 @@ async def generate_gen(
                 "text": chunk,
                 "prompt_tokens": context_len,
                 "generated_tokens": generated_tokens,
-                "offset": len(full_response),
             }

             if request_logprobs > 0:
@@ -1164,11 +1169,41 @@ async def generate_gen(
                     logprobs = self.get_logprobs(top_tokens, top_probs)
                     generation["logprobs"] = logprobs

-                    # The first logprob is the selected token prob
-                    generation["token_probs"] = {
-                        token: logprobs[token]
-                        for token in list(logprobs.keys())[:1]
-                    }
+                    token_ids = unwrap(
+                        result.get("token_ids"),
+                        torch.empty(0),
+                    )
+
+                    token_probs = unwrap(
+                        result.get("token_probs"),
+                        torch.empty(0),
+                    )
+
+                    if token_ids.numel() > 0 and token_probs.numel() > 0:
+                        token_ids = token_ids.flatten().tolist()
+                        token_probs = token_probs.flatten().tolist()
+
+                        tokens = [
+                            self.tokenizer.extended_id_to_piece.get(
+                                index, self.tokenizer.id_to_piece[index]
+                            )
+                            for index in token_ids
+                        ]
+
+                        generation["tokens"] = tokens
+                        generation["token_probs"] = [
+                            math.log(prob) for prob in token_probs
+                        ]
+
+                        # Calculate the offset of each token in the output, working backwards from
+                        # the end.
+                        offsets = []
+                        token_offset = 0
+                        for token in tokens:
+                            token_offset += len(token)
+                            offsets.append(len(full_response) - token_offset)
+                        offsets.reverse()
+                        generation["offset"] = offsets

             yield generation

diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py
index be5cfea9..5b88e625 100644
--- a/endpoints/OAI/types/chat_completion.py
+++ b/endpoints/OAI/types/chat_completion.py
@@ -9,11 +9,14 @@
 class ChatCompletionLogprob(BaseModel):
     token: str
     logprob: float
+
+
+class ChatCompletionLogprobChoice(ChatCompletionLogprob):
     top_logprobs: Optional[List["ChatCompletionLogprob"]] = None


 class ChatCompletionLogprobs(BaseModel):
-    content: List[ChatCompletionLogprob] = Field(default_factory=list)
+    content: List[ChatCompletionLogprobChoice] = Field(default_factory=list)


 class ChatCompletionMessage(BaseModel):

diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py
index 9e82b1b6..1786fbc8 100644
--- a/endpoints/OAI/utils/chat_completion.py
+++ b/endpoints/OAI/utils/chat_completion.py
@@ -22,6 +22,7 @@
 from endpoints.OAI.types.chat_completion import (
     ChatCompletionLogprobs,
     ChatCompletionLogprob,
+    ChatCompletionLogprobChoice,
     ChatCompletionMessage,
     ChatCompletionRequest,
     ChatCompletionRespChoice,
@@ -46,21 +47,24 @@ def _create_response(generations: List[dict], model_name: Optional[str]):
         logprob_response = None

-        token_probs = unwrap(generation.get("token_probs"), {})
+        tokens = unwrap(generation.get("tokens"), [])
+        token_probs = unwrap(generation.get("token_probs"), [])
         if token_probs:
             logprobs = unwrap(generation.get("logprobs"), [])
             collected_token_probs = []

-            for index, token in enumerate(token_probs.keys()):
+            for output_token, token_logprob, logprobs in zip(
+                tokens, token_probs, logprobs
+            ):
                 top_logprobs = [
                     ChatCompletionLogprob(token=token, logprob=logprob)
-                    for token, logprob in logprobs[index].items()
+                    for token, logprob in logprobs.items()
                 ]

                 collected_token_probs.append(
-                    ChatCompletionLogprob(
-                        token=token,
-                        logprob=token_probs[token],
+                    ChatCompletionLogprobChoice(
+                        token=output_token,
+                        logprob=token_logprob,
                         top_logprobs=top_logprobs,
                     )
                 )

diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py
index 2b5dfbf2..65905954 100644
--- a/endpoints/OAI/utils/completion.py
+++ b/endpoints/OAI/utils/completion.py
@@ -35,15 +35,16 @@ def _create_response(generations: Union[dict, List[dict]], model_name: str = "")
     for index, generation in enumerate(generations):
         logprob_response = None

-        token_probs = unwrap(generation.get("token_probs"), {})
+        tokens = unwrap(generation.get("tokens"), [])
+        token_probs = unwrap(generation.get("token_probs"), [])
         if token_probs:
             logprobs = unwrap(generation.get("logprobs"), [])
             offset = unwrap(generation.get("offset"), [])

             logprob_response = CompletionLogProbs(
                 text_offset=offset if isinstance(offset, list) else [offset],
-                token_logprobs=token_probs.values(),
-                tokens=token_probs.keys(),
+                token_logprobs=token_probs,
+                tokens=tokens,
                 top_logprobs=logprobs if isinstance(logprobs, list) else [logprobs],
             )
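
For reference, the offset bookkeeping added to generate_gen records, for each decoded token piece, the character position where that piece starts inside the accumulated response text, working backwards from the end of full_response. Below is a minimal standalone sketch of that idea; compute_offsets is a hypothetical helper written only for illustration (it is not part of the patch), and it assumes full_response ends with the concatenation of tokens:

    from typing import List


    def compute_offsets(full_response: str, tokens: List[str]) -> List[int]:
        """Start offset of each token piece within full_response.

        Walks the pieces backwards from the end of the text, then reverses
        the collected offsets back into generation order.
        """
        offsets = []
        token_offset = 0
        for token in reversed(tokens):
            token_offset += len(token)
            offsets.append(len(full_response) - token_offset)
        offsets.reverse()
        return offsets


    # "Hello world" decoded as ["Hello", " wor", "ld"] -> [0, 5, 9]
    print(compute_offsets("Hello world", ["Hello", " wor", "ld"]))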
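
On the endpoint side, the rewritten _create_response handlers assume that tokens, token_probs, and logprobs arrive as parallel lists with one entry per sampled position, where each logprobs[i] maps candidate pieces to their log-probabilities. The following rough, self-contained illustration of that contract uses invented sample values, with plain dicts standing in for the ChatCompletionLogprob models:

    # Hypothetical generation payload with two sampled tokens.
    generation = {
        "tokens": ["Hello", ","],
        "token_probs": [-0.11, -0.35],   # log-prob of each chosen token
        "logprobs": [                    # top candidates per position
            {"Hello": -0.11, "Hi": -2.31},
            {",": -0.35, "!": -1.87},
        ],
        "offset": [0, 5],
    }

    # Mirrors the zip in endpoints/OAI/utils/chat_completion.py: one choice
    # entry per output token, carrying its own logprob plus the per-position
    # top candidates.
    for output_token, token_logprob, top in zip(
        generation["tokens"], generation["token_probs"], generation["logprobs"]
    ):
        top_logprobs = [
            {"token": token, "logprob": logprob} for token, logprob in top.items()
        ]
        print(output_token, token_logprob, top_logprobs)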