Skip to content

Commit

Permalink
Merge pull request #3 from AWAS666/main
Browse files Browse the repository at this point in the history
Return distance for each entry
  • Loading branch information
vprelovac authored Oct 25, 2023
2 parents 3cadf23 + 6f90a78 commit a09dd73
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
5 changes: 3 additions & 2 deletions vectordb/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,9 @@ def search(self, query: str, top_n: int = 5, unique: bool = False) -> List[Dict[

results = [
{
"chunk": self.memory[i]["chunk"],
"metadata": self.metadata_memory[self.memory[i]["metadata_index"]],
"chunk": self.memory[i[0]]["chunk"],
"metadata": self.metadata_memory[self.memory[i[0]]["metadata_index"]],
"distance": i[1]
}
for i in indices
]
Expand Down
16 changes: 9 additions & 7 deletions vectordb/vector_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@ def run_mrpt(vector, vectors, k=15):
"""
Search for the most similar vectors using MRPT method.
"""
if isinstance(vector, list):
vector = np.array(vector).astype(np.float32)
index = mrpt.MRPTIndex(vectors)
res = index.exact_search(vector, k, return_distances=False)
return res
res = index.exact_search(vector, k, return_distances=True)
return res[0].tolist(), res[1].tolist()

@staticmethod
def run_faiss(vector, vectors, k=15):
Expand All @@ -41,8 +43,8 @@ def run_faiss(vector, vectors, k=15):
"""
index = faiss.IndexFlatL2(vectors.shape[1])
index.add(vectors)
_, indices = index.search(np.array([vector]), k)
return indices[0]
dis, indices = index.search(np.array([vector]), k)
return indices[0], dis[0]

@staticmethod
def run_sk(vector, vectors, k=15):
Expand All @@ -59,7 +61,7 @@ def run_sk(vector, vectors, k=15):
@staticmethod
def search_vectors(
query_embedding: List[float], embeddings: List[List[float]], top_n: int
) -> List[int]:
) -> List[tuple[int, float]]:
"""
Searches for the most similar vectors to the query_embedding in the given embeddings.
Expand All @@ -76,6 +78,6 @@ def search_vectors(
else:
call_search = VectorSearch.run_mrpt

indices = call_search(query_embedding, embeddings, top_n)
indices, dis = call_search(query_embedding, embeddings, top_n)

return indices
return list(zip(indices, dis))

0 comments on commit a09dd73

Please sign in to comment.