From 020f2738f5cec3cbd76e1092f00ca7f02adfc1b2 Mon Sep 17 00:00:00 2001
From: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
Date: Tue, 16 Jan 2024 12:17:53 +0100
Subject: [PATCH] Add `collection_name` parameter and creation (#215)

* Add collection_name parameter

* Fix linting
---
 .../weaviate/document_store.py                | 11 ++++
 .../weaviate/tests/test_document_store.py     | 50 ++++++++++++++++++-
 2 files changed, 60 insertions(+), 1 deletion(-)

diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py
index 9317fb9de..4c15d707e 100644
--- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py
+++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py
@@ -35,6 +35,7 @@ def __init__(
         self,
         *,
         url: Optional[str] = None,
+        collection_name: str = "default",
         auth_client_secret: Optional[AuthCredentials] = None,
         timeout_config: TimeoutType = (10, 60),
         proxies: Optional[Union[Dict, str]] = None,
@@ -79,6 +80,8 @@ def __init__(
         :param embedded_options: If set create an embedded Weaviate cluster inside the client, defaults to None.
             For a full list of options see `weaviate.embedded.EmbeddedOptions`.
         :param additional_config: Additional and advanced configuration options for weaviate, defaults to None.
+        :param collection_name: The name of the collection to use, defaults to "default".
+            If the collection does not exist it will be created.
         """
         self._client = weaviate.Client(
             url=url,
@@ -92,7 +95,14 @@ def __init__(
             additional_config=additional_config,
         )
 
+        # Test the connection; this call raises an exception if it fails.
+        self._client.schema.get()
+
+        if not self._client.schema.exists(collection_name):
+            self._client.schema.create_class({"class": collection_name})
+
         self._url = url
+        self._collection_name = collection_name
         self._auth_client_secret = auth_client_secret
         self._timeout_config = timeout_config
         self._proxies = proxies
@@ -114,6 +124,7 @@ def to_dict(self) -> Dict[str, Any]:
         return default_to_dict(
             self,
             url=self._url,
+            collection_name=self._collection_name,
             auth_client_secret=auth_client_secret,
             timeout_config=self._timeout_config,
             proxies=self._proxies,
diff --git a/integrations/weaviate/tests/test_document_store.py b/integrations/weaviate/tests/test_document_store.py
index 3c72934b1..0666151ee 100644
--- a/integrations/weaviate/tests/test_document_store.py
+++ b/integrations/weaviate/tests/test_document_store.py
@@ -1,4 +1,4 @@
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
 
 from haystack_integrations.document_stores.weaviate.document_store import WeaviateDocumentStore
 from weaviate.auth import AuthApiKey
@@ -13,10 +13,55 @@
 
 
 class TestWeaviateDocumentStore:
+    @patch("haystack_integrations.document_stores.weaviate.document_store.weaviate.Client")
+    def test_init(self, mock_weaviate_client_class):
+        mock_client = MagicMock()
+        mock_client.schema.exists.return_value = False
+        mock_weaviate_client_class.return_value = mock_client
+
+        WeaviateDocumentStore(
+            url="http://localhost:8080",
+            collection_name="my_collection",
+            auth_client_secret=AuthApiKey("my_api_key"),
+            proxies={"http": "http://proxy:1234"},
+            additional_headers={"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"},
+            embedded_options=EmbeddedOptions(
+                persistence_data_path=DEFAULT_PERSISTENCE_DATA_PATH,
+                binary_path=DEFAULT_BINARY_PATH,
+                version="1.23.0",
+                hostname="127.0.0.1",
+            ),
+            additional_config=Config(grpc_port_experimental=12345),
+        )
+
+        # Verify the client is created exactly once with the expected arguments (collection_name is not forwarded)
+        mock_weaviate_client_class.assert_called_once_with(
+            url="http://localhost:8080",
+            auth_client_secret=AuthApiKey("my_api_key"),
+            timeout_config=(10, 60),
+            proxies={"http": "http://proxy:1234"},
+            trust_env=False,
+            additional_headers={"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"},
+            startup_period=5,
+            embedded_options=EmbeddedOptions(
+                persistence_data_path=DEFAULT_PERSISTENCE_DATA_PATH,
+                binary_path=DEFAULT_BINARY_PATH,
+                version="1.23.0",
+                hostname="127.0.0.1",
+            ),
+            additional_config=Config(grpc_port_experimental=12345),
+        )
+
+        # schema.exists is mocked to return False, so the missing collection must be created
+        mock_client.schema.get.assert_called_once()
+        mock_client.schema.exists.assert_called_once_with("my_collection")
+        mock_client.schema.create_class.assert_called_once_with({"class": "my_collection"})
+
     @patch("haystack_integrations.document_stores.weaviate.document_store.weaviate")
     def test_to_dict(self, _mock_weaviate):
         document_store = WeaviateDocumentStore(
             url="http://localhost:8080",
+            collection_name="my_collection",
             auth_client_secret=AuthApiKey("my_api_key"),
             proxies={"http": "http://proxy:1234"},
             additional_headers={"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"},
@@ -32,6 +77,7 @@ def test_to_dict(self, _mock_weaviate):
             "type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore",
             "init_parameters": {
                 "url": "http://localhost:8080",
+                "collection_name": "my_collection",
                 "auth_client_secret": {
                     "type": "weaviate.auth.AuthApiKey",
                     "init_parameters": {"api_key": "my_api_key"},
@@ -67,6 +113,7 @@ def test_from_dict(self, _mock_weaviate):
                 "type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore",
                 "init_parameters": {
                     "url": "http://localhost:8080",
+                    "collection_name": "my_collection",
                     "auth_client_secret": {
                         "type": "weaviate.auth.AuthApiKey",
                         "init_parameters": {"api_key": "my_api_key"},
@@ -97,6 +144,7 @@ def test_from_dict(self, _mock_weaviate):
         )
 
         assert document_store._url == "http://localhost:8080"
+        assert document_store._collection_name == "my_collection"
         assert document_store._auth_client_secret == AuthApiKey("my_api_key")
         assert document_store._timeout_config == (10, 60)
         assert document_store._proxies == {"http": "http://proxy:1234"}