-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #128 from stacklok/issue-63
feat: enable weaviate usage in codegate
- Loading branch information
Showing
31 changed files
with
286 additions
and
29 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from codegate.pipeline.codegate_context_retriever.codegate import CodegateContextRetriever | ||
|
||
__all__ = ["CodegateContextRetriever"] |
82 changes: 82 additions & 0 deletions
82
src/codegate/pipeline/codegate_context_retriever/codegate.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
from typing import Optional | ||
|
||
from litellm import ChatCompletionRequest, ChatCompletionSystemMessage | ||
|
||
from codegate.pipeline.base import ( | ||
PipelineContext, | ||
PipelineResult, | ||
PipelineStep, | ||
) | ||
from src.codegate.storage.storage_engine import StorageEngine | ||
from src.codegate.utils.utils import generate_vector_string | ||
|
||
|
||
class CodegateContextRetriever(PipelineStep): | ||
""" | ||
Pipeline step that adds a context message to the completion request when it detects | ||
the word "codegate" in the user message. | ||
""" | ||
|
||
def __init__(self, system_prompt_message: Optional[str] = None): | ||
self._system_message = ChatCompletionSystemMessage( | ||
content=system_prompt_message, role="system" | ||
) | ||
self.storage_engine = StorageEngine() | ||
|
||
@property | ||
def name(self) -> str: | ||
""" | ||
Returns the name of this pipeline step. | ||
""" | ||
return "codegate-context-retriever" | ||
|
||
async def get_objects_from_search(self, search: str) -> list[object]: | ||
objects = await self.storage_engine.search(search) | ||
return objects | ||
|
||
def generate_context_str(self, objects: list[object]) -> str: | ||
context_str = "Please use the information about related packages " | ||
"to influence your answer:\n" | ||
for obj in objects: | ||
# generate dictionary from object | ||
package_obj = { | ||
"name": obj.properties["name"], | ||
"type": obj.properties["type"], | ||
"status": obj.properties["status"], | ||
"description": obj.properties["description"], | ||
} | ||
package_str = generate_vector_string(package_obj) | ||
context_str += package_str + "\n" | ||
return context_str | ||
|
||
async def process( | ||
self, request: ChatCompletionRequest, context: PipelineContext | ||
) -> PipelineResult: | ||
""" | ||
Process the completion request and add a system prompt if the user message contains | ||
the word "codegate". | ||
""" | ||
# no prompt configured | ||
if not self._system_message["content"]: | ||
return PipelineResult(request=request) | ||
|
||
last_user_message = self.get_last_user_message(request) | ||
|
||
if last_user_message is not None: | ||
last_user_message_str, last_user_idx = last_user_message | ||
if "codegate" in last_user_message_str.lower(): | ||
# strip codegate from prompt and trim it | ||
last_user_message_str = ( | ||
last_user_message_str.lower().replace("codegate", "").strip() | ||
) | ||
searched_objects = await self.get_objects_from_search(last_user_message_str) | ||
context_str = self.generate_context_str(searched_objects) | ||
# Add a system prompt to the completion request | ||
new_request = request.copy() | ||
new_request["messages"].insert(last_user_idx, context_str) | ||
return PipelineResult( | ||
request=new_request, | ||
) | ||
|
||
# Fall through | ||
return PipelineResult(request=request) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from codegate.storage.storage_engine import StorageEngine | ||
|
||
__all__ = [StorageEngine] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import structlog | ||
import weaviate | ||
from weaviate.classes.config import DataType | ||
from weaviate.classes.query import MetadataQuery | ||
|
||
from codegate.inference.inference_engine import LlamaCppInferenceEngine | ||
|
||
logger = structlog.get_logger("codegate") | ||
|
||
schema_config = [ | ||
{ | ||
"name": "Package", | ||
"properties": [ | ||
{"name": "name", "data_type": DataType.TEXT}, | ||
{"name": "type", "data_type": DataType.TEXT}, | ||
{"name": "status", "data_type": DataType.TEXT}, | ||
{"name": "description", "data_type": DataType.TEXT}, | ||
], | ||
}, | ||
] | ||
|
||
|
||
class StorageEngine: | ||
def get_client(self, data_path): | ||
try: | ||
client = weaviate.WeaviateClient( | ||
embedded_options=weaviate.EmbeddedOptions(persistence_data_path=data_path), | ||
) | ||
return client | ||
except Exception as e: | ||
logger.error(f"Error during client creation: {str(e)}") | ||
return None | ||
|
||
def __init__(self, data_path="./weaviate_data"): | ||
self.data_path = data_path | ||
self.inference_engine = LlamaCppInferenceEngine() | ||
self.model_path = "./models/all-minilm-L6-v2-q5_k_m.gguf" | ||
self.schema_config = schema_config | ||
|
||
# setup schema for weaviate | ||
weaviate_client = self.get_client(self.data_path) | ||
if weaviate_client is not None: | ||
try: | ||
weaviate_client.connect() | ||
self.setup_schema(weaviate_client) | ||
except Exception as e: | ||
logger.error(f"Failed to connect or setup schema: {str(e)}") | ||
finally: | ||
try: | ||
weaviate_client.close() | ||
except Exception as e: | ||
logger.info(f"Failed to close client: {str(e)}") | ||
else: | ||
logger.error("Could not find client, skipping schema setup.") | ||
|
||
def setup_schema(self, client): | ||
for class_config in self.schema_config: | ||
if not client.collections.exists(class_config["name"]): | ||
client.collections.create( | ||
class_config["name"], properties=class_config["properties"] | ||
) | ||
logger.info(f"Weaviate schema for class {class_config['name']} setup complete.") | ||
|
||
async def search(self, query: str, limit=5, distance=0.3) -> list[object]: | ||
""" | ||
Search the 'Package' collection based on a query string. | ||
Args: | ||
query (str): The text query for which to search. | ||
limit (int): The number of results to return. | ||
Returns: | ||
list: A list of matching results with their properties and distances. | ||
""" | ||
# Generate the vector for the query | ||
query_vector = await self.inference_engine.embed(self.model_path, [query]) | ||
|
||
# Perform the vector search | ||
weaviate_client = self.get_client(self.data_path) | ||
if weaviate_client is None: | ||
logger.error("Could not find client, not returning results.") | ||
return [] | ||
|
||
try: | ||
weaviate_client.connect() | ||
collection = weaviate_client.collections.get("Package") | ||
response = collection.query.near_vector( | ||
query_vector[0], | ||
limit=limit, | ||
distance=distance, | ||
return_metadata=MetadataQuery(distance=True), | ||
) | ||
|
||
weaviate_client.close() | ||
if not response: | ||
return [] | ||
return response.objects | ||
|
||
except Exception as e: | ||
logger.error(f"Error during search: {str(e)}") | ||
return [] | ||
finally: | ||
try: | ||
weaviate_client.close() | ||
except Exception as e: | ||
logger.info(f"Failed to close client: {str(e)}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
def generate_vector_string(package) -> str: | ||
vector_str = f"{package['name']}" | ||
package_url = "" | ||
type_map = { | ||
"pypi": "Python package available on PyPI", | ||
"npm": "JavaScript package available on NPM", | ||
"go": "Go package", | ||
"crates": "Rust package available on Crates", | ||
"java": "Java package", | ||
} | ||
status_messages = { | ||
"archived": "However, this package is found to be archived and no longer maintained.", | ||
"deprecated": "However, this package is found to be deprecated and no longer " | ||
"recommended for use.", | ||
"malicious": "However, this package is found to be malicious.", | ||
} | ||
vector_str += f" is a {type_map.get(package['type'], 'unknown type')} " | ||
package_url = f"https://trustypkg.dev/{package['type']}/{package['name']}" | ||
|
||
# Add extra status | ||
status_suffix = status_messages.get(package["status"], "") | ||
if status_suffix: | ||
vector_str += f"{status_suffix} For additional information refer to {package_url}" | ||
|
||
# add description | ||
vector_str += f" - Package offers this functionality: {package['description']}" | ||
return vector_str |
Binary file not shown.
Empty file.
Empty file.
Empty file.
Binary file not shown.
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"BucketedData":{},"SumData":{},"CountData":{},"ObjectCount":0} |
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
from unittest.mock import AsyncMock, MagicMock, patch | ||
|
||
import pytest | ||
|
||
from codegate.storage.storage_engine import ( | ||
StorageEngine, | ||
) # Adjust the import based on your actual path | ||
|
||
|
||
@pytest.fixture | ||
def mock_weaviate_client(): | ||
client = MagicMock() | ||
response = MagicMock() | ||
response.objects = [ | ||
{ | ||
"properties": { | ||
"name": "test", | ||
"type": "library", | ||
"status": "active", | ||
"description": "test description", | ||
} | ||
} | ||
] | ||
client.collections.get.return_value.query.near_vector.return_value = response | ||
return client | ||
|
||
|
||
@pytest.fixture | ||
def mock_inference_engine(): | ||
engine = AsyncMock() | ||
engine.embed.return_value = [0.1, 0.2, 0.3] # Adjust based on expected vector dimensions | ||
return engine | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_search(mock_weaviate_client, mock_inference_engine): | ||
# Patch the WeaviateClient and LlamaCppInferenceEngine inside the test function | ||
with ( | ||
patch("weaviate.WeaviateClient", return_value=mock_weaviate_client), | ||
patch( | ||
"codegate.inference.inference_engine.LlamaCppInferenceEngine", | ||
return_value=mock_inference_engine, | ||
), | ||
): | ||
|
||
# Initialize StorageEngine | ||
storage_engine = StorageEngine(data_path="./weaviate_data") | ||
|
||
# Invoke the search method | ||
results = await storage_engine.search("test query", 5, 0.3) | ||
|
||
# Assertions to validate the expected behavior | ||
assert len(results) == 1 # Assert that one result is returned | ||
assert results[0]["properties"]["name"] == "test" | ||
mock_weaviate_client.connect.assert_called() | ||
mock_weaviate_client.close.assert_called() |
Binary file not shown.
Empty file.
Empty file.
Empty file.
Binary file not shown.
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"BucketedData":{},"SumData":{},"CountData":{},"ObjectCount":0} |
Binary file not shown.
Binary file not shown.
Binary file not shown.