From 9b3e488bbf063768e3dfac7e6f847743ac78e19b Mon Sep 17 00:00:00 2001
From: Alejandro Ponce
Date: Mon, 2 Dec 2024 10:19:53 +0200
Subject: [PATCH] Respond with JSON if the request is non-streaming

We are currently not handling non-streaming requests, e.g.

```sh
% curl -SsX POST "http://localhost:8989/vllm/chat/completions" \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer $token" \
  -d '{
    "model": "Qwen/Qwen2.5-Coder-14B-Instruct",
    "messages": [{"role": "user", "content": "hello."}],
    "stream": false
  }'
```

This PR enables responding with the entire JSON if the request was non-streaming.
---
 .../extract_snippets/extract_snippets.py      |  1 +
 src/codegate/providers/anthropic/provider.py  |  2 +-
 src/codegate/providers/completion/base.py     | 17 +++++++++++--
 .../providers/litellmshim/litellmshim.py      | 24 ++++++++++++++++---
 .../providers/llamacpp/completion_handler.py  |  7 ++++--
 src/codegate/providers/llamacpp/provider.py   |  2 +-
 src/codegate/providers/openai/provider.py     |  2 +-
 src/codegate/providers/vllm/provider.py       |  2 +-
 .../providers/litellmshim/test_litellmshim.py |  2 +-
 tests/providers/test_registry.py              |  5 +++-
 10 files changed, 51 insertions(+), 13 deletions(-)

diff --git a/src/codegate/pipeline/extract_snippets/extract_snippets.py b/src/codegate/pipeline/extract_snippets/extract_snippets.py
index a50460ee..bee756e0 100644
--- a/src/codegate/pipeline/extract_snippets/extract_snippets.py
+++ b/src/codegate/pipeline/extract_snippets/extract_snippets.py
@@ -13,6 +13,7 @@
 
 logger = structlog.get_logger("codegate")
 
+
 def ecosystem_from_filepath(filepath: str) -> Optional[str]:
     """
     Determine language from filepath.
diff --git a/src/codegate/providers/anthropic/provider.py b/src/codegate/providers/anthropic/provider.py
index a035f7aa..4d7eba59 100644
--- a/src/codegate/providers/anthropic/provider.py
+++ b/src/codegate/providers/anthropic/provider.py
@@ -48,4 +48,4 @@ async def create_message(
         is_fim_request = self._is_fim_request(request, data)
         stream = await self.complete(data, x_api_key, is_fim_request)
 
-        return self._completion_handler.create_streaming_response(stream)
+        return self._completion_handler.create_response(stream)
diff --git a/src/codegate/providers/completion/base.py b/src/codegate/providers/completion/base.py
index 1e86129c..9c424a24 100644
--- a/src/codegate/providers/completion/base.py
+++ b/src/codegate/providers/completion/base.py
@@ -1,7 +1,8 @@
 from abc import ABC, abstractmethod
+from collections.abc import Iterator
 from typing import Any, AsyncIterator, Optional, Union
 
-from fastapi.responses import StreamingResponse
+from fastapi.responses import JSONResponse, StreamingResponse
 from litellm import ChatCompletionRequest, ModelResponse
 
 
@@ -23,5 +24,17 @@ async def execute_completion(
         pass
 
     @abstractmethod
-    def create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResponse:
+    def _create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResponse:
         pass
+
+    @abstractmethod
+    def _create_json_response(self, response: Any) -> JSONResponse:
+        pass
+
+    def create_response(self, response: Any) -> Union[JSONResponse, StreamingResponse]:
+        """
+        Create a FastAPI response from the completion response.
+        """
+        if isinstance(response, Iterator):
+            return self._create_streaming_response(response)
+        return self._create_json_response(response)
diff --git a/src/codegate/providers/litellmshim/litellmshim.py b/src/codegate/providers/litellmshim/litellmshim.py
index 1de5bb28..5954691f 100644
--- a/src/codegate/providers/litellmshim/litellmshim.py
+++ b/src/codegate/providers/litellmshim/litellmshim.py
@@ -1,10 +1,17 @@
 from typing import Any, AsyncIterator, Callable, Optional, Union
 
-from fastapi.responses import StreamingResponse
-from litellm import ChatCompletionRequest, ModelResponse, acompletion
+import structlog
+from fastapi.responses import JSONResponse, StreamingResponse
+from litellm import (
+    ChatCompletionRequest,
+    ModelResponse,
+    acompletion,
+)
 
 from codegate.providers.base import BaseCompletionHandler, StreamGenerator
 
+logger = structlog.get_logger("codegate")
+
 
 class LiteLLmShim(BaseCompletionHandler):
     """
@@ -42,7 +49,7 @@ async def execute_completion(
             return await self._fim_completion_func(**request)
         return await self._completion_func(**request)
 
-    def create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResponse:
+    def _create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResponse:
         """
         Create a streaming response from a stream generator. The StreamingResponse
         is the format that FastAPI expects for streaming responses.
@@ -56,3 +63,14 @@ def create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResp
             },
             status_code=200,
         )
+
+    def _create_json_response(self, response: ModelResponse) -> JSONResponse:
+        """
+        Create a JSON FastAPI response from a ModelResponse object.
+        ModelResponse is obtained when the request is not streaming.
+        """
+        # ModelResponse is not a Pydantic object but has a json method we can use to serialize it
+        if isinstance(response, ModelResponse):
+            return JSONResponse(status_code=200, content=response.json())
+        # Most other objects in LiteLLM are Pydantic, so we can use the model_dump method
+        return JSONResponse(status_code=200, content=response.model_dump())
diff --git a/src/codegate/providers/llamacpp/completion_handler.py b/src/codegate/providers/llamacpp/completion_handler.py
index f5e6fc1d..72490b1f 100644
--- a/src/codegate/providers/llamacpp/completion_handler.py
+++ b/src/codegate/providers/llamacpp/completion_handler.py
@@ -2,7 +2,7 @@
 import json
 from typing import Any, AsyncIterator, Iterator, Optional, Union
 
-from fastapi.responses import StreamingResponse
+from fastapi.responses import JSONResponse, StreamingResponse
 from litellm import ChatCompletionRequest, ModelResponse
 from llama_cpp.llama_types import (
     CreateChatCompletionStreamResponse,
@@ -75,7 +75,7 @@ async def execute_completion(
 
         return convert_to_async_iterator(response) if stream else response
 
-    def create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResponse:
+    def _create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResponse:
         """
         Create a streaming response from a stream generator. The StreamingResponse
         is the format that FastAPI expects for streaming responses.
@@ -89,3 +89,6 @@ def create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResp
             },
             status_code=200,
         )
+
+    def _create_json_response(self, response: Any) -> JSONResponse:
+        raise NotImplementedError("JSON Response in LlamaCPP not implemented yet.")
diff --git a/src/codegate/providers/llamacpp/provider.py b/src/codegate/providers/llamacpp/provider.py
index befc169e..efe06f09 100644
--- a/src/codegate/providers/llamacpp/provider.py
+++ b/src/codegate/providers/llamacpp/provider.py
@@ -43,4 +43,4 @@ async def create_completion(
         is_fim_request = self._is_fim_request(request, data)
         stream = await self.complete(data, None, is_fim_request=is_fim_request)
 
-        return self._completion_handler.create_streaming_response(stream)
+        return self._completion_handler.create_response(stream)
diff --git a/src/codegate/providers/openai/provider.py b/src/codegate/providers/openai/provider.py
index 60c36a4b..741d3143 100644
--- a/src/codegate/providers/openai/provider.py
+++ b/src/codegate/providers/openai/provider.py
@@ -49,4 +49,4 @@ async def create_completion(
         is_fim_request = self._is_fim_request(request, data)
         stream = await self.complete(data, api_key, is_fim_request=is_fim_request)
 
-        return self._completion_handler.create_streaming_response(stream)
+        return self._completion_handler.create_response(stream)
diff --git a/src/codegate/providers/vllm/provider.py b/src/codegate/providers/vllm/provider.py
index d2e710af..a342ac6f 100644
--- a/src/codegate/providers/vllm/provider.py
+++ b/src/codegate/providers/vllm/provider.py
@@ -57,4 +57,4 @@ async def create_completion(
         is_fim_request = self._is_fim_request(request, data)
         stream = await self.complete(data, api_key, is_fim_request=is_fim_request)
 
-        return self._completion_handler.create_streaming_response(stream)
+        return self._completion_handler.create_response(stream)
diff --git a/tests/providers/litellmshim/test_litellmshim.py b/tests/providers/litellmshim/test_litellmshim.py
index 73889a34..87b75803 100644
--- a/tests/providers/litellmshim/test_litellmshim.py
+++ b/tests/providers/litellmshim/test_litellmshim.py
@@ -117,7 +117,7 @@ async def mock_stream_gen():
     generator = mock_stream_gen()
 
     litellm_shim = LiteLLmShim(stream_generator=sse_stream_generator)
-    response = litellm_shim.create_streaming_response(generator)
+    response = litellm_shim._create_streaming_response(generator)
 
     # Verify response metadata
     assert isinstance(response, StreamingResponse)
diff --git a/tests/providers/test_registry.py b/tests/providers/test_registry.py
index 8c957f13..d7c97da9 100644
--- a/tests/providers/test_registry.py
+++ b/tests/providers/test_registry.py
@@ -43,12 +43,15 @@ def execute_completion(
     ) -> Any:
         pass
 
-    def create_streaming_response(
+    def _create_streaming_response(
         self,
         stream: AsyncIterator[Any],
     ) -> StreamingResponse:
         return StreamingResponse(stream)
 
+    def _create_json_response(self, response: Any) -> Any:
+        raise NotImplementedError
+
 
 class MockInputNormalizer(ModelInputNormalizer):
     def normalize(self, data: Dict) -> Dict:
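
For reviewers, here is a minimal standalone sketch of the dispatch this patch introduces in `BaseCompletionHandler.create_response`. The snippet is not part of the patch; the helper function and the sample payloads are illustrative only. It shows the two branches: a synchronous iterator of chunks is wrapped in a `StreamingResponse`, anything else is returned whole as a `JSONResponse`.

```python
from collections.abc import Iterator
from typing import Any, Union

from fastapi.responses import JSONResponse, StreamingResponse


def create_response(response: Any) -> Union[JSONResponse, StreamingResponse]:
    """Illustrative mirror of the dispatch added to BaseCompletionHandler.create_response."""
    if isinstance(response, Iterator):
        # Streaming case: serve the chunks as server-sent events.
        return StreamingResponse(
            (f"data: {chunk}\n\n" for chunk in response),
            media_type="text/event-stream",
        )
    # Non-streaming case ("stream": false): the full completion is already available.
    return JSONResponse(status_code=200, content=response)


# A plain dict (stand-in for a serialized ModelResponse) becomes a JSONResponse.
print(type(create_response({"choices": [{"message": {"content": "hello"}}]})))
# An iterator of chunks becomes a StreamingResponse.
print(type(create_response(iter(["chunk1", "chunk2"]))))
```

In the patch itself, the two branches are delegated to the per-handler `_create_streaming_response` and `_create_json_response` implementations.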