From b149d3398d4f62c282452121fc6057c4439eb123 Mon Sep 17 00:00:00 2001
From: Volodymyr Kuznetsov
Date: Mon, 8 Jul 2024 13:42:54 -0700
Subject: [PATCH 1/2] OAI: support stream_options argument

---
 endpoints/OAI/types/chat_completion.py |  1 +
 endpoints/OAI/types/common.py          |  5 +++++
 endpoints/OAI/utils/chat_completion.py | 19 ++++++++++++++++++-
 3 files changed, 24 insertions(+), 1 deletion(-)

diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py
index be5cfea9..b50e646b 100644
--- a/endpoints/OAI/types/chat_completion.py
+++ b/endpoints/OAI/types/chat_completion.py
@@ -64,3 +64,4 @@ class ChatCompletionStreamChunk(BaseModel):
     created: int = Field(default_factory=lambda: int(time()))
     model: str
     object: str = "chat.completion.chunk"
+    usage: Optional[UsageStats] = None
diff --git a/endpoints/OAI/types/common.py b/endpoints/OAI/types/common.py
index d44e41a5..6970adf7 100644
--- a/endpoints/OAI/types/common.py
+++ b/endpoints/OAI/types/common.py
@@ -18,6 +18,10 @@ class CompletionResponseFormat(BaseModel):
     type: str = "text"
 
 
+class ChatCompletionStreamOptions(BaseModel):
+    include_usage: Optional[bool] = False
+
+
 class CommonCompletionRequest(BaseSamplerRequest):
     """Represents a common completion request."""
 
@@ -27,6 +31,7 @@ class CommonCompletionRequest(BaseSamplerRequest):
 
     # Generation info (remainder is in BaseSamplerRequest superclass)
     stream: Optional[bool] = False
+    stream_options: Optional[ChatCompletionStreamOptions] = None
     logprobs: Optional[int] = Field(
         default_factory=lambda: get_default_sampler_value("logprobs", 0)
     )
diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py
index 9e82b1b6..9b91d1d8 100644
--- a/endpoints/OAI/utils/chat_completion.py
+++ b/endpoints/OAI/utils/chat_completion.py
@@ -246,6 +246,7 @@ async def stream_generate_chat_completion(
     gen_queue = asyncio.Queue()
     gen_tasks: List[asyncio.Task] = []
     disconnect_task = asyncio.create_task(request_disconnect_loop(request))
+    need_usage = data.stream_options and data.stream_options.include_usage
 
     try:
         gen_params = data.to_gen_params()
@@ -275,10 +276,26 @@ async def stream_generate_chat_completion(
                 raise generation
 
             response = _create_stream_chunk(const_id, generation, model_path.name)
-            yield response.model_dump_json()
+            yield response.model_dump_json(exclude=None if need_usage else "usage")
 
             # Check if all tasks are completed
             if all(task.done() for task in gen_tasks) and gen_queue.empty():
+                if need_usage:
+                    prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
+                    completion_tokens = unwrap(generation.get("generated_tokens"), 0)
+
+                    response = ChatCompletionStreamChunk(
+                        id=const_id,
+                        choices=[],
+                        model=unwrap(model_path.name, ""),
+                        usage=UsageStats(
+                            prompt_tokens=prompt_tokens,
+                            completion_tokens=completion_tokens,
+                            total_tokens=prompt_tokens + completion_tokens,
+                        ),
+                    )
+
+                    yield response.model_dump_json()
                 break
     except CancelledError:
         # Get out if the request gets disconnected

From c1b61441f46d7601076fe7f737475b3f2392f61d Mon Sep 17 00:00:00 2001
From: kingbri
Date: Fri, 12 Jul 2024 14:35:48 -0400
Subject: [PATCH 2/2] OAI: Fix usage chunk return

Place the logic into its proper utility function and clean up the
code formatting. Also, OAI's docs specify that a [DONE] return is
needed when everything is finished.
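
For reference, a client opts into the usage chunk roughly like this (an
illustrative sketch, not part of the patch; it assumes the openai Python
client with stream_options support, and the base URL, API key, and model
name are placeholders):

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:5000/v1", api_key="placeholder")

    stream = client.chat.completions.create(
        model="placeholder-model",
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,
        stream_options={"include_usage": True},
    )

    # Content chunks arrive as usual; the last chunk before [DONE] has
    # empty choices and a populated usage object with the token counts.
    for chunk in stream:
        if chunk.usage is not None:
            print(chunk.usage)
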
Signed-off-by: kingbri
---
 endpoints/OAI/utils/chat_completion.py | 49 ++++++++++++++++----------
 1 file changed, 30 insertions(+), 19 deletions(-)

diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py
index 9b91d1d8..10f25cdd 100644
--- a/endpoints/OAI/utils/chat_completion.py
+++ b/endpoints/OAI/utils/chat_completion.py
@@ -93,22 +93,37 @@ def _create_stream_chunk(
     const_id: str,
     generation: Optional[dict] = None,
     model_name: Optional[str] = None,
+    is_usage_chunk: bool = False,
 ):
     """Create a chat completion stream chunk from the provided text."""
 
     index = generation.get("index")
-    logprob_response = None
+    choices = []
+    usage_stats = None
+
+    if is_usage_chunk:
+        prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
+        completion_tokens = unwrap(generation.get("generated_tokens"), 0)
 
-    if "finish_reason" in generation:
+        usage_stats = UsageStats(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        )
+    elif "finish_reason" in generation:
         choice = ChatCompletionStreamChoice(
             index=index,
             finish_reason=generation.get("finish_reason"),
         )
+
+        choices.append(choice)
     else:
         message = ChatCompletionMessage(
             role="assistant", content=unwrap(generation.get("text"), "")
         )
 
+        logprob_response = None
+
         token_probs = unwrap(generation.get("token_probs"), {})
         if token_probs:
             logprobs = unwrap(generation.get("logprobs"), {})
@@ -132,8 +147,13 @@ def _create_stream_chunk(
             logprobs=logprob_response,
         )
 
+        choices.append(choice)
+
     chunk = ChatCompletionStreamChunk(
-        id=const_id, choices=[choice], model=unwrap(model_name, "")
+        id=const_id,
+        choices=choices,
+        model=unwrap(model_name, ""),
+        usage=usage_stats,
     )
 
     return chunk
@@ -246,7 +266,6 @@ async def stream_generate_chat_completion(
     gen_queue = asyncio.Queue()
     gen_tasks: List[asyncio.Task] = []
     disconnect_task = asyncio.create_task(request_disconnect_loop(request))
-    need_usage = data.stream_options and data.stream_options.include_usage
 
     try:
         gen_params = data.to_gen_params()
@@ -276,26 +295,18 @@ async def stream_generate_chat_completion(
                 raise generation
 
             response = _create_stream_chunk(const_id, generation, model_path.name)
-            yield response.model_dump_json(exclude=None if need_usage else "usage")
+            yield response.model_dump_json()
 
             # Check if all tasks are completed
             if all(task.done() for task in gen_tasks) and gen_queue.empty():
-                if need_usage:
-                    prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
-                    completion_tokens = unwrap(generation.get("generated_tokens"), 0)
-
-                    response = ChatCompletionStreamChunk(
-                        id=const_id,
-                        choices=[],
-                        model=unwrap(model_path.name, ""),
-                        usage=UsageStats(
-                            prompt_tokens=prompt_tokens,
-                            completion_tokens=completion_tokens,
-                            total_tokens=prompt_tokens + completion_tokens,
-                        ),
+                # Send a usage chunk
+                if data.stream_options and data.stream_options.include_usage:
+                    usage_chunk = _create_stream_chunk(
+                        const_id, generation, model_path.name, is_usage_chunk=True
                     )
+                    yield usage_chunk.model_dump_json()
 
-                    yield response.model_dump_json()
+                yield "[DONE]"
                 break
     except CancelledError:
         # Get out if the request gets disconnected