diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py
index 9b91d1d8..10f25cdd 100644
--- a/endpoints/OAI/utils/chat_completion.py
+++ b/endpoints/OAI/utils/chat_completion.py
@@ -93,22 +93,37 @@ def _create_stream_chunk(
     const_id: str,
     generation: Optional[dict] = None,
     model_name: Optional[str] = None,
+    is_usage_chunk: bool = False,
 ):
     """Create a chat completion stream chunk from the provided text."""
 
     index = generation.get("index")
-    logprob_response = None
+    choices = []
+    usage_stats = None
+
+    if is_usage_chunk:
+        prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
+        completion_tokens = unwrap(generation.get("generated_tokens"), 0)
 
-    if "finish_reason" in generation:
+        usage_stats = UsageStats(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        )
+    elif "finish_reason" in generation:
         choice = ChatCompletionStreamChoice(
             index=index,
             finish_reason=generation.get("finish_reason"),
         )
+
+        choices.append(choice)
     else:
         message = ChatCompletionMessage(
             role="assistant", content=unwrap(generation.get("text"), "")
         )
 
+        logprob_response = None
+
         token_probs = unwrap(generation.get("token_probs"), {})
         if token_probs:
             logprobs = unwrap(generation.get("logprobs"), {})
@@ -132,8 +147,13 @@ def _create_stream_chunk(
             logprobs=logprob_response,
         )
 
+        choices.append(choice)
+
     chunk = ChatCompletionStreamChunk(
-        id=const_id, choices=[choice], model=unwrap(model_name, "")
+        id=const_id,
+        choices=choices,
+        model=unwrap(model_name, ""),
+        usage=usage_stats,
     )
 
     return chunk
@@ -246,7 +266,6 @@ async def stream_generate_chat_completion(
     gen_queue = asyncio.Queue()
     gen_tasks: List[asyncio.Task] = []
     disconnect_task = asyncio.create_task(request_disconnect_loop(request))
-    need_usage = data.stream_options and data.stream_options.include_usage
 
     try:
         gen_params = data.to_gen_params()
@@ -276,26 +295,18 @@ async def stream_generate_chat_completion(
                 raise generation
 
             response = _create_stream_chunk(const_id, generation, model_path.name)
-            yield response.model_dump_json(exclude=None if need_usage else "usage")
+            yield response.model_dump_json()
 
             # Check if all tasks are completed
            if all(task.done() for task in gen_tasks) and gen_queue.empty():
-                if need_usage:
-                    prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
-                    completion_tokens = unwrap(generation.get("generated_tokens"), 0)
-
-                    response = ChatCompletionStreamChunk(
-                        id=const_id,
-                        choices=[],
-                        model=unwrap(model_path.name, ""),
-                        usage=UsageStats(
-                            prompt_tokens=prompt_tokens,
-                            completion_tokens=completion_tokens,
-                            total_tokens=prompt_tokens + completion_tokens,
-                        ),
+                # Send a usage chunk
+                if data.stream_options and data.stream_options.include_usage:
+                    usage_chunk = _create_stream_chunk(
+                        const_id, generation, model_path.name, is_usage_chunk=True
                     )
+                    yield usage_chunk.model_dump_json()
 
-                    yield response.model_dump_json()
+                yield "[DONE]"
                 break
     except CancelledError:
         # Get out if the request gets disconnected
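
A minimal client-side sketch of the new behavior, not part of the patch (assumptions: the server is reachable at http://127.0.0.1:5000/v1 and the official openai Python SDK is used; the base URL, API key, and model name are placeholders). With stream_options.include_usage set, the stream should end with a final chunk whose choices list is empty and whose usage field carries the counts built by _create_stream_chunk(..., is_usage_chunk=True):

    # Sketch only: consume the stream and read the trailing usage chunk.
    from openai import OpenAI

    client = OpenAI(base_url="http://127.0.0.1:5000/v1", api_key="placeholder")

    stream = client.chat.completions.create(
        model="placeholder-model",
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,
        stream_options={"include_usage": True},
    )

    for chunk in stream:
        if chunk.choices:
            # Normal delta chunks from the non-usage branches of _create_stream_chunk
            print(chunk.choices[0].delta.content or "", end="")
        elif chunk.usage:
            # Final usage-only chunk (is_usage_chunk=True path): choices is empty
            print(
                f"\nprompt={chunk.usage.prompt_tokens} "
                f"completion={chunk.usage.completion_tokens} "
                f"total={chunk.usage.total_tokens}"
            )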