From c46c248b279a2a23be49c8b7ee06d8604b43dc2b Mon Sep 17 00:00:00 2001 From: Sai krishna Date: Fri, 6 Dec 2024 20:14:56 +0530 Subject: [PATCH] feat: add mongo db integration --- backend/modules/vector_db/__init__.py | 5 +- backend/modules/vector_db/base.py | 8 +- backend/modules/vector_db/mongo.py | 390 ++++++++++++++++++++++---- backend/vectordb.requirements.txt | 3 +- compose.env | 9 +- docker-compose.yaml | 30 -- 6 files changed, 346 insertions(+), 99 deletions(-) diff --git a/backend/modules/vector_db/__init__.py b/backend/modules/vector_db/__init__.py index 2e50ba07..e4a9fde8 100644 --- a/backend/modules/vector_db/__init__.py +++ b/backend/modules/vector_db/__init__.py @@ -1,14 +1,13 @@ from backend.modules.vector_db.base import BaseVectorDB -from backend.modules.vector_db.mongo import MongoVectorDB +# from backend.modules.vector_db.mongo import MongoVectorDB from backend.modules.vector_db.qdrant import QdrantVectorDB - # from backend.modules.vector_db.singlestore import SingleStoreVectorDB # from backend.modules.vector_db.weaviate import WeaviateVectorDB from backend.types import VectorDBConfig SUPPORTED_VECTOR_DBS = { "qdrant": QdrantVectorDB, - "mongo": MongoVectorDB, + # "mongo": MongoVectorDB, # "weaviate": WeaviateVectorDB, # "singlestore": SingleStoreVectorDB, } diff --git a/backend/modules/vector_db/base.py b/backend/modules/vector_db/base.py index 65e80609..420ec852 100644 --- a/backend/modules/vector_db/base.py +++ b/backend/modules/vector_db/base.py @@ -6,6 +6,7 @@ from langchain.schema.vectorstore import VectorStore from backend.constants import DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE +from backend.logger import logger from backend.types import DataPointVector @@ -88,4 +89,9 @@ def get_embedding_dimensions(self, embeddings: Embeddings) -> int: """ Fetch embedding dimensions """ - raise NotImplementedError() + # Calculate embedding size + logger.debug(f"[VectorDB] Embedding a dummy doc to get vector dimensions") + partial_embeddings = embeddings.embed_documents(["Initial document"]) + vector_size = len(partial_embeddings[0]) + logger.debug(f"Vector size: {vector_size}") + return vector_size diff --git a/backend/modules/vector_db/mongo.py b/backend/modules/vector_db/mongo.py index 00febb71..b3ff0569 100644 --- a/backend/modules/vector_db/mongo.py +++ b/backend/modules/vector_db/mongo.py @@ -1,39 +1,86 @@ +import time from typing import List from langchain.docstore.document import Document from langchain.embeddings.base import Embeddings from langchain.schema.vectorstore import VectorStore -from langchain.vectorstores.mongodb_atlas import MongoDBAtlasVectorSearch -from pymongo import MongoClient, UpdateOne +from langchain_mongodb import MongoDBAtlasVectorSearch +from pymongo import MongoClient +from pymongo.operations import SearchIndexModel -from backend.constants import DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE +from backend.constants import ( + DATA_POINT_FQN_METADATA_KEY, + DATA_POINT_HASH_METADATA_KEY, + DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE, +) +from backend.logger import logger from backend.modules.vector_db.base import BaseVectorDB from backend.types import DataPointVector, VectorDBConfig +MAX_SCROLL_LIMIT = int(1e6) +BATCH_SIZE = 1000 + class MongoVectorDB(BaseVectorDB): def __init__(self, config: VectorDBConfig): """Initialize MongoDB vector database client""" - self.client = MongoClient( - "mongodb+srv://sai:mbdqh2jyt@cluster0.bz0r9.mongodb.net/?retryWrites=true&w=majority&appName=Cluster0" - ) - self.db = self.client[config.database_name] + self.config = config + self.client = MongoClient(config.url) + self.db = self.client[config.config.get("database_name")] def create_collection(self, collection_name: str, embeddings: Embeddings) -> None: """Create a collection with vector search index""" if collection_name in self.db.list_collection_names(): raise ValueError(f"Collection {collection_name} already exists in MongoDB") - collection = self.db.create_collection(collection_name) - # Create vector search index - collection.create_index( - [("embedding", "vectorSearch")], - { - "numDimensions": self.get_embedding_dimensions(embeddings), - "similarity": "cosine", + # Create the collection first + self.db.create_collection(collection_name) + + # Define the search index model + self._create_search_index(collection_name, embeddings) + + def _create_search_index(self, collection_name: str, embeddings: Embeddings): + # Mongo DB requires a vector search index to be created with a specific configuration + # Number of dimensions are calculated based on the embedding model. We use a sample text to get the number of dimensions + # Similarity is set to cosine as we are using cosine similarity for the vector search. This can be changed to euclidean, dotproduct etc as per the requirements. + # Path is set to "embedding" as embeddings are stored in the "embedding" field in the collection by default in MongoDB. + # Reference: https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/ + search_index_model = SearchIndexModel( + definition={ + "fields": [ + { + "type": "vector", + "path": "embedding", + "numDimensions": self.get_embedding_dimensions(embeddings), + "similarity": "cosine", + } + ] }, + name="vector_search_index", + type="vectorSearch", ) + # Create the search index + result = self.db[collection_name].create_search_index(model=search_index_model) + logger.debug(f"New search index named {result} is building.") + + # Immediate avaialbility of the index is not guaranteed upon creation. + # MongoDB documentation recommends polling for the index to be ready. + # Ensure this check to provide a seamless experience. + # TODO (mnvsk97): We might want to introduce a new status in the ingestion runs to reflex this. + logger.debug( + "Polling to check if the index is ready. This may take up to a minute." + ) + predicate = lambda index: index.get("queryable") is True + while True: + indices = list( + self.db[collection_name].list_search_indexes("vector_search_index") + ) + if len(indices) and predicate(indices[0]): + break + time.sleep(5) + logger.debug(f"{result} is ready for querying.") + def upsert_documents( self, collection_name: str, @@ -41,39 +88,92 @@ def upsert_documents( embeddings: Embeddings, incremental: bool = True, ): - """Upsert documents with their embeddings""" + if len(documents) == 0: + logger.warning("No documents to index") + return + # get record IDs to be upserted + logger.debug( + f"[Mongo] Adding {len(documents)} documents to collection {collection_name}" + ) + + """Upsert documenlots with their embeddings""" collection = self.db[collection_name] - # Generate embeddings for documents - texts = [doc.page_content for doc in documents] - embeddings_list = embeddings.embed_documents(texts) - - # Prepare documents for insertion - docs_to_insert = [] - for doc, embedding in zip(documents, embeddings_list): - mongo_doc = { - "text": doc.page_content, - "embedding": embedding, - "metadata": doc.metadata, - } - docs_to_insert.append(mongo_doc) - - # Use bulk write for better performance - if incremental: - collection.bulk_write( - [ - UpdateOne( - {"metadata.source": doc["metadata"]["source"]}, - {"$set": doc}, - upsert=True, - ) - for doc in docs_to_insert + data_point_fqns = [] + for document in documents: + if document.metadata.get(DATA_POINT_FQN_METADATA_KEY): + data_point_fqns.append( + document.metadata.get(DATA_POINT_FQN_METADATA_KEY) + ) + + logger.debug(f"[Mongo] documents: {documents}") + + record_ids_to_be_upserted: List[str] = self._get_records_to_be_upserted( + collection_name=collection_name, + data_point_fqns=data_point_fqns, + incremental=incremental, + ) + + # Add Documents + MongoDBAtlasVectorSearch( + collection=collection, + embedding=embeddings, + ).add_documents(documents=documents) + logger.debug( + f"[Mongo] Added {len(documents)} documents to collection {collection_name}" + ) + + if len(record_ids_to_be_upserted) > 0: + logger.debug( + f"[Mongo] Deleting {len(record_ids_to_be_upserted)} outdated documents from collection {collection_name}" + ) + for i in range( + 0, len(record_ids_to_be_upserted), DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE + ): + record_ids_to_be_processed = record_ids_to_be_upserted[ + i : i + DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE ] + collection.delete_many({"_id": {"$in": record_ids_to_be_processed}}) + logger.debug( + f"[Mongo] Deleted {len(record_ids_to_be_upserted)} outdated documents from collection {collection_name}" ) - else: - # TODO: only delete the existing documents with the in collection with given ids - collection.delete_many({}) - collection.insert_many(docs_to_insert) + + def _get_records_to_be_upserted( + self, collection_name: str, data_point_fqns: List[str], incremental: bool = True + ) -> List[str]: + """Get record IDs to be upserted""" + if not incremental: + return [] + + # For incremental deletion, we delete the documents with the same document_id + logger.debug( + f"[Mongo] Incremental Ingestion: Fetching documents for {len(data_point_fqns)} data point fqns for collection {collection_name}" + ) + + collection = self.db[collection_name] + record_ids_to_be_upserted = [] + + # Query in batches to avoid memory issues + for i in range(0, len(data_point_fqns), DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE): + batch_fqns = data_point_fqns[i : i + DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE] + + cursor = collection.find( + {f"metadata.{DATA_POINT_FQN_METADATA_KEY}": {"$in": batch_fqns}}, + {"_id": 1}, + ) + + for doc in cursor: + record_ids_to_be_upserted.append(doc["_id"]) + if len(record_ids_to_be_upserted) > MAX_SCROLL_LIMIT: + break + + if len(record_ids_to_be_upserted) > MAX_SCROLL_LIMIT: + break + + logger.debug( + f"[Mongo] Incremental Ingestion: collection={collection_name} Addition={len(data_point_fqns)}, Updates={len(record_ids_to_be_upserted)}" + ) + return record_ids_to_be_upserted def get_collections(self) -> List[str]: """Get all collection names""" @@ -90,6 +190,7 @@ def get_vector_store( return MongoDBAtlasVectorSearch( collection=self.db[collection_name], embedding=embeddings, + index_name="vector_search_index", ) def get_vector_client(self): @@ -102,24 +203,59 @@ def list_data_point_vectors( data_source_fqn: str, batch_size: int = DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE, ) -> List[DataPointVector]: - """List vectors for a data source""" - collection = self.db[collection_name] - vectors = [] - - cursor = collection.find( - {"metadata.data_source_fqn": data_source_fqn}, batch_size=batch_size + logger.debug( + f"[Mongo] Listing all data point vectors for collection {collection_name}" ) + stop = False + offset = None + data_point_vectors: List[DataPointVector] = [] + collection = self.db[collection_name] - for doc in cursor: - vector = DataPointVector( - id=str(doc["_id"]), - text=doc["text"], - metadata=doc["metadata"], - embedding=doc["embedding"], + while stop is not True: + batch_cursor = ( + collection.find( + {f"metadata.{DATA_POINT_FQN_METADATA_KEY}": data_source_fqn}, + { + "_id": 1, + f"metadata.{DATA_POINT_FQN_METADATA_KEY}": 1, + f"metadata.{DATA_POINT_HASH_METADATA_KEY}": 1, + }, + ) + .skip(offset if offset else 0) + .limit(batch_size) ) - vectors.append(vector) - return vectors + # Convert cursor to list to check if we got any results + batch_results = list(batch_cursor) + if not batch_results: + stop = True + break + + for doc in batch_results: + metadata = doc.get("metadata", {}) + if metadata.get(DATA_POINT_FQN_METADATA_KEY) and metadata.get( + DATA_POINT_HASH_METADATA_KEY + ): + data_point_vectors.append( + DataPointVector( + data_point_vector_id=str(doc["_id"]), + data_point_fqn=metadata.get(DATA_POINT_FQN_METADATA_KEY), + data_point_hash=metadata.get(DATA_POINT_HASH_METADATA_KEY), + ) + ) + if len(data_point_vectors) > MAX_SCROLL_LIMIT: + stop = True + break + + if len(batch_results) < batch_size: + stop = True + else: + offset = (offset if offset else 0) + batch_size + + logger.debug( + f"[Mongo] Listing {len(data_point_vectors)} data point vectors for collection {collection_name}" + ) + return data_point_vectors def delete_data_point_vectors( self, @@ -135,3 +271,143 @@ def delete_data_point_vectors( for i in range(0, len(vector_ids), batch_size): batch = vector_ids[i : i + batch_size] collection.delete_many({"_id": {"$in": batch}}) + + def list_documents_in_collection( + self, collection_name: str, base_document_id: str = None + ) -> List[str]: + """ + List all documents in a collection + """ + logger.debug( + f"[Mongo] Listing all documents with base document id {base_document_id} for collection {collection_name}" + ) + stop = False + offset = None + document_ids_set = set() + collection = self.db[collection_name] + + while stop is not True: + batch_cursor = ( + collection.find( + {f"metadata.{DATA_POINT_FQN_METADATA_KEY}": base_document_id} + if base_document_id + else {}, + {f"metadata.{DATA_POINT_FQN_METADATA_KEY}": 1}, + ) + .skip(offset if offset else 0) + .limit(BATCH_SIZE) + ) + + # Convert cursor to list to check if we got any results + batch_results = list(batch_cursor) + if not batch_results: + stop = True + break + + for doc in batch_results: + if doc.get("metadata") and doc.get("metadata").get( + DATA_POINT_FQN_METADATA_KEY + ): + document_ids_set.add( + doc.get("metadata").get(DATA_POINT_FQN_METADATA_KEY) + ) + if len(document_ids_set) > MAX_SCROLL_LIMIT: + stop = True + break + + if len(batch_results) < BATCH_SIZE: + stop = True + else: + offset = (offset if offset else 0) + BATCH_SIZE + + logger.debug( + f"[Mongo] Found {len(document_ids_set)} documents with base document id {base_document_id} for collection {collection_name}" + ) + return list(document_ids_set) + + def delete_documents(self, collection_name: str, document_ids: List[str]): + """ + Delete documents from the collection + """ + logger.debug( + f"[Mongo] Deleting {len(document_ids)} documents from collection {collection_name}" + ) + try: + collection = self.db[collection_name] + except Exception as exp: + logger.debug(exp) + return + + for i in range(0, len(document_ids), BATCH_SIZE): + document_ids_to_be_processed = document_ids[i : i + BATCH_SIZE] + collection.delete_many( + { + f"metadata.{DATA_POINT_FQN_METADATA_KEY}": { + "$in": document_ids_to_be_processed + } + } + ) + + logger.debug( + f"[Mongo] Deleted {len(document_ids)} documents from collection {collection_name}" + ) + + def list_document_vector_points( + self, collection_name: str + ) -> List[DataPointVector]: + """ + List all documents in a collection + """ + logger.debug( + f"[Mongo] Listing all document vector points for collection {collection_name}" + ) + stop = False + offset = None + document_vector_points: List[DataPointVector] = [] + collection = self.db[collection_name] + + while stop is not True: + batch_cursor = ( + collection.find( + {}, + { + "_id": 1, + f"metadata.{DATA_POINT_FQN_METADATA_KEY}": 1, + f"metadata.{DATA_POINT_HASH_METADATA_KEY}": 1, + }, + ) + .skip(offset if offset else 0) + .limit(BATCH_SIZE) + ) + + # Convert cursor to list to check if we got any results + batch_results = list(batch_cursor) + if not batch_results: + stop = True + break + + for doc in batch_results: + metadata = doc.get("metadata", {}) + if metadata.get(DATA_POINT_FQN_METADATA_KEY) and metadata.get( + DATA_POINT_HASH_METADATA_KEY + ): + document_vector_points.append( + DataPointVector( + point_id=str(doc["_id"]), + document_id=metadata.get(DATA_POINT_FQN_METADATA_KEY), + document_hash=metadata.get(DATA_POINT_HASH_METADATA_KEY), + ) + ) + if len(document_vector_points) > MAX_SCROLL_LIMIT: + stop = True + break + + if len(batch_results) < BATCH_SIZE: + stop = True + else: + offset = (offset if offset else 0) + BATCH_SIZE + + logger.debug( + f"[Mongo] Listed {len(document_vector_points)} document vector points for collection {collection_name}" + ) + return document_vector_points diff --git a/backend/vectordb.requirements.txt b/backend/vectordb.requirements.txt index 43799eea..40d677fd 100644 --- a/backend/vectordb.requirements.txt +++ b/backend/vectordb.requirements.txt @@ -4,5 +4,6 @@ singlestoredb==1.0.4 ### Weaviate client (in progress) weaviate-client==3.25.3 -### MongoDB client +### MongoDB pymongo==4.10.1 +langchain-mongodb==0.2.0 \ No newline at end of file diff --git a/compose.env b/compose.env index 79730d2b..256c59d9 100644 --- a/compose.env +++ b/compose.env @@ -13,7 +13,8 @@ MODELS_CONFIG_PATH="./models_config.yaml" METADATA_STORE_CONFIG='{"provider":"prisma"}' ML_REPO_NAME='' VECTOR_DB_CONFIG='{"provider":"qdrant","url":"http://qdrant-server:6333", "config": {"grpc_port": 6334, "prefer_grpc": false}}' -MONGO_DB_CONFIG='{"provider":"mongo","url":"connection_uri", "config": {}}' +# MONGO Example +# VECTOR_DB_CONFIG='{"provider":"mongo","url":"connection_uri", "config": {"database_name": "cognita"}}' COGNITA_BACKEND_PORT=8000 UNSTRUCTURED_IO_URL=http://unstructured-io-parsers:9500/ @@ -49,9 +50,3 @@ BRAVE_API_KEY= ## WHISPER WHISPER_PORT=10300 WHISPER_MODEL=Systran/faster-distil-whisper-large-v3 - -## MongoDB -ME_CONFIG_MONGODB_ADMINUSERNAME: root -ME_CONFIG_MONGODB_ADMINPASSWORD: example -ME_CONFIG_MONGODB_URL: mongodb://root:example@mongo:27017/ -ME_CONFIG_BASICAUTH: false diff --git a/docker-compose.yaml b/docker-compose.yaml index 5ff7d3e6..12e40b51 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -231,36 +231,6 @@ services: networks: - cognita-docker - mongo: - image: mongo - pull_policy: if_not_present - restart: unless-stopped - container_name: mongo - profiles: - - mongo - environment: - MONGO_INITDB_ROOT_USERNAME: ${MONGO_INITDB_ROOT_USERNAME} - MONGO_INITDB_ROOT_PASSWORD: ${MONGO_INITDB_ROOT_PASSWORD} - networks: - - cognita-docker - - mongo-express: - image: mongo-express - pull_policy: if_not_present - restart: unless-stopped - container_name: mongo-express - profiles: - - mongo-express - ports: - - 8081:8081 - environment: - ME_CONFIG_MONGODB_ADMINUSERNAME: root - ME_CONFIG_MONGODB_ADMINPASSWORD: example - ME_CONFIG_MONGODB_URL: mongodb://root:example@mongo:27017/ - ME_CONFIG_BASICAUTH: false - networks: - - cognita-docker - networks: cognita-docker: external: false