Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: enable weaviate usage in codegate #128

Merged
merged 11 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 2 additions & 26 deletions scripts/import_packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from weaviate.util import generate_uuid5

from codegate.inference.inference_engine import LlamaCppInferenceEngine
from src.codegate.utils.utils import generate_vector_string


class PackageImporter:
Expand Down Expand Up @@ -37,33 +38,8 @@ def setup_schema(self):
],
)

def generate_vector_string(self, package):
    """Build the sentence that is embedded to represent ``package``.

    The sentence names the package and describes its ecosystem. When the
    package status is ``archived``, ``deprecated`` or ``malicious`` a warning
    and a link to the package page on trustypkg.dev are appended.

    Args:
        package: mapping with at least the keys ``name``, ``type`` and
            ``status``.

    Returns:
        The descriptive sentence as a single string.
    """
    ecosystem_descriptions = {
        "pypi": "Python package available on PyPI",
        "npm": "JavaScript package available on NPM",
        "go": "Go package",
        "crates": "Rust package available on Crates",
        "java": "Java package",
    }
    warnings_by_status = {
        "archived": "However, this package is found to be archived and no longer maintained.",
        "deprecated": "However, this package is found to be deprecated and no longer "
        "recommended for use.",
        "malicious": "However, this package is found to be malicious.",
    }

    name, ecosystem = package["name"], package["type"]
    description = ecosystem_descriptions.get(ecosystem, "unknown type")
    # The trailing space is part of the established sentence format.
    pieces = [f"{name} is a {description} "]

    # Only problematic statuses contribute extra text.
    warning = warnings_by_status.get(package["status"], "")
    if warning:
        url = f"https://trustypkg.dev/{ecosystem}/{name}"
        pieces.append(f"{warning} For additional information refer to {url}")
    return "".join(pieces)

async def process_package(self, batch, package):
vector_str = self.generate_vector_string(package)
vector_str = generate_vector_string(package)
vector = await self.inference_engine.embed(self.model_path, [vector_str])
# This is where the synchronous call is made
batch.add_object(properties=package, vector=vector[0])
Expand Down
3 changes: 3 additions & 0 deletions src/codegate/pipeline/codegate_context_retriever/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from codegate.pipeline.codegate_context_retriever.codegate import CodegateContextRetriever

__all__ = ["CodegateContextRetriever"]
82 changes: 82 additions & 0 deletions src/codegate/pipeline/codegate_context_retriever/codegate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from typing import Optional

from litellm import ChatCompletionRequest, ChatCompletionSystemMessage

from codegate.pipeline.base import (
PipelineContext,
PipelineResult,
PipelineStep,
)
from src.codegate.storage.storage_engine import StorageEngine
from src.codegate.utils.utils import generate_vector_string


class CodegateContextRetriever(PipelineStep):
"""
Pipeline step that adds a context message to the completion request when it detects
the word "codegate" in the user message.
"""
Comment on lines +11 to +18
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question from @lukehinds on the duplicated PR #133:

Would like to understand this a little better, does it mean we only report on MAL pkgs if someone says 'codegate', what does the word trigger?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jhrozek, since you proposed that system: I believe the word "codegate" will be prepended automatically?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was just an example. We need to figure out when we trigger the RAG. With the Go implementation we did that on all conversations that had no code snippets, which is an easy implementation. We could do that for a start, but the downside was that the RAG was also called when the user didn't ask about any package. For example, if they asked "give me an example SQL query", the RAG would match all packages with the word SQL in them. It would be nice to work together with @ptelang to figure out how to fine-tune this.


def __init__(self, system_prompt_message: Optional[str] = None):
    """Set up the retriever.

    Args:
        system_prompt_message: text stored as the content of the step's
            system message. ``process`` returns the request unchanged when
            this content is empty or ``None``.
    """
    system_message = ChatCompletionSystemMessage(
        role="system",
        content=system_prompt_message,
    )
    self._system_message = system_message
    # Storage backend queried for packages related to the user prompt.
    self.storage_engine = StorageEngine()

@property
def name(self) -> str:
    """Return the identifier string of this pipeline step."""
    step_name = "codegate-context-retriever"
    return step_name

async def get_objects_from_search(self, search: str) -> list[object]:
    """Query the storage engine with ``search`` and return whatever it yields."""
    return await self.storage_engine.search(search)

def generate_context_str(self, objects: list[object]) -> str:
    """Render the search results as a context block for the model.

    Args:
        objects: search results; each item must expose a ``properties``
            mapping containing ``name``, ``type``, ``status`` and
            ``description``.

    Returns:
        An instruction header followed by one descriptive line per package.
    """
    # The two literals must be parenthesised to form ONE string. Written as
    # two separate statements, the second literal is a no-op expression and
    # the header is silently truncated.
    header = (
        "Please use the information about related packages "
        "to influence your answer:\n"
    )
    package_lines = []
    for obj in objects:
        # Copy only the fields generate_vector_string() reads.
        package_obj = {
            "name": obj.properties["name"],
            "type": obj.properties["type"],
            "status": obj.properties["status"],
            "description": obj.properties["description"],
        }
        package_lines.append(generate_vector_string(package_obj) + "\n")
    return header + "".join(package_lines)

async def process(
    self, request: ChatCompletionRequest, context: PipelineContext
) -> PipelineResult:
    """Add package context to the request when the user mentions "codegate".

    The last user message is searched (case-insensitively) for the word
    "codegate". When present, the word is stripped, the remaining text is
    used to query the storage engine, and the rendered results are inserted
    as a system message directly before that user message.

    Args:
        request: the incoming completion request.
        context: the pipeline context (not used by this step).

    Returns:
        A ``PipelineResult`` holding either the original request (no prompt
        configured, no user message, or no trigger word) or a copy of the
        request with the context message added.
    """
    # no prompt configured
    if not self._system_message["content"]:
        return PipelineResult(request=request)

    last_user_message = self.get_last_user_message(request)
    if last_user_message is None:
        return PipelineResult(request=request)

    last_user_message_str, last_user_idx = last_user_message
    if "codegate" not in last_user_message_str.lower():
        return PipelineResult(request=request)

    # strip codegate from prompt and trim it
    search_query = last_user_message_str.lower().replace("codegate", "").strip()
    searched_objects = await self.get_objects_from_search(search_query)
    context_str = self.generate_context_str(searched_objects)

    # Messages must be role/content mappings; inserting the bare string
    # would produce an invalid entry in the request's message list.
    context_message = ChatCompletionSystemMessage(content=context_str, role="system")

    # request.copy() is shallow, so give the new request its own message
    # list; otherwise the insert would also mutate the caller's request.
    new_request = request.copy()
    new_request["messages"] = list(request["messages"])
    new_request["messages"].insert(last_user_idx, context_message)
    return PipelineResult(request=new_request)
1 change: 1 addition & 0 deletions src/codegate/pipeline/extract_snippets/extract_snippets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

logger = structlog.get_logger("codegate")


def ecosystem_from_filepath(filepath: str) -> Optional[str]:
"""
Determine language from filepath.
Expand Down
2 changes: 2 additions & 0 deletions src/codegate/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from codegate import __description__, __version__
from codegate.config import Config
from codegate.pipeline.base import PipelineStep, SequentialPipelineProcessor
from codegate.pipeline.codegate_context_retriever.codegate import CodegateContextRetriever
from codegate.pipeline.codegate_system_prompt.codegate import CodegateSystemPrompt
from codegate.pipeline.extract_snippets.extract_snippets import CodeSnippetExtractor
from codegate.pipeline.version.version import CodegateVersion
Expand All @@ -26,6 +27,7 @@ def init_app() -> FastAPI:
CodegateVersion(),
CodeSnippetExtractor(),
CodegateSystemPrompt(Config.get_config().prompts.codegate_chat),
CodegateContextRetriever(Config.get_config().prompts.codegate_chat),
# CodegateSecrets(),
]
# Leaving the pipeline empty for now
Expand Down
3 changes: 3 additions & 0 deletions src/codegate/storage/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from codegate.storage.storage_engine import StorageEngine

# __all__ must contain names as strings; listing the class object itself
# makes `from codegate.storage import *` raise a TypeError.
__all__ = ["StorageEngine"]
106 changes: 106 additions & 0 deletions src/codegate/storage/storage_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import structlog
import weaviate
from weaviate.classes.config import DataType
from weaviate.classes.query import MetadataQuery

from codegate.inference.inference_engine import LlamaCppInferenceEngine

logger = structlog.get_logger("codegate")

# Collections that StorageEngine.setup_schema() creates when they are missing.
# Every property of the "Package" collection is stored as text.
schema_config = [
    {
        "name": "Package",
        "properties": [
            {"name": field_name, "data_type": DataType.TEXT}
            for field_name in ("name", "type", "status", "description")
        ],
    },
]


class StorageEngine:
    """Package store backed by an embedded Weaviate instance.

    A fresh client is created, connected and closed for each operation
    (schema setup and every search); no connection is kept on the instance.
    """

    def get_client(self, data_path):
        """Create an embedded Weaviate client persisting to ``data_path``.

        Returns:
            The (not yet connected) client, or ``None`` when creation fails;
            the failure is logged.
        """
        try:
            client = weaviate.WeaviateClient(
                embedded_options=weaviate.EmbeddedOptions(persistence_data_path=data_path),
            )
            return client
        except Exception as e:
            logger.error(f"Error during client creation: {str(e)}")
            return None

    def __init__(self, data_path="./weaviate_data"):
        """Store the configuration and make sure the schema exists.

        Args:
            data_path: directory used by embedded Weaviate for persistence.

        Connection or schema errors are logged, not raised, so construction
        succeeds even when Weaviate is unavailable.
        """
        self.data_path = data_path
        self.inference_engine = LlamaCppInferenceEngine()
        self.model_path = "./models/all-minilm-L6-v2-q5_k_m.gguf"
        self.schema_config = schema_config

        # setup schema for weaviate
        weaviate_client = self.get_client(self.data_path)
        if weaviate_client is not None:
            try:
                weaviate_client.connect()
                self.setup_schema(weaviate_client)
            except Exception as e:
                logger.error(f"Failed to connect or setup schema: {str(e)}")
            finally:
                try:
                    weaviate_client.close()
                except Exception as e:
                    logger.info(f"Failed to close client: {str(e)}")
        else:
            logger.error("Could not find client, skipping schema setup.")

    def setup_schema(self, client):
        """Create every collection from ``self.schema_config`` that is missing.

        Args:
            client: a connected Weaviate client.
        """
        for class_config in self.schema_config:
            if not client.collections.exists(class_config["name"]):
                client.collections.create(
                    class_config["name"], properties=class_config["properties"]
                )
        logger.info(f"Weaviate schema for class {class_config['name']} setup complete.")

    async def search(self, query: str, limit=5, distance=0.3) -> list[object]:
        """
        Search the 'Package' collection based on a query string.

        Args:
            query (str): The text query for which to search.
            limit (int): The number of results to return.
            distance (float): Distance threshold forwarded to Weaviate's
                ``near_vector`` query.

        Returns:
            list: The matching objects, or an empty list when no client is
            available, nothing matched, or the search failed (failures are
            logged, not raised).
        """
        # Generate the vector for the query
        query_vector = await self.inference_engine.embed(self.model_path, [query])

        # Perform the vector search
        weaviate_client = self.get_client(self.data_path)
        if weaviate_client is None:
            logger.error("Could not find client, not returning results.")
            return []

        try:
            weaviate_client.connect()
            collection = weaviate_client.collections.get("Package")
            response = collection.query.near_vector(
                query_vector[0],
                limit=limit,
                distance=distance,
                return_metadata=MetadataQuery(distance=True),
            )

            if not response:
                return []
            return response.objects

        except Exception as e:
            logger.error(f"Error during search: {str(e)}")
            return []
        finally:
            # Single cleanup point: the client is closed exactly once on
            # every path (success, empty response or error).
            try:
                weaviate_client.close()
            except Exception as e:
                logger.info(f"Failed to close client: {str(e)}")
27 changes: 27 additions & 0 deletions src/codegate/utils/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
def generate_vector_string(package) -> str:
    """Build the sentence that is embedded to represent ``package``.

    The sentence names the package, describes its ecosystem, appends a
    warning plus a trustypkg.dev link when the status is ``archived``,
    ``deprecated`` or ``malicious``, and always ends with the package
    description.

    Args:
        package: mapping with the keys ``name``, ``type``, ``status`` and
            ``description``.

    Returns:
        The descriptive sentence as a single string.
    """
    ecosystem_descriptions = {
        "pypi": "Python package available on PyPI",
        "npm": "JavaScript package available on NPM",
        "go": "Go package",
        "crates": "Rust package available on Crates",
        "java": "Java package",
    }
    warnings_by_status = {
        "archived": "However, this package is found to be archived and no longer maintained.",
        "deprecated": "However, this package is found to be deprecated and no longer "
        "recommended for use.",
        "malicious": "However, this package is found to be malicious.",
    }

    name, ecosystem = package["name"], package["type"]
    kind = ecosystem_descriptions.get(ecosystem, "unknown type")
    # The trailing space is part of the established sentence format.
    pieces = [f"{name} is a {kind} "]

    # Only problematic statuses contribute a warning and a reference link.
    warning = warnings_by_status.get(package["status"], "")
    if warning:
        url = f"https://trustypkg.dev/{ecosystem}/{name}"
        pieces.append(f"{warning} For additional information refer to {url}")

    # The description always closes the sentence.
    pieces.append(f" - Package offers this functionality: {package['description']}")
    return "".join(pieces)
Binary file added src/weaviate_data/classifications.db
Binary file not shown.
Empty file.
Empty file.
Empty file.
Binary file added src/weaviate_data/modules.db
Binary file not shown.
Empty file.
Empty file.
1 change: 1 addition & 0 deletions src/weaviate_data/package/9e4pu9kSqOe2/proplengths
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"BucketedData":{},"SumData":{},"CountData":{},"ObjectCount":0}
Binary file added src/weaviate_data/package/9e4pu9kSqOe2/version
Binary file not shown.
Binary file added src/weaviate_data/raft/raft.db
Binary file not shown.
Binary file added src/weaviate_data/schema.db
Binary file not shown.
56 changes: 56 additions & 0 deletions tests/test_storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from codegate.storage.storage_engine import (
StorageEngine,
) # Adjust the import based on your actual path


@pytest.fixture
def mock_weaviate_client():
    """Weaviate client double whose near-vector query yields one package."""
    package = {
        "properties": {
            "name": "test",
            "type": "library",
            "status": "active",
            "description": "test description",
        }
    }
    search_response = MagicMock()
    search_response.objects = [package]

    client = MagicMock()
    # Mirrors the call chain used by StorageEngine.search().
    client.collections.get.return_value.query.near_vector.return_value = search_response
    return client


@pytest.fixture
def mock_inference_engine():
    """Inference engine double whose ``embed`` coroutine returns one vector.

    ``embed`` is called with a list of texts and its result is indexed with
    ``[0]`` by the callers, so the stub returns a list of vectors rather than
    a single flat vector.
    """
    engine = AsyncMock()
    engine.embed.return_value = [[0.1, 0.2, 0.3]]  # one vector per input text
    return engine


@pytest.mark.asyncio
async def test_search(mock_weaviate_client, mock_inference_engine):
    """StorageEngine.search embeds the query and returns the Weaviate hits."""
    # Patch the names where StorageEngine looks them up. The storage module
    # binds LlamaCppInferenceEngine with `from ... import`, so patching it in
    # codegate.inference.inference_engine would leave the real engine in use.
    with (
        patch("weaviate.WeaviateClient", return_value=mock_weaviate_client),
        patch(
            "codegate.storage.storage_engine.LlamaCppInferenceEngine",
            return_value=mock_inference_engine,
        ),
    ):

        # Initialize StorageEngine
        storage_engine = StorageEngine(data_path="./weaviate_data")

        # Invoke the search method
        results = await storage_engine.search("test query", 5, 0.3)

        # Assertions to validate the expected behavior
        assert len(results) == 1  # Assert that one result is returned
        assert results[0]["properties"]["name"] == "test"
        # The mocked engine (not the real model) must have produced the vector.
        mock_inference_engine.embed.assert_awaited_once()
        mock_weaviate_client.collections.get.return_value.query.near_vector.assert_called_once()
        mock_weaviate_client.connect.assert_called()
        mock_weaviate_client.close.assert_called()
Binary file added weaviate_data/classifications.db
Binary file not shown.
Empty file.
Empty file.
Empty file.
Binary file added weaviate_data/modules.db
Binary file not shown.
Empty file.
Empty file.
1 change: 1 addition & 0 deletions weaviate_data/package/yhcabdxdWUhw/proplengths
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"BucketedData":{},"SumData":{},"CountData":{},"ObjectCount":0}
Binary file added weaviate_data/package/yhcabdxdWUhw/version
Binary file not shown.
Binary file added weaviate_data/raft/raft.db
Binary file not shown.
Binary file added weaviate_data/schema.db
Binary file not shown.