@router.post("/knowledge/{space_name}/document/sync")
async def document_sync(
    space_name: str,
    request: DocumentSyncRequest,
    service: Service = Depends(get_rag_service),
):
    """Sync knowledge documents of a space into its vector store.

    Args:
        space_name: name of the knowledge space to sync into.
        request: sync request carrying ``doc_ids`` plus optional
            ``model_name``, ``chunk_size`` and ``chunk_overlap``.
        service: RAG service resolved via FastAPI dependency injection.

    Returns:
        Result: ``Result.succ`` with the ids of the documents scheduled for
        syncing, or ``Result.failed`` (code "E000X") when the space does not
        exist, no doc ids were given, or any error occurs.
    """
    logger.info(f"Received params: {space_name}, {request}")
    try:
        space = service.get({"name": space_name})
        if space is None:
            return Result.failed(code="E000X", msg=f"space {space_name} not exist")
        if not request.doc_ids:
            return Result.failed(code="E000X", msg="doc_ids is None")
        # Share one ChunkParameters across all documents of this request.
        chunk_parameters = ChunkParameters(
            chunk_strategy="Automatic",
            chunk_size=request.chunk_size or 512,
            chunk_overlap=request.chunk_overlap or 50,
        )
        # BUG FIX: the previous version validated the whole doc_ids list but
        # built a sync request only for doc_ids[0], silently dropping the
        # rest. Build one KnowledgeSyncRequest per requested document.
        sync_requests = []
        for doc_id in request.doc_ids:
            sync_request = KnowledgeSyncRequest(
                doc_id=doc_id,
                space_id=str(space.id),
                model_name=request.model_name,
            )
            sync_request.chunk_parameters = chunk_parameters
            sync_requests.append(sync_request)
        doc_ids = await service.sync_document(requests=sync_requests)
        return Result.succ(doc_ids)
    except Exception as e:
        return Result.failed(code="E000X", msg=f"document sync error {e}")
Config from dbgpt.app.knowledge.chunk_db import DocumentChunkDao, DocumentChunkEntity @@ -32,13 +31,8 @@ from dbgpt.rag.assembler.summary import SummaryAssembler from dbgpt.rag.chunk_manager import ChunkParameters from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory -from dbgpt.rag.knowledge.base import ChunkStrategy, KnowledgeType +from dbgpt.rag.knowledge.base import KnowledgeType from dbgpt.rag.knowledge.factory import KnowledgeFactory -from dbgpt.rag.text_splitter.text_splitter import ( - RecursiveCharacterTextSplitter, - SpacyTextSplitter, -) -from dbgpt.serve.rag.api.schemas import KnowledgeSyncRequest from dbgpt.serve.rag.models.models import KnowledgeSpaceDao, KnowledgeSpaceEntity from dbgpt.serve.rag.service.service import SyncStatus from dbgpt.storage.vector_store.base import VectorStoreConfig @@ -199,186 +193,6 @@ def get_knowledge_documents(self, space, request: DocumentQueryRequest): total = knowledge_document_dao.get_knowledge_documents_count(query) return DocumentQueryResponse(data=data, total=total, page=page) - def batch_document_sync( - self, - space_name, - sync_requests: List[KnowledgeSyncRequest], - ) -> List[int]: - """batch sync knowledge document chunk into vector store - Args: - - space: Knowledge Space Name - - sync_requests: List[KnowledgeSyncRequest] - Returns: - - List[int]: document ids - """ - doc_ids = [] - for sync_request in sync_requests: - docs = knowledge_document_dao.documents_by_ids([sync_request.doc_id]) - if len(docs) == 0: - raise Exception( - f"there are document called, doc_id: {sync_request.doc_id}" - ) - doc = docs[0] - if ( - doc.status == SyncStatus.RUNNING.name - or doc.status == SyncStatus.FINISHED.name - ): - raise Exception( - f" doc:{doc.doc_name} status is {doc.status}, can not sync" - ) - chunk_parameters = sync_request.chunk_parameters - if chunk_parameters.chunk_strategy != ChunkStrategy.CHUNK_BY_SIZE.name: - space_context = self.get_space_context(space_name) - chunk_parameters.chunk_size = ( - 
CFG.KNOWLEDGE_CHUNK_SIZE - if space_context is None - else int(space_context["embedding"]["chunk_size"]) - ) - chunk_parameters.chunk_overlap = ( - CFG.KNOWLEDGE_CHUNK_OVERLAP - if space_context is None - else int(space_context["embedding"]["chunk_overlap"]) - ) - self._sync_knowledge_document(space_name, doc, chunk_parameters) - doc_ids.append(doc.id) - return doc_ids - - def sync_knowledge_document(self, space_name, sync_request: DocumentSyncRequest): - """sync knowledge document chunk into vector store - Args: - - space: Knowledge Space Name - - sync_request: DocumentSyncRequest - """ - from dbgpt.rag.text_splitter.pre_text_splitter import PreTextSplitter - - doc_ids = sync_request.doc_ids - self.model_name = sync_request.model_name or CFG.LLM_MODEL - for doc_id in doc_ids: - query = KnowledgeDocumentEntity(id=doc_id) - docs = knowledge_document_dao.get_documents(query) - if len(docs) == 0: - raise Exception( - f"there are document called, doc_id: {sync_request.doc_id}" - ) - doc = docs[0] - if ( - doc.status == SyncStatus.RUNNING.name - or doc.status == SyncStatus.FINISHED.name - ): - raise Exception( - f" doc:{doc.doc_name} status is {doc.status}, can not sync" - ) - - space_context = self.get_space_context(space_name) - chunk_size = ( - CFG.KNOWLEDGE_CHUNK_SIZE - if space_context is None - else int(space_context["embedding"]["chunk_size"]) - ) - chunk_overlap = ( - CFG.KNOWLEDGE_CHUNK_OVERLAP - if space_context is None - else int(space_context["embedding"]["chunk_overlap"]) - ) - if sync_request.chunk_size: - chunk_size = sync_request.chunk_size - if sync_request.chunk_overlap: - chunk_overlap = sync_request.chunk_overlap - separators = sync_request.separators or None - from dbgpt.rag.chunk_manager import ChunkParameters - - chunk_parameters = ChunkParameters( - chunk_size=chunk_size, chunk_overlap=chunk_overlap - ) - if CFG.LANGUAGE == "en": - text_splitter = RecursiveCharacterTextSplitter( - separators=separators, - chunk_size=chunk_size, - 
chunk_overlap=chunk_overlap, - length_function=len, - ) - else: - if separators and len(separators) > 1: - raise ValueError( - "SpacyTextSplitter do not support multipsle separators" - ) - try: - separator = "\n\n" if not separators else separators[0] - text_splitter = SpacyTextSplitter( - separator=separator, - pipeline="zh_core_web_sm", - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - ) - except Exception: - text_splitter = RecursiveCharacterTextSplitter( - separators=separators, - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - ) - if sync_request.pre_separator: - logger.info(f"Use preseparator, {sync_request.pre_separator}") - text_splitter = PreTextSplitter( - pre_separator=sync_request.pre_separator, - text_splitter_impl=text_splitter, - ) - chunk_parameters.text_splitter = text_splitter - self._sync_knowledge_document(space_name, doc, chunk_parameters) - return doc.id - - def _sync_knowledge_document( - self, - space_name, - doc: KnowledgeDocumentEntity, - chunk_parameters: ChunkParameters, - ) -> List[Chunk]: - """sync knowledge document chunk into vector store""" - embedding_factory = CFG.SYSTEM_APP.get_component( - "embedding_factory", EmbeddingFactory - ) - embedding_fn = embedding_factory.create( - model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL] - ) - - spaces = self.get_knowledge_space(KnowledgeSpaceRequest(name=space_name)) - if len(spaces) != 1: - raise Exception(f"invalid space name:{space_name}") - space = spaces[0] - - from dbgpt.storage.vector_store.base import VectorStoreConfig - - config = VectorStoreConfig( - name=space.name, - embedding_fn=embedding_fn, - max_chunks_once_load=CFG.KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD, - llm_client=self.llm_client, - model_name=self.model_name, - ) - vector_store_connector = VectorStoreConnector( - vector_store_type=space.vector_type, vector_store_config=config - ) - knowledge = KnowledgeFactory.create( - datasource=doc.content, - knowledge_type=KnowledgeType.get_by_value(doc.doc_type), - ) - 
assembler = EmbeddingAssembler.load_from_knowledge( - knowledge=knowledge, - chunk_parameters=chunk_parameters, - embeddings=embedding_fn, - vector_store_connector=vector_store_connector, - ) - chunk_docs = assembler.get_chunks() - doc.status = SyncStatus.RUNNING.name - doc.chunk_size = len(chunk_docs) - doc.gmt_modified = datetime.now() - knowledge_document_dao.update_knowledge_document(doc) - executor = CFG.SYSTEM_APP.get_component( - ComponentType.EXECUTOR_DEFAULT, ExecutorFactory - ).create() - executor.submit(self.async_doc_embedding, assembler, chunk_docs, doc) - logger.info(f"begin save document chunks, doc:{doc.doc_name}") - return chunk_docs - async def document_summary(self, request: DocumentSummaryRequest): """get document summary Args: diff --git a/dbgpt/rag/assembler/embedding.py b/dbgpt/rag/assembler/embedding.py index 2f1a6bd9b..095408c3e 100644 --- a/dbgpt/rag/assembler/embedding.py +++ b/dbgpt/rag/assembler/embedding.py @@ -1,9 +1,11 @@ """Embedding Assembler.""" +from concurrent.futures import ThreadPoolExecutor from typing import Any, List, Optional from dbgpt.core import Chunk, Embeddings from dbgpt.storage.vector_store.connector import VectorStoreConnector +from ...util.executor_utils import blocking_func_to_async from ..assembler.base import BaseAssembler from ..chunk_manager import ChunkParameters from ..embedding.embedding_factory import DefaultEmbeddingFactory @@ -98,6 +100,41 @@ def load_from_knowledge( embeddings=embeddings, ) + @classmethod + async def aload_from_knowledge( + cls, + knowledge: Knowledge, + vector_store_connector: VectorStoreConnector, + chunk_parameters: Optional[ChunkParameters] = None, + embedding_model: Optional[str] = None, + embeddings: Optional[Embeddings] = None, + executor: Optional[ThreadPoolExecutor] = None, + ) -> "EmbeddingAssembler": + """Load document embedding into vector store from path. + + Args: + knowledge: (Knowledge) Knowledge datasource. 
+ vector_store_connector: (VectorStoreConnector) VectorStoreConnector to use. + chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for + chunking. + embedding_model: (Optional[str]) Embedding model to use. + embeddings: (Optional[Embeddings]) Embeddings to use. + executor: (Optional[ThreadPoolExecutor) ThreadPoolExecutor to use. + + Returns: + EmbeddingAssembler + """ + executor = executor or ThreadPoolExecutor() + return await blocking_func_to_async( + executor, + cls, + knowledge, + vector_store_connector, + chunk_parameters, + embedding_model, + embeddings, + ) + def persist(self) -> List[str]: """Persist chunks into vector store. diff --git a/dbgpt/rag/index/base.py b/dbgpt/rag/index/base.py index 59f6b49cc..764c362e4 100644 --- a/dbgpt/rag/index/base.py +++ b/dbgpt/rag/index/base.py @@ -8,6 +8,7 @@ from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict from dbgpt.core import Chunk, Embeddings from dbgpt.storage.vector_store.filters import MetadataFilters +from dbgpt.util.executor_utils import blocking_func_to_async logger = logging.getLogger(__name__) @@ -46,6 +47,10 @@ def to_dict(self, **kwargs) -> Dict[str, Any]: class IndexStoreBase(ABC): """Index store base class.""" + def __init__(self, executor: Optional[ThreadPoolExecutor] = None): + """Init index store.""" + self._executor = executor or ThreadPoolExecutor() + @abstractmethod def load_document(self, chunks: List[Chunk]) -> List[str]: """Load document in index database. @@ -143,6 +148,27 @@ def load_document_with_limit( ) return ids + async def aload_document_with_limit( + self, chunks: List[Chunk], max_chunks_once_load: int = 10, max_threads: int = 1 + ) -> List[str]: + """Load document in index database with specified limit. + + Args: + chunks(List[Chunk]): Document chunks. + max_chunks_once_load(int): Max number of chunks to load at once. + max_threads(int): Max number of threads to use. + + Return: + List[str]: Chunk ids. 
+ """ + return await blocking_func_to_async( + self._executor, + self.load_document_with_limit, + chunks, + max_chunks_once_load, + max_threads, + ) + def similar_search( self, text: str, topk: int, filters: Optional[MetadataFilters] = None ) -> List[Chunk]: diff --git a/dbgpt/serve/rag/service/service.py b/dbgpt/serve/rag/service/service.py index 02a39a683..8d7f55fc5 100644 --- a/dbgpt/serve/rag/service/service.py +++ b/dbgpt/serve/rag/service/service.py @@ -443,7 +443,7 @@ async def _sync_knowledge_document( space_id, doc_vo: DocumentVO, chunk_parameters: ChunkParameters, - ) -> List[Chunk]: + ) -> None: """sync knowledge document chunk into vector store""" embedding_factory = CFG.SYSTEM_APP.get_component( "embedding_factory", EmbeddingFactory @@ -470,47 +470,45 @@ async def _sync_knowledge_document( datasource=doc.content, knowledge_type=KnowledgeType.get_by_value(doc.doc_type), ) - assembler = EmbeddingAssembler.load_from_knowledge( - knowledge=knowledge, - chunk_parameters=chunk_parameters, - vector_store_connector=vector_store_connector, - ) - chunk_docs = assembler.get_chunks() doc.status = SyncStatus.RUNNING.name - doc.chunk_size = len(chunk_docs) + doc.gmt_modified = datetime.now() self._document_dao.update_knowledge_document(doc) - # executor = CFG.SYSTEM_APP.get_component( - # ComponentType.EXECUTOR_DEFAULT, ExecutorFactory - # ).create() - # executor.submit(self.async_doc_embedding, assembler, chunk_docs, doc) - asyncio.create_task(self.async_doc_embedding(assembler, chunk_docs, doc)) + # asyncio.create_task(self.async_doc_embedding(assembler, chunk_docs, doc)) + asyncio.create_task( + self.async_doc_embedding( + knowledge, chunk_parameters, vector_store_connector, doc + ) + ) logger.info(f"begin save document chunks, doc:{doc.doc_name}") - return chunk_docs + # return chunk_docs @trace("async_doc_embedding") - async def async_doc_embedding(self, assembler, chunk_docs, doc): + async def async_doc_embedding( + self, knowledge, chunk_parameters, 
vector_store_connector, doc + ): """async document embedding into vector db Args: - - client: EmbeddingEngine Client - - chunk_docs: List[Document] - - doc: KnowledgeDocumentEntity + - knowledge: Knowledge + - chunk_parameters: ChunkParameters + - vector_store_connector: vector_store_connector + - doc: doc """ - logger.info( - f"async doc embedding sync, doc:{doc.doc_name}, chunks length is {len(chunk_docs)}" - ) + logger.info(f"async doc embedding sync, doc:{doc.doc_name}") try: with root_tracer.start_span( "app.knowledge.assembler.persist", - metadata={"doc": doc.doc_name, "chunks": len(chunk_docs)}, + metadata={"doc": doc.doc_name}, ): - # vector_ids = assembler.persist() - space = self.get({"name": doc.space}) - if space and space.vector_type == "KnowledgeGraph": - vector_ids = await assembler.apersist() - else: - vector_ids = assembler.persist() + assembler = await EmbeddingAssembler.aload_from_knowledge( + knowledge=knowledge, + chunk_parameters=chunk_parameters, + vector_store_connector=vector_store_connector, + ) + chunk_docs = assembler.get_chunks() + doc.chunk_size = len(chunk_docs) + vector_ids = await assembler.apersist() doc.status = SyncStatus.FINISHED.name doc.result = "document embedding success" if vector_ids is not None: diff --git a/dbgpt/storage/knowledge_graph/knowledge_graph.py b/dbgpt/storage/knowledge_graph/knowledge_graph.py index 5bcddf4fd..59a1ba39d 100644 --- a/dbgpt/storage/knowledge_graph/knowledge_graph.py +++ b/dbgpt/storage/knowledge_graph/knowledge_graph.py @@ -37,7 +37,7 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase): def __init__(self, config: BuiltinKnowledgeGraphConfig): """Create builtin knowledge graph instance.""" self._config = config - + super().__init__() self._llm_client = config.llm_client if not self._llm_client: raise ValueError("No llm client provided.") diff --git a/dbgpt/storage/vector_store/base.py b/dbgpt/storage/vector_store/base.py index 5046504b5..390286179 100644 --- a/dbgpt/storage/vector_store/base.py 
+++ b/dbgpt/storage/vector_store/base.py @@ -2,6 +2,7 @@ import logging import math from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor from typing import Any, List, Optional from dbgpt._private.pydantic import ConfigDict, Field @@ -9,6 +10,7 @@ from dbgpt.core.awel.flow import Parameter from dbgpt.rag.index.base import IndexStoreBase, IndexStoreConfig from dbgpt.storage.vector_store.filters import MetadataFilters +from dbgpt.util.executor_utils import blocking_func_to_async from dbgpt.util.i18n_utils import _ logger = logging.getLogger(__name__) @@ -102,6 +104,10 @@ class VectorStoreConfig(IndexStoreConfig): class VectorStoreBase(IndexStoreBase, ABC): """Vector store base class.""" + def __init__(self, executor: Optional[ThreadPoolExecutor] = None): + """Initialize vector store.""" + super().__init__(executor) + def filter_by_score_threshold( self, chunks: List[Chunk], score_threshold: float ) -> List[Chunk]: @@ -160,7 +166,7 @@ def _default_relevance_score_fn(self, distance: float) -> float: return 1.0 - distance / math.sqrt(2) async def aload_document(self, chunks: List[Chunk]) -> List[str]: # type: ignore - """Load document in index database. + """Async load document in index database. Args: chunks(List[Chunk]): document chunks. @@ -168,4 +174,4 @@ async def aload_document(self, chunks: List[Chunk]) -> List[str]: # type: ignor Return: List[str]: chunk ids. """ - raise NotImplementedError + return await blocking_func_to_async(self._executor, self.load_document, chunks) diff --git a/dbgpt/storage/vector_store/chroma_store.py b/dbgpt/storage/vector_store/chroma_store.py index 5e921377e..930579c48 100644 --- a/dbgpt/storage/vector_store/chroma_store.py +++ b/dbgpt/storage/vector_store/chroma_store.py @@ -62,6 +62,7 @@ def __init__(self, vector_store_config: ChromaVectorConfig) -> None: Args: vector_store_config(ChromaVectorConfig): vector store config. 
""" + super().__init__() chroma_vector_config = vector_store_config.to_dict(exclude_none=True) chroma_path = chroma_vector_config.get( "persist_path", os.path.join(PILOT_PATH, "data") diff --git a/dbgpt/storage/vector_store/connector.py b/dbgpt/storage/vector_store/connector.py index 72e9b7a13..b93872115 100644 --- a/dbgpt/storage/vector_store/connector.py +++ b/dbgpt/storage/vector_store/connector.py @@ -170,14 +170,22 @@ def load_document(self, chunks: List[Chunk]) -> List[str]: ) async def aload_document(self, chunks: List[Chunk]) -> List[str]: - """Load document in vector database. + """Async load document in vector database. Args: - chunks: document chunks. Return chunk ids. """ - return await self.client.aload_document( - chunks, + max_chunks_once_load = ( + self._index_store_config.max_chunks_once_load + if self._index_store_config + else 10 + ) + max_threads = ( + self._index_store_config.max_threads if self._index_store_config else 1 + ) + return await self.client.aload_document_with_limit( + chunks, max_chunks_once_load, max_threads ) def similar_search( diff --git a/dbgpt/storage/vector_store/elastic_store.py b/dbgpt/storage/vector_store/elastic_store.py index 9dee184c8..328af163b 100644 --- a/dbgpt/storage/vector_store/elastic_store.py +++ b/dbgpt/storage/vector_store/elastic_store.py @@ -125,6 +125,7 @@ def __init__(self, vector_store_config: ElasticsearchVectorConfig) -> None: Args: vector_store_config (ElasticsearchVectorConfig): ElasticsearchStore config. 
""" + super().__init__() connect_kwargs = {} elasticsearch_vector_config = vector_store_config.dict() self.uri = elasticsearch_vector_config.get("uri") or os.getenv( diff --git a/dbgpt/storage/vector_store/milvus_store.py b/dbgpt/storage/vector_store/milvus_store.py index ea69b35fd..1e0a612b6 100644 --- a/dbgpt/storage/vector_store/milvus_store.py +++ b/dbgpt/storage/vector_store/milvus_store.py @@ -149,8 +149,14 @@ def __init__(self, vector_store_config: MilvusVectorConfig) -> None: vector_store_config (MilvusVectorConfig): MilvusStore config. refer to https://milvus.io/docs/v2.0.x/manage_connection.md """ - from pymilvus import connections - + super().__init__() + try: + from pymilvus import connections + except ImportError: + raise ValueError( + "Could not import pymilvus python package. " + "Please install it with `pip install pymilvus`." + ) connect_kwargs = {} milvus_vector_config = vector_store_config.to_dict() self.uri = milvus_vector_config.get("uri") or os.getenv( @@ -373,8 +379,13 @@ def similar_search( self, text, topk, filters: Optional[MetadataFilters] = None ) -> List[Chunk]: """Perform a search on a query string and return results.""" - from pymilvus import Collection, DataType - + try: + from pymilvus import Collection, DataType + except ImportError: + raise ValueError( + "Could not import pymilvus python package. " + "Please install it with `pip install pymilvus`." + ) """similar_search in vector database.""" self.col = Collection(self.collection_name) schema = self.col.schema @@ -419,7 +430,13 @@ def similar_search_with_scores( Returns: List[Tuple[Document, float]]: Result doc and score. """ - from pymilvus import Collection + try: + from pymilvus import Collection, DataType + except ImportError: + raise ValueError( + "Could not import pymilvus python package. " + "Please install it with `pip install pymilvus`." 
+ ) self.col = Collection(self.collection_name) schema = self.col.schema @@ -429,7 +446,6 @@ def similar_search_with_scores( self.fields.remove(x.name) if x.is_primary: self.primary_field = x.name - from pymilvus import DataType if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR: self.vector_field = x.name @@ -526,15 +542,26 @@ def _search( def vector_name_exists(self): """Whether vector name exists.""" - from pymilvus import utility + try: + from pymilvus import utility + except ImportError: + raise ValueError( + "Could not import pymilvus python package. " + "Please install it with `pip install pymilvus`." + ) """is vector store name exist.""" return utility.has_collection(self.collection_name) def delete_vector_name(self, vector_name: str): """Delete vector name.""" - from pymilvus import utility - + try: + from pymilvus import utility + except ImportError: + raise ValueError( + "Could not import pymilvus python package. " + "Please install it with `pip install pymilvus`." + ) """milvus delete collection name""" logger.info(f"milvus vector_name:{vector_name} begin delete...") utility.drop_collection(self.collection_name) @@ -542,8 +569,13 @@ def delete_vector_name(self, vector_name: str): def delete_by_ids(self, ids): """Delete vector by ids.""" - from pymilvus import Collection - + try: + from pymilvus import Collection + except ImportError: + raise ValueError( + "Could not import pymilvus python package. " + "Please install it with `pip install pymilvus`." 
+ ) self.col = Collection(self.collection_name) # milvus delete vectors by ids logger.info(f"begin delete milvus ids: {ids}") diff --git a/dbgpt/storage/vector_store/oceanbase_store.py b/dbgpt/storage/vector_store/oceanbase_store.py index 513a0bfc7..a52e9f31c 100644 --- a/dbgpt/storage/vector_store/oceanbase_store.py +++ b/dbgpt/storage/vector_store/oceanbase_store.py @@ -717,7 +717,7 @@ def __init__(self, vector_store_config: OceanBaseConfig) -> None: """Create a OceanBaseStore instance.""" if vector_store_config.embedding_fn is None: raise ValueError("embedding_fn is required for OceanBaseStore") - + super().__init__() self.embeddings = vector_store_config.embedding_fn self.collection_name = vector_store_config.name vector_store_config = vector_store_config.dict() diff --git a/dbgpt/storage/vector_store/pgvector_store.py b/dbgpt/storage/vector_store/pgvector_store.py index f90b4f864..7103a17f5 100644 --- a/dbgpt/storage/vector_store/pgvector_store.py +++ b/dbgpt/storage/vector_store/pgvector_store.py @@ -63,6 +63,7 @@ def __init__(self, vector_store_config: PGVectorConfig) -> None: raise ImportError( "Please install the `langchain` package to use the PGVector." ) + super().__init__() self.connection_string = vector_store_config.connection_string self.embeddings = vector_store_config.embedding_fn self.collection_name = vector_store_config.name diff --git a/dbgpt/storage/vector_store/weaviate_store.py b/dbgpt/storage/vector_store/weaviate_store.py index 8daf67827..9e646ab5e 100644 --- a/dbgpt/storage/vector_store/weaviate_store.py +++ b/dbgpt/storage/vector_store/weaviate_store.py @@ -68,7 +68,7 @@ def __init__(self, vector_store_config: WeaviateVectorConfig) -> None: "Could not import weaviate python package. " "Please install it with `pip install weaviate-client`." ) - + super().__init__() self.weaviate_url = vector_store_config.weaviate_url self.embedding = vector_store_config.embedding_fn self.vector_name = vector_store_config.name