forked from mamei16/LLM_Web_search
-
Notifications
You must be signed in to change notification settings - Fork 0
/
qdrant_retriever.py
136 lines (118 loc) · 4.79 KB
/
qdrant_retriever.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from typing import (
Any,
Iterable,
List,
Optional,
Tuple,
cast,
Generator
)
import torch
from langchain_community.retrievers import QdrantSparseVectorRetriever
from langchain_community.vectorstores.qdrant import Qdrant
from langchain_core.pydantic_v1 import Field
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain.schema import Document
try:
from qdrant_client import QdrantClient, models
except ImportError:
pass
def batchify(_list: List, batch_size: int) -> Generator[List, None, None]:
    """Lazily split ``_list`` into consecutive slices of at most ``batch_size`` items.

    Args:
        _list: Sequence to split; it must support ``len()`` and slicing.
        batch_size: Number of items per slice. It is used as the step of a
            ``range``, so a value of ``0`` raises ``ValueError`` once the
            generator is first advanced.

    Yields:
        Slices of ``_list`` in their original order. The final slice is
        shorter when ``len(_list)`` is not a multiple of ``batch_size``.
    """
    total = len(_list)
    yield from (
        _list[start:start + batch_size]
        for start in range(0, total, batch_size)
    )
class MyQdrantSparseVectorRetriever(QdrantSparseVectorRetriever):
    """Qdrant sparse-vector retriever that encodes texts with its own tokenizer/model pairs.

    Documents and queries are encoded by separate tokenizer/model pairs. In both
    cases the sparse vector is ``log(1 + relu(logits))``, multiplied by the
    attention mask and reduced with a max over dimension 1. Only the non-zero
    components (index and weight) are handed to Qdrant.
    """

    # Tokenizer/model pair used by ``compute_document_vectors``.
    splade_doc_tokenizer: Any = Field(repr=False)
    splade_doc_model: Any = Field(repr=False)
    # Tokenizer/model pair used by ``compute_query_vector``.
    splade_query_tokenizer: Any = Field(repr=False)
    splade_query_model: Any = Field(repr=False)
    # Device the tokenized inputs are moved to before each forward pass.
    device: Any = Field(repr=False)
    # Number of texts encoded per forward pass in ``add_texts``.
    batch_size: int = Field(repr=False)
    # Not read by any method of this class. The previous annotation
    # ``Any or None`` was a boolean expression that evaluated to plain ``Any``;
    # ``Optional[Any]`` states the intent explicitly.
    sparse_encoder: Optional[Any] = Field(repr=False)

    class Config:
        """Configuration for this pydantic object."""
        arbitrary_types_allowed = True

    def compute_document_vectors(
        self, texts: List[str], batch_size: int
    ) -> Tuple[List[List[int]], List[List[float]]]:
        """Encode ``texts`` into sparse vectors with the document tokenizer and model.

        Args:
            texts: Texts to encode. They are sliced into batches, so the
                argument must support ``len()`` and slicing.
            batch_size: Maximum number of texts per forward pass.

        Returns:
            ``(indices, values)`` with one entry per text, in input order:
            the positions of the non-zero vector components and the weights
            at those positions. Each entry is the NumPy array produced by
            ``Tensor.numpy()`` rather than a plain Python list.
        """
        indices = []
        values = []
        for text_batch in batchify(texts, batch_size):
            with torch.no_grad():
                tokens = self.splade_doc_tokenizer(
                    text_batch, truncation=True, padding=True, return_tensors="pt"
                ).to(self.device)
                output = self.splade_doc_model(**tokens)
                logits, attention_mask = output.logits, tokens.attention_mask
                relu_log = torch.log(1 + torch.relu(logits))
                # Positions whose attention mask is 0 contribute a weight of 0.
                weighted_log = relu_log * attention_mask.unsqueeze(-1)
                tvecs, _ = torch.max(weighted_log, dim=1)

                # Keep only the non-zero components of each text's vector.
                for row in tvecs.cpu():
                    nonzero_idx = row.nonzero(as_tuple=True)[0]
                    indices.append(nonzero_idx.numpy())
                    values.append(row[nonzero_idx].numpy())
        return indices, values

    def compute_query_vector(self, text: str) -> Tuple[Any, Any]:
        """Encode a single query string into a sparse vector with the query model.

        The vector is ``log(1 + relu(logits))`` multiplied by the attention
        mask and reduced with a max over dimension 1.

        Args:
            text: The query text.

        Returns:
            ``(query_indices, query_values)``: NumPy arrays holding the
            positions of the non-zero components and the weights at those
            positions.
        """
        # NOTE(review): unlike the document path, the tokenizer is called
        # without ``truncation=True`` here - confirm that very long queries
        # cannot reach this method.
        with torch.no_grad():
            tokens = self.splade_query_tokenizer(text, return_tensors="pt").to(self.device)
            output = self.splade_query_model(**tokens)
            logits, attention_mask = output.logits, tokens.attention_mask
            relu_log = torch.log(1 + torch.relu(logits))
            weighted_log = relu_log * attention_mask.unsqueeze(-1)
            max_val, _ = torch.max(weighted_log, dim=1)
            query_vec = max_val.squeeze().cpu()

            query_indices = query_vec.nonzero().numpy().flatten()
            query_values = query_vec.detach().numpy()[query_indices]

        return query_indices, query_values

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ):
        """Encode ``texts`` and upsert them into the configured collection.

        Args:
            texts: Texts to store. Any iterable is accepted; it is consumed
                exactly once.
            metadatas: Optional per-text metadata, indexed in parallel with
                ``texts``. When omitted or empty, ``None`` is stored instead.
            **kwargs: Forwarded unchanged to ``client.upsert``.

        Note:
            Point ids are assigned as ``1..len(texts)`` on every call, so a
            later call reuses the ids of an earlier one.
            NOTE(review): presumably this replaces the earlier points -
            confirm that is intended.
        """
        client = cast(QdrantClient, self.client)

        # ``texts`` is only promised to be iterable; materialise it once so it
        # can be sliced into batches and indexed (a generator supports neither).
        texts = list(texts)
        if not texts:
            # Nothing to encode or upload.
            return

        indices, values = self.compute_document_vectors(texts, self.batch_size)

        points = [
            models.PointStruct(
                id=i + 1,
                vector={
                    self.sparse_vector_name: models.SparseVector(
                        indices=indices[i],
                        values=values[i],
                    )
                },
                payload={
                    self.content_payload_key: text,
                    self.metadata_payload_key: metadatas[i] if metadatas else None,
                },
            )
            for i, text in enumerate(texts)
        ]
        client.upsert(self.collection_name, points=points, **kwargs)

        # Compare on the string form so that "cuda", "cuda:0" and
        # ``torch.device`` objects are all recognised, not only the exact
        # string "cuda".
        if str(self.device).startswith("cuda"):
            torch.cuda.empty_cache()

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Search the collection with the sparse vector computed for ``query``.

        Args:
            query: The query text.
            run_manager: Callback manager supplied by the caller; not used
                in this method.

        Returns:
            Up to ``self.k`` documents, one per point returned by the search
            and in the same order.
        """
        client = cast(QdrantClient, self.client)
        query_indices, query_values = self.compute_query_vector(query)
        results = client.search(
            self.collection_name,
            query_filter=self.filter,
            query_vector=models.NamedSparseVector(
                name=self.sparse_vector_name,
                vector=models.SparseVector(
                    indices=query_indices,
                    values=query_values,
                ),
            ),
            limit=self.k,
            with_vectors=False,
            **self.search_options,
        )
        return [
            Qdrant._document_from_scored_point(
                point,
                self.collection_name,
                self.content_payload_key,
                self.metadata_payload_key,
            )
            for point in results
        ]