Add support for Qdrant VectorDB #812

Closed
Changes from 78 commits
88 commits
139d23a
Added the initial code for qdrant-vector client
May 25, 2024
a9f7eda
Added the Qdrant vectorstore db md
May 28, 2024
28dfebe
Modified the MD
May 28, 2024
63eea0e
Fixed the url typo
May 28, 2024
f4456b1
Added few more chagnes
May 28, 2024
c15bf31
Added comments and restructured the file
May 28, 2024
f3f2e5c
Update qdrant_vector_store_driver.md
hkhajgiwale May 28, 2024
a8d69db
Update qdrant_vector_store_driver.py
hkhajgiwale May 31, 2024
a56003d
Update vector-store-drivers.md
hkhajgiwale May 31, 2024
0106634
Update vector-store-drivers.md
hkhajgiwale May 31, 2024
51f51ea
Create test_qdrant_vector_store_driver.py
hkhajgiwale May 31, 2024
9a346f7
Merge branch 'griptape-ai:dev' into feature/qdrant_vector_store_driver
hkhajgiwale May 31, 2024
cd153b5
Update test_qdrant_vector_store_driver.py
hkhajgiwale May 31, 2024
6b18025
Merge branch 'dev' into feature/qdrant_vector_store_driver
hkhajgiwale May 31, 2024
baac892
Update poetry.lock
hkhajgiwale May 31, 2024
a8bfd8b
Update poetry.lock
hkhajgiwale May 31, 2024
6cf977f
Used attrs instead of attr
hkhajgiwale May 31, 2024
7e3a370
Removed unused TYPE_CHECKING
hkhajgiwale May 31, 2024
ec874e9
Making poetry at par with dev
May 31, 2024
1b6dafa
updated poetry.lock
May 31, 2024
2106b7e
Merge branch 'griptape-ai:dev' into feature/qdrant_vector_store_driver
hkhajgiwale May 31, 2024
308b245
Modified the test file
May 31, 2024
9b54769
Formatting
May 31, 2024
389769f
Merge branch 'griptape-ai:dev' into dev
hkhajgiwale May 31, 2024
f3e3854
Refactored the driver claass
May 31, 2024
5367273
Merge branch 'griptape-ai:dev' into feature/qdrant_vector_store_driver
hkhajgiwale Jun 1, 2024
08ddb70
Added the change
Jun 4, 2024
5d274cc
Removed commented code
Jun 4, 2024
7ab4492
Incorporated required changes
Jun 6, 2024
79555ba
Merge remote-tracking branch 'origin/dev' into feature/qdrant_vector_…
Jun 6, 2024
2cf5e6f
Reverting to dev
Jun 6, 2024
ed58a03
Removed create collections, batches and build payloads
Jun 6, 2024
b2582a0
Added the str_to_hash instead of uuid
Jun 6, 2024
f921751
Changed poetry lock
Jun 6, 2024
9415e6f
Reverted to dfev
Jun 6, 2024
f5af949
Merge branch 'dev' into feature/qdrant_vector_store_driver
hkhajgiwale Jun 6, 2024
6ba7f64
Update pyproject.toml
hkhajgiwale Jun 6, 2024
8ab2f01
Added correct poetry.lock
Jun 6, 2024
b976a3f
Formatting using ruff
Jun 6, 2024
a995df6
Added dict for points
Jun 6, 2024
6c74636
Fixed delete vector
Jun 6, 2024
8c2c447
Added import_optional_dependency
Jun 6, 2024
cd4872a
Added all the usecases
Jun 6, 2024
b1e6504
Fixed the docs
Jun 6, 2024
2e4b691
Added extra test cases and removed direct imports
Jun 6, 2024
7159324
Updated the usage code in docs
Jun 6, 2024
47e3858
Moved embedding driver to fixture
Jun 6, 2024
3dcf5ee
Merge branch 'dev' into feature/qdrant_vector_store_driver
hkhajgiwale Jun 6, 2024
392a60e
Update poetry.lock
hkhajgiwale Jun 6, 2024
2f96def
updated the test file for more code coverage
Jun 6, 2024
46c5b72
Merge branch 'dev' into feature/qdrant_vector_store_driver
hkhajgiwale Jun 6, 2024
4b02e10
Removed AsyncQdrantClient and removed with_payload and with_vector
Jun 7, 2024
337945b
Removed unused force_create, added api_key in init and removed with_p…
Jun 7, 2024
94dba24
Added the default parameters
Jun 14, 2024
c32f238
Merge branch 'griptape-ai:dev' into feature/qdrant_vector_store_driver
hkhajgiwale Jun 14, 2024
99f71e1
Fixed the test case
Jun 14, 2024
f3f5c71
Refactored code to eliminate try and catch
Jun 14, 2024
6c7cc9f
Merge branch 'dev' into feature/qdrant_vector_store_driver
hkhajgiwale Jun 18, 2024
7aefee7
Added the deterministic UUIDV5
Jun 23, 2024
6553750
Merge branch 'dev' into feature/qdrant_vector_store_driver
hkhajgiwale Jun 23, 2024
665330a
Updated poetry.loc
Jun 23, 2024
6db7435
Removed QueryResult and added Entry
Jun 23, 2024
105db33
Expanded the test coverage
Jun 23, 2024
c167b0c
Expanded the test coverage
Jun 23, 2024
a7a7378
Create ci.yml
hkhajgiwale Jun 24, 2024
8659e61
Update code-checks.yml
hkhajgiwale Jun 24, 2024
8b40c2d
Delete .github/workflows/ci.yml
hkhajgiwale Jun 24, 2024
a10e689
Formatted the test cases
Jun 24, 2024
901977f
Reverted adding the feature branch after testing completed
Jun 24, 2024
f117dd3
Update docs/griptape-framework/drivers/vector-store-drivers.md
hkhajgiwale Jun 24, 2024
b75e370
Apply suggestions from code review
hkhajgiwale Jun 24, 2024
ddbbdb3
Reverted int as optional owing to inheritance
Jun 24, 2024
6274954
Merge branch 'dev' into feature/qdrant_vector_store_driver
hkhajgiwale Jun 24, 2024
55e8cfa
Added the correct use cases
Jun 26, 2024
940ece5
Merge branch 'dev' into feature/qdrant_vector_store_driver
hkhajgiwale Jun 26, 2024
4c935dd
Updated the poetry.lock
Jun 26, 2024
f0d51c5
Removed the commented code
Jun 26, 2024
085822a
Handled the negative case
Jun 26, 2024
624bad1
Removed the results section
Jun 26, 2024
4d51ec8
Added the changes for access_token
hkhajgiwale Jul 1, 2024
1e6077c
Incorporated changes regarding adding the access token and fixed docs
Jul 1, 2024
acf642d
Removed the unused VECTOR_NAME
Jul 1, 2024
11053f5
Merge branch 'dev' into feature/qdrant_vector_store_driver
hkhajgiwale Jul 1, 2024
2035e4c
Update docs-integration-tests.yml
hkhajgiwale Jul 1, 2024
588cfb2
Added the QDRANT_CLUSTER_API_KEY
Jul 1, 2024
ad39e2f
Merge branch 'dev' into feature/qdrant_vector_store_driver
hkhajgiwale Jul 2, 2024
dd162b6
Added the QdrantVectorStoreDriver to the changelog
Jul 2, 2024
e1638a8
Updated the poetry.locl
Jul 2, 2024
109 changes: 109 additions & 0 deletions docs/griptape-framework/drivers/vector-store-drivers.md
@@ -406,3 +406,112 @@ vector_store_driver.upsert_text_artifacts(
result = vector_store_driver.query("What is griptape?")
print(result)
```

### Qdrant

!!! info
    This driver requires the `drivers-vector-qdrant` [extra](../index.md#extras).

The QdrantVectorStoreDriver supports the [Qdrant vector database](https://qdrant.tech/).

Here is an example of how the driver can be used to query information in a Qdrant collection:

```python
import os
import logging
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
from griptape.drivers import QdrantVectorStoreDriver, HuggingFaceHubEmbeddingDriver
from griptape.tokenizers import HuggingFaceTokenizer
from griptape.loaders import PdfLoader
from griptape.chunkers import TextChunker

# Configure logging
logging.basicConfig(level=logging.INFO)

# Setting the models
embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
qdrant_model = SentenceTransformer("all-MiniLM-L6-v2")
file_name = "linux_bible.pdf"
HUGGINGFACE_TOKEN = os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"]


# Using HuggingFaceTokenizer
def create_tokenizer(embedding_model):
    tokenizer = HuggingFaceTokenizer(max_output_tokens=1024, tokenizer=AutoTokenizer.from_pretrained(embedding_model))
    return tokenizer


def create_embedding_driver(embedding_model, tokenizer):
    embedding_driver = HuggingFaceHubEmbeddingDriver(
        api_token=HUGGINGFACE_TOKEN, model=embedding_model, tokenizer=tokenizer
    )
    return embedding_driver


# Instantiating QdrantVectorStoreDriver
def create_vector_store_driver(url, collection_name, embedding_driver):
    vector_store_driver = QdrantVectorStoreDriver(
        url=url, collection_name=collection_name, content_payload_key="content", embedding_driver=embedding_driver
    )
    return vector_store_driver


# Opening the file
def load_pdf(file_name, tokenizer):
    with open(file_name, "rb") as f:
        loader = PdfLoader(tokenizer=tokenizer, chunker=TextChunker(tokenizer=tokenizer, max_tokens=1024)).load(
            f.read()
        )
    return loader


def main():
    tokenizer = create_tokenizer(embedding_model)
    embedding_driver = create_embedding_driver(embedding_model, tokenizer)
    vector_store_driver = create_vector_store_driver(
        url="http://localhost:6333", collection_name="linux_bible", embedding_driver=embedding_driver
    )

    # Loading the data
    loader = load_pdf(file_name, tokenizer=tokenizer)

    # Generate metadata for each chunk (example metadata)
    metadata = [{"source": file_name, "page": i + 1} for i in range(len(loader))]

    # Upserting the vector
    try:
        for i, l in enumerate(loader):
            content = str(l)
            meta = metadata[i] if metadata else None
            vector_store_driver.upsert_vector(embedding_driver.try_embed_chunk(content), meta=meta, content=content)
        logging.info("Successfully upserted vectors with metadata.")
    except Exception as e:
        logging.error(f"Error during upsert_vector: {e}")

    # Querying the data
    query_string = "Who created linux?"
    query_results = vector_store_driver.query(query_string, count=6, include_vectors=True)
    for result in query_results:
        print(f"ID: {result.id}, Score: {result.score}, Vector: {result.vector}, Metadata: {result.meta}")

    # Retrieving single entries
    single_entry = vector_store_driver.load_entry(vector_id="00499354-9362-49c7-97b6-0220b9de84e7")
    if single_entry:
        print(f"Vector ID: {single_entry.id}")
        print(f"Vector: {single_entry.vector}")
        print(f"Metadata: {single_entry.meta}")
    else:
        print("Vector with ID 00499354-9362-49c7-97b6-0220b9de84e7 was not found.")

    # Retrieving multiple entries
    multiple_entries = vector_store_driver.load_entries(ids=["707a004c-23bc-476c-a097-e527977777f3"])
    print(multiple_entries)

    # Deleting the vector
    vector_store_driver.delete_vector(vector_id="00499354-9362-49c7-97b6-0220b9de84e7")


if __name__ == "__main__":
    main()

```
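
The example above passes both `meta` and `content` to `upsert_vector`, with `content_payload_key="content"`. As a rough illustration of what ends up in the Qdrant payload, the sketch below mirrors the merge step from the driver in this PR. `build_payload` is a hypothetical helper, not part of the driver, and unlike the driver it copies `meta` instead of mutating the caller's dictionary.

```python
from typing import Optional


def build_payload(meta: Optional[dict], content: Optional[str], content_payload_key: str = "content") -> dict:
    """Hypothetical helper: merge metadata and text content into one payload dict."""
    # Copy the metadata so the caller's dictionary is left untouched (an assumption of this sketch).
    payload = dict(meta) if meta else {}
    # The text content, when present, is stored under the configured payload key.
    if content:
        payload[content_payload_key] = content
    return payload


payload = build_payload({"source": "linux_bible.pdf", "page": 1}, "Linux is a kernel.")
print(payload)
```

Because the content shares the payload with the metadata, a metadata key equal to `content_payload_key` would be overwritten by the text content.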
2 changes: 2 additions & 0 deletions griptape/drivers/__init__.py
@@ -39,6 +39,7 @@
from .vector.pgvector_vector_store_driver import PgVectorVectorStoreDriver
from .vector.azure_mongodb_vector_store_driver import AzureMongoDbVectorStoreDriver
from .vector.dummy_vector_store_driver import DummyVectorStoreDriver
from .vector.qdrant_vector_store_driver import QdrantVectorStoreDriver

from .sql.base_sql_driver import BaseSqlDriver
from .sql.amazon_redshift_sql_driver import AmazonRedshiftSqlDriver
@@ -144,6 +145,7 @@
    "OpenSearchVectorStoreDriver",
    "AmazonOpenSearchVectorStoreDriver",
    "PgVectorVectorStoreDriver",
    "QdrantVectorStoreDriver",
    "DummyVectorStoreDriver",
    "BaseSqlDriver",
    "AmazonRedshiftSqlDriver",
210 changes: 210 additions & 0 deletions griptape/drivers/vector/qdrant_vector_store_driver.py
@@ -0,0 +1,210 @@
from __future__ import annotations
from typing import Optional
from attrs import define, field
from griptape.drivers import BaseVectorStoreDriver
from griptape.utils import import_optional_dependency
import uuid
import logging

VECTOR_NAME = None
DEFAULT_DISTANCE = "COSINE"
CONTENT_PAYLOAD_KEY = "data"


@define
class QdrantVectorStoreDriver(BaseVectorStoreDriver):
    """
    Attributes:
        location: An optional location for the Qdrant client. If set to ':memory:', an in-memory client is used.
        url: An optional Qdrant API URL.
        host: An optional Qdrant host.
        path: Persistence path for QdrantLocal. Defaults to None.
        port: The port number for the Qdrant client. Defaults to 6333.
        grpc_port: The gRPC port number for the Qdrant client. Defaults to 6334.
        prefer_grpc: A boolean indicating whether to prefer gRPC over HTTP. Defaults to False.
        force_disable_check_same_thread: For QdrantLocal, force disable check_same_thread. Defaults to False. Only use this if you can guarantee thread safety outside QdrantClient.
        timeout: Timeout for REST and gRPC API requests. Default: 5 seconds for REST and unlimited for gRPC.
        api_key: API key for authentication in Qdrant Cloud. Defaults to None.
        https: If True, use the HTTPS (SSL) protocol. Defaults to None.
        prefix: Prefix added to the REST URL path. Example: service/v1 will result in http://localhost:6333/service/v1/{qdrant-endpoint} for the REST API. Defaults to None.
        distance: The distance metric to be used for the vectors. Defaults to 'COSINE'.
        collection_name: The name of the Qdrant collection.
        vector_name: An optional name for the vectors.
        content_payload_key: The key for the content payload in the metadata. Defaults to 'data'.
    """

    location: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    host: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    path: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    port: int = field(default=6333, kw_only=True, metadata={"serializable": True})
    grpc_port: int = field(default=6334, kw_only=True, metadata={"serializable": True})
    prefer_grpc: bool = field(default=False, kw_only=True, metadata={"serializable": True})
    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    https: Optional[bool] = field(default=None, kw_only=True, metadata={"serializable": True})
    prefix: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    force_disable_check_same_thread: Optional[bool] = field(
        default=False, kw_only=True, metadata={"serializable": True}
    )
    timeout: Optional[int] = field(default=5, kw_only=True, metadata={"serializable": True})
    distance: str = field(default=DEFAULT_DISTANCE, kw_only=True, metadata={"serializable": True})
    collection_name: str = field(kw_only=True, metadata={"serializable": True})
    vector_name: Optional[str] = VECTOR_NAME
    content_payload_key: str = field(default=CONTENT_PAYLOAD_KEY, kw_only=True, metadata={"serializable": True})

    def __attrs_post_init__(self) -> None:
        self.client = import_optional_dependency("qdrant_client").QdrantClient(
            location=self.location,
            url=self.url,
            host=self.host,
            path=self.path,
            port=self.port,
            prefer_grpc=self.prefer_grpc,
            grpc_port=self.grpc_port,
            api_key=self.api_key,
            https=self.https,
            prefix=self.prefix,
            force_disable_check_same_thread=self.force_disable_check_same_thread,
            timeout=self.timeout,
        )

    def delete_vector(self, vector_id: str) -> None:
        """
        Delete a vector from the Qdrant collection based on its ID.

        Parameters:
            vector_id (str): ID of the vector to delete.
        """
        deletion_response = self.client.delete(
            collection_name=self.collection_name,
            points_selector=import_optional_dependency("qdrant_client.http.models").PointIdsList(points=[vector_id]),
        )
        if deletion_response.status == import_optional_dependency("qdrant_client.http.models").UpdateStatus.COMPLETED:
            logging.info(f"ID {vector_id} is successfully deleted")

    def query(
        self,
        query: str,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        **kwargs,
    ) -> list[BaseVectorStoreDriver.Entry]:
        """
        Query the Qdrant collection with the embedding of a query string.

        Parameters:
            query (str): Query string.
            count (Optional[int]): Optional number of results to return.
            namespace (Optional[str]): Optional namespace of the vectors.
            include_vectors (bool): Whether to include vectors in the results.

        Returns:
            list[BaseVectorStoreDriver.Entry]: List of Entry objects.
        """
        query_vector = self.embedding_driver.embed_string(query)

        # Create a search request
        results = self.client.search(collection_name=self.collection_name, query_vector=query_vector, limit=count)

        # Convert results to Entry objects
        query_results = [
            BaseVectorStoreDriver.Entry(
                id=result.id,
                vector=result.vector if include_vectors else [],
                score=result.score,
                meta={k: v for k, v in result.payload.items() if k not in ["_score", "_tensor_facets"]},
            )
            for result in results
        ]
        return query_results

    def upsert_vector(
        self,
        vector: list[float],
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        content: Optional[str] = None,
        **kwargs,
    ) -> str:
        """
        Upsert vectors into the Qdrant collection.

        Parameters:
            vector (list[float]): The vector to be upserted.
            vector_id (Optional[str]): Optional vector ID.
            namespace (Optional[str]): Optional namespace for the vector.
            meta (Optional[dict]): Optional dictionary containing metadata.
            content (Optional[str]): The text content to be included in the payload.

        Returns:
            str: The ID of the upserted vector.
        """

        if vector_id is None:
            vector_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, str(vector)))

        if meta is None:
            meta = {}

        if content:
            meta[self.content_payload_key] = content

        points = import_optional_dependency("qdrant_client.http.models").Batch(
            ids=[vector_id], vectors=[vector], payloads=[meta] if meta else None
        )

        self.client.upsert(collection_name=self.collection_name, points=points)
        return vector_id

    def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
        """
        Load a vector entry from the Qdrant collection based on its ID.

        Parameters:
            vector_id (str): ID of the vector to load.
            namespace (str, optional): Optional namespace of the vector.

        Returns:
            Optional[BaseVectorStoreDriver.Entry]: Vector entry if found, else None.
        """
        results = self.client.retrieve(collection_name=self.collection_name, ids=[vector_id])
        if results:
            entry = results[0]
            return BaseVectorStoreDriver.Entry(
                id=entry.id,
                vector=entry.vector,
                meta={k: v for k, v in entry.payload.items() if k not in ["_score", "_tensor_facets"]},
            )
        else:
            return None

    def load_entries(self, namespace: Optional[str] = None, **kwargs) -> list[BaseVectorStoreDriver.Entry]:
        """
        Load vector entries from the Qdrant collection.

        Parameters:
            namespace: Optional namespace of the vectors.

        Returns:
            list[BaseVectorStoreDriver.Entry]: Entries built from the retrieved points.
        """

        results = self.client.retrieve(
            collection_name=self.collection_name,
            ids=kwargs.get("ids", []),
            with_payload=kwargs.get("with_payload", True),
            with_vectors=kwargs.get("with_vectors", True),
        )
        if not results:
            logging.error("An error occurred or no results found.")
            return []

        return [
            BaseVectorStoreDriver.Entry(
                id=entry.id,
                vector=entry.vector if kwargs.get("with_vectors", True) else [],
                meta={k: v for k, v in entry.payload.items() if k not in ["_score", "_tensor_facets"]},
            )
            for entry in results
        ]
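
When no `vector_id` is supplied, `upsert_vector` in this diff derives one with `uuid.uuid5(uuid.NAMESPACE_DNS, str(vector))`. The standalone sketch below illustrates the consequence: identical vectors map to the same ID, so upserting the same vector twice overwrites one point rather than creating a second. `deterministic_vector_id` is a hypothetical name used only for this illustration.

```python
import uuid


def deterministic_vector_id(vector: list[float]) -> str:
    """Hypothetical helper: derive a stable UUIDv5 from the vector's string form."""
    # uuid5 hashes the namespace plus the name, so equal inputs always yield equal IDs.
    return str(uuid.uuid5(uuid.NAMESPACE_DNS, str(vector)))


first = deterministic_vector_id([0.1, 0.2, 0.3])
second = deterministic_vector_id([0.1, 0.2, 0.3])
other = deterministic_vector_id([0.1, 0.2, 0.4])

print(first == second)  # same vector, same ID
print(first == other)   # different vector, different ID
```

One implication of hashing only the vector is that two chunks with different text but identical embeddings would share an ID. Callers who need to keep them apart can pass an explicit `vector_id`.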