Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Qdrant query count not optional #972

Merged
merged 12 commits into from
Jul 13, 2024
93 changes: 25 additions & 68 deletions docs/griptape-framework/drivers/vector-store-drivers.md
Original file line number Diff line number Diff line change
@@ -24,7 +24,6 @@ The [LocalVectorStoreDriver](../../reference/griptape/drivers/vector/local_vecto

```python
import os
from griptape.artifacts import BaseArtifact
from griptape.drivers import LocalVectorStoreDriver, OpenAiEmbeddingDriver
from griptape.loaders import WebLoader

@@ -40,16 +39,11 @@ artifacts = WebLoader(max_tokens=100).load("https://www.griptape.ai")
# Upsert Artifacts into the Vector Store Driver
[vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in artifacts]

results = vector_store_driver.query(
"creativity",
count=3,
namespace="griptape"
)
results = vector_store_driver.query(query="What is griptape?")

values = [r.to_artifact().value for r in results]

print("\n\n".join(values))

```

### Griptape Cloud Knowledge Base
@@ -58,7 +52,6 @@ The [GriptapeCloudKnowledgeBaseVectorStoreDriver](../../reference/griptape/drive

```python
import os
from griptape.artifacts import BaseArtifact
from griptape.drivers import GriptapeCloudKnowledgeBaseVectorStoreDriver


@@ -68,12 +61,11 @@ gt_cloud_knowledge_base_id = os.environ["GRIPTAPE_CLOUD_KB_ID"]

vector_store_driver = GriptapeCloudKnowledgeBaseVectorStoreDriver(api_key=gt_cloud_api_key, knowledge_base_id=gt_cloud_knowledge_base_id)

results =vector_store_driver.query(query="What is griptape?")
results = vector_store_driver.query(query="What is griptape?")

values = [r.to_artifact().value for r in results]

print("\n\n".join(values))

```

### Pinecone
@@ -86,50 +78,28 @@ The [PineconeVectorStoreDriver](../../reference/griptape/drivers/vector/pinecone
Here is an example of how the Driver can be used to load and query information in a Pinecone cluster:

```python
import os
import hashlib
import json
from urllib.request import urlopen
import os
from griptape.drivers import PineconeVectorStoreDriver, OpenAiEmbeddingDriver
from griptape.loaders import WebLoader

def load_data(driver: PineconeVectorStoreDriver) -> None:
response = urlopen(
"https://raw.githubusercontent.com/wedeploy-examples/"
"supermarket-web-example/master/products.json"
)

for product in json.loads(response.read()):
driver.upsert_text(
product["description"],
vector_id=hashlib.md5(product["title"].encode()).hexdigest(),
meta={
"title": product["title"],
"description": product["description"],
"type": product["type"],
"price": product["price"],
"rating": product["rating"],
},
namespace="supermarket-products",
)

# Initialize an Embedding Driver
embedding_driver = OpenAiEmbeddingDriver(api_key=os.environ["OPENAI_API_KEY"])

vector_store_driver = PineconeVectorStoreDriver(
api_key=os.environ["PINECONE_API_KEY"],
environment=os.environ["PINECONE_ENVIRONMENT"],
index_name=os.environ['PINECONE_INDEX_NAME'],
index_name=os.environ["PINECONE_INDEX_NAME"],
embedding_driver=embedding_driver,
)

load_data(vector_store_driver)
# Load Artifacts from the web
artifacts = WebLoader(max_tokens=100).load("https://www.griptape.ai")

results = vector_store_driver.query(
"fruit",
count=3,
filter={"price": {"$lte": 15}, "rating": {"$gte": 4}},
namespace="supermarket-products",
)
# Upsert Artifacts into the Vector Store Driver
[vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in artifacts]

results = vector_store_driver.query(query="What is griptape?")

values = [r.to_artifact().value for r in results]

@@ -175,7 +145,7 @@ vector_store_driver.upsert_text_artifacts(
}
)

results =vector_store_driver.query(query="What is griptape?")
results = vector_store_driver.query(query="What is griptape?")

values = [r.to_artifact().value for r in results]

@@ -227,7 +197,7 @@ vector_store_driver.upsert_text_artifacts(
}
)

results =vector_store_driver.query(query="What is griptape?")
results = vector_store_driver.query(query="What is griptape?")

values = [r.to_artifact().value for r in results]

@@ -298,7 +268,7 @@ vector_store_driver.upsert_text_artifacts(
}
)

results =vector_store_driver.query(query="What is griptape?")
results = vector_store_driver.query(query="What is griptape?")

values = [r.to_artifact().value for r in results]

@@ -341,7 +311,7 @@ vector_store_driver.upsert_text_artifacts(
}
)

results =vector_store_driver.query(query="What is griptape?")
results = vector_store_driver.query(query="What is griptape?")

values = [r.to_artifact().value for r in results]

@@ -388,7 +358,7 @@ vector_store_driver.upsert_text_artifacts(
}
)

results =vector_store_driver.query(query="What is griptape?")
results = vector_store_driver.query(query="What is griptape?")

values = [r.to_artifact().value for r in results]

@@ -450,7 +420,7 @@ vector_store_driver.upsert_text_artifacts(
}
)

results =vector_store_driver.query(query="What is griptape?")
results = vector_store_driver.query(query="What is griptape?")

values = [r.to_artifact().value for r in results]

@@ -473,49 +443,36 @@ from griptape.tokenizers import HuggingFaceTokenizer
from griptape.loaders import WebLoader

# Set up environment variables
embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"
host = os.environ["QDRANT_CLUSTER_ENDPOINT"]
huggingface_token = os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"]
api_key = os.environ["QDRANT_CLUSTER_API_KEY"]

# Initialize HuggingFace Embedding Driver
embedding_driver = HuggingFaceHubEmbeddingDriver(
api_token=huggingface_token,
model=embedding_model_name,
tokenizer=HuggingFaceTokenizer(model=embedding_model_name, max_output_tokens=512),
)
# Initialize an Embedding Driver.
embedding_driver = OpenAiEmbeddingDriver(api_key=os.environ["OPENAI_API_KEY"])

# Initialize Qdrant Vector Store Driver
vector_store_driver = QdrantVectorStoreDriver(
url=host,
collection_name="griptape",
content_payload_key="content",
embedding_driver=embedding_driver,
api_key=os.environ["QDRANT_CLUSTER_API_KEY"],
api_key=api_key,
)

# Load Artifacts from the web
artifacts = WebLoader().load("https://www.griptape.ai")

# Encode text to get embeddings
embeddings = embedding_driver.embed_text_artifact(artifacts[0])

# Recreate Qdrant collection
vector_store_driver.client.recreate_collection(
collection_name=vector_store_driver.collection_name,
vectors_config={
"size": len(embeddings),
"size": 1536,
"distance": vector_store_driver.distance
},
)

# Upsert vector into Qdrant
vector_store_driver.upsert_vector(
vector=embeddings,
vector_id=str(artifacts[0].id),
content=artifacts[0].value
)
# Upsert Artifacts into the Vector Store Driver
[vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in artifacts]

results =vector_store_driver.query(query="What is griptape?")
results = vector_store_driver.query(query="What is griptape?")

values = [r.to_artifact().value for r in results]

4 changes: 3 additions & 1 deletion griptape/drivers/vector/qdrant_vector_store_driver.py
Original file line number Diff line number Diff line change
@@ -106,7 +106,9 @@ def query(
query_vector = self.embedding_driver.embed_string(query)

# Create a search request
results = self.client.search(collection_name=self.collection_name, query_vector=query_vector, limit=count)
request = {"collection_name": self.collection_name, "query_vector": query_vector, "limit": count}
request = {k: v for k, v in request.items() if v is not None}
results = self.client.search(**request)

# Convert results to QueryResult objects
query_results = [
Loading