diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 45a79dfe..74a02f69 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -1,5 +1,6 @@
 """The model container class for ExLlamaV2 models."""
 import gc
+from itertools import zip_longest
 import pathlib
 import time
 
@@ -486,9 +487,11 @@ def get_logprobs(self, logits: torch.Tensor, max_logprobs: int):
         )
         top_values = top_values[0].tolist()
 
-        return dict(zip(top_tokens, top_values, strict=True))
+        return dict(zip_longest(top_tokens, top_values))
 
     def get_token_probs(self, token_ids: torch.tensor, token_probs: torch.Tensor):
+        normalized_probs = torch.log(token_probs)
+
         tokens = list(
             map(
                 lambda index: self.tokenizer.extended_id_to_piece.get(
@@ -498,7 +501,7 @@ def get_token_probs(self, token_ids: torch.tensor, token_probs: torch.Tensor):
             )
         )
 
-        return dict(zip(tokens, token_probs[0].tolist(), strict=True))
+        return dict(zip_longest(tokens, normalized_probs[0].tolist()))
 
     def generate(self, prompt: str, **kwargs):
         """Generate a response to a prompt"""