Commit
Sai krishna committed on Dec 6, 2024
1 parent 47b00ae · commit db810a1
Showing 2 changed files with 36 additions and 32 deletions.
@@ -1,21 +1,22 @@
-from typing import List, Optional
-from langchain_openai import OpenAIEmbeddings
-from pymongo import MongoClient, UpdateOne
+from typing import List
 
 from langchain.docstore.document import Document
 from langchain.embeddings.base import Embeddings
-from langchain.vectorstores.mongodb_atlas import MongoDBAtlasVectorSearch
 from langchain.schema.vectorstore import VectorStore
+from langchain.vectorstores.mongodb_atlas import MongoDBAtlasVectorSearch
+from pymongo import MongoClient, UpdateOne
 
 from backend import logger
 from backend.constants import DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE
-from backend.types import DataPointVector, VectorDBConfig
 from backend.modules.vector_db.base import BaseVectorDB
+from backend.types import DataPointVector, VectorDBConfig
 
 
 class MongoVectorDB(BaseVectorDB):
     def __init__(self, config: VectorDBConfig):
         """Initialize MongoDB vector database client"""
-        self.client = MongoClient(config.url)
+        self.client = MongoClient(
+            "mongodb+srv://sai:[email protected]/?retryWrites=true&w=majority&appName=Cluster0"
+        )
         self.db = self.client[config.database_name]
@@ -24,13 +25,14 @@ def create_collection(self, collection_name: str, embeddings: Embeddings) -> None:
             raise ValueError(f"Collection {collection_name} already exists in MongoDB")
 
         collection = self.db.create_collection(collection_name)
-        # Create vector search index
-        collection.create_index([
-            ("embedding", "vectorSearch")
-        ], {
-            "numDimensions": self.get_embedding_dimensions(embeddings),
-            "similarity": "cosine"
-        })
+        # Create vector search index
+        collection.create_index(
+            [("embedding", "vectorSearch")],
+            {
+                "numDimensions": self.get_embedding_dimensions(embeddings),
+                "similarity": "cosine",
+            },
+        )
 
     def upsert_documents(
         self,
@@ -41,7 +43,7 @@ def upsert_documents(
     ):
         """Upsert documents with their embeddings"""
         collection = self.db[collection_name]
-        
+
         # Generate embeddings for documents
         texts = [doc.page_content for doc in documents]
         embeddings_list = embeddings.embed_documents(texts)
@@ -52,19 +54,22 @@ def upsert_documents(
             mongo_doc = {
                 "text": doc.page_content,
                 "embedding": embedding,
-                "metadata": doc.metadata
+                "metadata": doc.metadata,
             }
             docs_to_insert.append(mongo_doc)
 
         # Use bulk write for better performance
         if incremental:
-            collection.bulk_write([
-                UpdateOne(
-                    {"metadata.source": doc["metadata"]["source"]},
-                    {"$set": doc},
-                    upsert=True
-                ) for doc in docs_to_insert
-            ])
+            collection.bulk_write(
+                [
+                    UpdateOne(
+                        {"metadata.source": doc["metadata"]["source"]},
+                        {"$set": doc},
+                        upsert=True,
+                    )
+                    for doc in docs_to_insert
+                ]
+            )
         else:
             # TODO: only delete the existing documents with the in collection with given ids
             collection.delete_many({})
@@ -100,21 +105,20 @@ def list_data_point_vectors(
         """List vectors for a data source"""
         collection = self.db[collection_name]
         vectors = []
 
         cursor = collection.find(
-            {"metadata.data_source_fqn": data_source_fqn},
-            batch_size=batch_size
+            {"metadata.data_source_fqn": data_source_fqn}, batch_size=batch_size
         )
 
         for doc in cursor:
             vector = DataPointVector(
                 id=str(doc["_id"]),
                 text=doc["text"],
                 metadata=doc["metadata"],
-                embedding=doc["embedding"]
+                embedding=doc["embedding"],
             )
             vectors.append(vector)
 
         return vectors
 
     def delete_data_point_vectors(
@@ -126,8 +130,8 @@ def delete_data_point_vectors(
         """Delete vectors by their IDs"""
         collection = self.db[collection_name]
         vector_ids = [vector.id for vector in data_point_vectors]
 
         # Delete in batches
        for i in range(0, len(vector_ids), batch_size):
-            batch = vector_ids[i:i + batch_size]
+            batch = vector_ids[i : i + batch_size]
             collection.delete_many({"_id": {"$in": batch}})
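
For context, the class touched by this commit wraps a pymongo client behind the BaseVectorDB interface. The snippet below is a rough usage sketch reconstructed only from the method bodies visible in this diff; the VectorDBConfig fields, the import path, and the exact keyword signatures are assumptions rather than confirmed API.

# Hypothetical usage sketch. Everything not visible in the diff above
# (VectorDBConfig constructor, import path, exact method signatures) is assumed.
from langchain.docstore.document import Document
from langchain_openai import OpenAIEmbeddings

from backend.modules.vector_db.mongo import MongoVectorDB  # import path assumed
from backend.types import VectorDBConfig

# Passing the connection string through config (as the pre-commit
# MongoClient(config.url) did) avoids hardcoding credentials in source.
config = VectorDBConfig(
    url="mongodb+srv://<user>:<password>@<cluster-host>/",  # field name per old __init__
    database_name="vector_store",
)

db = MongoVectorDB(config)
embeddings = OpenAIEmbeddings()

db.create_collection("docs", embeddings)
db.upsert_documents(
    collection_name="docs",
    documents=[Document(page_content="hello world", metadata={"source": "demo.txt"})],
    embeddings=embeddings,
    incremental=True,
)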