diff --git a/.github/workflows/docs-integration-tests.yml b/.github/workflows/docs-integration-tests.yml index 79e0c9f3a..cd21ad1ac 100644 --- a/.github/workflows/docs-integration-tests.yml +++ b/.github/workflows/docs-integration-tests.yml @@ -120,6 +120,8 @@ jobs: PUSHER_SECRET: ${{ secrets.INTEG_PUSHER_SECRET }} PUSHER_CLUSTER: ${{ secrets.INTEG_PUSHER_CLUSTER }} ZENROWS_API_KEY: ${{ secrets.INTEG_ZENROWS_API_KEY }} + QDRANT_CLUSTER_ENDPOINT: ${{ secrets.INTEG_QDRANT_CLUSTER_ENDPOINT }} + QDRANT_CLUSTER_API_KEY: ${{ secrets.INTEG_QDRANT_CLUSTER_API_KEY }} services: postgres: image: ankane/pgvector:v0.5.0 diff --git a/CHANGELOG.md b/CHANGELOG.md index 273fe5429..29ca44e2a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -263,6 +263,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `MarkdownifyWebScraperDriver` for scraping text from web pages using playwright and converting to markdown using markdownify. - `VoyageAiEmbeddingDriver` for use with VoyageAi's embedding models. - `AnthropicStructureConfig` for providing Structures with Anthropic Prompt and VoyageAi Embedding Driver configuration. +- `QdrantVectorStoreDriver` to integrate with Qdrant vector databases. ### Fixed - Improved system prompt in `ToolTask` to support more use cases. diff --git a/docs/griptape-framework/drivers/vector-store-drivers.md b/docs/griptape-framework/drivers/vector-store-drivers.md index 095c14f22..f0ea9b2a0 100644 --- a/docs/griptape-framework/drivers/vector-store-drivers.md +++ b/docs/griptape-framework/drivers/vector-store-drivers.md @@ -406,3 +406,69 @@ vector_store_driver.upsert_text_artifacts( result = vector_store_driver.query("What is griptape?") print(result) ``` + +### Qdrant + +!!! info + This driver requires the `drivers-vector-qdrant` [extra](../index.md#extras). + +The QdrantVectorStoreDriver supports the [Qdrant vector database](https://qdrant.tech/). 
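+
+Below is a minimal sketch of constructing the driver against an in-memory Qdrant instance. The `OpenAiEmbeddingDriver` and collection name here are placeholders; any embedding driver can be substituted:
+
+```python
+from griptape.drivers import OpenAiEmbeddingDriver, QdrantVectorStoreDriver
+
+# location=":memory:" starts an in-memory Qdrant client, which is handy for local experiments.
+# OpenAiEmbeddingDriver reads OPENAI_API_KEY from the environment and stands in for any embedding driver.
+vector_store_driver = QdrantVectorStoreDriver(
+    location=":memory:",
+    collection_name="griptape",
+    embedding_driver=OpenAiEmbeddingDriver(),
+)
+```
+
+The collection itself still needs to be created before vectors are upserted, as shown in the full example below.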
+
+Here is an example of how the driver can be used to load content into a Qdrant collection:
+
+```python
+import os
+from sentence_transformers import SentenceTransformer
+from griptape.drivers import QdrantVectorStoreDriver, HuggingFaceHubEmbeddingDriver
+from griptape.tokenizers import HuggingFaceTokenizer
+from griptape.loaders import WebLoader
+
+# Read configuration from environment variables
+embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"
+qdrant_url = os.environ["QDRANT_CLUSTER_ENDPOINT"]
+huggingface_token = os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"]
+
+# Initialize embedding model
+embedding_model = SentenceTransformer(embedding_model_name)
+
+# Initialize HuggingFace embedding driver
+embedding_driver = HuggingFaceHubEmbeddingDriver(
+    api_token=huggingface_token,
+    model=embedding_model_name,
+    tokenizer=HuggingFaceTokenizer(model=embedding_model_name, max_output_tokens=512),
+)
+
+# Initialize Qdrant vector store driver
+vector_store_driver = QdrantVectorStoreDriver(
+    url=qdrant_url,
+    collection_name="griptape",
+    content_payload_key="content",
+    embedding_driver=embedding_driver,
+    api_key=os.environ["QDRANT_CLUSTER_API_KEY"],
+)
+
+# Load data from the website
+artifacts = WebLoader().load("https://www.griptape.ai")
+
+# Encode text to get embeddings
+embeddings = embedding_model.encode(artifacts[0].value)
+
+# Recreate Qdrant collection
+vector_store_driver.client.recreate_collection(
+    collection_name=vector_store_driver.collection_name,
+    vectors_config={
+        "size": embedding_model.get_sentence_embedding_dimension(),
+        "distance": vector_store_driver.distance
+    },
+)
+
+# Upsert vector into Qdrant
+vector_store_driver.upsert_vector(
+    vector=embeddings.tolist(),
+    vector_id=str(artifacts[0].id),
+    content=artifacts[0].value
+)
+
+print("Vectors successfully inserted into Qdrant.")
+
+```
diff --git a/griptape/drivers/__init__.py b/griptape/drivers/__init__.py
index 34b43103b..fa2934a38 100644
--- a/griptape/drivers/__init__.py
+++ b/griptape/drivers/__init__.py
@@ -39,6 +39,7 @@
 from .vector.pgvector_vector_store_driver import PgVectorVectorStoreDriver
 from .vector.azure_mongodb_vector_store_driver import AzureMongoDbVectorStoreDriver
 from .vector.dummy_vector_store_driver import DummyVectorStoreDriver
+from .vector.qdrant_vector_store_driver import QdrantVectorStoreDriver
 
 from .sql.base_sql_driver import BaseSqlDriver
 from .sql.amazon_redshift_sql_driver import AmazonRedshiftSqlDriver
@@ -144,6 +145,7 @@
     "OpenSearchVectorStoreDriver",
     "AmazonOpenSearchVectorStoreDriver",
     "PgVectorVectorStoreDriver",
+    "QdrantVectorStoreDriver",
     "DummyVectorStoreDriver",
     "BaseSqlDriver",
     "AmazonRedshiftSqlDriver",
diff --git a/griptape/drivers/vector/qdrant_vector_store_driver.py b/griptape/drivers/vector/qdrant_vector_store_driver.py
new file mode 100644
index 000000000..5159bdb0a
--- /dev/null
+++ b/griptape/drivers/vector/qdrant_vector_store_driver.py
@@ -0,0 +1,207 @@
+from __future__ import annotations
+from typing import Optional
+from attrs import define, field
+from griptape.drivers import BaseVectorStoreDriver
+from griptape.utils import import_optional_dependency
+import uuid
+import logging
+
+DEFAULT_DISTANCE = "Cosine"
+CONTENT_PAYLOAD_KEY = "data"
+
+
+@define
+class QdrantVectorStoreDriver(BaseVectorStoreDriver):
+    """
+    Attributes:
+        location: An optional location for the Qdrant client. If set to ':memory:', an in-memory client is used.
+        url: An optional Qdrant API URL.
+        host: An optional Qdrant host.
+        path: Persistence path for QdrantLocal. Default: None.
+        port: The port number for the Qdrant client. Default: 6333.
+        grpc_port: The gRPC port number for the Qdrant client. Default: 6334.
+        prefer_grpc: A boolean indicating whether to prefer gRPC over HTTP. Default: False.
+        force_disable_check_same_thread: For QdrantLocal, force disable check_same_thread. Default: False. Only use this if you can guarantee that you can resolve the thread safety outside QdrantClient.
+        timeout: Timeout for REST and gRPC API requests. Default: 5 seconds for REST and unlimited for gRPC.
+        api_key: API key for authentication in Qdrant Cloud. Default: None.
+        https: If true, use the HTTPS (SSL) protocol. Default: None.
+        prefix: Add a prefix to the REST URL path. Example: 'service/v1' will result in 'http://localhost:6333/service/v1/{qdrant-endpoint}' for the REST API. Default: None.
+        distance: The distance metric to be used for the vectors. Default: 'Cosine'.
+        collection_name: The name of the Qdrant collection.
+        vector_name: An optional name for the vectors.
+        content_payload_key: The key for the content payload in the metadata. Default: 'data'.
+    """
+
+    location: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
+    url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
+    host: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
+    path: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
+    port: int = field(default=6333, kw_only=True, metadata={"serializable": True})
+    grpc_port: int = field(default=6334, kw_only=True, metadata={"serializable": True})
+    prefer_grpc: bool = field(default=False, kw_only=True, metadata={"serializable": True})
+    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
+    https: Optional[bool] = field(default=None, kw_only=True, metadata={"serializable": True})
+    prefix: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
+    force_disable_check_same_thread: Optional[bool] = field(
+        default=False, kw_only=True, metadata={"serializable": True}
+    )
+    timeout: Optional[int] = field(default=5, kw_only=True, metadata={"serializable": True})
+    distance: str = field(default=DEFAULT_DISTANCE, kw_only=True, metadata={"serializable": True})
+    collection_name: str = field(kw_only=True, metadata={"serializable": True})
+    vector_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
+    content_payload_key: str = field(default=CONTENT_PAYLOAD_KEY, kw_only=True, metadata={"serializable": True})
+
+    def __attrs_post_init__(self) -> None:
+        self.client = import_optional_dependency("qdrant_client").QdrantClient(
+            location=self.location,
+            url=self.url,
+            host=self.host,
+            path=self.path,
+            port=self.port,
+            prefer_grpc=self.prefer_grpc,
+            grpc_port=self.grpc_port,
+            api_key=self.api_key,
+            https=self.https,
+            prefix=self.prefix,
+            force_disable_check_same_thread=self.force_disable_check_same_thread,
+            timeout=self.timeout,
+        )
+
+    def delete_vector(self, vector_id: str) -> None:
+        """
+        Delete a vector from the Qdrant collection based on its ID.
+
+        Parameters:
+            vector_id (str): ID of the vector to delete.
+        """
+        deletion_response = self.client.delete(
+            collection_name=self.collection_name,
+            points_selector=import_optional_dependency("qdrant_client.http.models").PointIdsList(points=[vector_id]),
+        )
+        if deletion_response.status == import_optional_dependency("qdrant_client.http.models").UpdateStatus.COMPLETED:
+            logging.info(f"ID {vector_id} is successfully deleted")
+
+    def query(
+        self,
+        query: str,
+        count: Optional[int] = None,
+        namespace: Optional[str] = None,
+        include_vectors: bool = False,
+        **kwargs,
+    ) -> list[BaseVectorStoreDriver.Entry]:
+        """
+        Query the Qdrant collection based on a query string.
+
+        Parameters:
+            query (str): Query string.
+            count (Optional[int]): Optional number of results to return.
+            namespace (Optional[str]): Optional namespace of the vectors.
+            include_vectors (bool): Whether to include vectors in the results.
+
+        Returns:
+            list[BaseVectorStoreDriver.Entry]: List of Entry objects.
+        """
+        query_vector = self.embedding_driver.embed_string(query)
+
+        # Search the collection for the closest vectors
+        results = self.client.search(collection_name=self.collection_name, query_vector=query_vector, limit=count if count else BaseVectorStoreDriver.DEFAULT_QUERY_COUNT)
+
+        # Convert results to Entry objects
+        query_results = [
+            BaseVectorStoreDriver.Entry(
+                id=result.id,
+                vector=result.vector if include_vectors else [],
+                score=result.score,
+                meta={k: v for k, v in result.payload.items() if k not in ["_score", "_tensor_facets"]},
+            )
+            for result in results
+        ]
+        return query_results
+
+    def upsert_vector(
+        self,
+        vector: list[float],
+        vector_id: Optional[str] = None,
+        namespace: Optional[str] = None,
+        meta: Optional[dict] = None,
+        content: Optional[str] = None,
+        **kwargs,
+    ) -> str:
+        """
+        Upsert vectors into the Qdrant collection.
+
+        Parameters:
+            vector (list[float]): The vector to be upserted.
+            vector_id (Optional[str]): Optional vector ID.
+            namespace (Optional[str]): Optional namespace for the vector.
+            meta (Optional[dict]): Optional dictionary containing metadata.
+            content (Optional[str]): The text content to be included in the payload.
+
+        Returns:
+            str: The ID of the upserted vector.
+        """
+
+        if vector_id is None:
+            vector_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, str(vector)))
+
+        if meta is None:
+            meta = {}
+
+        if content:
+            meta[self.content_payload_key] = content
+
+        points = import_optional_dependency("qdrant_client.http.models").Batch(
+            ids=[vector_id], vectors=[vector], payloads=[meta] if meta else None
+        )
+
+        self.client.upsert(collection_name=self.collection_name, points=points)
+        return vector_id
+
+    def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
+        """
+        Load a vector entry from the Qdrant collection based on its ID.
+
+        Parameters:
+            vector_id (str): ID of the vector to load.
+            namespace (str, optional): Optional namespace of the vector.
+
+        Returns:
+            Optional[BaseVectorStoreDriver.Entry]: Vector entry if found, else None.
+        """
+        results = self.client.retrieve(collection_name=self.collection_name, ids=[vector_id])
+        if results:
+            entry = results[0]
+            return BaseVectorStoreDriver.Entry(
+                id=entry.id,
+                vector=entry.vector,
+                meta={k: v for k, v in entry.payload.items() if k not in ["_score", "_tensor_facets"]},
+            )
+        else:
+            return None
+
+    def load_entries(self, namespace: Optional[str] = None, **kwargs) -> list[BaseVectorStoreDriver.Entry]:
+        """
+        Load vector entries from the Qdrant collection.
+
+        Parameters:
+            namespace: Optional namespace of the vectors.
+            **kwargs: Optional keyword arguments: 'ids' (list of vector IDs to retrieve), 'with_payload' and 'with_vectors' (both default to True).
+
+        Returns:
+            list[BaseVectorStoreDriver.Entry]: List of vector entries.
+ """ + + results = self.client.retrieve( + collection_name=self.collection_name, + ids=kwargs.get("ids", []), + with_payload=kwargs.get("with_payload", True), + with_vectors=kwargs.get("with_vectors", True), + ) + + return [ + BaseVectorStoreDriver.Entry( + id=entry.id, + vector=entry.vector if kwargs.get("with_vectors", True) else [], + meta={k: v for k, v in entry.payload.items() if k not in ["_score", "_tensor_facets"]}, + ) + for entry in results + ] diff --git a/poetry.lock b/poetry.lock index 768a4e8ae..64f01aa75 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1844,6 +1844,74 @@ googleapis-common-protos = ">=1.5.5" grpcio = ">=1.62.2" protobuf = ">=4.21.6" +[[package]] +name = "grpcio-tools" +version = "1.62.2" +description = "Protobuf code generator for gRPC" +optional = true +python-versions = ">=3.7" +files = [ + {file = "grpcio-tools-1.62.2.tar.gz", hash = "sha256:5fd5e1582b678e6b941ee5f5809340be5e0724691df5299aae8226640f94e18f"}, + {file = "grpcio_tools-1.62.2-cp310-cp310-linux_armv7l.whl", hash = "sha256:1679b4903aed2dc5bd8cb22a452225b05dc8470a076f14fd703581efc0740cdb"}, + {file = "grpcio_tools-1.62.2-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:9d41e0e47dd075c075bb8f103422968a65dd0d8dc8613288f573ae91eb1053ba"}, + {file = "grpcio_tools-1.62.2-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:987e774f74296842bbffd55ea8826370f70c499e5b5f71a8cf3103838b6ee9c3"}, + {file = "grpcio_tools-1.62.2-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40cd4eeea4b25bcb6903b82930d579027d034ba944393c4751cdefd9c49e6989"}, + {file = "grpcio_tools-1.62.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b6746bc823958499a3cf8963cc1de00072962fb5e629f26d658882d3f4c35095"}, + {file = "grpcio_tools-1.62.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:2ed775e844566ce9ce089be9a81a8b928623b8ee5820f5e4d58c1a9d33dfc5ae"}, + {file = "grpcio_tools-1.62.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bdc5dd3f57b5368d5d661d5d3703bcaa38bceca59d25955dff66244dbc987271"}, + {file = "grpcio_tools-1.62.2-cp310-cp310-win32.whl", hash = "sha256:3a8d6f07e64c0c7756f4e0c4781d9d5a2b9cc9cbd28f7032a6fb8d4f847d0445"}, + {file = "grpcio_tools-1.62.2-cp310-cp310-win_amd64.whl", hash = "sha256:e33b59fb3efdddeb97ded988a871710033e8638534c826567738d3edce528752"}, + {file = "grpcio_tools-1.62.2-cp311-cp311-linux_armv7l.whl", hash = "sha256:472505d030135d73afe4143b0873efe0dcb385bd6d847553b4f3afe07679af00"}, + {file = "grpcio_tools-1.62.2-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:ec674b4440ef4311ac1245a709e87b36aca493ddc6850eebe0b278d1f2b6e7d1"}, + {file = "grpcio_tools-1.62.2-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:184b4174d4bd82089d706e8223e46c42390a6ebac191073b9772abc77308f9fa"}, + {file = "grpcio_tools-1.62.2-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c195d74fe98541178ece7a50dad2197d43991e0f77372b9a88da438be2486f12"}, + {file = "grpcio_tools-1.62.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a34d97c62e61bfe9e6cff0410fe144ac8cca2fc979ad0be46b7edf026339d161"}, + {file = "grpcio_tools-1.62.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cbb8453ae83a1db2452b7fe0f4b78e4a8dd32be0f2b2b73591ae620d4d784d3d"}, + {file = "grpcio_tools-1.62.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4f989e5cebead3ae92c6abf6bf7b19949e1563a776aea896ac5933f143f0c45d"}, + {file = "grpcio_tools-1.62.2-cp311-cp311-win32.whl", hash = 
"sha256:c48fabe40b9170f4e3d7dd2c252e4f1ff395dc24e49ac15fc724b1b6f11724da"}, + {file = "grpcio_tools-1.62.2-cp311-cp311-win_amd64.whl", hash = "sha256:8c616d0ad872e3780693fce6a3ac8ef00fc0963e6d7815ce9dcfae68ba0fc287"}, + {file = "grpcio_tools-1.62.2-cp312-cp312-linux_armv7l.whl", hash = "sha256:10cc3321704ecd17c93cf68c99c35467a8a97ffaaed53207e9b2da6ae0308ee1"}, + {file = "grpcio_tools-1.62.2-cp312-cp312-macosx_10_10_universal2.whl", hash = "sha256:9be84ff6d47fd61462be7523b49d7ba01adf67ce4e1447eae37721ab32464dd8"}, + {file = "grpcio_tools-1.62.2-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:d82f681c9a9d933a9d8068e8e382977768e7779ddb8870fa0cf918d8250d1532"}, + {file = "grpcio_tools-1.62.2-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:04c607029ae3660fb1624ed273811ffe09d57d84287d37e63b5b802a35897329"}, + {file = "grpcio_tools-1.62.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:72b61332f1b439c14cbd3815174a8f1d35067a02047c32decd406b3a09bb9890"}, + {file = "grpcio_tools-1.62.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:8214820990d01b52845f9fbcb92d2b7384a0c321b303e3ac614c219dc7d1d3af"}, + {file = "grpcio_tools-1.62.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:462e0ab8dd7c7b70bfd6e3195eebc177549ede5cf3189814850c76f9a340d7ce"}, + {file = "grpcio_tools-1.62.2-cp312-cp312-win32.whl", hash = "sha256:fa107460c842e4c1a6266150881694fefd4f33baa544ea9489601810c2210ef8"}, + {file = "grpcio_tools-1.62.2-cp312-cp312-win_amd64.whl", hash = "sha256:759c60f24c33a181bbbc1232a6752f9b49fbb1583312a4917e2b389fea0fb0f2"}, + {file = "grpcio_tools-1.62.2-cp37-cp37m-linux_armv7l.whl", hash = "sha256:45db5da2bcfa88f2b86b57ef35daaae85c60bd6754a051d35d9449c959925b57"}, + {file = "grpcio_tools-1.62.2-cp37-cp37m-macosx_10_10_universal2.whl", hash = "sha256:ab84bae88597133f6ea7a2bdc57b2fda98a266fe8d8d4763652cbefd20e73ad7"}, + {file = "grpcio_tools-1.62.2-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:7a49bccae1c7d154b78e991885c3111c9ad8c8fa98e91233de425718f47c6139"}, + {file = "grpcio_tools-1.62.2-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a7e439476b29d6dac363b321781a113794397afceeb97dad85349db5f1cb5e9a"}, + {file = "grpcio_tools-1.62.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ea369c4d1567d1acdf69c8ea74144f4ccad9e545df7f9a4fc64c94fa7684ba3"}, + {file = "grpcio_tools-1.62.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:4f955702dc4b530696375251319d05223b729ed24e8673c2129f7a75d2caefbb"}, + {file = "grpcio_tools-1.62.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:3708a747aa4b6b505727282ca887041174e146ae030ebcadaf4c1d346858df62"}, + {file = "grpcio_tools-1.62.2-cp37-cp37m-win_amd64.whl", hash = "sha256:2ce149ea55eadb486a7fb75a20f63ef3ac065ee6a0240ed25f3549ce7954c653"}, + {file = "grpcio_tools-1.62.2-cp38-cp38-linux_armv7l.whl", hash = "sha256:58cbb24b3fa6ae35aa9c210fcea3a51aa5fef0cd25618eb4fd94f746d5a9b703"}, + {file = "grpcio_tools-1.62.2-cp38-cp38-macosx_10_10_universal2.whl", hash = "sha256:6413581e14a80e0b4532577766cf0586de4dd33766a31b3eb5374a746771c07d"}, + {file = "grpcio_tools-1.62.2-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:47117c8a7e861382470d0e22d336e5a91fdc5f851d1db44fa784b9acea190d87"}, + {file = "grpcio_tools-1.62.2-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9f1ba79a253df9e553d20319c615fa2b429684580fa042dba618d7f6649ac7e4"}, + {file = "grpcio_tools-1.62.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", 
hash = "sha256:04a394cf5e51ba9be412eb9f6c482b6270bd81016e033e8eb7d21b8cc28fe8b5"}, + {file = "grpcio_tools-1.62.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:3c53b221378b035ae2f1881cbc3aca42a6075a8e90e1a342c2f205eb1d1aa6a1"}, + {file = "grpcio_tools-1.62.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:c384c838b34d1b67068e51b5bbe49caa6aa3633acd158f1ab16b5da8d226bc53"}, + {file = "grpcio_tools-1.62.2-cp38-cp38-win32.whl", hash = "sha256:19ea69e41c3565932aa28a202d1875ec56786aea46a2eab54a3b28e8a27f9517"}, + {file = "grpcio_tools-1.62.2-cp38-cp38-win_amd64.whl", hash = "sha256:1d768a5c07279a4c461ebf52d0cec1c6ca85c6291c71ec2703fe3c3e7e28e8c4"}, + {file = "grpcio_tools-1.62.2-cp39-cp39-linux_armv7l.whl", hash = "sha256:5b07b5874187e170edfbd7aa2ca3a54ebf3b2952487653e8c0b0d83601c33035"}, + {file = "grpcio_tools-1.62.2-cp39-cp39-macosx_10_10_universal2.whl", hash = "sha256:d58389fe8be206ddfb4fa703db1e24c956856fcb9a81da62b13577b3a8f7fda7"}, + {file = "grpcio_tools-1.62.2-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:7d8b4e00c3d7237b92260fc18a561cd81f1da82e8be100db1b7d816250defc66"}, + {file = "grpcio_tools-1.62.2-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1fe08d2038f2b7c53259b5c49e0ad08c8e0ce2b548d8185993e7ef67e8592cca"}, + {file = "grpcio_tools-1.62.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:19216e1fb26dbe23d12a810517e1b3fbb8d4f98b1a3fbebeec9d93a79f092de4"}, + {file = "grpcio_tools-1.62.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:b8574469ecc4ff41d6bb95f44e0297cdb0d95bade388552a9a444db9cd7485cd"}, + {file = "grpcio_tools-1.62.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:4f6f32d39283ea834a493fccf0ebe9cfddee7577bdcc27736ad4be1732a36399"}, + {file = "grpcio_tools-1.62.2-cp39-cp39-win32.whl", hash = "sha256:76eb459bdf3fb666e01883270beee18f3f11ed44488486b61cd210b4e0e17cc1"}, + {file = "grpcio_tools-1.62.2-cp39-cp39-win_amd64.whl", hash = "sha256:217c2ee6a7ce519a55958b8622e21804f6fdb774db08c322f4c9536c35fdce7c"}, +] + +[package.dependencies] +grpcio = ">=1.62.2" +protobuf = ">=4.21.6,<5.0dev" +setuptools = "*" + [[package]] name = "h11" version = "0.14.0" @@ -1855,6 +1923,32 @@ files = [ {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, ] +[[package]] +name = "h2" +version = "4.1.0" +description = "HTTP/2 State-Machine based protocol implementation" +optional = true +python-versions = ">=3.6.1" +files = [ + {file = "h2-4.1.0-py3-none-any.whl", hash = "sha256:03a46bcf682256c95b5fd9e9a99c1323584c3eec6440d379b9903d709476bc6d"}, + {file = "h2-4.1.0.tar.gz", hash = "sha256:a83aca08fbe7aacb79fec788c9c0bac936343560ed9ec18b82a13a12c28d2abb"}, +] + +[package.dependencies] +hpack = ">=4.0,<5" +hyperframe = ">=6.0,<7" + +[[package]] +name = "hpack" +version = "4.0.0" +description = "Pure-Python HPACK header compression" +optional = true +python-versions = ">=3.6.1" +files = [ + {file = "hpack-4.0.0-py3-none-any.whl", hash = "sha256:84a076fad3dc9a9f8063ccb8041ef100867b1878b25ef0ee63847a5d53818a6c"}, + {file = "hpack-4.0.0.tar.gz", hash = "sha256:fc41de0c63e687ebffde81187a948221294896f6bdc0ae2312708df339430095"}, +] + [[package]] name = "htmldate" version = "1.8.1" @@ -1926,6 +2020,7 @@ files = [ [package.dependencies] anyio = "*" certifi = "*" +h2 = {version = ">=3,<5", optional = true, markers = "extra == \"http2\""} httpcore = "==1.*" idna = "*" sniffio = "*" @@ -1981,6 +2076,17 @@ testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", 
"gr torch = ["safetensors", "torch"] typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"] +[[package]] +name = "hyperframe" +version = "6.0.1" +description = "HTTP/2 framing layer for Python" +optional = true +python-versions = ">=3.6.1" +files = [ + {file = "hyperframe-6.0.1-py3-none-any.whl", hash = "sha256:0ec6bafd80d8ad2195c4f03aacba3a8265e57bc4cff261e802bf39970ed02a15"}, + {file = "hyperframe-6.0.1.tar.gz", hash = "sha256:ae510046231dc8e9ecb1a6586f63d2347bf4c8905914aa84ba585ae85f28a914"}, +] + [[package]] name = "identify" version = "2.5.36" @@ -3787,6 +3893,25 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "portalocker" +version = "2.8.2" +description = "Wraps the portalocker recipe for easy usage" +optional = true +python-versions = ">=3.8" +files = [ + {file = "portalocker-2.8.2-py3-none-any.whl", hash = "sha256:cfb86acc09b9aa7c3b43594e19be1345b9d16af3feb08bf92f23d4dce513a28e"}, + {file = "portalocker-2.8.2.tar.gz", hash = "sha256:2b035aa7828e46c58e9b31390ee1f169b98e1066ab10b9a6a861fe7e25ee4f33"}, +] + +[package.dependencies] +pywin32 = {version = ">=226", markers = "platform_system == \"Windows\""} + +[package.extras] +docs = ["sphinx (>=1.7.1)"] +redis = ["redis"] +tests = ["pytest (>=5.4.1)", "pytest-cov (>=2.8.1)", "pytest-mypy (>=0.8.0)", "pytest-timeout (>=2.1.0)", "redis", "sphinx (>=6.0.0)", "types-redis"] + [[package]] name = "pprintpp" version = "0.4.0" @@ -4597,6 +4722,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -4645,6 +4771,32 @@ files = [ [package.dependencies] pyyaml = "*" +[[package]] +name = "qdrant-client" +version = "1.9.1" +description = "Client library for the Qdrant vector search engine" +optional = true +python-versions = ">=3.8" +files = [ + {file = "qdrant_client-1.9.1-py3-none-any.whl", hash = "sha256:b9b7e0e5c1a51410d8bb5106a869a51e12f92ab45a99030f27aba790553bd2c8"}, + {file = "qdrant_client-1.9.1.tar.gz", hash = "sha256:186b9c31d95aefe8f2db84b7746402d7365bd63b305550e530e31bde2002ce79"}, +] + +[package.dependencies] +grpcio = ">=1.41.0" +grpcio-tools = ">=1.41.0" +httpx = {version = ">=0.20.0", extras = ["http2"]} +numpy = [ + {version = ">=1.21", markers = "python_version >= \"3.8\" and python_version < \"3.12\""}, + {version = ">=1.26", markers = "python_version >= \"3.12\""}, +] +portalocker = ">=2.7.0,<3.0.0" +pydantic = ">=1.10.8" +urllib3 = ">=1.26.14,<3" + +[package.extras] +fastembed = ["fastembed (==0.2.6)"] + [[package]] name 
= "readme-renderer" version = "43.0" @@ -6369,7 +6521,7 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] [extras] -all = ["anthropic", "beautifulsoup4", "boto3", "cohere", "duckduckgo-search", "elevenlabs", "filetype", "google-generativeai", "mail-parser", "markdownify", "marqo", "ollama", "opensearch-py", "pandas", "pgvector", "pillow", "pinecone-client", "playwright", "psycopg2-binary", "pusher", "pymongo", "pypdf", "redis", "snowflake-sqlalchemy", "sqlalchemy-redshift", "torch", "trafilatura", "transformers", "voyageai"] +all = ["anthropic", "beautifulsoup4", "boto3", "cohere", "duckduckgo-search", "elevenlabs", "filetype", "google-generativeai", "mail-parser", "markdownify", "marqo", "ollama", "opensearch-py", "pandas", "pgvector", "pillow", "pinecone-client", "playwright", "psycopg2-binary", "pusher", "pymongo", "pypdf", "qdrant-client", "redis", "snowflake-sqlalchemy", "sqlalchemy-redshift", "torch", "trafilatura", "transformers", "voyageai"] drivers-embedding-amazon-bedrock = ["boto3"] drivers-embedding-amazon-sagemaker = ["boto3"] drivers-embedding-cohere = ["cohere"] @@ -6399,6 +6551,7 @@ drivers-vector-mongodb = ["pymongo"] drivers-vector-opensearch = ["opensearch-py"] drivers-vector-pinecone = ["pinecone-client"] drivers-vector-postgresql = ["pgvector", "psycopg2-binary"] +drivers-vector-qdrant = ["qdrant-client"] drivers-vector-redis = ["redis"] drivers-web-scraper-markdownify = ["beautifulsoup4", "markdownify", "playwright"] drivers-web-scraper-trafilatura = ["trafilatura"] @@ -6412,4 +6565,4 @@ loaders-pdf = ["pypdf"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "937d5e870407a493038178b3060c663feb9a1b6e6fc16e55fa380a18d02b0c80" +content-hash = "d4a00633119a6b9616fc0ea31ae354f806d1fc363e4a930fa912d2d710aa8938" diff --git a/pyproject.toml b/pyproject.toml index 7fc62d6af..da32484a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ markdownify = {version = "^0.11.6", optional = true} voyageai = {version = "^0.2.1", optional = true} elevenlabs = {version = "^1.1.2", optional = true} torch = {version = "^2.3.0", optional = true} +qdrant-client = { version = ">=1.9.1", optional = true } pusher = {version = "^3.3.2", optional = true} ollama = {version = "^0.2.1", optional = true} duckduckgo-search = {version = "^6.1.6", optional = true} @@ -87,6 +88,7 @@ drivers-vector-redis = ["redis"] drivers-vector-opensearch = ["opensearch-py"] drivers-vector-amazon-opensearch = ["opensearch-py", "boto3"] drivers-vector-postgresql = ["pgvector", "psycopg2-binary"] +drivers-vector-qdrant = ["qdrant-client"] drivers-embedding-amazon-bedrock = ["boto3"] drivers-embedding-amazon-sagemaker = ["boto3"] @@ -124,6 +126,7 @@ all = [ "snowflake", "marqo", "pinecone-client", + "qdrant-client", "pymongo", "redis", "opensearch-py", diff --git a/tests/unit/drivers/vector/test_qdrant_vector_store_driver.py b/tests/unit/drivers/vector/test_qdrant_vector_store_driver.py new file mode 100644 index 000000000..4431b146f --- /dev/null +++ b/tests/unit/drivers/vector/test_qdrant_vector_store_driver.py @@ -0,0 +1,171 @@ +import pytest +from unittest.mock import MagicMock, patch +from griptape.drivers import QdrantVectorStoreDriver +from tests.mocks.mock_embedding_driver import MockEmbeddingDriver +from griptape.utils 
import import_optional_dependency +import uuid + + +class TestQdrantVectorStoreDriver: + @pytest.fixture + def embedding_driver(self): + return MockEmbeddingDriver() + + @pytest.fixture + def mock_engine(self): + return MagicMock() + + @pytest.fixture(autouse=True) + def driver(self, embedding_driver): + driver = QdrantVectorStoreDriver( + url="http://some_url", + port=8080, + grpc_port=50051, + prefer_grpc=True, + api_key=None, + https=False, + prefix=None, + force_disable_check_same_thread=False, + timeout=5, + distance="COSINE", + collection_name="some_collection", + vector_name=None, + content_payload_key="data", + embedding_driver=embedding_driver, + ) + return driver + + def test_attrs_post_init(self, driver): + with patch("griptape.drivers.vector.qdrant_vector_store_driver.import_optional_dependency") as mock_import: + mock_qdrant_client = MagicMock() + mock_import.return_value.QdrantClient.return_value = mock_qdrant_client + + driver.__attrs_post_init__() + + mock_import.assert_called_once_with("qdrant_client") + mock_import.return_value.QdrantClient.assert_called_once_with( + location=driver.location, + url=driver.url, + host=driver.host, + path=driver.path, + port=driver.port, + prefer_grpc=driver.prefer_grpc, + grpc_port=driver.grpc_port, + api_key=driver.api_key, + https=driver.https, + prefix=driver.prefix, + force_disable_check_same_thread=driver.force_disable_check_same_thread, + timeout=driver.timeout, + ) + assert driver.client == mock_qdrant_client + + def test_delete_vector(self, driver): + vector_id = "test_vector_id" + + mock_deletion_response = MagicMock() + mock_deletion_response.status = import_optional_dependency("qdrant_client.http.models").UpdateStatus.COMPLETED + + with patch.object(driver.client, "delete", return_value=mock_deletion_response) as mock_delete, patch( + "griptape.drivers.vector.qdrant_vector_store_driver.import_optional_dependency" + ) as mock_import: + mock_import.return_value.PointIdsList.return_value = MagicMock() + mock_import.return_value.UpdateStatus = import_optional_dependency("qdrant_client.http.models").UpdateStatus + + driver.delete_vector(vector_id) + + mock_delete.assert_called_once_with( + collection_name=driver.collection_name, + points_selector=mock_import.return_value.PointIdsList(points=[vector_id]), + ) + + def test_query(self, driver): + mock_query_result = [ + MagicMock( + id="foo", vector=[0, 1, 0], score=42, payload={"foo": "bar", "_score": 0.99, "_tensor_facets": []} + ) + ] + + with patch.object( + driver.embedding_driver, "embed_string", return_value=[0.1, 0.2, 0.3] + ) as mock_embed, patch.object(driver.client, "search", return_value=mock_query_result) as mock_search: + query = "test" + count = 10 + include_vectors = True + + results = driver.query(query, count, include_vectors=include_vectors) + + mock_embed.assert_called_once_with(query) + mock_search.assert_called_once_with( + collection_name=driver.collection_name, query_vector=[0.1, 0.2, 0.3], limit=count + ) + + assert len(results) == 1 + assert results[0].id == "foo" + assert results[0].vector == [0, 1, 0] if include_vectors else [] + assert results[0].score == 42 + assert results[0].meta == {"foo": "bar"} + + def test_upsert_with_batch(self, driver): + vector = [0.1, 0.2, 0.3] + vector_id = str(uuid.uuid4()) + meta = {"meta_key": "meta_value"} + + with patch("griptape.drivers.vector.qdrant_vector_store_driver.import_optional_dependency") as mock_import: + mock_batch = MagicMock() + mock_import.return_value.Batch.return_value = mock_batch + mock_qdrant_client = 
MagicMock() + driver.client = mock_qdrant_client + + result = driver.upsert_vector(vector=vector, vector_id=vector_id, meta=meta) + + mock_import.assert_called_once_with("qdrant_client.http.models") + mock_import.return_value.Batch.assert_called_once_with(ids=[vector_id], vectors=[vector], payloads=[meta]) + driver.client.upsert.assert_called_once_with(collection_name=driver.collection_name, points=mock_batch) + assert result == vector_id + + def test_load_entry(self, driver): + vector_id = str(uuid.uuid4()) + mock_entry = MagicMock() + mock_entry.id = vector_id + mock_entry.vector = [0.1, 0.2, 0.3] + mock_entry.payload = {"meta_key": "meta_value", "_score": 0.99, "_tensor_facets": []} + + with patch.object(driver.client, "retrieve", return_value=[mock_entry]): + result = driver.load_entry(vector_id) + + driver.client.retrieve.assert_called_once_with(collection_name=driver.collection_name, ids=[vector_id]) + + assert result.id == vector_id + assert result.vector == [0.1, 0.2, 0.3] + assert result.meta == {"meta_key": "meta_value"} + + with patch.object(driver.client, "retrieve", return_value=[]): + result = driver.load_entry(vector_id) + + driver.client.retrieve.assert_called_with(collection_name=driver.collection_name, ids=[vector_id]) + assert result is None + + def test_load_entries(self, driver): + mock_entries = [ + MagicMock( + id="id1", vector=[0.1, 0.2, 0.3], payload={"key1": "value1", "_score": 0.99, "_tensor_facets": []} + ), + MagicMock( + id="id2", vector=[0.4, 0.5, 0.6], payload={"key2": "value2", "_score": 0.88, "_tensor_facets": []} + ), + ] + + with patch.object(driver.client, "retrieve", return_value=mock_entries) as mock_retrieve: + results = driver.load_entries(ids=["id1", "id2"], with_payload=True, with_vectors=True) + + mock_retrieve.assert_called_once_with( + collection_name=driver.collection_name, ids=["id1", "id2"], with_payload=True, with_vectors=True + ) + + assert len(results) == 2 + assert results[0].id == "id1" + assert results[0].vector == [0.1, 0.2, 0.3] + assert results[0].meta == {"key1": "value1"} + assert results[1].id == "id2" + assert results[1].vector == [0.4, 0.5, 0.6] + assert results[1].meta == {"key2": "value2"}