Skip to content

Commit

Permalink
Merge pull request #128 from stacklok/issue-63
Browse files Browse the repository at this point in the history
feat: enable weaviate usage in codegate
  • Loading branch information
yrobla authored Dec 2, 2024
2 parents d4f1ab8 + 529ef2d commit 466d495
Show file tree
Hide file tree
Showing 31 changed files with 286 additions and 29 deletions.
6 changes: 3 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 2 additions & 26 deletions scripts/import_packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from weaviate.util import generate_uuid5

from codegate.inference.inference_engine import LlamaCppInferenceEngine
from src.codegate.utils.utils import generate_vector_string


class PackageImporter:
Expand Down Expand Up @@ -37,33 +38,8 @@ def setup_schema(self):
],
)

def generate_vector_string(self, package):
vector_str = f"{package['name']}"
package_url = ""
type_map = {
"pypi": "Python package available on PyPI",
"npm": "JavaScript package available on NPM",
"go": "Go package",
"crates": "Rust package available on Crates",
"java": "Java package",
}
status_messages = {
"archived": "However, this package is found to be archived and no longer maintained.",
"deprecated": "However, this package is found to be deprecated and no longer "
"recommended for use.",
"malicious": "However, this package is found to be malicious.",
}
vector_str += f" is a {type_map.get(package['type'], 'unknown type')} "
package_url = f"https://trustypkg.dev/{package['type']}/{package['name']}"

# Add extra status
status_suffix = status_messages.get(package["status"], "")
if status_suffix:
vector_str += f"{status_suffix} For additional information refer to {package_url}"
return vector_str

async def process_package(self, batch, package):
vector_str = self.generate_vector_string(package)
vector_str = generate_vector_string(package)
vector = await self.inference_engine.embed(self.model_path, [vector_str])
# This is where the synchronous call is made
batch.add_object(properties=package, vector=vector[0])
Expand Down
3 changes: 3 additions & 0 deletions src/codegate/pipeline/codegate_context_retriever/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from codegate.pipeline.codegate_context_retriever.codegate import CodegateContextRetriever

__all__ = ["CodegateContextRetriever"]
82 changes: 82 additions & 0 deletions src/codegate/pipeline/codegate_context_retriever/codegate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from typing import Optional

from litellm import ChatCompletionRequest, ChatCompletionSystemMessage

from codegate.pipeline.base import (
PipelineContext,
PipelineResult,
PipelineStep,
)
from src.codegate.storage.storage_engine import StorageEngine
from src.codegate.utils.utils import generate_vector_string


class CodegateContextRetriever(PipelineStep):
    """
    Pipeline step that adds a context message to the completion request when it detects
    the word "codegate" in the user message.

    The context is built by vector-searching the package storage for the user's
    query and rendering each hit with ``generate_vector_string``.
    """

    def __init__(self, system_prompt_message: Optional[str] = None):
        # The configured prompt doubles as the on/off switch for this step:
        # process() is a no-op when content is None (see below).
        self._system_message = ChatCompletionSystemMessage(
            content=system_prompt_message, role="system"
        )
        self.storage_engine = StorageEngine()

    @property
    def name(self) -> str:
        """
        Returns the name of this pipeline step.
        """
        return "codegate-context-retriever"

    async def get_objects_from_search(self, search: str) -> list[object]:
        """Vector-search stored packages for the given query text."""
        objects = await self.storage_engine.search(search)
        return objects

    def generate_context_str(self, objects: list[object]) -> str:
        """Render searched package objects into a single context string."""
        # BUG FIX: the second half of this sentence used to be a standalone
        # string-literal statement (missed implicit concatenation), so it was
        # silently dropped from the generated context.
        context_str = (
            "Please use the information about related packages "
            "to influence your answer:\n"
        )
        for obj in objects:
            # generate dictionary from object
            package_obj = {
                "name": obj.properties["name"],
                "type": obj.properties["type"],
                "status": obj.properties["status"],
                "description": obj.properties["description"],
            }
            package_str = generate_vector_string(package_obj)
            context_str += package_str + "\n"
        return context_str

    async def process(
        self, request: ChatCompletionRequest, context: PipelineContext
    ) -> PipelineResult:
        """
        Process the completion request and add a system prompt if the user message contains
        the word "codegate".
        """
        # no prompt configured
        if not self._system_message["content"]:
            return PipelineResult(request=request)

        last_user_message = self.get_last_user_message(request)

        if last_user_message is not None:
            last_user_message_str, last_user_idx = last_user_message
            if "codegate" in last_user_message_str.lower():
                # strip codegate from prompt and trim it
                last_user_message_str = (
                    last_user_message_str.lower().replace("codegate", "").strip()
                )
                searched_objects = await self.get_objects_from_search(last_user_message_str)
                context_str = self.generate_context_str(searched_objects)
                # Add a system prompt to the completion request.
                # BUG FIX: messages holds role/content message dicts; inserting
                # the bare string produced a malformed request downstream.
                new_request = request.copy()
                new_request["messages"].insert(
                    last_user_idx,
                    ChatCompletionSystemMessage(content=context_str, role="system"),
                )
                return PipelineResult(
                    request=new_request,
                )

        # Fall through
        return PipelineResult(request=request)
2 changes: 2 additions & 0 deletions src/codegate/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from codegate import __description__, __version__
from codegate.config import Config
from codegate.pipeline.base import PipelineStep, SequentialPipelineProcessor
from codegate.pipeline.codegate_context_retriever.codegate import CodegateContextRetriever
from codegate.pipeline.codegate_system_prompt.codegate import CodegateSystemPrompt
from codegate.pipeline.extract_snippets.extract_snippets import CodeSnippetExtractor
from codegate.pipeline.version.version import CodegateVersion
Expand All @@ -27,6 +28,7 @@ def init_app() -> FastAPI:
CodegateVersion(),
CodeSnippetExtractor(),
CodegateSystemPrompt(Config.get_config().prompts.codegate_chat),
CodegateContextRetriever(Config.get_config().prompts.codegate_chat),
# CodegateSecrets(),
]
# Leaving the pipeline empty for now
Expand Down
3 changes: 3 additions & 0 deletions src/codegate/storage/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from codegate.storage.storage_engine import StorageEngine

# BUG FIX: __all__ entries must be strings naming the public API, not the
# objects themselves; a non-string entry breaks `from codegate.storage import *`
# and confuses static tooling.
__all__ = ["StorageEngine"]
106 changes: 106 additions & 0 deletions src/codegate/storage/storage_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import structlog
import weaviate
from weaviate.classes.config import DataType
from weaviate.classes.query import MetadataQuery

from codegate.inference.inference_engine import LlamaCppInferenceEngine

logger = structlog.get_logger("codegate")

# Weaviate collection definitions created at startup. Each entry names a
# collection and its typed properties; StorageEngine.setup_schema() creates
# any collection listed here that does not already exist.
schema_config = [
    {
        "name": "Package",
        "properties": [
            {"name": "name", "data_type": DataType.TEXT},
            {"name": "type", "data_type": DataType.TEXT},
            {"name": "status", "data_type": DataType.TEXT},
            {"name": "description", "data_type": DataType.TEXT},
        ],
    },
]


class StorageEngine:
    """Wrapper around an embedded Weaviate instance used to store and
    vector-search package metadata (the ``Package`` collection)."""

    def __init__(self, data_path="./weaviate_data"):
        """Initialize the engine and ensure the Weaviate schema exists.

        Args:
            data_path: Directory for Weaviate's embedded persistence files.
        """
        self.data_path = data_path
        self.inference_engine = LlamaCppInferenceEngine()
        self.model_path = "./models/all-minilm-L6-v2-q5_k_m.gguf"
        self.schema_config = schema_config

        # setup schema for weaviate
        weaviate_client = self.get_client(self.data_path)
        if weaviate_client is not None:
            try:
                weaviate_client.connect()
                self.setup_schema(weaviate_client)
            except Exception as e:
                logger.error(f"Failed to connect or setup schema: {str(e)}")
            finally:
                # Best-effort close; failure to close is not fatal here.
                try:
                    weaviate_client.close()
                except Exception as e:
                    logger.info(f"Failed to close client: {str(e)}")
        else:
            logger.error("Could not find client, skipping schema setup.")

    def get_client(self, data_path):
        """Create an embedded Weaviate client, or return None on failure."""
        try:
            client = weaviate.WeaviateClient(
                embedded_options=weaviate.EmbeddedOptions(persistence_data_path=data_path),
            )
            return client
        except Exception as e:
            logger.error(f"Error during client creation: {str(e)}")
            return None

    def setup_schema(self, client):
        """Create every collection in ``self.schema_config`` that is missing."""
        for class_config in self.schema_config:
            if not client.collections.exists(class_config["name"]):
                client.collections.create(
                    class_config["name"], properties=class_config["properties"]
                )
            logger.info(f"Weaviate schema for class {class_config['name']} setup complete.")

    async def search(self, query: str, limit=5, distance=0.3) -> list[object]:
        """
        Search the 'Package' collection based on a query string.
        Args:
            query (str): The text query for which to search.
            limit (int): The number of results to return.
            distance (float): Maximum vector distance for a match.
        Returns:
            list: A list of matching results with their properties and distances;
            empty list when no client is available, nothing matches, or an
            error occurs.
        """
        # Generate the vector for the query
        query_vector = await self.inference_engine.embed(self.model_path, [query])

        # Perform the vector search
        weaviate_client = self.get_client(self.data_path)
        if weaviate_client is None:
            logger.error("Could not find client, not returning results.")
            return []

        try:
            weaviate_client.connect()
            collection = weaviate_client.collections.get("Package")
            response = collection.query.near_vector(
                query_vector[0],
                limit=limit,
                distance=distance,
                return_metadata=MetadataQuery(distance=True),
            )
            # BUG FIX: the client was previously closed here *and* in finally,
            # so every successful search logged a spurious double-close failure.
            # The finally block is now the single owner of cleanup.
            if not response:
                return []
            return response.objects

        except Exception as e:
            logger.error(f"Error during search: {str(e)}")
            return []
        finally:
            try:
                weaviate_client.close()
            except Exception as e:
                logger.info(f"Failed to close client: {str(e)}")
27 changes: 27 additions & 0 deletions src/codegate/utils/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
def generate_vector_string(package) -> str:
    """Build the descriptive sentence that is embedded for *package*.

    The sentence names the package, describes its ecosystem, appends a
    warning (with a trustypkg.dev link) for archived/deprecated/malicious
    statuses, and ends with the package description.
    """
    ecosystem_labels = {
        "pypi": "Python package available on PyPI",
        "npm": "JavaScript package available on NPM",
        "go": "Go package",
        "crates": "Rust package available on Crates",
        "java": "Java package",
    }
    status_warnings = {
        "archived": "However, this package is found to be archived and no longer maintained.",
        "deprecated": "However, this package is found to be deprecated and no longer "
        "recommended for use.",
        "malicious": "However, this package is found to be malicious.",
    }

    label = ecosystem_labels.get(package["type"], "unknown type")
    pieces = [f"{package['name']} is a {label} "]

    # Append the status warning plus reference link, when the status has one.
    warning = status_warnings.get(package["status"], "")
    if warning:
        reference = f"https://trustypkg.dev/{package['type']}/{package['name']}"
        pieces.append(f"{warning} For additional information refer to {reference}")

    # Close with the functional description.
    pieces.append(f" - Package offers this functionality: {package['description']}")
    return "".join(pieces)
Binary file added src/weaviate_data/classifications.db
Binary file not shown.
Empty file.
Empty file.
Empty file.
Binary file added src/weaviate_data/modules.db
Binary file not shown.
Empty file.
Empty file.
1 change: 1 addition & 0 deletions src/weaviate_data/package/9e4pu9kSqOe2/proplengths
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"BucketedData":{},"SumData":{},"CountData":{},"ObjectCount":0}
Binary file added src/weaviate_data/package/9e4pu9kSqOe2/version
Binary file not shown.
Binary file added src/weaviate_data/raft/raft.db
Binary file not shown.
Binary file added src/weaviate_data/schema.db
Binary file not shown.
56 changes: 56 additions & 0 deletions tests/test_storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from codegate.storage.storage_engine import (
StorageEngine,
) # Adjust the import based on your actual path


@pytest.fixture
def mock_weaviate_client():
    """Weaviate client stub whose Package collection returns one canned hit."""
    canned_response = MagicMock()
    canned_response.objects = [
        {
            "properties": {
                "name": "test",
                "type": "library",
                "status": "active",
                "description": "test description",
            }
        }
    ]
    fake_client = MagicMock()
    fake_collection = fake_client.collections.get.return_value
    fake_collection.query.near_vector.return_value = canned_response
    return fake_client


@pytest.fixture
def mock_inference_engine():
    """Inference engine stub returning a fixed embedding batch."""
    engine = AsyncMock()
    # BUG FIX: embed() is called with a *list* of texts and its result is
    # indexed as query_vector[0] in StorageEngine.search, so it must return a
    # list of vectors, not a single flat vector.
    engine.embed.return_value = [[0.1, 0.2, 0.3]]  # one vector per input text
    return engine


@pytest.mark.asyncio
async def test_search(mock_weaviate_client, mock_inference_engine):
    """StorageEngine.search returns the mocked hit and closes its client."""
    # BUG FIX: patch names where they are *used*. storage_engine.py does
    # `from codegate.inference.inference_engine import LlamaCppInferenceEngine`,
    # so patching the name in its home module leaves the already-imported real
    # class in place (and would try to load an actual model).
    with (
        patch("weaviate.WeaviateClient", return_value=mock_weaviate_client),
        patch(
            "codegate.storage.storage_engine.LlamaCppInferenceEngine",
            return_value=mock_inference_engine,
        ),
    ):

        # Initialize StorageEngine
        storage_engine = StorageEngine(data_path="./weaviate_data")

        # Invoke the search method
        results = await storage_engine.search("test query", 5, 0.3)

        # Assertions to validate the expected behavior
        assert len(results) == 1  # Assert that one result is returned
        assert results[0]["properties"]["name"] == "test"
        mock_weaviate_client.connect.assert_called()
        mock_weaviate_client.close.assert_called()
Binary file added weaviate_data/classifications.db
Binary file not shown.
Empty file.
Empty file.
Empty file.
Binary file added weaviate_data/modules.db
Binary file not shown.
Empty file.
Empty file.
1 change: 1 addition & 0 deletions weaviate_data/package/yhcabdxdWUhw/proplengths
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"BucketedData":{},"SumData":{},"CountData":{},"ObjectCount":0}
Binary file added weaviate_data/package/yhcabdxdWUhw/version
Binary file not shown.
Binary file added weaviate_data/raft/raft.db
Binary file not shown.
Binary file added weaviate_data/schema.db
Binary file not shown.

0 comments on commit 466d495

Please sign in to comment.