Fix logprobs when multiple tokens are returned at once. #141

Open · wants to merge 2 commits into base: main
81 changes: 58 additions & 23 deletions backends/exllamav2/model.py
@@ -767,21 +767,25 @@ def get_special_tokens(
         }

     def get_logprobs(self, token_ids: torch.Tensor, token_probs: torch.Tensor):
-        top_tokens = [
-            self.tokenizer.extended_id_to_piece.get(
-                index, self.tokenizer.id_to_piece[index]
-            )
-            for index in token_ids.flatten().tolist()
-        ]
+        logprobs = []
+        for token_idx in range(token_ids.shape[1]):
+            top_tokens = [
+                self.tokenizer.extended_id_to_piece.get(
+                    index, self.tokenizer.id_to_piece[index]
+                )
+                for index in token_ids[0, token_idx].tolist()
+            ]

-        top_values = torch.log(token_probs).flatten().tolist()
+            top_values = torch.log(token_probs[0, token_idx]).tolist()

-        # Cannot return -inf in JSON
-        cleaned_values = [
-            -1000 if value == float("-inf") else value for value in top_values
-        ]
+            # Cannot return -inf in JSON
+            cleaned_values = [
+                -1000 if value == float("-inf") else value for value in top_values
+            ]

-        return dict(zip_longest(top_tokens, cleaned_values))
+            logprobs.append(dict(zip_longest(top_tokens, cleaned_values)))
+
+        return logprobs

     async def generate(self, prompt: str, **kwargs):
         """Generate a response to a prompt"""
@@ -793,8 +797,9 @@ async def generate(self, prompt: str, **kwargs):
             "text": "",
             "prompt_tokens": 0,
             "generation_tokens": 0,
+            "tokens": [],
             "offset": [],
-            "token_probs": {},
+            "token_probs": [],
             "logprobs": [],
         }

@@ -811,13 +816,14 @@
         if len(generations) > 0:
             for generation in generations:
                 joined_generation["text"] += unwrap(generation.get("text"), "")
-                joined_generation["offset"].append(unwrap(generation.get("offset"), -1))
-                joined_generation["token_probs"].update(
-                    unwrap(generation.get("token_probs"), {})
+                joined_generation["tokens"].extend(unwrap(generation.get("tokens"), []))
+                joined_generation["offset"].extend(unwrap(generation.get("offset"), []))
+                joined_generation["token_probs"].extend(
+                    unwrap(generation.get("token_probs"), [])
                 )

                 # Include empty logprob dicts for index preservation
-                joined_generation["logprobs"].append(
+                joined_generation["logprobs"].extend(
                     unwrap(generation.get("logprobs"), {})
                 )
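
A rough sketch of what the list-based merging above does when two streamed chunks are joined; the chunk contents are invented and only the field names follow the diff.

chunk_a = {"text": "Hel", "tokens": ["Hel"], "offset": [0],
           "token_probs": [-0.1], "logprobs": [{"Hel": -0.1}]}
chunk_b = {"text": "lo!", "tokens": ["lo", "!"], "offset": [3, 5],
           "token_probs": [-0.2, -0.3], "logprobs": [{"lo": -0.2}, {"!": -0.3}]}

joined = {"text": "", "tokens": [], "offset": [], "token_probs": [], "logprobs": []}
for chunk in (chunk_a, chunk_b):
    joined["text"] += chunk["text"]
    joined["tokens"].extend(chunk["tokens"])          # one entry per generated token
    joined["offset"].extend(chunk["offset"])
    joined["token_probs"].extend(chunk["token_probs"])
    joined["logprobs"].extend(chunk["logprobs"])

# joined["tokens"] == ["Hel", "lo", "!"]; all four lists stay index-aligned.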

@@ -1145,7 +1151,6 @@ async def generate_gen(
                 "text": chunk,
                 "prompt_tokens": context_len,
                 "generated_tokens": generated_tokens,
-                "offset": len(full_response),
             }

             if request_logprobs > 0:
@@ -1164,11 +1169,41 @@ async def generate_gen(
                 logprobs = self.get_logprobs(top_tokens, top_probs)
                 generation["logprobs"] = logprobs

-                # The first logprob is the selected token prob
-                generation["token_probs"] = {
-                    token: logprobs[token]
-                    for token in list(logprobs.keys())[:1]
-                }
+                token_ids = unwrap(
+                    result.get("token_ids"),
+                    torch.empty(0),
+                )
+
+                token_probs = unwrap(
+                    result.get("token_probs"),
+                    torch.empty(0),
+                )
+
+                if token_ids.numel() > 0 and token_probs.numel() > 0:
+                    token_ids = token_ids.flatten().tolist()
+                    token_probs = token_probs.flatten().tolist()
+
+                    tokens = [
+                        self.tokenizer.extended_id_to_piece.get(
+                            index, self.tokenizer.id_to_piece[index]
+                        )
+                        for index in token_ids
+                    ]
+
+                    generation["tokens"] = tokens
+                    generation["token_probs"] = [
+                        math.log(prob) for prob in token_probs
+                    ]
+
+                    # Calculate the offset of each token in the output,
+                    # working backwards from the end.
+                    offsets = []
+                    token_offset = 0
+                    for token in reversed(tokens):
+                        token_offset += len(token)
+                        offsets.append(len(full_response) - token_offset)
+                    offsets.reverse()
+                    generation["offset"] = offsets

             yield generation
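
A standalone illustration of the backwards offset pass at the end of the hunk, assuming full_response already contains this chunk's text; the strings are invented.

full_response = "Hello world!"        # assumed accumulated output so far
tokens = ["Hello", " world", "!"]     # assumed token pieces forming its tail

offsets = []
token_offset = 0
for token in reversed(tokens):
    token_offset += len(token)
    offsets.append(len(full_response) - token_offset)
offsets.reverse()

print(offsets)  # [0, 5, 11]: character position of each token in the output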

5 changes: 4 additions & 1 deletion endpoints/OAI/types/chat_completion.py
@@ -9,11 +9,14 @@
 class ChatCompletionLogprob(BaseModel):
     token: str
     logprob: float
+
+
+class ChatCompletionLogprobChoice(ChatCompletionLogprob):
+    top_logprobs: Optional[List["ChatCompletionLogprob"]] = None


 class ChatCompletionLogprobs(BaseModel):
-    content: List[ChatCompletionLogprob] = Field(default_factory=list)
+    content: List[ChatCompletionLogprobChoice] = Field(default_factory=list)


 class ChatCompletionMessage(BaseModel):
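
A small self-contained sketch of how the revised models nest once populated. The token strings and logprob values are invented, and model_dump assumes Pydantic v2 (use .dict() on v1).

from typing import List, Optional
from pydantic import BaseModel, Field


class ChatCompletionLogprob(BaseModel):
    token: str
    logprob: float


class ChatCompletionLogprobChoice(ChatCompletionLogprob):
    top_logprobs: Optional[List["ChatCompletionLogprob"]] = None


class ChatCompletionLogprobs(BaseModel):
    content: List[ChatCompletionLogprobChoice] = Field(default_factory=list)


payload = ChatCompletionLogprobs(
    content=[
        ChatCompletionLogprobChoice(
            token="Hello",
            logprob=-0.11,
            top_logprobs=[
                ChatCompletionLogprob(token="Hello", logprob=-0.11),
                ChatCompletionLogprob(token="Hi", logprob=-2.30),
            ],
        )
    ]
)
print(payload.model_dump())  # one entry per generated token, each with its top candidates
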
57 changes: 34 additions & 23 deletions endpoints/OAI/utils/chat_completion.py
@@ -22,6 +22,7 @@
 from endpoints.OAI.types.chat_completion import (
     ChatCompletionLogprobs,
     ChatCompletionLogprob,
+    ChatCompletionLogprobChoice,
     ChatCompletionMessage,
     ChatCompletionRequest,
     ChatCompletionRespChoice,
@@ -46,22 +47,24 @@ def _create_response(generations: List[dict], model_name: Optional[str]):

         logprob_response = None

-        token_probs = unwrap(generation.get("token_probs"), {})
+        token_probs = unwrap(generation.get("token_probs"), [])
         if token_probs:
+            tokens = unwrap(generation.get("tokens"), [])
             logprobs = unwrap(generation.get("logprobs"), [])

             collected_token_probs = []
-            for index, token in enumerate(token_probs.keys()):
-                top_logprobs = [
-                    ChatCompletionLogprob(token=token, logprob=logprob)
-                    for token, logprob in logprobs[index].items()
+            for generated_token, generated_token_logprob, top_logprobs in zip(
+                tokens, token_probs, logprobs, strict=True
+            ):
+                completion_logprobs = [
+                    ChatCompletionLogprob(token=token, logprob=token_logprob)
+                    for token, token_logprob in top_logprobs.items()
                 ]

                 collected_token_probs.append(
-                    ChatCompletionLogprob(
-                        token=token,
-                        logprob=token_probs[token],
-                        top_logprobs=top_logprobs,
+                    ChatCompletionLogprobChoice(
+                        token=generated_token,
+                        logprob=generated_token_logprob,
+                        top_logprobs=completion_logprobs,
                     )
                 )
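
The zip(..., strict=True) call relies on tokens, token_probs, and logprobs staying index-aligned, one entry per generated token. A short sketch of the behaviour it enforces; the values are invented, and strict= requires Python 3.10 or newer.

tokens = ["Hello", " world"]
token_probs = [-0.11, -0.51]
logprobs = [{"Hello": -0.11, "Hi": -2.30}, {" world": -0.51, " there": -0.92}]

# Aligned lists iterate normally; a length mismatch raises ValueError
# instead of silently truncating the logprob output.
for piece, piece_logprob, top in zip(tokens, token_probs, logprobs, strict=True):
    print(piece, piece_logprob, top)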

@@ -109,22 +112,30 @@ def _create_stream_chunk(
         role="assistant", content=unwrap(generation.get("text"), "")
     )

     logprob_response = None

     token_probs = unwrap(generation.get("token_probs"), {})
     if token_probs:
-        logprobs = unwrap(generation.get("logprobs"), {})
-        top_logprobs = [
-            ChatCompletionLogprob(token=token, logprob=logprob)
-            for token, logprob in logprobs.items()
-        ]
-
-        generated_token = next(iter(token_probs))
-        token_prob_response = ChatCompletionLogprob(
-            token=generated_token,
-            logprob=token_probs[generated_token],
-            top_logprobs=top_logprobs,
-        )
+        tokens = unwrap(generation.get("tokens"), [])
+        logprobs = unwrap(generation.get("logprobs"), [])
+        collected_token_probs = []
+        for generated_token, generated_token_logprob, top_logprobs in zip(
+            tokens, token_probs, logprobs, strict=True
+        ):
+            completion_logprobs = [
+                ChatCompletionLogprob(token=token, logprob=token_logprob)
+                for token, token_logprob in top_logprobs.items()
+            ]
+
+            collected_token_probs.append(
+                ChatCompletionLogprobChoice(
+                    token=generated_token,
+                    logprob=generated_token_logprob,
+                    top_logprobs=completion_logprobs,
+                )
+            )

-        logprob_response = ChatCompletionLogprobs(content=[token_prob_response])
+        logprob_response = ChatCompletionLogprobs(content=collected_token_probs)

     choice = ChatCompletionStreamChoice(
         index=index,
7 changes: 4 additions & 3 deletions endpoints/OAI/utils/completion.py
@@ -35,15 +35,16 @@ def _create_response(generations: Union[dict, List[dict]], model_name: str = "")
     for index, generation in enumerate(generations):
         logprob_response = None

-        token_probs = unwrap(generation.get("token_probs"), {})
+        tokens = unwrap(generation.get("tokens"), [])
+        token_probs = unwrap(generation.get("token_probs"), [])
         if token_probs:
             logprobs = unwrap(generation.get("logprobs"), [])
             offset = unwrap(generation.get("offset"), [])

             logprob_response = CompletionLogProbs(
                 text_offset=offset if isinstance(offset, list) else [offset],
-                token_logprobs=token_probs.values(),
-                tokens=token_probs.keys(),
+                token_logprobs=token_probs,
+                tokens=tokens,
                 top_logprobs=logprobs if isinstance(logprobs, list) else [logprobs],
             )
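
For the legacy completions endpoint, the generation fields now map one-to-one onto the OpenAI-style logprobs object. A rough sketch of the resulting shape, with invented values:

# Assumed per-token data coming out of the backend generator.
generation = {
    "tokens": ["Hello", " world"],
    "token_probs": [-0.11, -0.51],                    # already log-probabilities
    "offset": [0, 5],
    "logprobs": [{"Hello": -0.11}, {" world": -0.51}],
}

logprob_payload = {
    "text_offset": generation["offset"],
    "token_logprobs": generation["token_probs"],
    "tokens": generation["tokens"],
    "top_logprobs": generation["logprobs"],
}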
