-
Notifications
You must be signed in to change notification settings - Fork 9
/
evaluate-bm25-qdrant.py
119 lines (83 loc) · 2.92 KB
/
evaluate-bm25-qdrant.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from fastembed import SparseTextEmbedding
from ipdb import launch_ipdb_on_exception
from qdrant_client import QdrantClient, models
import os
import json
# Runtime configuration; every value can be overridden via the environment.
# DATASET names both the folder under data/ and the Qdrant collection queried.
DATASET = os.environ.get("DATASET", "quora")
QDRANT_URL = os.environ.get("QDRANT_URL", "http://localhost:6333")
# No API key by default (e.g. a local, unauthenticated Qdrant instance).
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
def load_queries(dataset=None, split="test", data_dir="data"):
    """Load BEIR-style queries that have at least one relevant document.

    Reads ``<data_dir>/<dataset>/queries.jsonl``, which holds one JSON object
    per line with at least an ``_id`` key. Also reads
    ``<data_dir>/<dataset>/qrels/<split>.tsv``: a header line, then
    tab-separated ``query_id``, ``doc_id``, ``score`` rows. Every qrels row
    with a positive integer score appends its ``doc_id`` to that query's
    ``doc_ids`` list.

    Args:
        dataset: Dataset folder name; falls back to the module-level
            ``DATASET`` setting when ``None``.
        split: Name of the qrels file (without ``.tsv``) to read.
        data_dir: Root directory that contains the dataset folders.

    Returns:
        Dict mapping query id to the original query row (all JSON fields)
        plus a ``doc_ids`` list. Queries with no positively scored
        document are omitted.

    Raises:
        OSError: If either input file cannot be opened.
        ValueError: If a qrels row does not have exactly three tab-separated
            fields, or its score is not an integer.
    """
    if dataset is None:
        dataset = DATASET
    base_dir = os.path.join(data_dir, dataset)

    queries = {}
    queries_path = os.path.join(base_dir, "queries.jsonl")
    with open(queries_path, "r", encoding="utf-8") as file:
        for line in file:
            if not line.strip():
                continue  # tolerate blank / trailing lines in the JSONL file
            row = json.loads(line)
            queries[row["_id"]] = {**row, "doc_ids": []}

    qrels_path = os.path.join(base_dir, "qrels", f"{split}.tsv")
    with open(qrels_path, "r", encoding="utf-8") as file:
        next(file, None)  # skip the header row; an empty file is not an error
        for line in file:
            record = line.strip()
            if not record:
                continue
            query_id, doc_id, score = record.split("\t")
            query = queries.get(query_id)
            # Judgements for queries absent from queries.jsonl are skipped
            # instead of failing with KeyError.
            if query is not None and int(score) > 0:
                query["doc_ids"].append(doc_id)

    # Only queries with at least one relevant document can be evaluated.
    return {
        query_id: query
        for query_id, query in queries.items()
        if query["doc_ids"]
    }
def main():
    """Evaluate BM25 sparse retrieval in Qdrant against the dataset's qrels.

    For each query returned by ``load_queries`` (at most
    ``number_of_queries``), this:

    1. Embeds the query text with fastembed's ``Qdrant/bm25`` sparse model.
    2. Searches the ``bm25`` vector of the Qdrant collection named
       ``DATASET`` for the top ``limit`` points.
    3. Counts how many of the query's relevant document ids appear among
       the returned point ids.

    Prints each query row with its hit count. Then prints four summary
    figures:

    - total hits out of total relevant documents;
    - pooled precision, ``hits / (queries * limit)``;
    - per-query averaged precision at ``limit``;
    - per-query averaged recall at ``limit``.

    If no query is processed, it prints a notice and returns instead of
    dividing by zero.
    """
    n = 0  # relevant documents summed over all processed queries
    hits = 0  # relevant documents found among the top-`limit` results
    limit = 10
    number_of_queries = 100_000  # cap on how many queries are evaluated

    queries = load_queries()
    client = QdrantClient(QDRANT_URL, api_key=QDRANT_API_KEY)
    model = SparseTextEmbedding(model_name="Qdrant/bm25")

    def search_sparse(query, limit):
        """Return the top ``limit`` Qdrant points for the query text ``query``."""
        # ipdb's context manager opens a post-mortem debugger if anything in
        # this block raises.
        with launch_ipdb_on_exception():
            # query_embed returns an iterable of embeddings; a single query
            # string yields one, which is taken here.
            sparse_vector_fe = list(model.query_embed(query))[0]
            sparse_vector = models.SparseVector(
                values=sparse_vector_fe.values.tolist(),
                indices=sparse_vector_fe.indices.tolist(),
            )
            result = client.query_points(
                collection_name=DATASET,
                query=sparse_vector,
                using="bm25",
                # Only point ids are compared below, so payloads are not fetched.
                with_payload=False,
                limit=limit,
            )
            return result.points

    recalls = []
    precisions = []
    num_queries = 0
    for idx, query in enumerate(queries.values()):
        if idx >= number_of_queries:
            print(f"Processed {number_of_queries} queries, stopping...")
            break
        num_queries += 1
        result = search_sparse(query["text"], limit)
        # Ids are normalised to str so they compare equal to the string doc
        # ids from the qrels; a set keeps each membership test O(1).
        found_ids = {str(hit.id) for hit in result}
        relevant_ids = query["doc_ids"]
        query_hits = sum(1 for doc_id in relevant_ids if doc_id in found_ids)
        n += len(relevant_ids)
        hits += query_hits
        # load_queries drops queries without relevant documents, so the
        # denominator is non-zero.
        recalls.append(query_hits / len(relevant_ids))
        precisions.append(query_hits / limit)
        print(f"Processing query: {query}, hits: {query_hits}")

    if num_queries == 0:
        # Nothing was evaluated; the summary ratios below would divide by zero.
        print("No queries to evaluate.")
        return

    print(f"Total hits: {hits} out of {n}, which is {hits/n}")
    print(f"Precision: {hits/(num_queries * limit)}")
    average_precision = sum(precisions) / len(precisions)
    print(f"Average precision: {average_precision}")
    average_recall = sum(recalls) / len(recalls)
    print(f"Average recall: {average_recall}")
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()