Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
yrobla committed Nov 29, 2024
1 parent 9e5c84a commit 1e30daa
Show file tree
Hide file tree
Showing 13 changed files with 60 additions and 64 deletions.
43 changes: 24 additions & 19 deletions src/codegate/storage/storage_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,31 +19,34 @@


class StorageEngine:
def __init__(self, data_path='./weaviate_data'):
self.client = weaviate.WeaviateClient(
def get_client(self, data_path):
client = weaviate.WeaviateClient(
embedded_options=weaviate.EmbeddedOptions(
persistence_data_path=data_path
),
)
return client

def __init__(self, data_path='./weaviate_data'):
self.__logger = setup_logging()
self.data_path = data_path
self.inference_engine = LlamaCppInferenceEngine()
self.model_path = "./models/all-minilm-L6-v2-q5_k_m.gguf"
self.schema_config = schema_config
self.connect()
self.setup_schema()

def connect(self):
self.client.connect()
if self.client.is_ready():
self.__logger.info("Weaviate connection established and client is ready.")
else:
raise Exception("Weaviate client is not ready.")
# setup schema for weaviate
weaviate_client = self.get_client(self.data_path)
try:
weaviate_client.connect()
self.setup_schema(weaviate_client)
finally:
weaviate_client.close()

def setup_schema(self):
def setup_schema(self, client):
for class_config in self.schema_config:
if not self.client.collections.exists(class_config['name']):
self.client.collections.create(class_config['name'],
properties=class_config['properties'])
if not client.collections.exists(class_config['name']):
client.collections.create(class_config['name'],
properties=class_config['properties'])
self.__logger.info(
f"Weaviate schema for class {class_config['name']} setup complete.")

Expand All @@ -63,18 +66,20 @@ async def search(self, query: str, limit=5, distance=0.3) -> list[object]:

# Perform the vector search
try:
collection = self.client.collections.get("Package")
weaviate_client = self.get_client(self.data_path)
weaviate_client.connect()
collection = weaviate_client.collections.get("Package")
response = collection.query.near_vector(
query_vector[0], limit=limit, distance=distance,
return_metadata=MetadataQuery(distance=True))

weaviate_client.close()
if not response:
return []
return response.objects

except Exception as e:
self.__logger.error(f"Error during search: {str(e)}")
return []

def close(self):
self.client.close()
self.__logger.info("Connection closed.")
finally:
weaviate_client.close()
Binary file added src/weaviate_data/classifications.db
Binary file not shown.
Empty file.
Empty file.
Empty file.
Binary file added src/weaviate_data/modules.db
Binary file not shown.
Empty file.
Empty file.
1 change: 1 addition & 0 deletions src/weaviate_data/package/9e4pu9kSqOe2/proplengths
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"BucketedData":{},"SumData":{},"CountData":{},"ObjectCount":0}
Binary file added src/weaviate_data/package/9e4pu9kSqOe2/version
Binary file not shown.
Binary file added src/weaviate_data/raft/raft.db
Binary file not shown.
Binary file added src/weaviate_data/schema.db
Binary file not shown.
80 changes: 35 additions & 45 deletions tests/test_storage.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,48 @@
import pytest
from unittest.mock import Mock, AsyncMock
from codegate.storage.storage_engine import StorageEngine # Adjust the import according to your project structure
from unittest.mock import patch, MagicMock, AsyncMock
from codegate.storage.storage_engine import StorageEngine # Adjust the import based on your actual path


@pytest.fixture
def mock_client():
client = Mock()
client.connect = Mock()
client.is_ready = Mock(return_value=True)
client.schema.contains = Mock(return_value=False)
client.schema.create_class = Mock()
client.collections.get = Mock()
client.close = Mock()
def mock_weaviate_client():
client = MagicMock()
response = MagicMock()
response.objects = [
{
'properties': {
'name': 'test',
'type': 'library',
'status': 'active',
'description': 'test description'
}
}
]
client.collections.get.return_value.query.near_vector.return_value = response
return client


@pytest.fixture
def mock_logger():
logger = Mock()
return logger


@pytest.fixture
def mock_inference_engine():
inference_engine = AsyncMock()
inference_engine.embed = AsyncMock(
return_value=[0.1, 0.2, 0.3]) # Adjust based on expected vector dimensions
return inference_engine


@pytest.fixture
def storage_engine(mock_client, mock_logger, mock_inference_engine):
engine = StorageEngine(data_path='./weaviate_data')
engine.client = mock_client
engine.__logger = mock_logger
engine.inference_engine = mock_inference_engine
engine = AsyncMock()
engine.embed.return_value = [0.1, 0.2, 0.3] # Adjust based on expected vector dimensions
return engine


def test_connect(storage_engine, mock_client):
storage_engine.connect()
mock_client.connect.assert_called_once()
mock_client.is_ready.assert_called_once()


@pytest.mark.asyncio
async def test_search(storage_engine, mock_client):
query = "test query"
results = await storage_engine.search(query)
storage_engine.inference_engine.embed.assert_called_once_with(
"./models/all-minilm-L6-v2-q5_k_m.gguf", [query])
assert results is not None # Further asserts can be based on your application logic


def test_close(storage_engine, mock_client):
storage_engine.close()
mock_client.close.assert_called_once()
async def test_search(mock_weaviate_client, mock_inference_engine):
# Patch the WeaviateClient and LlamaCppInferenceEngine inside the test function
with patch('weaviate.WeaviateClient', return_value=mock_weaviate_client), \
patch('codegate.inference.inference_engine.LlamaCppInferenceEngine',
return_value=mock_inference_engine):

# Initialize StorageEngine
storage_engine = StorageEngine(data_path='./weaviate_data')

# Invoke the search method
results = await storage_engine.search("test query", 5, 0.3)

# Assertions to validate the expected behavior
assert len(results) == 1 # Assert that one result is returned
assert results[0]['properties']['name'] == 'test'
mock_weaviate_client.connect.assert_called()
mock_weaviate_client.close.assert_called()

0 comments on commit 1e30daa

Please sign in to comment.