Keep the code coverage high (#80)
We still need to add unit tests for OpenAI; we will add them in a separate patch.
jhrozek authored Nov 25, 2024
1 parent e48b1c5 commit 1b711ff
Showing 7 changed files with 457 additions and 3 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -23,6 +23,7 @@ bandit = ">=1.7.10"
 build = ">=1.0.0"
 wheel = ">=0.40.0"
 litellm = ">=1.52.11"
+pytest-asyncio = "0.24.0"
 
 [build-system]
 requires = ["poetry-core"]
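Note: the pytest-asyncio pin is what lets the new test modules mark coroutine tests with @pytest.mark.asyncio. A minimal sketch of the pattern (the coroutine below is a stand-in for illustration, not part of this patch):

import asyncio

import pytest


@pytest.mark.asyncio
async def test_runs_on_an_event_loop():
    # pytest-asyncio supplies the event loop that drives this coroutine
    await asyncio.sleep(0)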
6 changes: 5 additions & 1 deletion src/codegate/providers/litellmshim/generators.py
@@ -1,6 +1,8 @@
 import json
 from typing import Any, AsyncIterator
 
+from pydantic import BaseModel
+
 # Since different providers typically use one of these formats for streaming
 # responses, we have a single stream generator for each format that is then plugged
 # into the adapter.
@@ -10,7 +12,9 @@ async def sse_stream_generator(stream: AsyncIterator[Any]) -> AsyncIterator[str]
"""OpenAI-style SSE format"""
try:
async for chunk in stream:
if hasattr(chunk, "model_dump_json"):
if isinstance(chunk, BaseModel):
# alternatively we might want to just dump the whole object
# this might even allow us to tighten the typing of the stream
chunk = chunk.model_dump_json(exclude_none=True, exclude_unset=True)
try:
yield f"data:{chunk}\n\n"
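Note: the isinstance check replaces duck typing with an explicit pydantic contract. For reference, a minimal sketch (assuming pydantic v2; the Chunk model is hypothetical) of what the new branch does to a chunk:

from typing import Optional

from pydantic import BaseModel


class Chunk(BaseModel):
    text: str
    finish_reason: Optional[str] = None


# exclude_none/exclude_unset keep the SSE payload minimal; unset and None
# fields are dropped, so this prints {"text":"hi"}
print(Chunk(text="hi").model_dump_json(exclude_none=True, exclude_unset=True))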
5 changes: 3 additions & 2 deletions src/codegate/providers/litellmshim/litellmshim.py
@@ -14,8 +14,9 @@ class LiteLLmShim(BaseCompletionHandler):
     LiteLLM API.
     """
 
-    def __init__(self, adapter: BaseAdapter):
+    def __init__(self, adapter: BaseAdapter, completion_func=acompletion):
         self._adapter = adapter
+        self._completion_func = completion_func
 
     async def complete(self, data: Dict, api_key: str) -> AsyncIterator[Any]:
         """
@@ -28,7 +29,7 @@ async def complete(self, data: Dict, api_key: str) -> AsyncIterator[Any]:
         if completion_request is None:
             raise Exception("Couldn't translate the request")
 
-        response = await acompletion(**completion_request)
+        response = await self._completion_func(**completion_request)
 
         if isinstance(response, ModelResponse):
             return self._adapter.translate_completion_output_params(response)
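Note: making the completion function injectable is what will let the planned OpenAI unit tests stub out the network call. A rough sketch of how such a test might look; the MagicMock adapter, the import path, and how api_key is threaded into the request are assumptions, not part of this patch:

from unittest.mock import MagicMock

import pytest
from litellm import ModelResponse

from codegate.providers.litellmshim.litellmshim import LiteLLmShim


@pytest.mark.asyncio
async def test_complete_calls_injected_func():
    mock_adapter = MagicMock()
    # hand the request through unchanged and echo the response back
    mock_adapter.translate_completion_input_params.side_effect = lambda data: data
    mock_adapter.translate_completion_output_params.side_effect = lambda resp: resp

    async def fake_completion(**kwargs):
        return ModelResponse(id="fake-id", choices=[])

    shim = LiteLLmShim(adapter=mock_adapter, completion_func=fake_completion)
    result = await shim.complete({"model": "gpt-4", "messages": []}, api_key="key")
    assert result.id == "fake-id"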
154 changes: 154 additions & 0 deletions tests/providers/anthropic/test_adapter.py
@@ -0,0 +1,154 @@
from typing import AsyncIterator, Dict, List, Union

import pytest
from litellm import ModelResponse
from litellm.adapters.anthropic_adapter import AnthropicStreamWrapper
from litellm.types.llms.anthropic import (
    ContentBlockDelta,
    ContentBlockStart,
    ContentTextBlockDelta,
    MessageChunk,
    MessageStartBlock,
)
from litellm.types.utils import Delta, StreamingChoices

from codegate.providers.anthropic.adapter import AnthropicAdapter


@pytest.fixture
def adapter():
    return AnthropicAdapter()


def test_translate_completion_input_params(adapter):
    # Test input data
    completion_request = {
        "model": "claude-3-haiku-20240307",
        "max_tokens": 1024,
        "stream": True,
        "messages": [
            {
                "role": "user",
                "system": "You are an expert code reviewer",
                "content": [
                    {
                        "type": "text",
                        "text": "Review this code"
                    }
                ]
            }
        ]
    }
    expected = {
        "max_tokens": 1024,
        "messages": [
            {"content": [{"text": "Review this code", "type": "text"}], "role": "user"}
        ],
        "model": "claude-3-haiku-20240307",
        "stream": True,
    }

    # Get translation
    result = adapter.translate_completion_input_params(completion_request)
    assert result == expected


@pytest.mark.asyncio
async def test_translate_completion_output_params_streaming(adapter):
    # Test stream data
    async def mock_stream():
        messages = [
            ModelResponse(
                id="test_id_1",
                choices=[
                    StreamingChoices(
                        finish_reason=None,
                        index=0,
                        delta=Delta(content="Hello", role="assistant")),
                ],
                model="claude-3-haiku-20240307",
            ),
            ModelResponse(
                id="test_id_2",
                choices=[
                    StreamingChoices(
                        finish_reason=None,
                        index=0,
                        delta=Delta(content="world", role="assistant")),
                ],
                model="claude-3-haiku-20240307",
            ),
            ModelResponse(
                id="test_id_3",
                choices=[
                    StreamingChoices(
                        finish_reason=None,
                        index=0,
                        delta=Delta(content="!", role="assistant")),
                ],
                model="claude-3-haiku-20240307",
            ),
        ]
        for msg in messages:
            yield msg

    expected: List[Union[MessageStartBlock, ContentBlockStart, ContentBlockDelta]] = [
        MessageStartBlock(
            type="message_start",
            message=MessageChunk(
                id="msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY",
                type="message",
                role="assistant",
                content=[],
                # litellm makes up a message start block with hardcoded values
                model="claude-3-5-sonnet-20240620",
                stop_reason=None,
                stop_sequence=None,
                usage={"input_tokens": 25, "output_tokens": 1},
            ),
        ),
        ContentBlockStart(
            type="content_block_start",
            index=0,
            content_block={"type": "text", "text": ""},
        ),
        ContentBlockDelta(
            type="content_block_delta",
            index=0,
            delta=ContentTextBlockDelta(type="text_delta", text="Hello"),
        ),
        ContentBlockDelta(
            type="content_block_delta",
            index=0,
            delta=ContentTextBlockDelta(type="text_delta", text="world"),
        ),
        ContentBlockDelta(
            type="content_block_delta",
            index=0,
            delta=ContentTextBlockDelta(type="text_delta", text="!"),
        ),
        # litellm doesn't seem to have a type for message stop
        dict(type="message_stop"),
    ]

    stream = adapter.translate_completion_output_params_streaming(mock_stream())
    assert isinstance(stream, AnthropicStreamWrapper)

    # just so that we can zip over the expected chunks
    stream_list = [chunk async for chunk in stream]
    # Verify we got all chunks
    assert len(stream_list) == 6

    for chunk, expected_chunk in zip(stream_list, expected):
        assert chunk == expected_chunk


def test_stream_generator_initialization(adapter):
    # Verify the default stream generator is set
    from codegate.providers.litellmshim import anthropic_stream_generator
    assert adapter.stream_generator == anthropic_stream_generator


def test_custom_stream_generator():
    # Test that we can inject a custom stream generator
    async def custom_generator(stream: AsyncIterator[Dict]) -> AsyncIterator[str]:
        async for chunk in stream:
            yield "custom: " + str(chunk)

    adapter = AnthropicAdapter(stream_generator=custom_generator)
    assert adapter.stream_generator == custom_generator
80 changes: 80 additions & 0 deletions tests/providers/litellmshim/test_generators.py
@@ -0,0 +1,80 @@
from typing import AsyncIterator

import pytest
from litellm import ModelResponse

from codegate.providers.litellmshim import (
    anthropic_stream_generator,
    sse_stream_generator,
)


@pytest.mark.asyncio
async def test_sse_stream_generator():
    # Mock stream data
    mock_chunks = [
        ModelResponse(id="1", choices=[{"text": "Hello"}]),
        ModelResponse(id="2", choices=[{"text": "World"}]),
    ]

    async def mock_stream():
        for chunk in mock_chunks:
            yield chunk

    # Collect generated SSE messages
    messages = []
    async for message in sse_stream_generator(mock_stream()):
        messages.append(message)

    # Verify format and content
    assert len(messages) == len(mock_chunks) + 1  # +1 for the [DONE] message
    assert all(msg.startswith("data:") for msg in messages)
    assert "Hello" in messages[0]
    assert "World" in messages[1]
    assert messages[-1] == "data: [DONE]\n\n"


@pytest.mark.asyncio
async def test_anthropic_stream_generator():
    # Mock Anthropic-style chunks
    mock_chunks = [
        {"type": "message_start", "message": {"id": "1"}},
        {"type": "content_block_start", "content_block": {"text": "Hello"}},
        {"type": "content_block_stop", "content_block": {"text": "World"}},
    ]

    async def mock_stream():
        for chunk in mock_chunks:
            yield chunk

    # Collect generated SSE messages
    messages = []
    async for message in anthropic_stream_generator(mock_stream()):
        messages.append(message)

    # Verify format and content
    assert len(messages) == 3
    for msg, chunk in zip(messages, mock_chunks):
        assert msg.startswith(f"event: {chunk['type']}\ndata:")
    assert "Hello" in messages[1]  # content_block_start message
    assert "World" in messages[2]  # content_block_stop message


@pytest.mark.asyncio
async def test_generators_error_handling():
    async def error_stream() -> AsyncIterator[str]:
        raise Exception("Test error")
        yield  # This will never be reached, but is needed for AsyncIterator typing

    # Test SSE generator error handling
    messages = []
    async for message in sse_stream_generator(error_stream()):
        messages.append(message)
    assert len(messages) == 2
    assert "Test error" in messages[0]
    assert messages[1] == "data: [DONE]\n\n"

    # Test Anthropic generator error handling
    messages = []
    async for message in anthropic_stream_generator(error_stream()):
        messages.append(message)
    assert len(messages) == 1
    assert "Test error" in messages[0]