diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py
index be5cfea9..b50e646b 100644
--- a/endpoints/OAI/types/chat_completion.py
+++ b/endpoints/OAI/types/chat_completion.py
@@ -64,3 +64,4 @@ class ChatCompletionStreamChunk(BaseModel):
     created: int = Field(default_factory=lambda: int(time()))
     model: str
     object: str = "chat.completion.chunk"
+    usage: Optional[UsageStats] = None
diff --git a/endpoints/OAI/types/common.py b/endpoints/OAI/types/common.py
index d44e41a5..6970adf7 100644
--- a/endpoints/OAI/types/common.py
+++ b/endpoints/OAI/types/common.py
@@ -18,6 +18,10 @@ class CompletionResponseFormat(BaseModel):
     type: str = "text"
 
 
+class ChatCompletionStreamOptions(BaseModel):
+    include_usage: Optional[bool] = False
+
+
 class CommonCompletionRequest(BaseSamplerRequest):
     """Represents a common completion request."""
 
@@ -27,6 +31,7 @@ class CommonCompletionRequest(BaseSamplerRequest):
 
     # Generation info (remainder is in BaseSamplerRequest superclass)
     stream: Optional[bool] = False
+    stream_options: Optional[ChatCompletionStreamOptions] = None
     logprobs: Optional[int] = Field(
         default_factory=lambda: get_default_sampler_value("logprobs", 0)
     )
diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py
index 9e82b1b6..10f25cdd 100644
--- a/endpoints/OAI/utils/chat_completion.py
+++ b/endpoints/OAI/utils/chat_completion.py
@@ -93,22 +93,37 @@ def _create_stream_chunk(
     const_id: str,
     generation: Optional[dict] = None,
     model_name: Optional[str] = None,
+    is_usage_chunk: bool = False,
 ):
     """Create a chat completion stream chunk from the provided text."""
 
     index = generation.get("index")
-    logprob_response = None
+    choices = []
+    usage_stats = None
+
+    if is_usage_chunk:
+        prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
+        completion_tokens = unwrap(generation.get("generated_tokens"), 0)
 
-    if "finish_reason" in generation:
+        usage_stats = UsageStats(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        )
+    elif "finish_reason" in generation:
         choice = ChatCompletionStreamChoice(
             index=index,
             finish_reason=generation.get("finish_reason"),
         )
+
+        choices.append(choice)
     else:
         message = ChatCompletionMessage(
             role="assistant", content=unwrap(generation.get("text"), "")
         )
 
+        logprob_response = None
+
         token_probs = unwrap(generation.get("token_probs"), {})
         if token_probs:
             logprobs = unwrap(generation.get("logprobs"), {})
@@ -132,8 +147,13 @@ def _create_stream_chunk(
             logprobs=logprob_response,
         )
 
+        choices.append(choice)
+
     chunk = ChatCompletionStreamChunk(
-        id=const_id, choices=[choice], model=unwrap(model_name, "")
+        id=const_id,
+        choices=choices,
+        model=unwrap(model_name, ""),
+        usage=usage_stats,
     )
 
     return chunk
@@ -279,6 +299,14 @@ async def stream_generate_chat_completion(
 
             # Check if all tasks are completed
             if all(task.done() for task in gen_tasks) and gen_queue.empty():
+                # Send a usage chunk
+                if data.stream_options and data.stream_options.include_usage:
+                    usage_chunk = _create_stream_chunk(
+                        const_id, generation, model_path.name, is_usage_chunk=True
+                    )
+                    yield usage_chunk.model_dump_json()
+
+                yield "[DONE]"
                 break
         except CancelledError:
             # Get out if the request gets disconnected
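
As a quick client-side sanity check of the new `stream_options` path, here is a minimal sketch (not part of the diff). It assumes a local OpenAI-compatible server exposing these routes and an `openai` Python client recent enough to accept `stream_options`; the base URL, API key, and model name are placeholders:

```python
# Minimal sketch, assumptions noted above: local OpenAI-compatible server,
# openai client with stream_options support; URL/key/model are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:5000/v1", api_key="placeholder")

stream = client.chat.completions.create(
    model="my-local-model",
    messages=[{"role": "user", "content": "Say hello."}],
    stream=True,
    # Opt in to the trailing usage chunk added by this change
    stream_options={"include_usage": True},
)

for chunk in stream:
    # Normal chunks carry deltas; the final usage chunk has an empty choices list
    if chunk.choices:
        print(chunk.choices[0].delta.content or "", end="", flush=True)
    if chunk.usage:
        print(
            f"\nprompt={chunk.usage.prompt_tokens} "
            f"completion={chunk.usage.completion_tokens} "
            f"total={chunk.usage.total_tokens}"
        )
```

With `include_usage` enabled, the regular delta chunks stream as before with `usage` serialized as `null`, and a final chunk with an empty `choices` list carries the aggregated token counts before `[DONE]`.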