OAI: Fix usage chunk return
Place the logic into its proper utility function and clean up the
code with formatting.

Also, OAI's docs specify that a [DONE] return is needed when everything
is finished.

Signed-off-by: kingbri <[email protected]>
bdashore3 committed Jul 12, 2024
1 parent b149d33 commit c1b6144
Showing 1 changed file with 30 additions and 19 deletions.
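
For context, the [DONE] sentinel and the usage chunk mentioned in the commit message follow OpenAI's streaming convention: when stream_options.include_usage is set, the last content-bearing chunk is followed by a chunk with an empty choices list and a populated usage object, and the stream then terminates with [DONE]. Below is a minimal client-side sketch of consuming such a stream with the standard openai Python client; the base URL, API key, and model name are placeholders, not values taken from this repository.

from openai import OpenAI

# Placeholder endpoint and model; any OpenAI-compatible server applies.
client = OpenAI(base_url="http://localhost:5000/v1", api_key="placeholder")

stream = client.chat.completions.create(
    model="my-local-model",
    messages=[{"role": "user", "content": "Hello!"}],
    stream=True,
    stream_options={"include_usage": True},
)

for chunk in stream:
    if chunk.choices:
        # Normal content deltas
        print(chunk.choices[0].delta.content or "", end="")
    elif chunk.usage:
        # Final usage-only chunk (empty choices); the [DONE] sentinel is
        # consumed by the client library itself.
        print(f"\nprompt={chunk.usage.prompt_tokens}, "
              f"completion={chunk.usage.completion_tokens}")
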
endpoints/OAI/utils/chat_completion.py: 30 additions & 19 deletions
@@ -93,22 +93,37 @@ def _create_stream_chunk(
     const_id: str,
     generation: Optional[dict] = None,
     model_name: Optional[str] = None,
+    is_usage_chunk: bool = False,
 ):
     """Create a chat completion stream chunk from the provided text."""
 
     index = generation.get("index")
-    logprob_response = None
+    choices = []
+    usage_stats = None
 
-    if "finish_reason" in generation:
+    if is_usage_chunk:
+        prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
+        completion_tokens = unwrap(generation.get("generated_tokens"), 0)
+
+        usage_stats = UsageStats(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        )
+    elif "finish_reason" in generation:
         choice = ChatCompletionStreamChoice(
             index=index,
             finish_reason=generation.get("finish_reason"),
         )
+
+        choices.append(choice)
     else:
         message = ChatCompletionMessage(
             role="assistant", content=unwrap(generation.get("text"), "")
         )
 
+        logprob_response = None
+
         token_probs = unwrap(generation.get("token_probs"), {})
         if token_probs:
             logprobs = unwrap(generation.get("logprobs"), {})
@@ -132,8 +147,13 @@ def _create_stream_chunk(
             logprobs=logprob_response,
         )
 
+        choices.append(choice)
+
     chunk = ChatCompletionStreamChunk(
-        id=const_id, choices=[choice], model=unwrap(model_name, "")
+        id=const_id,
+        choices=choices,
+        model=unwrap(model_name, ""),
+        usage=usage_stats,
     )
 
     return chunk
@@ -246,7 +266,6 @@ async def stream_generate_chat_completion(
     gen_queue = asyncio.Queue()
     gen_tasks: List[asyncio.Task] = []
     disconnect_task = asyncio.create_task(request_disconnect_loop(request))
-    need_usage = data.stream_options and data.stream_options.include_usage
 
     try:
         gen_params = data.to_gen_params()
@@ -276,26 +295,18 @@
                 raise generation
 
             response = _create_stream_chunk(const_id, generation, model_path.name)
-            yield response.model_dump_json(exclude=None if need_usage else "usage")
+            yield response.model_dump_json()
 
             # Check if all tasks are completed
             if all(task.done() for task in gen_tasks) and gen_queue.empty():
-                if need_usage:
-                    prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
-                    completion_tokens = unwrap(generation.get("generated_tokens"), 0)
-
-                    response = ChatCompletionStreamChunk(
-                        id=const_id,
-                        choices=[],
-                        model=unwrap(model_path.name, ""),
-                        usage=UsageStats(
-                            prompt_tokens=prompt_tokens,
-                            completion_tokens=completion_tokens,
-                            total_tokens=prompt_tokens + completion_tokens,
-                        ),
-                    )
-
-                    yield response.model_dump_json()
+                # Send a usage chunk
+                if data.stream_options and data.stream_options.include_usage:
+                    usage_chunk = _create_stream_chunk(
+                        const_id, generation, model_path.name, is_usage_chunk=True
+                    )
+                    yield usage_chunk.model_dump_json()
+
+                yield "[DONE]"
                 break
     except CancelledError:
         # Get out if the request gets disconnected
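
For illustration, here is a standalone sketch of what the new usage chunk amounts to, using plain dicts in place of the project's Pydantic models. The helper name and field layout follow the general OpenAI chunk format and are illustrative only, not taken from this file.

from typing import Optional

def build_usage_chunk(const_id: str, generation: dict,
                      model_name: Optional[str] = None) -> dict:
    # Mirrors the is_usage_chunk=True branch: no choices, only token totals.
    prompt_tokens = generation.get("prompt_tokens") or 0
    completion_tokens = generation.get("generated_tokens") or 0

    return {
        "id": const_id,
        "object": "chat.completion.chunk",
        "model": model_name or "",
        "choices": [],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }

# The final two events a client sees, in order:
#   a usage chunk with an empty choices list, then [DONE]
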
