Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Qdrant #1180

Merged
merged 19 commits into from
Mar 7, 2024
12 changes: 8 additions & 4 deletions databases/qdrant/docker/chatbot/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import streamlit as st
import os

# [START gke_databases_qdrant_docker_chat_model]
vertexAI = ChatVertexAI(model_name="gemini-pro", streaming=True, convert_system_message_to_human=True)
prompt_template = ChatPromptTemplate.from_messages(
[
Expand All @@ -47,33 +48,36 @@
)

embedding_model = VertexAIEmbeddings("textembedding-gecko@001")
# [END gke_databases_qdrant_docker_chat_model]

# [START gke_databases_qdrant_docker_chat_client]
client = QdrantClient(
url=os.getenv("QDRANT_URL"),
api_key=os.getenv("APIKEY"),
)
collection_name = os.getenv("COLLECTION_NAME")
qdrant = Qdrant(client, collection_name, embeddings=embedding_model)

# [END gke_databases_qdrant_docker_chat_client]
def format_docs(docs):
return "\n\n".join([d.page_content for d in docs])

st.title("🤖 Chatbot")
if "messages" not in st.session_state:
st.session_state["messages"] = [{"role": "ai", "content": "How can I help you?"}]

# [START gke_databases_qdrant_docker_chat_session]
if "memory" not in st.session_state:
st.session_state["memory"] = ConversationBufferWindowMemory(
memory_key="history",
ai_prefix="Bob",
human_prefix="User",
k=3,
)

# [END gke_databases_qdrant_docker_chat_session]
# [START gke_databases_qdrant_docker_chat_history]
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.write(message["content"])

# [END gke_databases_qdrant_docker_chat_history]
if chat_input := st.chat_input():
with st.chat_message("human"):
st.write(chat_input)
Expand Down
11 changes: 9 additions & 2 deletions databases/qdrant/docker/embed-docs/embedding-job.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,27 @@
from langchain_community.vectorstores import Qdrant
from google.cloud import storage
import os

# [START gke_databases_qdrant_docker_embed_docs_retrieval]
bucketname = os.getenv("BUCKET_NAME")
filename = os.getenv("FILE_NAME")

storage_client = storage.Client()
bucket = storage_client.bucket(bucketname)
blob = bucket.blob(filename)
blob.download_to_filename("/documents/" + filename)
# [END gke_databases_qdrant_docker_embed_docs_retrieval]

# [START gke_databases_qdrant_docker_embed_docs_split]
loader = PyPDFLoader("/documents/" + filename)
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
documents = loader.load_and_split(text_splitter)
# [END gke_databases_qdrant_docker_embed_docs_split]

# [START gke_databases_qdrant_docker_embed_docs_embed]
embeddings = VertexAIEmbeddings("textembedding-gecko@001")
# [END gke_databases_qdrant_docker_embed_docs_embed]

# [START gke_databases_qdrant_docker_embed_docs_storage]
qdrant = Qdrant.from_documents(
documents, embeddings,
collection_name=os.getenv("COLLECTION_NAME"),
Expand All @@ -40,7 +47,7 @@
shard_number=6,
replication_factor=2
)

# [END gke_databases_qdrant_docker_embed_docs_storage]
print(filename + " was successfully embedded")
print(f"# of vectors = {len(documents)}")

4 changes: 2 additions & 2 deletions databases/qdrant/docker/embed-docs/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def bucket():

# Setup K8 configs
config.load_incluster_config()

# [START gke_databases_qdrant_docker_embed_endpoint_job]
def kube_create_job_object(name, container_image, bucket_name, f_name, namespace="qdrant", container_name="jobcontainer", env_vars={}):

body = client.V1Job(api_version="batch/v1", kind="Job")
Expand All @@ -64,7 +64,7 @@ def kube_create_job_object(name, container_image, bucket_name, f_name, namespace

body.spec = client.V1JobSpec(backoff_limit=3, ttl_seconds_after_finished=60, template=template.template)
return body

# [END gke_databases_qdrant_docker_embed_endpoint_job]
def kube_test_credentials():
try:
api_response = api_instance.get_api_resources()
Expand Down
17 changes: 14 additions & 3 deletions databases/qdrant/manifests/04-qdrant-client/app.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#!/usr/bin/env python

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -12,20 +14,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# [START gke_databases_qdrant_manifests_04_imports]
from qdrant_client import QdrantClient
from qdrant_client.http import models
import os
import sys
import csv
# [END gke_databases_qdrant_manifests_04_imports]

def main(query_string):

# [START gke_databases_qdrant_manifests_04_create_client]
qdrant = QdrantClient(
url="http://qdrant-database:6333", api_key=os.getenv("APIKEY"))
# [END gke_databases_qdrant_manifests_04_create_client]

# Create a collection
# [START gke_databases_qdrant_manifests_04_create_collection]
books = [*csv.DictReader(open('/usr/local/dataset/dataset.csv'))]
# [END gke_databases_qdrant_manifests_04_create_collection]

# [START gke_databases_quadrant_manifests_04_prepare_doc]
documents: list[dict[str, any]] = []
metadata: list[dict[str, any]] = []
ids: list[int] = []
Expand All @@ -40,11 +47,15 @@ def main(query_string):
"publishDate": doc["publishDate"],
}
)
# [END gke_databases_quadrant_manifests_04_prepare_doc]

# [START gke_databases_quadrant_manifests_04_add_to_collection]
# Add my_books to the collection
qdrant.add(collection_name="my_books", documents=documents, metadata=metadata, ids=ids, parallel=2)
# [END gke_databases_quadrant_manifests_04_add_to_collection]

# Query the collection
# [START gke_databases_quadrant_manifests_04_query_collection]
results = qdrant.query(
collection_name="my_books",
query_text=query_string,
Expand All @@ -54,7 +65,7 @@ def main(query_string):
print("Title:", result.metadata["title"], "\nAuthor:", result.metadata["author"])
print("Description:", result.metadata["document"], "Published:", result.metadata["publishDate"], "\nScore:", result.score)
print("-----")

# [END gke_databases_quadrant_manifests_04_query_collection]
if __name__ == "__main__":
if len(sys.argv) > 1:
query_string = " ".join(sys.argv[1:])
Expand Down