API: Fix finish_reason returns
OAI expects finish_reason to be "stop" or "length" (there are others,
but they're not in the current scope of this project).

Make all completion and chat completion responses return this value
from the model generation itself rather than using a placeholder.

Signed-off-by: kingbri <[email protected]>
kingbri1 committed Mar 18, 2024
1 parent 25f5d4a commit 5c7fc69
Showing 5 changed files with 34 additions and 16 deletions.
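For context on the commit message above, the mapping it describes is small: a generation that stops because it hit the token limit reports "length", one that stops on its own reports "stop", and an in-progress streaming chunk carries no finish reason yet. A minimal sketch of that rule, using hypothetical names rather than code from this repository:

from typing import Optional


def pick_finish_reason(eos: bool, generated_tokens: int, max_tokens: int) -> Optional[str]:
    """Map generation state to an OAI-style finish_reason (illustrative sketch only)."""
    if generated_tokens >= max_tokens:
        # The token budget cut the generation short
        return "length"
    if eos:
        # The model finished on its own
        return "stop"
    # Generation is still running; streaming chunks omit finish_reason
    return None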
7 changes: 7 additions & 0 deletions backends/exllamav2/model.py
@@ -605,6 +605,9 @@ async def generate(self, prompt: str, **kwargs):
             joined_generation["generation_tokens"] = unwrap(
                 generations[-1].get("generated_tokens"), 0
             )
+            joined_generation["finish_reason"] = unwrap(
+                generations[-1].get("finish_reason"), "stop"
+            )
 
         return joined_generation
 
@@ -1004,6 +1007,10 @@ def generate_gen_sync(self, prompt: str, **kwargs):
                 last_chunk_time = now
 
             if eos or generated_tokens == max_tokens:
+                finish_reason = "length" if generated_tokens == max_tokens else "stop"
+                generation = {"finish_reason": finish_reason}
+                yield generation
+
                 break
 
         # Print response
4 changes: 2 additions & 2 deletions endpoints/OAI/types/chat_completion.py
@@ -24,15 +24,15 @@ class ChatCompletionMessage(BaseModel):
 class ChatCompletionRespChoice(BaseModel):
     # Index is 0 since we aren't using multiple choices
     index: int = 0
-    finish_reason: str
+    finish_reason: Optional[str] = None
     message: ChatCompletionMessage
     logprobs: Optional[ChatCompletionLogprobs] = None
 
 
 class ChatCompletionStreamChoice(BaseModel):
     # Index is 0 since we aren't using multiple choices
     index: int = 0
-    finish_reason: Optional[str]
+    finish_reason: Optional[str] = None
     delta: Union[ChatCompletionMessage, dict] = {}
     logprobs: Optional[ChatCompletionLogprobs] = None
 
2 changes: 1 addition & 1 deletion endpoints/OAI/types/completion.py
@@ -22,7 +22,7 @@ class CompletionRespChoice(BaseModel):
 
     # Index is 0 since we aren't using multiple choices
     index: int = 0
-    finish_reason: str
+    finish_reason: Optional[str] = None
     logprobs: Optional[CompletionLogProbs] = None
     text: str
 
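With finish_reason now optional on these Pydantic models, a response choice can carry either a concrete reason or null. A small usage sketch, assuming Pydantic v2 (model_dump_json, as used elsewhere in this commit) and that the class is importable from the module path shown in the diff:

from endpoints.OAI.types.completion import CompletionRespChoice

# Final choice of a finished generation: finish_reason is populated
done = CompletionRespChoice(finish_reason="stop", text="Hello, world")
print(done.model_dump_json())  # finish_reason serializes as "stop"

# A choice built without a finish_reason now defaults to None instead of
# failing validation, and serializes as null
partial = CompletionRespChoice(text="Hello")
print(partial.model_dump_json())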
27 changes: 17 additions & 10 deletions endpoints/OAI/utils/chat_completion.py
@@ -60,7 +60,9 @@ def _create_response(generation: dict, model_name: Optional[str]):
     logprob_response = ChatCompletionLogprobs(content=collected_token_probs)
 
     choice = ChatCompletionRespChoice(
-        finish_reason="Generated", message=message, logprobs=logprob_response
+        finish_reason=generation.get("finish_reason"),
+        message=message,
+        logprobs=logprob_response,
     )
 
     prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
@@ -83,14 +85,15 @@ def _create_stream_chunk(
     const_id: str,
     generation: Optional[dict] = None,
     model_name: Optional[str] = None,
-    finish_reason: Optional[str] = None,
 ):
     """Create a chat completion stream chunk from the provided text."""
 
     logprob_response = None
 
-    if finish_reason:
-        message = {}
+    if "finish_reason" in generation:
+        choice = ChatCompletionStreamChoice(
+            finish_reason=generation.get("finish_reason")
+        )
     else:
         message = ChatCompletionMessage(
             role="assistant", content=unwrap(generation.get("text"), "")
@@ -113,10 +116,10 @@
 
         logprob_response = ChatCompletionLogprobs(content=[token_prob_response])
 
-    # The finish reason can be None
-    choice = ChatCompletionStreamChoice(
-        finish_reason=finish_reason, delta=message, logprobs=logprob_response
-    )
+        choice = ChatCompletionStreamChoice(
+            delta=message,
+            logprobs=logprob_response,
+        )
 
     chunk = ChatCompletionStreamChunk(
         id=const_id, choices=[choice], model=unwrap(model_name, "")
@@ -165,10 +168,14 @@ async def stream_generate_chat_completion(
 
             yield response.model_dump_json()
 
+            # Break if the generation is finished
+            if "finish_reason" in generation:
+                break
+
         # Yield a finish response on successful generation
-        finish_response = _create_stream_chunk(const_id, finish_reason="stop")
+        # finish_response = _create_stream_chunk(const_id, finish_reason="stop")
 
-        yield finish_response.model_dump_json()
+        # yield finish_response.model_dump_json()
     except CancelledError:
         # Get out if the request gets disconnected
 
10 changes: 7 additions & 3 deletions endpoints/OAI/utils/completion.py
@@ -39,7 +39,7 @@ def _create_response(generation: dict, model_name: Optional[str]):
     )
 
     choice = CompletionRespChoice(
-        finish_reason="Generated",
+        finish_reason=generation.get("finish_reason"),
         text=unwrap(generation.get("text"), ""),
         logprobs=logprob_response,
     )
@@ -69,11 +69,15 @@ async def stream_generate_completion(data: CompletionRequest, model_path: pathli
         )
         async for generation in new_generation:
             response = _create_response(generation, model_path.name)
 
             yield response.model_dump_json()
 
+            # Break if the generation is finished
+            if "finish_reason" in generation:
+                yield "[DONE]"
+                break
+
         # Yield a finish response on successful generation
-        yield "[DONE]"
-
+        # yield "[DONE]"
     except CancelledError:
         # Get out if the request gets disconnected
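Taken together, the streaming endpoints now key entirely off the generator output: the backend yields text chunks and then a final chunk containing finish_reason, and the endpoint stops once it sees that key. A self-contained sketch of that consumption pattern, using a stand-in generator rather than this project's backend API:

import asyncio
from typing import AsyncIterator


async def fake_generation_stream() -> AsyncIterator[dict]:
    # Stand-in for the backend generator: text chunks, then a finishing chunk
    yield {"text": "Hello"}
    yield {"text": ", world"}
    yield {"finish_reason": "stop"}


async def consume() -> None:
    async for generation in fake_generation_stream():
        print(generation)
        # Mirrors the endpoint loops above: exit once the backend reports a finish_reason
        if "finish_reason" in generation:
            break


asyncio.run(consume())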
