Skip to content

Commit

Permalink
fix(ChatKnowledge): add aload_document (#1548)
Browse files Browse the repository at this point in the history
  • Loading branch information
Aries-ckt authored May 23, 2024
1 parent 7f55aa4 commit 83d7e9d
Show file tree
Hide file tree
Showing 14 changed files with 180 additions and 238 deletions.
26 changes: 22 additions & 4 deletions dbgpt/app/knowledge/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
EMBEDDING_MODEL_CONFIG,
KNOWLEDGE_UPLOAD_ROOT_PATH,
)
from dbgpt.rag import ChunkParameters
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.knowledge.base import ChunkStrategy
from dbgpt.rag.knowledge.factory import KnowledgeFactory
Expand Down Expand Up @@ -235,13 +236,30 @@ async def document_upload(


@router.post("/knowledge/{space_name}/document/sync")
def document_sync(space_name: str, request: DocumentSyncRequest):
async def document_sync(
space_name: str,
request: DocumentSyncRequest,
service: Service = Depends(get_rag_service),
):
logger.info(f"Received params: {space_name}, {request}")
try:
knowledge_space_service.sync_knowledge_document(
space_name=space_name, sync_request=request
space = service.get({"name": space_name})
if space is None:
return Result.failed(code="E000X", msg=f"space {space_name} not exist")
if request.doc_ids is None or len(request.doc_ids) == 0:
return Result.failed(code="E000X", msg="doc_ids is None")
sync_request = KnowledgeSyncRequest(
doc_id=request.doc_ids[0],
space_id=str(space.id),
model_name=request.model_name,
)
return Result.succ([])
sync_request.chunk_parameters = ChunkParameters(
chunk_strategy="Automatic",
chunk_size=request.chunk_size or 512,
chunk_overlap=request.chunk_overlap or 50,
)
doc_ids = await service.sync_document(requests=[sync_request])
return Result.succ(doc_ids)
except Exception as e:
return Result.failed(code="E000X", msg=f"document sync error {e}")

Expand Down
188 changes: 1 addition & 187 deletions dbgpt/app/knowledge/service.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import json
import logging
from datetime import datetime
from typing import List

from dbgpt._private.config import Config
from dbgpt.app.knowledge.chunk_db import DocumentChunkDao, DocumentChunkEntity
Expand Down Expand Up @@ -32,13 +31,8 @@
from dbgpt.rag.assembler.summary import SummaryAssembler
from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.knowledge.base import ChunkStrategy, KnowledgeType
from dbgpt.rag.knowledge.base import KnowledgeType
from dbgpt.rag.knowledge.factory import KnowledgeFactory
from dbgpt.rag.text_splitter.text_splitter import (
RecursiveCharacterTextSplitter,
SpacyTextSplitter,
)
from dbgpt.serve.rag.api.schemas import KnowledgeSyncRequest
from dbgpt.serve.rag.models.models import KnowledgeSpaceDao, KnowledgeSpaceEntity
from dbgpt.serve.rag.service.service import SyncStatus
from dbgpt.storage.vector_store.base import VectorStoreConfig
Expand Down Expand Up @@ -199,186 +193,6 @@ def get_knowledge_documents(self, space, request: DocumentQueryRequest):
total = knowledge_document_dao.get_knowledge_documents_count(query)
return DocumentQueryResponse(data=data, total=total, page=page)

def batch_document_sync(
    self,
    space_name,
    sync_requests: List[KnowledgeSyncRequest],
) -> List[int]:
    """Batch-sync knowledge document chunks into the vector store.

    Args:
        space_name: Knowledge space name.
        sync_requests: One KnowledgeSyncRequest per document to sync.

    Returns:
        List[int]: ids of the documents scheduled for syncing.

    Raises:
        Exception: when a document id cannot be found, or the document is
            already RUNNING or FINISHED.
    """
    doc_ids = []
    for sync_request in sync_requests:
        docs = knowledge_document_dao.documents_by_ids([sync_request.doc_id])
        if len(docs) == 0:
            # Fixed garbled message (was: "there are document called, ...").
            raise Exception(
                f"there is no document found, doc_id: {sync_request.doc_id}"
            )
        doc = docs[0]
        if doc.status in (SyncStatus.RUNNING.name, SyncStatus.FINISHED.name):
            raise Exception(
                f" doc:{doc.doc_name} status is {doc.status}, can not sync"
            )
        chunk_parameters = sync_request.chunk_parameters
        # Only CHUNK_BY_SIZE honors the caller-supplied sizes; any other
        # strategy falls back to the space context, then global defaults.
        if chunk_parameters.chunk_strategy != ChunkStrategy.CHUNK_BY_SIZE.name:
            space_context = self.get_space_context(space_name)
            chunk_parameters.chunk_size = (
                CFG.KNOWLEDGE_CHUNK_SIZE
                if space_context is None
                else int(space_context["embedding"]["chunk_size"])
            )
            chunk_parameters.chunk_overlap = (
                CFG.KNOWLEDGE_CHUNK_OVERLAP
                if space_context is None
                else int(space_context["embedding"]["chunk_overlap"])
            )
        self._sync_knowledge_document(space_name, doc, chunk_parameters)
        doc_ids.append(doc.id)
    return doc_ids

def sync_knowledge_document(self, space_name, sync_request: DocumentSyncRequest):
    """Sync knowledge document chunks into the vector store.

    Args:
        space_name: Knowledge space name.
        sync_request: DocumentSyncRequest carrying ``doc_ids`` plus optional
            chunk size/overlap, separators and a pre-separator.

    Returns:
        The id of the last processed document.

    Raises:
        Exception: when a document id is not found or the document is
            already RUNNING or FINISHED.
        ValueError: when multiple separators are given for SpacyTextSplitter.
    """
    from dbgpt.rag.chunk_manager import ChunkParameters
    from dbgpt.rag.text_splitter.pre_text_splitter import PreTextSplitter

    doc_ids = sync_request.doc_ids
    self.model_name = sync_request.model_name or CFG.LLM_MODEL
    for doc_id in doc_ids:
        query = KnowledgeDocumentEntity(id=doc_id)
        docs = knowledge_document_dao.get_documents(query)
        if len(docs) == 0:
            # Bug fix: the message formatted ``sync_request.doc_id``, but
            # DocumentSyncRequest only has ``doc_ids`` — formatting the
            # error itself would have raised. Use the loop variable.
            raise Exception(f"there is no document found, doc_id: {doc_id}")
        doc = docs[0]
        if doc.status in (SyncStatus.RUNNING.name, SyncStatus.FINISHED.name):
            raise Exception(
                f" doc:{doc.doc_name} status is {doc.status}, can not sync"
            )

        # Chunk sizing: space context overrides global defaults; explicit
        # request values override both.
        space_context = self.get_space_context(space_name)
        chunk_size = (
            CFG.KNOWLEDGE_CHUNK_SIZE
            if space_context is None
            else int(space_context["embedding"]["chunk_size"])
        )
        chunk_overlap = (
            CFG.KNOWLEDGE_CHUNK_OVERLAP
            if space_context is None
            else int(space_context["embedding"]["chunk_overlap"])
        )
        if sync_request.chunk_size:
            chunk_size = sync_request.chunk_size
        if sync_request.chunk_overlap:
            chunk_overlap = sync_request.chunk_overlap
        separators = sync_request.separators or None

        chunk_parameters = ChunkParameters(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        if CFG.LANGUAGE == "en":
            text_splitter = RecursiveCharacterTextSplitter(
                separators=separators,
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                length_function=len,
            )
        else:
            if separators and len(separators) > 1:
                # Typo fix in message (was "multipsle").
                raise ValueError(
                    "SpacyTextSplitter do not support multiple separators"
                )
            try:
                separator = "\n\n" if not separators else separators[0]
                text_splitter = SpacyTextSplitter(
                    separator=separator,
                    pipeline="zh_core_web_sm",
                    chunk_size=chunk_size,
                    chunk_overlap=chunk_overlap,
                )
            except Exception:
                # Spacy model unavailable: fall back to character splitting.
                text_splitter = RecursiveCharacterTextSplitter(
                    separators=separators,
                    chunk_size=chunk_size,
                    chunk_overlap=chunk_overlap,
                )
        if sync_request.pre_separator:
            logger.info(f"Use preseparator, {sync_request.pre_separator}")
            text_splitter = PreTextSplitter(
                pre_separator=sync_request.pre_separator,
                text_splitter_impl=text_splitter,
            )
        chunk_parameters.text_splitter = text_splitter
        self._sync_knowledge_document(space_name, doc, chunk_parameters)
    # NOTE(review): only the last processed doc id is returned (and this
    # raises NameError on an empty doc_ids list) — confirm callers expect a
    # single id rather than the full list.
    return doc.id

def _sync_knowledge_document(
    self,
    space_name,
    doc: KnowledgeDocumentEntity,
    chunk_parameters: ChunkParameters,
) -> List[Chunk]:
    """Chunk a document and schedule its embedding into the vector store.

    Marks the document RUNNING, persists that state, then submits the
    embedding work (``async_doc_embedding``) to the default executor — the
    actual vector-store write happens asynchronously.

    Args:
        space_name: Knowledge space name; must resolve to exactly one space.
        doc: Document entity to sync.
        chunk_parameters: Chunking configuration for the assembler.

    Returns:
        List[Chunk]: the chunks produced for the document.

    Raises:
        Exception: when ``space_name`` does not match exactly one space.
    """
    embedding_factory = CFG.SYSTEM_APP.get_component(
        "embedding_factory", EmbeddingFactory
    )
    embedding_fn = embedding_factory.create(
        model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
    )

    spaces = self.get_knowledge_space(KnowledgeSpaceRequest(name=space_name))
    if len(spaces) != 1:
        raise Exception(f"invalid space name:{space_name}")
    space = spaces[0]

    # VectorStoreConfig is imported at module level; the previous
    # function-local re-import was redundant and has been removed.
    config = VectorStoreConfig(
        name=space.name,
        embedding_fn=embedding_fn,
        max_chunks_once_load=CFG.KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD,
        llm_client=self.llm_client,
        model_name=self.model_name,
    )
    vector_store_connector = VectorStoreConnector(
        vector_store_type=space.vector_type, vector_store_config=config
    )
    knowledge = KnowledgeFactory.create(
        datasource=doc.content,
        knowledge_type=KnowledgeType.get_by_value(doc.doc_type),
    )
    assembler = EmbeddingAssembler.load_from_knowledge(
        knowledge=knowledge,
        chunk_parameters=chunk_parameters,
        embeddings=embedding_fn,
        vector_store_connector=vector_store_connector,
    )
    chunk_docs = assembler.get_chunks()
    doc.status = SyncStatus.RUNNING.name
    doc.chunk_size = len(chunk_docs)
    doc.gmt_modified = datetime.now()
    knowledge_document_dao.update_knowledge_document(doc)
    executor = CFG.SYSTEM_APP.get_component(
        ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
    ).create()
    executor.submit(self.async_doc_embedding, assembler, chunk_docs, doc)
    logger.info(f"begin save document chunks, doc:{doc.doc_name}")
    return chunk_docs

async def document_summary(self, request: DocumentSummaryRequest):
"""get document summary
Args:
Expand Down
37 changes: 37 additions & 0 deletions dbgpt/rag/assembler/embedding.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""Embedding Assembler."""
from concurrent.futures import ThreadPoolExecutor
from typing import Any, List, Optional

from dbgpt.core import Chunk, Embeddings
from dbgpt.storage.vector_store.connector import VectorStoreConnector

from ...util.executor_utils import blocking_func_to_async
from ..assembler.base import BaseAssembler
from ..chunk_manager import ChunkParameters
from ..embedding.embedding_factory import DefaultEmbeddingFactory
Expand Down Expand Up @@ -98,6 +100,41 @@ def load_from_knowledge(
embeddings=embeddings,
)

@classmethod
async def aload_from_knowledge(
    cls,
    knowledge: Knowledge,
    vector_store_connector: VectorStoreConnector,
    chunk_parameters: Optional[ChunkParameters] = None,
    embedding_model: Optional[str] = None,
    embeddings: Optional[Embeddings] = None,
    executor: Optional[ThreadPoolExecutor] = None,
) -> "EmbeddingAssembler":
    """Asynchronously load document embedding into vector store from path.

    Runs the (blocking) assembler construction in a thread pool via
    ``blocking_func_to_async`` so the event loop is not blocked.

    Args:
        knowledge: (Knowledge) Knowledge datasource.
        vector_store_connector: (VectorStoreConnector) VectorStoreConnector to use.
        chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
            chunking.
        embedding_model: (Optional[str]) Embedding model to use.
        embeddings: (Optional[Embeddings]) Embeddings to use.
        executor: (Optional[ThreadPoolExecutor]) ThreadPoolExecutor to use.

    Returns:
        EmbeddingAssembler
    """
    # NOTE(review): when no executor is passed, a fresh ThreadPoolExecutor is
    # created per call and never shut down — confirm this is intended.
    executor = executor or ThreadPoolExecutor()
    # Arguments are passed positionally to ``cls``; their order must match
    # the constructor's positional parameter order.
    return await blocking_func_to_async(
        executor,
        cls,
        knowledge,
        vector_store_connector,
        chunk_parameters,
        embedding_model,
        embeddings,
    )

def persist(self) -> List[str]:
"""Persist chunks into vector store.
Expand Down
26 changes: 26 additions & 0 deletions dbgpt/rag/index/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict
from dbgpt.core import Chunk, Embeddings
from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util.executor_utils import blocking_func_to_async

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -46,6 +47,10 @@ def to_dict(self, **kwargs) -> Dict[str, Any]:
class IndexStoreBase(ABC):
"""Index store base class."""

def __init__(self, executor: Optional[ThreadPoolExecutor] = None):
    """Init index store.

    Args:
        executor: Thread pool used to run blocking index operations off the
            event loop; a private default pool is created when omitted.
    """
    # ``or`` is safe here: a ThreadPoolExecutor instance is always truthy.
    self._executor = executor or ThreadPoolExecutor()

@abstractmethod
def load_document(self, chunks: List[Chunk]) -> List[str]:
"""Load document in index database.
Expand Down Expand Up @@ -143,6 +148,27 @@ def load_document_with_limit(
)
return ids

async def aload_document_with_limit(
    self, chunks: List[Chunk], max_chunks_once_load: int = 10, max_threads: int = 1
) -> List[str]:
    """Load document in index database with specified limit.

    Async wrapper: delegates to the synchronous ``load_document_with_limit``
    on ``self._executor`` so the event loop is not blocked.

    Args:
        chunks(List[Chunk]): Document chunks.
        max_chunks_once_load(int): Max number of chunks to load at once.
        max_threads(int): Max number of threads to use.

    Return:
        List[str]: Chunk ids.
    """
    return await blocking_func_to_async(
        self._executor,
        self.load_document_with_limit,
        chunks,
        max_chunks_once_load,
        max_threads,
    )

def similar_search(
self, text: str, topk: int, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
Expand Down
Loading

0 comments on commit 83d7e9d

Please sign in to comment.