Skip to content

Commit

Permalink
Fix logprobs when multiple tokens are returned at once.
Browse files Browse the repository at this point in the history
  • Loading branch information
zewt committed Jun 24, 2024
1 parent d03752e commit e55d3c7
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 33 deletions.
81 changes: 58 additions & 23 deletions backends/exllamav2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,21 +767,25 @@ def get_special_tokens(
}

def get_logprobs(self, token_ids: torch.Tensor, token_probs: torch.Tensor):
top_tokens = [
self.tokenizer.extended_id_to_piece.get(
index, self.tokenizer.id_to_piece[index]
)
for index in token_ids.flatten().tolist()
]
logprobs = []
for token_idx in range(token_ids.shape[1]):
top_tokens = [
self.tokenizer.extended_id_to_piece.get(
index, self.tokenizer.id_to_piece[index]
)
for index in token_ids[0, token_idx].tolist()
]

top_values = torch.log(token_probs).flatten().tolist()
top_values = torch.log(token_probs[0, token_idx]).tolist()

# Cannot return -inf in JSON
cleaned_values = [
-1000 if value == float("-inf") else value for value in top_values
]
# Cannot return -inf in JSON
cleaned_values = [
-1000 if value == float("-inf") else value for value in top_values
]

return dict(zip_longest(top_tokens, cleaned_values))
logprobs.append(dict(zip_longest(top_tokens, cleaned_values)))

return logprobs

async def generate(self, prompt: str, **kwargs):
"""Generate a response to a prompt"""
Expand All @@ -793,8 +797,9 @@ async def generate(self, prompt: str, **kwargs):
"text": "",
"prompt_tokens": 0,
"generation_tokens": 0,
"tokens": [],
"offset": [],
"token_probs": {},
"token_probs": [],
"logprobs": [],
}

Expand All @@ -811,13 +816,14 @@ async def generate(self, prompt: str, **kwargs):
if len(generations) > 0:
for generation in generations:
joined_generation["text"] += unwrap(generation.get("text"), "")
joined_generation["offset"].append(unwrap(generation.get("offset"), -1))
joined_generation["token_probs"].update(
unwrap(generation.get("token_probs"), {})
joined_generation["tokens"].extend(unwrap(generation.get("tokens"), []))
joined_generation["offset"].extend(unwrap(generation.get("offset"), []))
joined_generation["token_probs"].extend(
unwrap(generation.get("token_probs"), [])
)

# Include empty logprob dicts for index preservation
joined_generation["logprobs"].append(
joined_generation["logprobs"].extend(
unwrap(generation.get("logprobs"), {})
)

Expand Down Expand Up @@ -1145,7 +1151,6 @@ async def generate_gen(
"text": chunk,
"prompt_tokens": context_len,
"generated_tokens": generated_tokens,
"offset": len(full_response),
}

if request_logprobs > 0:
Expand All @@ -1164,11 +1169,41 @@ async def generate_gen(
logprobs = self.get_logprobs(top_tokens, top_probs)
generation["logprobs"] = logprobs

# The first logprob is the selected token prob
generation["token_probs"] = {
token: logprobs[token]
for token in list(logprobs.keys())[:1]
}
token_ids = unwrap(
result.get("token_ids"),
torch.empty(0),
)

token_probs = unwrap(
result.get("token_probs"),
torch.empty(0),
)

if token_ids.numel() > 0 and token_probs.numel() > 0:
token_ids = token_ids.flatten().tolist()
token_probs = token_probs.flatten().tolist()

tokens = [
self.tokenizer.extended_id_to_piece.get(
index, self.tokenizer.id_to_piece[index]
)
for index in token_ids
]

generation["tokens"] = tokens
generation["token_probs"] = [
math.log(prob) for prob in token_probs
]

# Calculate the offset of each token in the output, working backwards from
# the end.
offsets = []
token_offset = 0
for token in tokens:
token_offset += len(token)
offsets.append(len(full_response) - token_offset)
offsets.reverse()
generation["offset"] = offsets

yield generation

Expand Down
5 changes: 4 additions & 1 deletion endpoints/OAI/types/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@
class ChatCompletionLogprob(BaseModel):
token: str
logprob: float


class ChatCompletionLogprobChoice(ChatCompletionLogprob):
top_logprobs: Optional[List["ChatCompletionLogprob"]] = None


class ChatCompletionLogprobs(BaseModel):
content: List[ChatCompletionLogprob] = Field(default_factory=list)
content: List[ChatCompletionLogprobChoice] = Field(default_factory=list)


class ChatCompletionMessage(BaseModel):
Expand Down
16 changes: 10 additions & 6 deletions endpoints/OAI/utils/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from endpoints.OAI.types.chat_completion import (
ChatCompletionLogprobs,
ChatCompletionLogprob,
ChatCompletionLogprobChoice,
ChatCompletionMessage,
ChatCompletionRequest,
ChatCompletionRespChoice,
Expand All @@ -46,21 +47,24 @@ def _create_response(generations: List[dict], model_name: Optional[str]):

logprob_response = None

token_probs = unwrap(generation.get("token_probs"), {})
tokens = unwrap(generation.get("tokens"), [])
token_probs = unwrap(generation.get("token_probs"), [])
if token_probs:
logprobs = unwrap(generation.get("logprobs"), [])

collected_token_probs = []
for index, token in enumerate(token_probs.keys()):
for output_token, token_logprob, logprobs in zip(
tokens, token_probs, logprobs
):
top_logprobs = [
ChatCompletionLogprob(token=token, logprob=logprob)
for token, logprob in logprobs[index].items()
for token, logprob in logprobs.items()
]

collected_token_probs.append(
ChatCompletionLogprob(
token=token,
logprob=token_probs[token],
ChatCompletionLogprobChoice(
token=output_token,
logprob=token_logprob,
top_logprobs=top_logprobs,
)
)
Expand Down
7 changes: 4 additions & 3 deletions endpoints/OAI/utils/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,16 @@ def _create_response(generations: Union[dict, List[dict]], model_name: str = "")
for index, generation in enumerate(generations):
logprob_response = None

token_probs = unwrap(generation.get("token_probs"), {})
tokens = unwrap(generation.get("tokens"), [])
token_probs = unwrap(generation.get("token_probs"), [])
if token_probs:
logprobs = unwrap(generation.get("logprobs"), [])
offset = unwrap(generation.get("offset"), [])

logprob_response = CompletionLogProbs(
text_offset=offset if isinstance(offset, list) else [offset],
token_logprobs=token_probs.values(),
tokens=token_probs.keys(),
token_logprobs=token_probs,
tokens=tokens,
top_logprobs=logprobs if isinstance(logprobs, list) else [logprobs],
)

Expand Down

0 comments on commit e55d3c7

Please sign in to comment.