-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathqdrant_test.py
126 lines (101 loc) · 4.24 KB
/
qdrant_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import uuid
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from qdrant_client.http.models import PointStruct, CollectionStatus, UpdateStatus
from qdrant_client.http.models import Filter, FieldCondition, MatchValue
from qdrant_client.http import models
from typing import List
import openai
from openai.embeddings_utils import get_embedding
# Placeholder credential assigned at import time; it must be replaced with a
# real OpenAI API key before any embedding call can succeed.
# NOTE(review): prefer loading the key from the environment rather than
# committing it to source control.
openai.api_key = "YOUR-OPENAI-API-KEY"
class QdrantVectorStore:
    """Wrapper around one Qdrant collection that stores OpenAI text embeddings.

    On construction the store opens a Qdrant client and makes sure the target
    collection can be fetched, creating it when the lookup fails. Text is
    embedded with ``get_embedding`` before being upserted or searched.
    """

    # Class-level fallback so ``_embed`` still works on instances that were
    # built without running ``__init__``.
    embedding_model = "text-embedding-ada-002"

    def __init__(self,
                 host: str = "localhost",
                 port: int = 6333,
                 db_path: str = "/Users/het/qdrant/qdrant_storage",
                 collection_name: str = "test_collection",
                 vector_size: int = 1536,
                 vector_distance=Distance.COSINE,
                 embedding_model: str = "text-embedding-ada-002",
                 ):
        """Open the client and ensure the collection exists.

        Args:
            host: Server address, forwarded as ``url`` when no ``db_path`` is
                given.
            port: Server port, used together with ``host``.
            db_path: Directory for on-disk storage. When set (the default), it
                takes precedence and ``host``/``port`` are not used.
            collection_name: Collection every method of this store operates on.
            vector_size: Dimension of the stored vectors.
                NOTE(review): must equal the output size of
                ``embedding_model`` - confirm when changing the model.
            vector_distance: Distance metric used when creating the collection.
            embedding_model: Engine name passed to ``get_embedding``.
        """
        self.embedding_model = embedding_model
        self.collection_name = collection_name

        # A server address and a local storage path are alternative ways of
        # opening a client, so only one of them is passed.
        # NOTE(review): recent qdrant-client releases are believed to raise
        # ValueError when both are supplied - confirm against the pinned version.
        if db_path:
            self.client = QdrantClient(path=db_path)
        else:
            self.client = QdrantClient(url=host, port=port)

        try:
            self.client.get_collection(collection_name=collection_name)
        except Exception:
            # Any failure of the lookup is treated as "collection missing";
            # a genuine connectivity problem will surface again, unhandled,
            # from the creation call below.
            print("Collection does not exist, creating collection now")
            self.set_up_collection(collection_name, vector_size, vector_distance)

    def set_up_collection(self, collection_name: str, vector_size: int, vector_distance: Distance):
        """(Re)create ``collection_name`` with the given vector configuration.

        Uses ``recreate_collection``, so this is not a no-op on a collection
        that already exists.
        """
        self.client.recreate_collection(
            collection_name=collection_name,
            vectors_config=VectorParams(size=vector_size, distance=vector_distance),
        )
        # Fetch the collection back; the result is discarded, so this call only
        # serves as a check that the lookup no longer raises.
        self.client.get_collection(collection_name=collection_name)

    def _embed(self, text: str):
        """Return the embedding of ``text`` produced by the configured model."""
        return get_embedding(text, engine=self.embedding_model)

    def upsert_data(self, data: List[dict]):
        """Embed the ``"text"`` field of every item and upsert it as a point.

        Each item is stored whole as the point payload under a freshly
        generated UUID. The outcome is reported on stdout; nothing is returned.

        Args:
            data: Dicts that each carry the text to embed under ``"text"``.
        """
        if not data:
            # Nothing to embed or send, so skip the round trip entirely.
            print("No data to insert")
            return

        points = [
            PointStruct(
                id=str(uuid.uuid4()),
                vector=self._embed(item.get("text")),
                payload=item,
            )
            for item in data
        ]
        operation_info = self.client.upsert(
            collection_name=self.collection_name,
            wait=True,
            points=points,
        )
        if operation_info.status == UpdateStatus.COMPLETED:
            print("Data inserted successfully!")
        else:
            print("Failed to insert data")

    def search(self, input_query: str, limit: int = 3):
        """Return up to ``limit`` hits for ``input_query``.

        Returns:
            A list of dicts with the keys ``"id"``, ``"similarity_score"`` and
            ``"text"`` (``None`` when the payload has no such field).
        """
        hits = self.client.search(
            collection_name=self.collection_name,
            query_vector=self._embed(input_query),
            limit=limit,
        )
        return [
            {
                "id": hit.id,
                "similarity_score": hit.score,
                # A hit without a payload yields None instead of raising.
                "text": (hit.payload or {}).get("text"),
            }
            for hit in hits
        ]

    def search_with_filter(self, input_query: str, filter: dict, limit: int = 3):
        """Return up to ``limit`` hits whose payload matches every filter entry.

        Args:
            input_query: Text to embed and search for.
            filter: Mapping of payload key to the exact value it must match;
                all entries are combined as ``must`` conditions. ``None`` is
                treated like an empty mapping. (The name shadows the builtin
                but is kept because it is part of the public signature.)
            limit: Maximum number of hits.

        Returns:
            A list of dicts with the keys ``"id"``, ``"similarity_score"``,
            ``"quote"`` and ``"person"``.
            NOTE(review): ``upsert_data`` embeds the ``"text"`` field while
            this method reports ``"quote"``/``"person"`` - confirm the payload
            schema callers actually store.
        """
        conditions = [
            FieldCondition(key=key, match=MatchValue(value=value))
            for key, value in (filter or {}).items()
        ]
        hits = self.client.search(
            collection_name=self.collection_name,
            query_vector=self._embed(input_query),
            query_filter=Filter(must=conditions),
            limit=limit,
        )
        results = []
        for hit in hits:
            payload = hit.payload or {}
            results.append({
                "id": hit.id,
                "similarity_score": hit.score,
                "quote": payload.get("quote"),
                "person": payload.get("person"),
            })
        return results

    def delete(self, text_ids: list):
        """Delete the points whose ids are listed in ``text_ids``."""
        self.client.delete(
            collection_name=self.collection_name,
            points_selector=models.PointIdsList(points=text_ids),
        )

    def delete_collection(self, collection_name: str):
        """Delete ``collection_name`` and report it on stdout."""
        self.client.delete_collection(collection_name=collection_name)
        print("collection deleted")

    def get_collection(self, collection_name: str):
        """Return the client's description of ``collection_name``."""
        return self.client.get_collection(collection_name=collection_name)