From b01f6cb7eed436f030e871673b46336560899b73 Mon Sep 17 00:00:00 2001
From: kingbri
Date: Thu, 8 Feb 2024 19:28:19 -0500
Subject: [PATCH] Model: Fix logprobs unwrapping

Take the log of the token probs: they are already normalized, so the
log directly yields the proper logprob values. Also, don't error out
when a token prob doesn't exist in the dict; use zip_longest so the
missing entry becomes None instead.

Signed-off-by: kingbri
---
 backends/exllamav2/model.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 45a79dfe..74a02f69 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -1,5 +1,6 @@
 """The model container class for ExLlamaV2 models."""
 import gc
+from itertools import zip_longest
 import pathlib
 import time
 
@@ -486,9 +487,11 @@ def get_logprobs(self, logits: torch.Tensor, max_logprobs: int):
         )
         top_values = top_values[0].tolist()
 
-        return dict(zip(top_tokens, top_values, strict=True))
+        return dict(zip_longest(top_tokens, top_values))
 
     def get_token_probs(self, token_ids: torch.tensor, token_probs: torch.Tensor):
+        normalized_probs = torch.log(token_probs)
+
         tokens = list(
             map(
                 lambda index: self.tokenizer.extended_id_to_piece.get(
@@ -498,7 +501,7 @@ def get_token_probs(self, token_ids: torch.tensor, token_probs: torch.Tensor):
             )
         )
 
-        return dict(zip(tokens, token_probs[0].tolist(), strict=True))
+        return dict(zip_longest(tokens, normalized_probs[0].tolist()))
 
     def generate(self, prompt: str, **kwargs):
         """Generate a response to a prompt"""
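
For context, a minimal standalone sketch of the two behaviors this patch
relies on. The probability values and token pieces below are made-up
illustrations, not output from the real model or tokenizer:

    from itertools import zip_longest

    import torch

    # Token probs that are already normalized (softmax output, sums to 1).
    token_probs = torch.tensor([[0.7, 0.2, 0.1]])

    # Because the probs are normalized, torch.log yields logprobs
    # directly; no further normalization (e.g. log_softmax) is needed.
    normalized_probs = torch.log(token_probs)
    print(normalized_probs[0].tolist())  # ~[-0.357, -1.609, -2.303]

    # Suppose one token piece could not be resolved, leaving the token
    # list shorter than the value list. zip(..., strict=True) raises
    # ValueError on the length mismatch; zip_longest pads the shorter
    # side with None instead.
    tokens = ["Hello", "World"]  # hypothetical pieces, one missing
    values = normalized_probs[0].tolist()

    print(dict(zip_longest(tokens, values)))
    # {'Hello': -0.356..., 'World': -1.609..., None: -2.302...}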