-
Notifications
You must be signed in to change notification settings - Fork 33
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
13 changed files
with
60 additions
and
64 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
Empty file.
Empty file.
Empty file.
Binary file not shown.
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"BucketedData":{},"SumData":{},"CountData":{},"ObjectCount":0} |
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,58 +1,48 @@ | ||
import pytest | ||
from unittest.mock import Mock, AsyncMock | ||
from codegate.storage.storage_engine import StorageEngine # Adjust the import according to your project structure | ||
from unittest.mock import patch, MagicMock, AsyncMock | ||
from codegate.storage.storage_engine import StorageEngine # Adjust the import based on your actual path | ||
|
||
|
||
@pytest.fixture | ||
def mock_client(): | ||
client = Mock() | ||
client.connect = Mock() | ||
client.is_ready = Mock(return_value=True) | ||
client.schema.contains = Mock(return_value=False) | ||
client.schema.create_class = Mock() | ||
client.collections.get = Mock() | ||
client.close = Mock() | ||
def mock_weaviate_client(): | ||
client = MagicMock() | ||
response = MagicMock() | ||
response.objects = [ | ||
{ | ||
'properties': { | ||
'name': 'test', | ||
'type': 'library', | ||
'status': 'active', | ||
'description': 'test description' | ||
} | ||
} | ||
] | ||
client.collections.get.return_value.query.near_vector.return_value = response | ||
return client | ||
|
||
|
||
@pytest.fixture | ||
def mock_logger(): | ||
logger = Mock() | ||
return logger | ||
|
||
|
||
@pytest.fixture | ||
def mock_inference_engine(): | ||
inference_engine = AsyncMock() | ||
inference_engine.embed = AsyncMock( | ||
return_value=[0.1, 0.2, 0.3]) # Adjust based on expected vector dimensions | ||
return inference_engine | ||
|
||
|
||
@pytest.fixture | ||
def storage_engine(mock_client, mock_logger, mock_inference_engine): | ||
engine = StorageEngine(data_path='./weaviate_data') | ||
engine.client = mock_client | ||
engine.__logger = mock_logger | ||
engine.inference_engine = mock_inference_engine | ||
engine = AsyncMock() | ||
engine.embed.return_value = [0.1, 0.2, 0.3] # Adjust based on expected vector dimensions | ||
return engine | ||
|
||
|
||
def test_connect(storage_engine, mock_client): | ||
storage_engine.connect() | ||
mock_client.connect.assert_called_once() | ||
mock_client.is_ready.assert_called_once() | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_search(storage_engine, mock_client): | ||
query = "test query" | ||
results = await storage_engine.search(query) | ||
storage_engine.inference_engine.embed.assert_called_once_with( | ||
"./models/all-minilm-L6-v2-q5_k_m.gguf", [query]) | ||
assert results is not None # Further asserts can be based on your application logic | ||
|
||
|
||
def test_close(storage_engine, mock_client): | ||
storage_engine.close() | ||
mock_client.close.assert_called_once() | ||
async def test_search(mock_weaviate_client, mock_inference_engine): | ||
# Patch the WeaviateClient and LlamaCppInferenceEngine inside the test function | ||
with patch('weaviate.WeaviateClient', return_value=mock_weaviate_client), \ | ||
patch('codegate.inference.inference_engine.LlamaCppInferenceEngine', | ||
return_value=mock_inference_engine): | ||
|
||
# Initialize StorageEngine | ||
storage_engine = StorageEngine(data_path='./weaviate_data') | ||
|
||
# Invoke the search method | ||
results = await storage_engine.search("test query", 5, 0.3) | ||
|
||
# Assertions to validate the expected behavior | ||
assert len(results) == 1 # Assert that one result is returned | ||
assert results[0]['properties']['name'] == 'test' | ||
mock_weaviate_client.connect.assert_called() | ||
mock_weaviate_client.close.assert_called() |