Merge pull request #149 from stacklok/normalize-vllm-output
Respond with JSON if the request is non-stream
aponcedeleonch authored Dec 2, 2024
2 parents 366bd6e + 9b3e488 commit d4f1ab8
Showing 9 changed files with 50 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/codegate/providers/anthropic/provider.py
@@ -48,4 +48,4 @@ async def create_message(

         is_fim_request = self._is_fim_request(request, data)
         stream = await self.complete(data, x_api_key, is_fim_request)
-        return self._completion_handler.create_streaming_response(stream)
+        return self._completion_handler.create_response(stream)
17 changes: 15 additions & 2 deletions src/codegate/providers/completion/base.py
@@ -1,7 +1,8 @@
 from abc import ABC, abstractmethod
+from collections.abc import Iterator
 from typing import Any, AsyncIterator, Optional, Union

-from fastapi.responses import StreamingResponse
+from fastapi.responses import JSONResponse, StreamingResponse
 from litellm import ChatCompletionRequest, ModelResponse


@@ -23,5 +24,17 @@ async def execute_completion(
         pass

     @abstractmethod
-    def create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResponse:
+    def _create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResponse:
         pass
+
+    @abstractmethod
+    def _create_json_response(self, response: Any) -> JSONResponse:
+        pass
+
+    def create_response(self, response: Any) -> Union[JSONResponse, StreamingResponse]:
+        """
+        Create a FastAPI response from the completion response.
+        """
+        if isinstance(response, Iterator):
+            return self._create_streaming_response(response)
+        return self._create_json_response(response)
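
The key addition here is create_response, which picks the HTTP response type from the shape of the completion result: anything that is an Iterator gets streamed, everything else is serialized as JSON. The snippet below is a minimal standalone sketch of that dispatch check (hypothetical names, no codegate or FastAPI imports), just to show which branch a given object would take.

from collections.abc import Iterator
from typing import Any


def pick_response_kind(response: Any) -> str:
    """Mirror the create_response check: iterators stream, everything else is JSON."""
    if isinstance(response, Iterator):
        return "streaming"  # create_response would call _create_streaming_response
    return "json"           # create_response would call _create_json_response


# A generator object is an Iterator and takes the streaming branch;
# a plain completion payload such as a dict takes the JSON branch.
chunks = iter(["data: {...}\n\n", "data: [DONE]\n\n"])
print(pick_response_kind(chunks))            # streaming
print(pick_response_kind({"choices": []}))   # json
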
24 changes: 21 additions & 3 deletions src/codegate/providers/litellmshim/litellmshim.py
@@ -1,10 +1,17 @@
 from typing import Any, AsyncIterator, Callable, Optional, Union

-from fastapi.responses import StreamingResponse
-from litellm import ChatCompletionRequest, ModelResponse, acompletion
+import structlog
+from fastapi.responses import JSONResponse, StreamingResponse
+from litellm import (
+    ChatCompletionRequest,
+    ModelResponse,
+    acompletion,
+)

 from codegate.providers.base import BaseCompletionHandler, StreamGenerator

+logger = structlog.get_logger("codegate")
+

 class LiteLLmShim(BaseCompletionHandler):
     """
@@ -42,7 +49,7 @@ async def execute_completion(
             return await self._fim_completion_func(**request)
         return await self._completion_func(**request)

-    def create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResponse:
+    def _create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResponse:
         """
         Create a streaming response from a stream generator. The StreamingResponse
         is the format that FastAPI expects for streaming responses.
@@ -56,3 +63,14 @@ def create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResponse:
             },
             status_code=200,
         )
+
+    def _create_json_response(self, response: ModelResponse) -> JSONResponse:
+        """
+        Create a JSON FastAPI response from a ModelResponse object.
+        ModelResponse is obtained when the request is not streaming.
+        """
+        # ModelResponse is not a Pydantic object but has a json method we can use to serialize
+        if isinstance(response, ModelResponse):
+            return JSONResponse(status_code=200, content=response.json())
+        # Most of others objects in LiteLLM are Pydantic, we can use the model_dump method
+        return JSONResponse(status_code=200, content=response.model_dump())
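
The JSON branch distinguishes two kinds of non-streaming results: a litellm ModelResponse, which (per the comment above) is serialized via its json() method, and other LiteLLM objects that are Pydantic models and expose model_dump(). Below is a small standalone illustration of that fallback, using toy stand-in classes rather than the real LiteLLM types, so the branching can be seen in isolation.

from typing import Any


class FakeModelResponse:
    """Toy stand-in for litellm.ModelResponse, serialized via json()."""

    def json(self) -> dict:
        return {"id": "cmpl-1", "object": "chat.completion", "choices": []}


class FakePydanticResult:
    """Toy stand-in for other LiteLLM objects that expose model_dump()."""

    def model_dump(self) -> dict:
        return {"id": "cmpl-2", "object": "chat.completion", "choices": []}


def serialize_result(response: Any) -> dict:
    """Same branching as _create_json_response, minus the JSONResponse wrapper."""
    if isinstance(response, FakeModelResponse):
        return response.json()
    return response.model_dump()


print(serialize_result(FakeModelResponse()))   # serialized with json()
print(serialize_result(FakePydanticResult()))  # serialized with model_dump()
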
7 changes: 5 additions & 2 deletions src/codegate/providers/llamacpp/completion_handler.py
@@ -2,7 +2,7 @@
 import json
 from typing import Any, AsyncIterator, Iterator, Optional, Union

-from fastapi.responses import StreamingResponse
+from fastapi.responses import JSONResponse, StreamingResponse
 from litellm import ChatCompletionRequest, ModelResponse
 from llama_cpp.llama_types import (
     CreateChatCompletionStreamResponse,
@@ -75,7 +75,7 @@ async def execute_completion(

         return convert_to_async_iterator(response) if stream else response

-    def create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResponse:
+    def _create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResponse:
         """
         Create a streaming response from a stream generator. The StreamingResponse
         is the format that FastAPI expects for streaming responses.
@@ -89,3 +89,6 @@ def create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResponse:
             },
             status_code=200,
         )
+
+    def _create_json_response(self, response: Any) -> JSONResponse:
+        raise NotImplementedError("JSON Reponse in LlamaCPP not implemented yet.")
2 changes: 1 addition & 1 deletion src/codegate/providers/llamacpp/provider.py
@@ -43,4 +43,4 @@ async def create_completion(

         is_fim_request = self._is_fim_request(request, data)
         stream = await self.complete(data, None, is_fim_request=is_fim_request)
-        return self._completion_handler.create_streaming_response(stream)
+        return self._completion_handler.create_response(stream)
2 changes: 1 addition & 1 deletion src/codegate/providers/openai/provider.py
@@ -49,4 +49,4 @@ async def create_completion(

         is_fim_request = self._is_fim_request(request, data)
         stream = await self.complete(data, api_key, is_fim_request=is_fim_request)
-        return self._completion_handler.create_streaming_response(stream)
+        return self._completion_handler.create_response(stream)
2 changes: 1 addition & 1 deletion src/codegate/providers/vllm/provider.py
@@ -57,4 +57,4 @@ async def create_completion(

         is_fim_request = self._is_fim_request(request, data)
         stream = await self.complete(data, api_key, is_fim_request=is_fim_request)
-        return self._completion_handler.create_streaming_response(stream)
+        return self._completion_handler.create_response(stream)
2 changes: 1 addition & 1 deletion tests/providers/litellmshim/test_litellmshim.py
@@ -117,7 +117,7 @@ async def mock_stream_gen():
     generator = mock_stream_gen()

     litellm_shim = LiteLLmShim(stream_generator=sse_stream_generator)
-    response = litellm_shim.create_streaming_response(generator)
+    response = litellm_shim._create_streaming_response(generator)

     # Verify response metadata
     assert isinstance(response, StreamingResponse)
5 changes: 4 additions & 1 deletion tests/providers/test_registry.py
@@ -43,12 +43,15 @@ def execute_completion(
     ) -> Any:
         pass

-    def create_streaming_response(
+    def _create_streaming_response(
         self,
         stream: AsyncIterator[Any],
     ) -> StreamingResponse:
         return StreamingResponse(stream)

+    def _create_json_response(self, response: Any) -> Any:
+        raise NotImplementedError
+

 class MockInputNormalizer(ModelInputNormalizer):
     def normalize(self, data: Dict) -> Dict:
