Commit 25ac074: Fix linting errors in CI
lukehinds committed Nov 25, 2024 (parent: 560c56e)
Showing 8 changed files with 1,065 additions and 78 deletions.
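
The edits below are pure formatting: single quotes become double quotes, trailing commas are added before closing brackets, long lines are wrapped inside parentheses, and top-level definitions get two blank lines of separation. That is Black's house style, so the following sketch shows one way such fixes are typically reproduced locally; treating Black as the formatter behind the CI lint step is an assumption, since the commit itself does not name the tool.

# A minimal sketch, assuming Black is the formatter (the diff matches its
# style, but the tool is not named anywhere in this commit).
import black

SOURCE = """json_files = [
    'data/archived.jsonl',
    'data/deprecated.jsonl',
    'data/malicious.jsonl',
]
"""

# format_str applies the normalizations seen throughout this diff:
# double quotes, magic trailing commas, and line-length-based wrapping.
print(black.format_str(SOURCE, mode=black.Mode(line_length=88)))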
964 changes: 959 additions & 5 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion pyproject.toml
@@ -5,7 +5,6 @@ description = "Generative AI CodeGen security gateway"
 readme = "README.md"
 authors = []
 packages = [{include = "codegate", from = "src"}]
-requires-python = ">=3.11"

 [tool.poetry.dependencies]
 python = ">=3.11"
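
Dropping requires-python leaves a single source of truth for the interpreter bound: the python = ">=3.11" constraint that Poetry reads from [tool.poetry.dependencies]. As a quick illustration of what that specifier accepts, here is a sketch using the third-party packaging library (an assumption for demonstration; it is not part of this commit):

# Sanity-checking the ">=3.11" bound kept in pyproject.toml; the specifier
# comes from the diff, while the "packaging" library is assumed.
from packaging.specifiers import SpecifierSet
from packaging.version import Version

spec = SpecifierSet(">=3.11")
print(Version("3.11.0") in spec)  # True
print(Version("3.10.9") in spec)  # False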
65 changes: 39 additions & 26 deletions scripts/import_packages.py
@@ -1,14 +1,15 @@
 import json
-from utils.embedding_util import generate_embeddings

 import weaviate
+from weaviate.classes.config import DataType, Property
 from weaviate.embedded import EmbeddedOptions
-from weaviate.classes.config import Property, DataType

+from utils.embedding_util import generate_embeddings
+
 json_files = [
-    'data/archived.jsonl',
-    'data/deprecated.jsonl',
-    'data/malicious.jsonl',
+    "data/archived.jsonl",
+    "data/deprecated.jsonl",
+    "data/malicious.jsonl",
 ]


@@ -21,7 +22,7 @@ def setup_schema(client):
             Property(name="type", data_type=DataType.TEXT),
             Property(name="status", data_type=DataType.TEXT),
             Property(name="description", data_type=DataType.TEXT),
-        ]
+        ],
     )


@@ -47,11 +48,20 @@ def generate_vector_string(package):

     # add extra status
     if package["status"] == "archived":
-        vector_str += f". However, this package is found to be archived and no longer maintained. For additional information refer to {package_url}"
+        vector_str += (
+            f". However, this package is found to be archived and no longer "
+            f"maintained. For additional information refer to {package_url}"
+        )
     elif package["status"] == "deprecated":
-        vector_str += f". However, this package is found to be deprecated and no longer recommended for use. For additional information refer to {package_url}"
+        vector_str += (
+            f". However, this package is found to be deprecated and no longer "
+            f"recommended for use. For additional information refer to {package_url}"
+        )
     elif package["status"] == "malicious":
-        vector_str += f". However, this package is found to be malicious. For additional information refer to {package_url}"
+        vector_str += (
+            f". However, this package is found to be malicious. "
+            f"For additional information refer to {package_url}"
+        )
     return vector_str
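
The wrapped replacements above rely on Python joining adjacent string literals at compile time: splitting one long f-string into several inside parentheses changes the source layout only, not the value stored in vector_str. A small sketch of that equivalence, with a hypothetical package_url:

# Adjacent literals inside parentheses are concatenated at compile time, so
# the wrapped form is identical at runtime to the single long f-string.
package_url = "https://example.com/pkg"  # hypothetical value for the demo

one_line = f". However, this package is found to be malicious. For additional information refer to {package_url}"
wrapped = (
    f". However, this package is found to be malicious. "
    f"For additional information refer to {package_url}"
)
assert one_line == wrapped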


@@ -62,34 +72,38 @@ def add_data(client):
     existing_packages = list(collection.iterator())
     packages_dict = {}
     for package in existing_packages:
-        key = package.properties['name']+"/"+package.properties['type']
+        key = package.properties["name"] + "/" + package.properties["type"]
         value = {
-            'status': package.properties['status'],
-            'description': package.properties['description'],
+            "status": package.properties["status"],
+            "description": package.properties["description"],
         }
         packages_dict[key] = value

     for json_file in json_files:
-        with open(json_file, 'r') as f:
+        with open(json_file, "r") as f:
             print("Adding data from", json_file)
             with collection.batch.dynamic() as batch:
                 for line in f:
                     package = json.loads(line)

                     # now add the status column
-                    if 'archived' in json_file:
-                        package['status'] = 'archived'
-                    elif 'deprecated' in json_file:
-                        package['status'] = 'deprecated'
-                    elif 'malicious' in json_file:
-                        package['status'] = 'malicious'
+                    if "archived" in json_file:
+                        package["status"] = "archived"
+                    elif "deprecated" in json_file:
+                        package["status"] = "deprecated"
+                    elif "malicious" in json_file:
+                        package["status"] = "malicious"
                     else:
-                        package['status'] = 'unknown'
+                        package["status"] = "unknown"

                     # check for the existing package and only add if different
-                    key = package['name']+"/"+package['type']
+                    key = package["name"] + "/" + package["type"]
                     if key in packages_dict:
-                        if packages_dict[key]['status'] == package['status'] and packages_dict[key]['description'] == package['description']:
+                        if (
+                            packages_dict[key]["status"] == package["status"]
+                            and packages_dict[key]["description"]
+                            == package["description"]
+                        ):
                             print("Package already exists", key)
                             continue

@@ -104,17 +118,16 @@ def add_data(client):
 def run_import():
     client = weaviate.WeaviateClient(
         embedded_options=EmbeddedOptions(
-            persistence_data_path="./weaviate_data",
-            grpc_port=50052
+            persistence_data_path="./weaviate_data", grpc_port=50052
         ),
     )
     with client:
         client.connect()
-        print('is_ready:', client.is_ready())
+        print("is_ready:", client.is_ready())

         setup_schema(client)
         add_data(client)


-if __name__ == '__main__':
+if __name__ == "__main__":
     run_import()
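
For reference, run_import() drives an embedded Weaviate instance entirely from this script. The sketch below mirrors only the calls visible in this hunk: the persistence path and gRPC port come straight from the diff, and the schema and data loading remain in setup_schema() and add_data() as shown above.

# A usage sketch of the embedded-Weaviate pattern from run_import(), limited
# to calls that appear in this diff.
import weaviate
from weaviate.embedded import EmbeddedOptions

client = weaviate.WeaviateClient(
    embedded_options=EmbeddedOptions(
        persistence_data_path="./weaviate_data", grpc_port=50052
    ),
)
with client:
    client.connect()
    print("is_ready:", client.is_ready())  # confirms the embedded server is up
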
52 changes: 28 additions & 24 deletions tests/providers/anthropic/test_adapter.py
@@ -19,38 +19,35 @@
 def adapter():
     return AnthropicAdapter()

+
 def test_translate_completion_input_params(adapter):
     # Test input data
     completion_request = {
         "model": "claude-3-haiku-20240307",
         "max_tokens": 1024,
         "stream": True,
         "messages": [
-            {
-                "role": "user",
-                "system": "You are an expert code reviewer",
-                "content": [
-                    {
-                        "type": "text",
-                        "text": "Review this code"
-                    }
-                ]
-            }
-        ]
+            {
+                "role": "user",
+                "system": "You are an expert code reviewer",
+                "content": [{"type": "text", "text": "Review this code"}],
+            }
+        ],
     }
     expected = {
-        'max_tokens': 1024,
-        'messages': [
-            {'content': [{'text': 'Review this code', 'type': 'text'}], 'role': 'user'}
+        "max_tokens": 1024,
+        "messages": [
+            {"content": [{"text": "Review this code", "type": "text"}], "role": "user"}
         ],
-        'model': 'claude-3-haiku-20240307',
-        'stream': True
+        "model": "claude-3-haiku-20240307",
+        "stream": True,
     }

     # Get translation
     result = adapter.translate_completion_input_params(completion_request)
     assert result == expected

+
 @pytest.mark.asyncio
 async def test_translate_completion_output_params_streaming(adapter):
     # Test stream data
@@ -62,33 +59,38 @@ async def mock_stream():
                 StreamingChoices(
                     finish_reason=None,
                     index=0,
-                    delta=Delta(content="Hello", role="assistant")),
+                    delta=Delta(content="Hello", role="assistant"),
+                ),
             ],
             model="claude-3-haiku-20240307",
         ),
         ModelResponse(
             id="test_id_2",
             choices=[
-                StreamingChoices(finish_reason=None,
-                    index=0,
-                    delta=Delta(content="world", role="assistant")),
+                StreamingChoices(
+                    finish_reason=None,
+                    index=0,
+                    delta=Delta(content="world", role="assistant"),
+                ),
             ],
             model="claude-3-haiku-20240307",
         ),
         ModelResponse(
             id="test_id_2",
             choices=[
-                StreamingChoices(finish_reason=None,
-                    index=0,
-                    delta=Delta(content="!", role="assistant")),
+                StreamingChoices(
+                    finish_reason=None,
+                    index=0,
+                    delta=Delta(content="!", role="assistant"),
+                ),
             ],
             model="claude-3-haiku-20240307",
         ),
     ]
     for msg in messages:
         yield msg

-    expected: List[Union[MessageStartBlock,ContentBlockStart,ContentBlockDelta]] = [
+    expected: List[Union[MessageStartBlock, ContentBlockStart, ContentBlockDelta]] = [
     MessageStartBlock(
         type="message_start",
         message=MessageChunk(
@@ -142,8 +144,10 @@
 def test_stream_generator_initialization(adapter):
     # Verify the default stream generator is set
     from codegate.providers.litellmshim import anthropic_stream_generator
+
     assert adapter.stream_generator == anthropic_stream_generator

+
 def test_custom_stream_generator():
     # Test that we can inject a custom stream generator
     async def custom_generator(stream: AsyncIterator[Dict]) -> AsyncIterator[str]:
6 changes: 4 additions & 2 deletions tests/providers/litellmshim/test_generators.py
@@ -14,7 +14,7 @@ async def test_sse_stream_generator():
     # Mock stream data
     mock_chunks = [
         ModelResponse(id="1", choices=[{"text": "Hello"}]),
-        ModelResponse(id="2", choices=[{"text": "World"}])
+        ModelResponse(id="2", choices=[{"text": "World"}]),
     ]

     async def mock_stream():
@@ -33,13 +33,14 @@ async def mock_stream():
     assert "World" in messages[1]
     assert messages[-1] == "data: [DONE]\n\n"

+
 @pytest.mark.asyncio
 async def test_anthropic_stream_generator():
     # Mock Anthropic-style chunks
     mock_chunks = [
         {"type": "message_start", "message": {"id": "1"}},
         {"type": "content_block_start", "content_block": {"text": "Hello"}},
-        {"type": "content_block_stop", "content_block": {"text": "World"}}
+        {"type": "content_block_stop", "content_block": {"text": "World"}},
     ]

     async def mock_stream():
@@ -58,6 +59,7 @@ async def mock_stream():
     assert "Hello" in messages[1]  # content_block_start message
     assert "World" in messages[2]  # content_block_stop message

+
 @pytest.mark.asyncio
 async def test_generators_error_handling():
     async def error_stream() -> AsyncIterator[str]:
18 changes: 13 additions & 5 deletions tests/providers/litellmshim/test_litellmshim.py
@@ -27,22 +27,27 @@ def translate_completion_output_params(self, response: ModelResponse) -> Any:
         return response

     def translate_completion_output_params_streaming(
-        self, completion_stream: Any,
+        self,
+        completion_stream: Any,
     ) -> Any:
         async def modified_stream():
             async for chunk in completion_stream:
                 chunk.mock_adapter_processed = True
                 yield chunk
+
         return modified_stream()

+
 @pytest.fixture
 def mock_adapter():
     return MockAdapter()

+
 @pytest.fixture
 def litellm_shim(mock_adapter):
     return LiteLLmShim(mock_adapter)

+
 @pytest.mark.asyncio
 async def test_complete_non_streaming(litellm_shim, mock_adapter):
     # Mock response
@@ -55,7 +60,7 @@ async def test_complete_non_streaming(litellm_shim, mock_adapter):
     # Test data
     data = {
         "messages": [{"role": "user", "content": "Hello"}],
-        "model": "gpt-3.5-turbo"
+        "model": "gpt-3.5-turbo",
     }
     api_key = "test-key"
@@ -71,6 +76,7 @@ async def test_complete_non_streaming(litellm_shim, mock_adapter):
     # Verify adapter processed the input
     assert called_args["mock_adapter_processed"] is True

+
 @pytest.mark.asyncio
 async def test_complete_streaming():
     # Mock streaming response with specific test content
@@ -86,7 +92,7 @@ async def mock_stream() -> AsyncIterator[ModelResponse]:
     data = {
         "messages": [{"role": "user", "content": "Hello"}],
         "model": "gpt-3.5-turbo",
-        "stream": True
+        "stream": True,
     }
     api_key = "test-key"

@@ -114,6 +120,7 @@ async def mock_stream() -> AsyncIterator[ModelResponse]:
     assert called_args["stream"] is True
     assert called_args["api_key"] == api_key

+
 @pytest.mark.asyncio
 async def test_create_streaming_response(litellm_shim):
     # Create a simple async generator that we know works
@@ -133,6 +140,7 @@ async def mock_stream_gen():
     assert response.headers["Connection"] == "keep-alive"
     assert response.headers["Transfer-Encoding"] == "chunked"

+
 @pytest.mark.asyncio
 async def test_complete_invalid_params():
     mock_completion = AsyncMock()
@@ -148,8 +156,8 @@ async def test_complete_invalid_params():

     # Execute and verify specific exception is raised
     with pytest.raises(
-            ValueError,
-            match="Required fields 'messages' and 'model' must be present",
+        ValueError,
+        match="Required fields 'messages' and 'model' must be present",
     ):
         await litellm_shim.complete(data, api_key)

(Diffs for the two remaining changed files are not rendered.)