Skip to content

Commit

Permalink
small test updates
Browse files Browse the repository at this point in the history
  • Loading branch information
eavanvalkenburg committed Jan 23, 2025
1 parent aad2908 commit 7854a57
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
from pymongo.asynchronous.database import AsyncDatabase
from pymongo.driver_info import DriverInfo

from semantic_kernel.connectors.memory.mongodb_atlas.const import MONGODB_ID_FIELD
from semantic_kernel.connectors.memory.mongodb_atlas.const import (
DEFAULT_DB_NAME,
DEFAULT_SEARCH_INDEX_NAME,
MONGODB_ID_FIELD,
)
from semantic_kernel.connectors.memory.mongodb_atlas.utils import create_index_definition
from semantic_kernel.data.filter_clauses import AnyTagsEqualTo, EqualTo
from semantic_kernel.data.kernel_search_results import KernelSearchResults
Expand Down Expand Up @@ -57,9 +61,9 @@ class MongoDBAtlasCollection(

def __init__(
self,
collection_name: str,
data_model_type: type[TModel],
data_model_definition: VectorStoreRecordDefinition | None = None,
collection_name: str | None = None,
index_name: str | None = None,
mongo_client: AsyncMongoClient | None = None,
**kwargs: Any,
Expand All @@ -81,17 +85,16 @@ def __init__(
env_file_encoding: str | None = None
"""
if not collection_name:
raise VectorStoreInitializationException("Collection name is required.")
if mongo_client and "database_name" in kwargs:
managed_client = not mongo_client
if mongo_client:
super().__init__(
data_model_type=data_model_type,
data_model_definition=data_model_definition,
mongo_client=mongo_client,
collection_name=collection_name,
database_name=kwargs["database_name"],
index_name=index_name or f"{collection_name}_idx",
managed_client=False,
database_name=kwargs.get("database_name", DEFAULT_DB_NAME),
index_name=index_name or DEFAULT_SEARCH_INDEX_NAME,
managed_client=managed_client,
)
return

Expand All @@ -103,17 +106,15 @@ def __init__(
env_file_encoding=kwargs.get("env_file_encoding"),
connection_string=kwargs.get("connection_string"),
database_name=kwargs.get("database_name"),
index_name=index_name,
)
except ValidationError as exc:
raise VectorStoreInitializationException("Failed to create MongoDB Atlas settings.") from exc
managed_client = not mongo_client
if not mongo_client:
mongo_client = AsyncMongoClient(
mongodb_atlas_settings.connection_string.get_secret_value(),
driver=DriverInfo("Microsoft Semantic Kernel", metadata.version("semantic-kernel")),
)
if not mongodb_atlas_settings.database_name:
raise VectorStoreInitializationException("Database name is required.")

super().__init__(
data_model_type=data_model_type,
Expand All @@ -122,7 +123,7 @@ def __init__(
mongo_client=mongo_client,
managed_client=managed_client,
database_name=mongodb_atlas_settings.database_name,
index_name=index_name or f"{collection_name}_idx",
index_name=mongodb_atlas_settings.index_name,
)

def _get_database(self) -> AsyncDatabase:
Expand Down Expand Up @@ -186,16 +187,18 @@ def _deserialize_store_models_to_dicts(self, records: Sequence[Any], **kwargs: A
async def create_collection(self, **kwargs) -> None:
"""Create a new collection in MongoDB Atlas.
This first creates a collection, with the kwargs.
Then creates a search index based on the data model definition.
Args:
**kwargs: Additional keyword arguments.
"""
database = self._get_database()
collection = await database.create_collection(self.collection_name, **kwargs)
collection = await self._get_database().create_collection(self.collection_name, **kwargs)
await collection.create_search_index(create_index_definition(self.data_model_definition, self.index_name))

@override
async def does_collection_exist(self, **kwargs) -> bool:
return self.collection_name in await self._get_database().list_collection_names()
return bool(await self._get_database().list_collection_names(filter={"name": self.collection_name}))

@override
async def delete_collection(self, **kwargs) -> None:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright (c) Microsoft. All rights reserved.

from typing import Annotated, ClassVar
from typing import ClassVar

from pydantic import Field, SecretStr
from pydantic import SecretStr

from semantic_kernel.connectors.memory.mongodb_atlas.const import DEFAULT_DB_NAME, DEFAULT_SEARCH_INDEX_NAME
from semantic_kernel.kernel_pydantic import KernelBaseSettings
Expand All @@ -17,12 +17,13 @@ class MongoDBAtlasSettings(KernelBaseSettings):
- connection_string: str - MongoDB Atlas connection string
(Env var MONGODB_ATLAS_CONNECTION_STRING)
- database_name: str - MongoDB Atlas database name, defaults to 'default'
(Env var MONGODB_ATLAS_DATABASE_NAME)
- index_name: str - MongoDB Atlas search index name, defaults to 'default'
(Env var MONGODB_ATLAS_INDEX_NAME)
"""

env_prefix: ClassVar[str] = "MONGODB_ATLAS_"

connection_string: SecretStr
database_name: str = DEFAULT_DB_NAME
index_name: Annotated[str, Field(deprecated="This field is not used with the new store and collection")] = (
DEFAULT_SEARCH_INDEX_NAME
)
index_name: str = DEFAULT_SEARCH_INDEX_NAME
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
# Copyright (c) Microsoft. All rights reserved.


from unittest.mock import AsyncMock, patch

from pymongo import AsyncMongoClient
from pymongo.asynchronous.cursor import AsyncCursor
from pymongo.results import UpdateResult
from pytest import mark, raises

from semantic_kernel.connectors.memory.mongodb_atlas.const import DEFAULT_DB_NAME, DEFAULT_SEARCH_INDEX_NAME
from semantic_kernel.connectors.memory.mongodb_atlas.mongodb_atlas_collection import MongoDBAtlasCollection
from semantic_kernel.exceptions.vector_store_exceptions import VectorStoreInitializationException


def test_mongodb_atlas_collection_initialization(mongodb_atlas_unit_test_env, data_model_definition, mock_mongo_client):
Expand All @@ -21,6 +23,27 @@ def test_mongodb_atlas_collection_initialization(mongodb_atlas_unit_test_env, da
assert isinstance(collection.mongo_client, AsyncMongoClient)


@mark.parametrize("exclude_list", [["MONGODB_ATLAS_CONNECTION_STRING"]], indirect=True)
def test_mongodb_atlas_collection_initialization_fail(mongodb_atlas_unit_test_env, data_model_definition):
with raises(VectorStoreInitializationException):
MongoDBAtlasCollection(
collection_name="test_collection",
data_model_type=dict,
data_model_definition=data_model_definition,
)


@mark.parametrize("exclude_list", [["MONGODB_ATLAS_DATABASE_NAME", "MONGODB_ATLAS_INDEX_NAME"]], indirect=True)
def test_mongodb_atlas_collection_initialization_defaults(mongodb_atlas_unit_test_env, data_model_definition):
collection = MongoDBAtlasCollection(
collection_name="test_collection",
data_model_type=dict,
data_model_definition=data_model_definition,
)
assert collection.database_name == DEFAULT_DB_NAME
assert collection.index_name == DEFAULT_SEARCH_INDEX_NAME


async def test_mongodb_atlas_collection_upsert(mongodb_atlas_unit_test_env, data_model_definition, mock_get_collection):
collection = MongoDBAtlasCollection(
data_model_type=dict,
Expand Down Expand Up @@ -58,3 +81,16 @@ async def test_mongodb_atlas_collection_delete(mongodb_atlas_unit_test_env, data
with patch.object(collection, "_get_collection", new=mock_get_collection) as mock_get:
await collection._inner_delete(["test_id"])
mock_get.return_value.delete_many.assert_called_with({"_id": {"$in": ["test_id"]}})


async def test_mongodb_atlas_collection_collection_exists(
mongodb_atlas_unit_test_env, data_model_definition, mock_get_database
):
collection = MongoDBAtlasCollection(
data_model_type=dict,
data_model_definition=data_model_definition,
collection_name="test_collection",
)
with patch.object(collection, "_get_database", new=mock_get_database) as mock_get:
mock_get.return_value.list_collection_names.return_value = ["test_collection"]
assert await collection.does_collection_exist()
11 changes: 0 additions & 11 deletions python/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 7854a57

Please sign in to comment.