diff --git a/CHANGELOG.md b/CHANGELOG.md index 82eaeef3c..cd83a793c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Native function calling support to `OpenAiChatPromptDriver`, `AzureOpenAiChatPromptDriver`, `AnthropicPromptDriver`, `AmazonBedrockPromptDriver`, `GooglePromptDriver`, and `CoherePromptDriver`. - `OllamaEmbeddingDriver` for generating embeddings with Ollama. +- `GriptapeCloudKnowledgeBaseVectorStoreDriver` to query Griptape Cloud Knowledge Bases. ### Changed diff --git a/docs/griptape-framework/drivers/vector-store-drivers.md b/docs/griptape-framework/drivers/vector-store-drivers.md index da6fdc242..ea2b72a56 100644 --- a/docs/griptape-framework/drivers/vector-store-drivers.md +++ b/docs/griptape-framework/drivers/vector-store-drivers.md @@ -1,6 +1,6 @@ ## Overview -Griptape provides a way to build drivers for vector DBs where embeddings can be stored and queried. Every vector store driver implements the following methods: +Griptape provides a way to build drivers for vector DBs where embeddings can be stored and queried. Every Vector Store Driver implements the following methods: - `upsert_text_artifact()` for updating or inserting a new [TextArtifact](../../reference/griptape/artifacts/text_artifact.md) into vector DBs. The method will automatically generate embeddings for a given value. - `upsert_text_artifacts()` for updating or inserting multiple [TextArtifact](../../reference/griptape/artifacts/text_artifact.md)s into vector DBs. The method will automatically generate embeddings for given values. @@ -8,19 +8,19 @@ Griptape provides a way to build drivers for vector DBs where embeddings can be - `upsert_vector()` for updating and inserting new vectors directly. - `query()` for querying vector DBs. -Each vector driver takes a [BaseEmbeddingDriver](../../reference/griptape/drivers/embedding/base_embedding_driver.md) used to dynamically generate embeddings for strings. +Each Vector Store Driver takes a [BaseEmbeddingDriver](../../reference/griptape/drivers/embedding/base_embedding_driver.md) used to dynamically generate embeddings for strings. !!! info - When working with vector database indexes with Griptape drivers, make sure the number of dimensions is equal to 1536. Nearly all embedding models create vectors with this number of dimensions. Check the documentation for your vector database on how to create/update vector indexes. + When working with vector database indexes with Griptape Drivers, make sure the number of dimensions is equal to 1536. Nearly all embedding models create vectors with this number of dimensions. Check the documentation for your vector database on how to create/update vector indexes. !!! info - More vector drivers are coming soon. + More Vector Store Drivers are coming soon. ## Vector Store Drivers ### Local -The [LocalVectorStoreDriver](../../reference/griptape/drivers/vector/local_vector_store_driver.md) can be used to load and query data from memory. Here is a complete example of how the driver can be used to load a webpage into the driver and query it later: +The [LocalVectorStoreDriver](../../reference/griptape/drivers/vector/local_vector_store_driver.md) can be used to load and query data from memory. 
Here is a complete example of how the Driver can be used to load a webpage into the Driver and query it later: ```python import os @@ -29,12 +29,15 @@ from griptape.drivers import LocalVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader -# Initialize an embedding driver +# Initialize an Embedding Driver embedding_driver = OpenAiEmbeddingDriver(api_key=os.environ["OPENAI_API_KEY"]) vector_store_driver = LocalVectorStoreDriver(embedding_driver=embedding_driver) + +# Load Artifacts from the web artifacts = WebLoader(max_tokens=100).load("https://www.griptape.ai") +# Upsert Artifacts into the Vector Store Driver [vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in artifacts] results = vector_store_driver.query( @@ -49,14 +52,38 @@ print("\n\n".join(values)) ``` +### Griptape Cloud Knowledge Base + +The [GriptapeCloudKnowledgeBaseVectorStoreDriver](../../reference/griptape/drivers/vector/griptape_cloud_knowledge_base_vector_store_driver.md) can be used to query data from a Griptape Cloud Knowledge Base. Loading into Knowledge Bases is not supported at this time, only querying. Here is a complete example of how the Driver can be used to query an existing Knowledge Base: + +```python +import os +from griptape.artifacts import BaseArtifact +from griptape.drivers import GriptapeCloudKnowledgeBaseVectorStoreDriver + + +# Initialize environment variables +gt_cloud_api_key = os.environ["GRIPTAPE_CLOUD_API_KEY"] +gt_cloud_knowledge_base_id = os.environ["GRIPTAPE_CLOUD_KB_ID"] + +vector_store_driver = GriptapeCloudKnowledgeBaseVectorStoreDriver(api_key=gt_cloud_api_key, knowledge_base_id=gt_cloud_knowledge_base_id) + +results =vector_store_driver.query(query="What is griptape?") + +values = [r.to_artifact().value for r in results] + +print("\n\n".join(values)) + +``` + ### Pinecone !!! info - This driver requires the `drivers-vector-pinecone` [extra](../index.md#extras). + This Driver requires the `drivers-vector-pinecone` [extra](../index.md#extras). The [PineconeVectorStoreDriver](../../reference/griptape/drivers/vector/pinecone_vector_store_driver.md) supports the [Pinecone vector database](https://www.pinecone.io/). -Here is an example of how the driver can be used to load and query information in a Pinecone cluster: +Here is an example of how the Driver can be used to load and query information in a Pinecone cluster: ```python import os @@ -85,7 +112,7 @@ def load_data(driver: PineconeVectorStoreDriver) -> None: namespace="supermarket-products", ) -# Initialize an embedding driver +# Initialize an Embedding Driver embedding_driver = OpenAiEmbeddingDriver(api_key=os.environ["OPENAI_API_KEY"]) vector_store_driver = PineconeVectorStoreDriver( @@ -97,72 +124,79 @@ vector_store_driver = PineconeVectorStoreDriver( load_data(vector_store_driver) -result = vector_store_driver.query( +results = vector_store_driver.query( "fruit", count=3, filter={"price": {"$lte": 15}, "rating": {"$gte": 4}}, namespace="supermarket-products", ) + +values = [r.to_artifact().value for r in results] + +print("\n\n".join(values)) ``` ### Marqo !!! info - This driver requires the `drivers-vector-marqo` [extra](../index.md#extras). + This Driver requires the `drivers-vector-marqo` [extra](../index.md#extras). The [MarqoVectorStoreDriver](../../reference/griptape/drivers/vector/marqo_vector_store_driver.md) supports the Marqo vector database. 
-Here is an example of how the driver can be used to load and query information in a Marqo cluster:
+Here is an example of how the Driver can be used to load and query information in a Marqo cluster:
 
 ```python
 import os
 from griptape.drivers import MarqoVectorStoreDriver, OpenAiEmbeddingDriver, OpenAiChatPromptDriver
 from griptape.loaders import WebLoader
 
-# Initialize an embedding driver
+# Initialize an Embedding Driver
 embedding_driver = OpenAiEmbeddingDriver(api_key=os.environ["OPENAI_API_KEY"])
 
 prompt_driver = OpenAiChatPromptDriver(model="gpt-3.5-turbo")
 
 # Define the namespace
 namespace = 'griptape-ai'
 
-# Initialize the vector store driver
-vector_store = MarqoVectorStoreDriver(
+# Initialize the Vector Store Driver
+vector_store_driver = MarqoVectorStoreDriver(
     api_key=os.environ["MARQO_API_KEY"],
     url=os.environ["MARQO_URL"],
     index=os.environ["MARQO_INDEX_NAME"],
     embedding_driver=embedding_driver,
 )
 
-# Load artifacts from the web
+# Load Artifacts from the web
 artifacts = WebLoader(max_tokens=200).load("https://www.griptape.ai")
 
-# Upsert the artifacts into the vector store
-vector_store.upsert_text_artifacts(
+# Upsert Artifacts into the Vector Store Driver
+vector_store_driver.upsert_text_artifacts(
     {
         "griptape": artifacts,
     }
 )
 
-result = vector_store.query(query="What is griptape?")
-print(result)
+results = vector_store_driver.query(query="What is griptape?")
+
+values = [r.to_artifact().value for r in results]
+
+print("\n\n".join(values))
 ```
 
 ### Mongodb Atlas
 
 !!! info
-    This driver requires the `drivers-vector-mongodb` [extra](../index.md#extras).
+    This Driver requires the `drivers-vector-mongodb` [extra](../index.md#extras).
 
 The [MongodbAtlasVectorStoreDriver](../../reference/griptape/drivers/vector/mongodb_atlas_vector_store_driver.md) provides support for storing vector data in a MongoDB Atlas database. 
-Here is an example of how the driver can be used to load and query information in a MongoDb Atlas Cluster: +Here is an example of how the Driver can be used to load and query information in a MongoDb Atlas Cluster: ```python from griptape.drivers import MongoDbAtlasVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader import os -# Initialize an embedding driver +# Initialize an Embedding Driver embedding_driver = OpenAiEmbeddingDriver(api_key=os.environ["OPENAI_API_KEY"]) host = os.environ["MONGODB_HOST"] @@ -173,8 +207,8 @@ collection_name = os.environ[ "MONGODB_COLLECTION_NAME"] index_name = os.environ["MONGODB_INDEX_NAME"] vector_path = os.environ["MONGODB_VECTOR_PATH"] -# Initialize the vector store driver -vector_store = MongoDbAtlasVectorStoreDriver( +# Initialize the Vector Store Driver +vector_store_driver = MongoDbAtlasVectorStoreDriver( connection_string=f"mongodb+srv://{username}:{password}@{host}/{database_name}", database_name=database_name, collection_name=collection_name, @@ -183,18 +217,21 @@ vector_store = MongoDbAtlasVectorStoreDriver( vector_path=vector_path, ) -# Load artifacts from the web +# Load Artifacts from the web artifacts = WebLoader(max_tokens=200).load("https://www.griptape.ai") -# Upsert the artifacts into the vector store -vector_store.upsert_text_artifacts( +# Upsert Artifacts into the Vector Store Driver +vector_store_driver.upsert_text_artifacts( { "griptape": artifacts, } ) -result = vector_store.query(query="What is griptape?") -print(result) +results =vector_store_driver.query(query="What is griptape?") + +values = [r.to_artifact().value for r in results] + +print("\n\n".join(values)) ``` The format for creating a vector index should look similar to the following: @@ -219,18 +256,18 @@ Replace `path_to_vector` with the expected field name where the vector content w ### Azure MongoDB !!! info - This driver requires the `drivers-vector-mongodb` [extra](../index.md#extras). + This Driver requires the `drivers-vector-mongodb` [extra](../index.md#extras). The [AzureMongoDbVectorStoreDriver](../../reference/griptape/drivers/vector/azure_mongodb_vector_store_driver.md) provides support for storing vector data in an Azure CosmosDb database account using the MongoDb vCore API -Here is an example of how the driver can be used to load and query information in an Azure CosmosDb MongoDb vCore database. It is very similar to the Driver for [MongoDb Atlas](#mongodb-atlas): +Here is an example of how the Driver can be used to load and query information in an Azure CosmosDb MongoDb vCore database. 
It is very similar to the Driver for [MongoDb Atlas](#mongodb-atlas): ```python from griptape.drivers import AzureMongoDbVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader import os -# Initialize an embedding driver +# Initialize an Embedding Driver embedding_driver = OpenAiEmbeddingDriver(api_key=os.environ["OPENAI_API_KEY"]) azure_host = os.environ["AZURE_MONGODB_HOST"] @@ -241,8 +278,8 @@ collection_name = os.environ["AZURE_MONGODB_COLLECTION_NAME"] index_name = os.environ["AZURE_MONGODB_INDEX_NAME"] vector_path = os.environ["AZURE_MONGODB_VECTOR_PATH"] -# Initialize the vector store driver -vector_store = AzureMongoDbVectorStoreDriver( +# Initialize the Vector Store Driver +vector_store_driver = AzureMongoDbVectorStoreDriver( connection_string=f"mongodb+srv://{username}:{password}@{azure_host}/{database_name}?tls=true&authMechanism=SCRAM-SHA-256&retrywrites=false&maxIdleTimeMS=120000", database_name=database_name, collection_name=collection_name, @@ -251,28 +288,31 @@ vector_store = AzureMongoDbVectorStoreDriver( vector_path=vector_path, ) -# Load artifacts from the web +# Load Artifacts from the web artifacts = WebLoader(max_tokens=200).load("https://www.griptape.ai") -# Upsert the artifacts into the vector store -vector_store.upsert_text_artifacts( +# Upsert Artifacts into the Vector Store Driver +vector_store_driver.upsert_text_artifacts( { "griptape": artifacts, } ) -result = vector_store.query(query="What is griptape?") -print(result) +results =vector_store_driver.query(query="What is griptape?") + +values = [r.to_artifact().value for r in results] + +print("\n\n".join(values)) ``` ### Redis !!! info - This driver requires the `drivers-vector-redis` [extra](../index.md#extras). + This Driver requires the `drivers-vector-redis` [extra](../index.md#extras). The [RedisVectorStoreDriver](../../reference/griptape/drivers/vector/redis_vector_store_driver.md) integrates with the Redis vector storage system. -Here is an example of how the driver can be used to load and query information in a Redis Cluster: +Here is an example of how the Driver can be used to load and query information in a Redis Cluster: ```python import os @@ -280,7 +320,7 @@ from griptape.drivers import RedisVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader import numpy as np # Assuming you'd use numpy to create a dummy vector for the sake of example. -# Initialize an embedding driver +# Initialize an Embedding Driver embedding_driver = OpenAiEmbeddingDriver(api_key=os.environ["OPENAI_API_KEY"]) vector_store_driver = RedisVectorStoreDriver( @@ -291,18 +331,21 @@ vector_store_driver = RedisVectorStoreDriver( embedding_driver=embedding_driver, ) -# Load artifacts from the web +# Load Artifacts from the web artifacts = WebLoader(max_tokens=200).load("https://www.griptape.ai") -# Upsert the artifacts into the vector store +# Upsert Artifacts into the Vector Store Driver vector_store_driver.upsert_text_artifacts( { "griptape": artifacts, } ) -result = vector_store_driver.query(query="What is griptape?") -print(result) +results =vector_store_driver.query(query="What is griptape?") + +values = [r.to_artifact().value for r in results] + +print("\n\n".join(values)) ``` The format for creating a vector index should be similar to the following: @@ -313,11 +356,11 @@ FT.CREATE idx:griptape ON hash PREFIX 1 "griptape:" SCHEMA namespace TAG vector ### OpenSearch !!! info - This driver requires the `drivers-vector-opensearch` [extra](../index.md#extras). 
+ This Driver requires the `drivers-vector-opensearch` [extra](../index.md#extras). The [OpenSearchVectorStoreDriver](../../reference/griptape/drivers/vector/opensearch_vector_store_driver.md) integrates with the OpenSearch platform, allowing for storage, retrieval, and querying of vector data. -Here is an example of how the driver can be used to load and query information in an OpenSearch Cluster: +Here is an example of how the Driver can be used to load and query information in an OpenSearch Cluster: ```python import os @@ -325,7 +368,7 @@ import boto3 from griptape.drivers import AmazonOpenSearchVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader -# Initialize an embedding driver +# Initialize an Embedding Driver embedding_driver = OpenAiEmbeddingDriver(api_key=os.environ["OPENAI_API_KEY"]) vector_store_driver = AmazonOpenSearchVectorStoreDriver( @@ -335,19 +378,21 @@ vector_store_driver = AmazonOpenSearchVectorStoreDriver( embedding_driver=embedding_driver, ) -# Load artifacts from the web +# Load Artifacts from the web artifacts = WebLoader(max_tokens=200).load("https://www.griptape.ai") -# Upsert the artifacts into the vector store +# Upsert Artifacts into the Vector Store Driver vector_store_driver.upsert_text_artifacts( { "griptape": artifacts, } ) -result = vector_store_driver.query(query="What is griptape?") +results =vector_store_driver.query(query="What is griptape?") + +values = [r.to_artifact().value for r in results] -print(result) +print("\n\n".join(values)) ``` The body mappings for creating a vector index should look similar to the following: @@ -366,18 +411,18 @@ The body mappings for creating a vector index should look similar to the followi ### PGVector !!! info - This driver requires the `drivers-vector-postgresql` [extra](../index.md#extras). + This Driver requires the `drivers-vector-postgresql` [extra](../index.md#extras). -The [PGVectorVectorStoreDriver](../../reference/griptape/drivers/vector/pgvector_vector_store_driver.md) integrates with PGVector, a vector storage and search extension for Postgres. While Griptape will handle enabling the extension, PGVector must be installed and ready for use in your Postgres instance before using this vector store driver. +The [PGVectorVectorStoreDriver](../../reference/griptape/drivers/vector/pgvector_vector_store_driver.md) integrates with PGVector, a vector storage and search extension for Postgres. While Griptape will handle enabling the extension, PGVector must be installed and ready for use in your Postgres instance before using this Vector Store Driver. -Here is an example of how the driver can be used to load and query information in a Postgres database: +Here is an example of how the Driver can be used to load and query information in a Postgres database: ```python import os from griptape.drivers import PgVectorVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader -# Initialize an embedding driver. +# Initialize an Embedding Driver. embedding_driver = OpenAiEmbeddingDriver(api_key=os.environ["OPENAI_API_KEY"]) db_user = os.environ["POSTGRES_USER"] @@ -395,26 +440,31 @@ vector_store_driver = PgVectorVectorStoreDriver( # Install required Postgres extensions and create database schema. 
vector_store_driver.setup() -web_loader = WebLoader() -artifacts = web_loader.load("https://www.griptape.ai") +# Load Artifacts from the web +artifacts = WebLoader().load("https://www.griptape.ai") + +# Upsert Artifacts into the Vector Store Driver vector_store_driver.upsert_text_artifacts( { "griptape": artifacts, } ) -result = vector_store_driver.query("What is griptape?") -print(result) +results =vector_store_driver.query(query="What is griptape?") + +values = [r.to_artifact().value for r in results] + +print("\n\n".join(values)) ``` ### Qdrant !!! info - This driver requires the `drivers-vector-qdrant` [extra](../index.md#extras). + This Driver requires the `drivers-vector-qdrant` [extra](../index.md#extras). The QdrantVectorStoreDriver supports the [Qdrant vector database](https://qdrant.tech/). -Here is an example of how the driver can be used to query information in a Qdrant collection: +Here is an example of how the Driver can be used to query information in a Qdrant collection: ```python import os @@ -427,14 +477,14 @@ embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2" host = os.environ["QDRANT_CLUSTER_ENDPOINT"] huggingface_token = os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"] -# Initialize HuggingFace embedding driver +# Initialize HuggingFace Embedding Driver embedding_driver = HuggingFaceHubEmbeddingDriver( api_token=huggingface_token, model=embedding_model_name, tokenizer=HuggingFaceTokenizer(model=embedding_model_name, max_output_tokens=512), ) -# Initialize Qdrant vector store driver +# Initialize Qdrant Vector Store Driver vector_store_driver = QdrantVectorStoreDriver( url=host, collection_name="griptape", @@ -443,7 +493,7 @@ vector_store_driver = QdrantVectorStoreDriver( api_key=os.environ["QDRANT_CLUSTER_API_KEY"], ) -# Load data from the website +# Load Artifacts from the web artifacts = WebLoader().load("https://www.griptape.ai") # Encode text to get embeddings @@ -465,6 +515,9 @@ vector_store_driver.upsert_vector( content=artifacts[0].value ) -print("Vectors successfully inserted into Qdrant.") +results =vector_store_driver.query(query="What is griptape?") +values = [r.to_artifact().value for r in results] + +print("\n\n".join(values)) ``` diff --git a/griptape/drivers/__init__.py b/griptape/drivers/__init__.py index 4e0fe6672..4ecb6c9bd 100644 --- a/griptape/drivers/__init__.py +++ b/griptape/drivers/__init__.py @@ -41,6 +41,7 @@ from .vector.azure_mongodb_vector_store_driver import AzureMongoDbVectorStoreDriver from .vector.dummy_vector_store_driver import DummyVectorStoreDriver from .vector.qdrant_vector_store_driver import QdrantVectorStoreDriver +from .vector.griptape_cloud_knowledge_base_vector_store_driver import GriptapeCloudKnowledgeBaseVectorStoreDriver from .sql.base_sql_driver import BaseSqlDriver from .sql.amazon_redshift_sql_driver import AmazonRedshiftSqlDriver @@ -149,6 +150,7 @@ "PgVectorVectorStoreDriver", "QdrantVectorStoreDriver", "DummyVectorStoreDriver", + "GriptapeCloudKnowledgeBaseVectorStoreDriver", "BaseSqlDriver", "AmazonRedshiftSqlDriver", "SnowflakeSqlDriver", diff --git a/griptape/drivers/vector/griptape_cloud_knowledge_base_vector_store_driver.py b/griptape/drivers/vector/griptape_cloud_knowledge_base_vector_store_driver.py new file mode 100644 index 000000000..0c7d5c961 --- /dev/null +++ b/griptape/drivers/vector/griptape_cloud_knowledge_base_vector_store_driver.py @@ -0,0 +1,101 @@ +from urllib.parse import urljoin +import requests +from typing import Optional, Any +from attrs import Factory, define, field +from 
griptape.artifacts import TextArtifact, ListArtifact +from griptape.drivers import BaseEmbeddingDriver, BaseVectorStoreDriver, DummyEmbeddingDriver + + +@define +class GriptapeCloudKnowledgeBaseVectorStoreDriver(BaseVectorStoreDriver): + """A vector store driver for Griptape Cloud Knowledge Bases. + + Attributes: + api_key: API Key for Griptape Cloud. + knowledge_base_id: Knowledge Base ID for Griptape Cloud. + base_url: Base URL for Griptape Cloud. + headers: Headers for Griptape Cloud. + """ + + api_key: str = field(kw_only=True, metadata={"serializable": True}) + knowledge_base_id: str = field(kw_only=True, metadata={"serializable": True}) + base_url: str = field(default="https://cloud.griptape.ai", kw_only=True) + headers: dict = field( + default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True), kw_only=True + ) + embedding_driver: BaseEmbeddingDriver = field( + default=Factory(lambda: DummyEmbeddingDriver()), metadata={"serializable": True}, kw_only=True, init=False + ) + + def upsert_vector( + self, + vector: list[float], + vector_id: Optional[str] = None, + namespace: Optional[str] = None, + meta: Optional[dict] = None, + **kwargs, + ) -> str: + raise NotImplementedError(f"{self.__class__.__name__} does not support vector upsert.") + + def upsert_text_artifact( + self, + artifact: TextArtifact, + namespace: Optional[str] = None, + meta: Optional[dict] = None, + vector_id: Optional[str] = None, + **kwargs, + ) -> str: + raise NotImplementedError(f"{self.__class__.__name__} does not support text artifact upsert.") + + def upsert_text( + self, + string: str, + vector_id: Optional[str] = None, + namespace: Optional[str] = None, + meta: Optional[dict] = None, + **kwargs, + ) -> str: + raise NotImplementedError(f"{self.__class__.__name__} does not support text upsert.") + + def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> BaseVectorStoreDriver.Entry: + raise NotImplementedError(f"{self.__class__.__name__} does not support entry loading.") + + def load_entries(self, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]: + raise NotImplementedError(f"{self.__class__.__name__} does not support entry loading.") + + def load_artifacts(self, namespace: Optional[str] = None) -> ListArtifact: + raise NotImplementedError(f"{self.__class__.__name__} does not support Artifact loading.") + + def query( + self, + query: str, + count: Optional[int] = None, + namespace: Optional[str] = None, + include_vectors: Optional[bool] = None, + distance_metric: Optional[str] = None, + # GriptapeCloudKnowledgeBaseVectorStoreDriver-specific params: + filter: Optional[dict] = None, + **kwargs, + ) -> list[BaseVectorStoreDriver.Entry]: + """Performs a query on the Knowledge Base. + + Performs a query on the Knowledge Base and returns Artifacts with close vector proximity to the query, optionally filtering to only those that match the provided filter(s). 
+ """ + url = urljoin(self.base_url.strip("/"), f"/api/knowledge-bases/{self.knowledge_base_id}/query") + + request: dict[str, Any] = { + "query": query, + "count": count, + "distance_metric": distance_metric, + "filter": filter, + "include_vectors": include_vectors, + } + request = {k: v for k, v in request.items() if v is not None} + + response = requests.post(url, json=request, headers=self.headers).json() + entries = response.get("entries", []) + entry_list = [BaseVectorStoreDriver.Entry.from_dict(entry) for entry in entries] + return entry_list + + def delete_vector(self, vector_id: str): + raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.") diff --git a/tests/unit/drivers/vector/test_griptape_cloud_knowledge_base_vector_store_driver.py b/tests/unit/drivers/vector/test_griptape_cloud_knowledge_base_vector_store_driver.py new file mode 100644 index 000000000..957edebb8 --- /dev/null +++ b/tests/unit/drivers/vector/test_griptape_cloud_knowledge_base_vector_store_driver.py @@ -0,0 +1,69 @@ +import pytest +import uuid +from griptape.drivers import GriptapeCloudKnowledgeBaseVectorStoreDriver + + +class TestGriptapeCloudKnowledgeBaseVectorStoreDriver: + test_ids = [str(uuid.uuid4()), str(uuid.uuid4())] + test_vecs = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + test_namespaces = [str(uuid.uuid4()), str(uuid.uuid4())] + test_metas = [{"key": "value1"}, {"key": "value2"}] + test_scores = [0.7, 0.8] + + @pytest.fixture + def driver(self, mocker): + test_entries = { + "entries": [ + { + "id": self.test_ids[0], + "vector": self.test_vecs[0], + "namespace": self.test_namespaces[0], + "meta": self.test_metas[0], + "score": self.test_scores[0], + }, + { + "id": self.test_ids[1], + "vector": self.test_vecs[1], + "namespace": self.test_namespaces[1], + "meta": self.test_metas[1], + "score": self.test_scores[1], + }, + ] + } + + mock_response = mocker.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = test_entries + mocker.patch("requests.post", return_value=mock_response) + + return GriptapeCloudKnowledgeBaseVectorStoreDriver(api_key="foo bar", knowledge_base_id="1") + + def test_query(self, driver): + result = driver.query( + "some query", count=10, namespace="foo", include_vectors=True, distance_metric="bar", filter={"foo": "bar"} + ) + + assert result[0].id == self.test_ids[0] + assert result[1].id == self.test_ids[1] + assert result[0].vector == self.test_vecs[0] + assert result[1].vector == self.test_vecs[1] + assert result[0].namespace == self.test_namespaces[0] + assert result[1].namespace == self.test_namespaces[1] + assert result[0].meta == self.test_metas[0] + assert result[1].meta == self.test_metas[1] + assert result[0].score == self.test_scores[0] + assert result[1].score == self.test_scores[1] + + def test_query_defaults(self, driver): + result = driver.query("some query") + + assert result[0].id == self.test_ids[0] + assert result[1].id == self.test_ids[1] + assert result[0].vector == self.test_vecs[0] + assert result[1].vector == self.test_vecs[1] + assert result[0].namespace == self.test_namespaces[0] + assert result[1].namespace == self.test_namespaces[1] + assert result[0].meta == self.test_metas[0] + assert result[1].meta == self.test_metas[1] + assert result[0].score == self.test_scores[0] + assert result[1].score == self.test_scores[1]
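For readers following the new driver, the request/response contract that `query()` and the unit test above exercise can be summarized in a small standalone sketch. The endpoint path, the dropping of `None`-valued parameters, and the `Entry.from_dict` conversion mirror the driver code in this change; the knowledge base ID and the sample response values are invented for illustration and are not real API data.

```python
from typing import Any, Optional
from urllib.parse import urljoin

from griptape.drivers import BaseVectorStoreDriver

# Same URL construction as GriptapeCloudKnowledgeBaseVectorStoreDriver.query();
# the knowledge base ID here is a placeholder.
base_url = "https://cloud.griptape.ai"
knowledge_base_id = "00000000-0000-0000-0000-000000000000"
url = urljoin(base_url.strip("/"), f"/api/knowledge-bases/{knowledge_base_id}/query")
print(url)


def build_request(query: str, count: Optional[int] = None, filter: Optional[dict] = None) -> dict[str, Any]:
    # query() builds the JSON body the same way: optional parameters that are
    # None are dropped before the POST to the Knowledge Base query endpoint.
    request: dict[str, Any] = {"query": query, "count": count, "filter": filter}
    return {k: v for k, v in request.items() if v is not None}


print(build_request("What is griptape?", count=10))  # {'query': 'What is griptape?', 'count': 10}

# The response body is expected to contain an "entries" list shaped like the
# payload mocked in the unit test; these sample values are made up.
sample_response = {
    "entries": [
        {"id": "entry-1", "vector": [0.1, 0.2, 0.3], "namespace": "griptape", "meta": {"key": "value"}, "score": 0.8},
    ]
}

# query() converts each entry dict into a BaseVectorStoreDriver.Entry, which is
# what callers (and the documentation examples above) receive and work with.
entries = [BaseVectorStoreDriver.Entry.from_dict(entry) for entry in sample_response.get("entries", [])]
print(entries[0].id, entries[0].score)
```

As the driver itself makes explicit, only querying is wired up: the upsert, load, and delete methods raise `NotImplementedError`, matching the documentation note that loading into Knowledge Bases is not supported at this time.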