diff --git a/pyproject.toml b/pyproject.toml
index 5fa962ad..d889f6fa 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,6 +23,7 @@ bandit = ">=1.7.10"
 build = ">=1.0.0"
 wheel = ">=0.40.0"
 litellm = ">=1.52.11"
+pytest-asyncio = "0.24.0"
 
 [build-system]
 requires = ["poetry-core"]
diff --git a/src/codegate/providers/litellmshim/generators.py b/src/codegate/providers/litellmshim/generators.py
index 65f4f89c..306f1900 100644
--- a/src/codegate/providers/litellmshim/generators.py
+++ b/src/codegate/providers/litellmshim/generators.py
@@ -1,6 +1,8 @@
 import json
 from typing import Any, AsyncIterator
 
+from pydantic import BaseModel
+
 # Since different providers typically use one of these formats for streaming
 # responses, we have a single stream generator for each format that is then plugged
 # into the adapter.
@@ -10,7 +12,9 @@ async def sse_stream_generator(stream: AsyncIterator[Any]) -> AsyncIterator[str]
     """OpenAI-style SSE format"""
     try:
         async for chunk in stream:
-            if hasattr(chunk, "model_dump_json"):
+            if isinstance(chunk, BaseModel):
+                # alternatively we might want to just dump the whole object
+                # this might even allow us to tighten the typing of the stream
                 chunk = chunk.model_dump_json(exclude_none=True, exclude_unset=True)
             try:
                 yield f"data:{chunk}\n\n"
diff --git a/src/codegate/providers/litellmshim/litellmshim.py b/src/codegate/providers/litellmshim/litellmshim.py
index 0829debc..e6b54c81 100644
--- a/src/codegate/providers/litellmshim/litellmshim.py
+++ b/src/codegate/providers/litellmshim/litellmshim.py
@@ -14,8 +14,9 @@ class LiteLLmShim(BaseCompletionHandler):
     LiteLLM API.
     """
 
-    def __init__(self, adapter: BaseAdapter):
+    def __init__(self, adapter: BaseAdapter, completion_func=acompletion):
         self._adapter = adapter
+        self._completion_func = completion_func
 
     async def complete(self, data: Dict, api_key: str) -> AsyncIterator[Any]:
         """
@@ -28,7 +29,7 @@ async def complete(self, data: Dict, api_key: str) -> AsyncIterator[Any]:
         if completion_request is None:
             raise Exception("Couldn't translate the request")
 
-        response = await acompletion(**completion_request)
+        response = await self._completion_func(**completion_request)
 
         if isinstance(response, ModelResponse):
             return self._adapter.translate_completion_output_params(response)
diff --git a/tests/providers/anthropic/test_adapter.py b/tests/providers/anthropic/test_adapter.py
new file mode 100644
index 00000000..e5597d80
--- /dev/null
+++ b/tests/providers/anthropic/test_adapter.py
@@ -0,0 +1,154 @@
+from typing import AsyncIterator, Dict, List, Union
+
+import pytest
+from litellm import ModelResponse
+from litellm.adapters.anthropic_adapter import AnthropicStreamWrapper
+from litellm.types.llms.anthropic import (
+    ContentBlockDelta,
+    ContentBlockStart,
+    ContentTextBlockDelta,
+    MessageChunk,
+    MessageStartBlock,
+)
+from litellm.types.utils import Delta, StreamingChoices
+
+from codegate.providers.anthropic.adapter import AnthropicAdapter
+
+
+@pytest.fixture
+def adapter():
+    return AnthropicAdapter()
+
+def test_translate_completion_input_params(adapter):
+    # Test input data
+    completion_request = {
+        "model": "claude-3-haiku-20240307",
+        "max_tokens": 1024,
+        "stream": True,
+        "messages": [
+            {
+                "role": "user",
+                "system": "You are an expert code reviewer",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": "Review this code"
+                    }
+                ]
+            }
+        ]
+    }
+    expected = {
+        'max_tokens': 1024,
+        'messages': [
+            {'content': [{'text': 'Review this code', 'type': 'text'}], 'role': 'user'}
+        ],
+        'model': 'claude-3-haiku-20240307',
+        'stream': True
+    }
+
+    # Get translation
+    result = adapter.translate_completion_input_params(completion_request)
+    assert result == expected
+
+@pytest.mark.asyncio
+async def test_translate_completion_output_params_streaming(adapter):
+    # Test stream data
+    async def mock_stream():
+        messages = [
+            ModelResponse(
+                id="test_id_1",
+                choices=[
+                    StreamingChoices(
+                        finish_reason=None,
+                        index=0,
+                        delta=Delta(content="Hello", role="assistant")),
+                ],
+                model="claude-3-haiku-20240307",
+            ),
+            ModelResponse(
+                id="test_id_2",
+                choices=[
+                    StreamingChoices(finish_reason=None,
+                        index=0,
+                        delta=Delta(content="world", role="assistant")),
+                ],
+                model="claude-3-haiku-20240307",
+            ),
+            ModelResponse(
+                id="test_id_2",
+                choices=[
+                    StreamingChoices(finish_reason=None,
+                        index=0,
+                        delta=Delta(content="!", role="assistant")),
+                ],
+                model="claude-3-haiku-20240307",
+            ),
+        ]
+        for msg in messages:
+            yield msg
+
+    expected: List[Union[MessageStartBlock,ContentBlockStart,ContentBlockDelta]] = [
+        MessageStartBlock(
+            type="message_start",
+            message=MessageChunk(
+                id="msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY",
+                type="message",
+                role="assistant",
+                content=[],
+                # litellm makes up a message start block with hardcoded values
+                model="claude-3-5-sonnet-20240620",
+                stop_reason=None,
+                stop_sequence=None,
+                usage={"input_tokens": 25, "output_tokens": 1},
+            ),
+        ),
+        ContentBlockStart(
+            type="content_block_start",
+            index=0,
+            content_block={"type": "text", "text": ""},
+        ),
+        ContentBlockDelta(
+            type="content_block_delta",
+            index=0,
+            delta=ContentTextBlockDelta(type="text_delta", text="Hello"),
+        ),
+        ContentBlockDelta(
+            type="content_block_delta",
+            index=0,
+            delta=ContentTextBlockDelta(type="text_delta", text="world"),
+        ),
+        ContentBlockDelta(
+            type="content_block_delta",
+            index=0,
+            delta=ContentTextBlockDelta(type="text_delta", text="!"),
+        ),
+        # litellm doesn't seem to have a type for message stop
+        dict(type="message_stop"),
+    ]
+
+    stream = adapter.translate_completion_output_params_streaming(mock_stream())
+    assert isinstance(stream, AnthropicStreamWrapper)
+
+    # just so that we can zip over the expected chunks
+    stream_list = [chunk async for chunk in stream]
+    # Verify we got all chunks
+    assert len(stream_list) == 6
+
+    for chunk, expected_chunk in zip(stream_list, expected):
+        assert chunk == expected_chunk
+
+
+def test_stream_generator_initialization(adapter):
+    # Verify the default stream generator is set
+    from codegate.providers.litellmshim import anthropic_stream_generator
+    assert adapter.stream_generator == anthropic_stream_generator
+
+def test_custom_stream_generator():
+    # Test that we can inject a custom stream generator
+    async def custom_generator(stream: AsyncIterator[Dict]) -> AsyncIterator[str]:
+        async for chunk in stream:
+            yield "custom: " + str(chunk)
+
+    adapter = AnthropicAdapter(stream_generator=custom_generator)
+    assert adapter.stream_generator == custom_generator
diff --git a/tests/providers/litellmshim/test_generators.py b/tests/providers/litellmshim/test_generators.py
new file mode 100644
index 00000000..7a8de33b
--- /dev/null
+++ b/tests/providers/litellmshim/test_generators.py
@@ -0,0 +1,80 @@
+from typing import AsyncIterator
+
+import pytest
+from litellm import ModelResponse
+
+from codegate.providers.litellmshim import (
+    anthropic_stream_generator,
+    sse_stream_generator,
+)
+
+
+@pytest.mark.asyncio
+async def test_sse_stream_generator():
+    # Mock stream data
+    mock_chunks = [
+        ModelResponse(id="1", choices=[{"text": "Hello"}]),
+        ModelResponse(id="2", choices=[{"text": "World"}])
+    ]
+
+    async def mock_stream():
+        for chunk in mock_chunks:
+            yield chunk
+
+    # Collect generated SSE messages
+    messages = []
+    async for message in sse_stream_generator(mock_stream()):
+        messages.append(message)
+
+    # Verify format and content
+    assert len(messages) == len(mock_chunks) + 1  # +1 for the [DONE] message
+    assert all(msg.startswith("data:") for msg in messages)
+    assert "Hello" in messages[0]
+    assert "World" in messages[1]
+    assert messages[-1] == "data: [DONE]\n\n"
+
+@pytest.mark.asyncio
+async def test_anthropic_stream_generator():
+    # Mock Anthropic-style chunks
+    mock_chunks = [
+        {"type": "message_start", "message": {"id": "1"}},
+        {"type": "content_block_start", "content_block": {"text": "Hello"}},
+        {"type": "content_block_stop", "content_block": {"text": "World"}}
+    ]
+
+    async def mock_stream():
+        for chunk in mock_chunks:
+            yield chunk
+
+    # Collect generated SSE messages
+    messages = []
+    async for message in anthropic_stream_generator(mock_stream()):
+        messages.append(message)
+
+    # Verify format and content
+    assert len(messages) == 3
+    for msg, chunk in zip(messages, mock_chunks):
+        assert msg.startswith(f"event: {chunk['type']}\ndata:")
+    assert "Hello" in messages[1]  # content_block_start message
+    assert "World" in messages[2]  # content_block_stop message
+
+@pytest.mark.asyncio
+async def test_generators_error_handling():
+    async def error_stream() -> AsyncIterator[str]:
+        raise Exception("Test error")
+        yield  # This will never be reached, but is needed for AsyncIterator typing
+
+    # Test SSE generator error handling
+    messages = []
+    async for message in sse_stream_generator(error_stream()):
+        messages.append(message)
+    assert len(messages) == 2
+    assert "Test error" in messages[0]
+    assert messages[1] == "data: [DONE]\n\n"
+
+    # Test Anthropic generator error handling
+    messages = []
+    async for message in anthropic_stream_generator(error_stream()):
+        messages.append(message)
+    assert len(messages) == 1
+    assert "Test error" in messages[0]
diff --git a/tests/providers/litellmshim/test_litellmshim.py b/tests/providers/litellmshim/test_litellmshim.py
new file mode 100644
index 00000000..1b97bee7
--- /dev/null
+++ b/tests/providers/litellmshim/test_litellmshim.py
@@ -0,0 +1,157 @@
+from typing import Any, AsyncIterator, Dict
+from unittest.mock import AsyncMock
+
+import pytest
+from fastapi.responses import StreamingResponse
+from litellm import ChatCompletionRequest, ModelResponse
+
+from codegate.providers.litellmshim import BaseAdapter, LiteLLmShim
+
+
+class MockAdapter(BaseAdapter):
+    def __init__(self):
+        self.stream_generator = AsyncMock()
+        super().__init__(self.stream_generator)
+
+    def translate_completion_input_params(self, kwargs: Dict) -> ChatCompletionRequest:
+        # Validate required fields
+        if "messages" not in kwargs or "model" not in kwargs:
+            raise ValueError("Required fields 'messages' and 'model' must be present")
+
+        modified_kwargs = kwargs.copy()
+        modified_kwargs["mock_adapter_processed"] = True
+        return ChatCompletionRequest(**modified_kwargs)
+
+    def translate_completion_output_params(self, response: ModelResponse) -> Any:
+        response.mock_adapter_processed = True
+        return response
+
+    def translate_completion_output_params_streaming(
+        self, completion_stream: Any,
+    ) -> Any:
+        async def modified_stream():
+            async for chunk in completion_stream:
+                chunk.mock_adapter_processed = True
+                yield chunk
+        return modified_stream()
+
+@pytest.fixture
+def mock_adapter():
+    return MockAdapter()
+
+@pytest.fixture
+def litellm_shim(mock_adapter):
+    return LiteLLmShim(mock_adapter)
+
+@pytest.mark.asyncio
+async def test_complete_non_streaming(litellm_shim, mock_adapter):
+    # Mock response
+    mock_response = ModelResponse(id="123", choices=[{"text": "test response"}])
+    mock_completion = AsyncMock(return_value=mock_response)
+
+    # Create shim with mocked completion
+    litellm_shim = LiteLLmShim(mock_adapter, completion_func=mock_completion)
+
+    # Test data
+    data = {
+        "messages": [{"role": "user", "content": "Hello"}],
+        "model": "gpt-3.5-turbo"
+    }
+    api_key = "test-key"
+
+    # Execute
+    result = await litellm_shim.complete(data, api_key)
+
+    # Verify
+    assert result == mock_response
+    mock_completion.assert_called_once()
+    called_args = mock_completion.call_args[1]
+    assert called_args["api_key"] == api_key
+    assert called_args["messages"] == data["messages"]
+    # Verify adapter processed the input
+    assert called_args["mock_adapter_processed"] is True
+
+@pytest.mark.asyncio
+async def test_complete_streaming():
+    # Mock streaming response with specific test content
+    async def mock_stream() -> AsyncIterator[ModelResponse]:
+        yield ModelResponse(id="123", choices=[{"text": "chunk1"}])
+        yield ModelResponse(id="123", choices=[{"text": "chunk2"}])
+
+    mock_completion = AsyncMock(return_value=mock_stream())
+    mock_adapter = MockAdapter()
+    litellm_shim = LiteLLmShim(mock_adapter, completion_func=mock_completion)
+
+    # Test data
+    data = {
+        "messages": [{"role": "user", "content": "Hello"}],
+        "model": "gpt-3.5-turbo",
+        "stream": True
+    }
+    api_key = "test-key"
+
+    # Execute
+    result_stream = await litellm_shim.complete(data, api_key)
+
+    # Verify stream contents and adapter processing
+    chunks = []
+    async for chunk in result_stream:
+        chunks.append(chunk)
+        # Verify each chunk was processed by the adapter
+        assert hasattr(chunk, "mock_adapter_processed")
+        assert chunk.mock_adapter_processed is True
+
+    assert len(chunks) == 2
+    assert chunks[0].choices[0]["text"] == "chunk1"
+    assert chunks[1].choices[0]["text"] == "chunk2"
+
+    # Verify completion function was called with correct parameters
+    mock_completion.assert_called_once()
+    called_args = mock_completion.call_args[1]
+    assert called_args["mock_adapter_processed"] is True  # Verify input was processed
+    assert called_args["messages"] == data["messages"]
+    assert called_args["model"] == data["model"]
+    assert called_args["stream"] is True
+    assert called_args["api_key"] == api_key
+
+@pytest.mark.asyncio
+async def test_create_streaming_response(litellm_shim):
+    # Create a simple async generator that we know works
+    async def mock_stream_gen():
+        for msg in ["Hello", "World"]:
+            yield msg.encode()  # FastAPI expects bytes
+
+    # Create and verify the generator
+    generator = mock_stream_gen()
+
+    response = litellm_shim.create_streaming_response(generator)
+
+    # Verify response metadata
+    assert isinstance(response, StreamingResponse)
+    assert response.status_code == 200
+    assert response.headers["Cache-Control"] == "no-cache"
+    assert response.headers["Connection"] == "keep-alive"
+    assert response.headers["Transfer-Encoding"] == "chunked"
+
+@pytest.mark.asyncio
+async def test_complete_invalid_params():
+    mock_completion = AsyncMock()
+    mock_adapter = MockAdapter()
+    litellm_shim = LiteLLmShim(mock_adapter, completion_func=mock_completion)
+
+    # Test data missing required fields
+    data = {
+        "invalid_field": "test"
+        # missing 'messages' and 'model'
+    }
+    api_key = "test-key"
+
+    # Execute and verify specific exception is raised
+    with pytest.raises(
+        ValueError,
+        match="Required fields 'messages' and 'model' must be present",
+    ):
+        await litellm_shim.complete(data, api_key)
+
+    # Verify the completion function was never called
+    mock_completion.assert_not_called()
diff --git a/tests/providers/test_registry.py b/tests/providers/test_registry.py
new file mode 100644
index 00000000..c1dc542b
--- /dev/null
+++ b/tests/providers/test_registry.py
@@ -0,0 +1,57 @@
+from typing import Any, AsyncIterator, Dict
+
+import pytest
+from fastapi import FastAPI
+from fastapi.responses import StreamingResponse
+
+from codegate.providers.base import BaseCompletionHandler, BaseProvider
+from codegate.providers.registry import ProviderRegistry
+
+
+class MockCompletionHandler(BaseCompletionHandler):
+    async def complete(self, data: Dict, api_key: str) -> AsyncIterator[Any]:
+        yield "test"
+
+    def create_streaming_response(
+        self, stream: AsyncIterator[Any],
+    ) -> StreamingResponse:
+        return StreamingResponse(stream)
+
+class MockProvider(BaseProvider):
+    def _setup_routes(self) -> None:
+        @self.router.get("/test")
+        def test_route():
+            return {"message": "test"}
+
+@pytest.fixture
+def mock_completion_handler():
+    return MockCompletionHandler()
+
+@pytest.fixture
+def app():
+    return FastAPI()
+
+@pytest.fixture
+def registry(app):
+    return ProviderRegistry(app)
+
+def test_add_provider(registry, mock_completion_handler):
+    provider = MockProvider(mock_completion_handler)
+    registry.add_provider("test", provider)
+
+    assert "test" in registry.providers
+    assert registry.providers["test"] == provider
+
+def test_get_provider(registry, mock_completion_handler):
+    provider = MockProvider(mock_completion_handler)
+    registry.add_provider("test", provider)
+
+    assert registry.get_provider("test") == provider
+    assert registry.get_provider("nonexistent") is None
+
+def test_provider_routes_added(app, registry, mock_completion_handler):
+    provider = MockProvider(mock_completion_handler)
+    registry.add_provider("test", provider)
+
+    routes = [route for route in app.routes if route.path == "/test"]
+    assert len(routes) == 1
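
For reference, and not part of the patch itself: a minimal sketch of driving sse_stream_generator directly, using only imports already exercised in tests/providers/litellmshim/test_generators.py; the chunk values are illustrative.

import asyncio

from litellm import ModelResponse

from codegate.providers.litellmshim import sse_stream_generator


async def main():
    async def chunks():
        # pydantic BaseModel chunks are serialized via model_dump_json(exclude_none=True, exclude_unset=True)
        yield ModelResponse(id="1", choices=[{"text": "Hello"}])

    async for message in sse_stream_generator(chunks()):
        # Prints each "data:{...}\n\n" frame and the final "data: [DONE]\n\n" marker
        print(repr(message))


asyncio.run(main())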