Skip to content

Commit

Permalink
feat(rag): New knowledge config api
Browse files Browse the repository at this point in the history
  • Loading branch information
fangyinc committed Jul 20, 2024
1 parent 8970e93 commit 7d8f1f5
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 23 deletions.
63 changes: 62 additions & 1 deletion dbgpt/app/knowledge/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,27 @@
from dbgpt.app.knowledge.service import KnowledgeService
from dbgpt.app.openapi.api_v1.api_v1 import no_stream_generator, stream_generator
from dbgpt.app.openapi.api_view_model import Result
from dbgpt.configs import TAG_KEY_KNOWLEDGE_FACTORY_DOMAIN_TYPE
from dbgpt.configs.model_config import (
EMBEDDING_MODEL_CONFIG,
KNOWLEDGE_UPLOAD_ROOT_PATH,
)
from dbgpt.core.awel.dag.dag_manager import DAGManager
from dbgpt.rag import ChunkParameters
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.knowledge.base import ChunkStrategy, KnowledgeType
from dbgpt.rag.knowledge.factory import KnowledgeFactory
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
from dbgpt.serve.rag.api.schemas import KnowledgeSyncRequest
from dbgpt.serve.rag.api.schemas import (
KnowledgeConfigResponse,
KnowledgeDomainType,
KnowledgeStorageType,
KnowledgeSyncRequest,
)
from dbgpt.serve.rag.connector import VectorStoreConnector
from dbgpt.serve.rag.service.service import Service
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.util.i18n_utils import _
from dbgpt.util.tracer import SpanType, root_tracer

logger = logging.getLogger(__name__)
Expand All @@ -53,6 +61,11 @@ def get_rag_service() -> Service:
return Service.get_instance(CFG.SYSTEM_APP)


def get_dag_manager() -> DAGManager:
    """Return the DAGManager instance looked up through ``CFG.SYSTEM_APP``."""
    system_app = CFG.SYSTEM_APP
    return DAGManager.get_instance(system_app)


@router.post("/knowledge/space/add")
def space_add(request: KnowledgeSpaceRequest):
print(f"/space/add params: {request}")
Expand Down Expand Up @@ -148,6 +161,54 @@ def chunk_strategies():
return Result.failed(code="E000X", msg=f"chunk strategies error {e}")


@router.get("/knowledge/space/config", response_model=Result[KnowledgeConfigResponse])
async def space_config() -> Result[KnowledgeConfigResponse]:
    """Get the knowledge space configuration.

    Builds the list of storage types ("VectorStore", "KnowledgeGraph" and
    "FullText"). Every storage type offers the built-in "Normal" domain type;
    the vector store additionally offers one domain type per tag value found
    under ``TAG_KEY_KNOWLEDGE_FACTORY_DOMAIN_TYPE`` in the DAG manager.

    Returns:
        Result[KnowledgeConfigResponse]: A successful result wrapping the
            storage list, or a failed result with code "E000X" if anything
            goes wrong while building it.
    """
    try:
        storage_list: List[KnowledgeStorageType] = []
        dag_manager: DAGManager = get_dag_manager()

        # Vector storage: "Normal" plus the domain types contributed by DAGs.
        vs_domain_types = [KnowledgeDomainType(name="Normal", desc="Normal")]
        dag_map = dag_manager.get_dags_by_tag_key(TAG_KEY_KNOWLEDGE_FACTORY_DOMAIN_TYPE)
        for domain_type, dags in dag_map.items():
            if not dags:
                # A tag value without any DAG has nothing to describe; skip it
                # instead of failing the whole request on ``dags[0]``.
                continue
            vs_domain_types.append(
                KnowledgeDomainType(
                    # Fall back to the tag value when the DAG has no description.
                    name=domain_type,
                    desc=dags[0].description or domain_type,
                )
            )

        storage_list.append(
            KnowledgeStorageType(
                name="VectorStore",
                desc=_("Vector Store"),
                domain_types=vs_domain_types,
            )
        )
        # Graph storage
        storage_list.append(
            KnowledgeStorageType(
                name="KnowledgeGraph",
                desc=_("Knowledge Graph"),
                domain_types=[KnowledgeDomainType(name="Normal", desc="Normal")],
            )
        )
        # Full-text storage
        storage_list.append(
            KnowledgeStorageType(
                name="FullText",
                desc=_("Full Text"),
                domain_types=[KnowledgeDomainType(name="Normal", desc="Normal")],
            )
        )

        return Result.succ(
            KnowledgeConfigResponse(
                storage=storage_list,
            )
        )
    except Exception as e:
        # Keep the endpoint's error contract, but record the traceback so the
        # failure can be diagnosed from the server logs.
        logger.exception("Failed to build knowledge space config")
        return Result.failed(code="E000X", msg=f"space config error {e}")


@router.post("/knowledge/{space_name}/document/list")
def document_list(space_name: str, query_request: DocumentQueryRequest):
print(f"/document/list params: {space_name}, {query_request}")
Expand Down
57 changes: 37 additions & 20 deletions dbgpt/app/openapi/api_v1/api_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import uuid
from concurrent.futures import Executor
from typing import List, Optional
from typing import List, Optional, cast

import aiofiles
from fastapi import APIRouter, Body, Depends, File, UploadFile
Expand All @@ -21,8 +21,11 @@
)
from dbgpt.app.scene import BaseChat, ChatFactory, ChatScene
from dbgpt.component import ComponentType
from dbgpt.configs import TAG_KEY_KNOWLEDGE_CHAT_DOMAIN_TYPE
from dbgpt.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
from dbgpt.core.awel import CommonLLMHttpRequestBody
from dbgpt.core.awel import BaseOperator, CommonLLMHttpRequestBody
from dbgpt.core.awel.dag.dag_manager import DAGManager
from dbgpt.core.awel.util.chat_util import safe_chat_stream_with_dag_task
from dbgpt.core.schema.api import (
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
Expand Down Expand Up @@ -127,6 +130,11 @@ def get_worker_manager() -> WorkerManager:
return worker_manager


def get_dag_manager() -> DAGManager:
    """Return the DAGManager instance looked up through ``CFG.SYSTEM_APP``."""
    system_app = CFG.SYSTEM_APP
    return DAGManager.get_instance(system_app)


def get_chat_flow() -> FlowService:
    """Return the FlowService instance looked up through ``CFG.SYSTEM_APP``."""
    system_app = CFG.SYSTEM_APP
    return FlowService.get_instance(system_app)
Expand Down Expand Up @@ -252,7 +260,7 @@ async def params_load(
sys_code: Optional[str] = None,
doc_file: UploadFile = File(...),
):
print(f"params_load: {conv_uid},{chat_mode},{model_name}")
logger.info(f"params_load: {conv_uid},{chat_mode},{model_name}")
try:
if doc_file:
# Save the uploaded file
Expand Down Expand Up @@ -335,7 +343,7 @@ async def chat_completions(
dialogue: ConversationVo = Body(),
flow_service: FlowService = Depends(get_chat_flow),
):
print(
logger.info(
f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},{dialogue.model_name}"
)
headers = {
Expand All @@ -344,6 +352,7 @@ async def chat_completions(
"Connection": "keep-alive",
"Transfer-Encoding": "chunked",
}
domain_type = _parse_domain_type(dialogue)
if dialogue.chat_mode == ChatScene.ChatAgent.value():
return StreamingResponse(
multi_agents.app_agent_chat(
Expand Down Expand Up @@ -378,9 +387,9 @@ async def chat_completions(
headers=headers,
media_type="text/event-stream",
)
elif is_fin_report_chat(dialogue):
elif domain_type is not None:
return StreamingResponse(
chat_with_business_flow(dialogue),
chat_with_domain_flow(dialogue, domain_type),
headers=headers,
media_type="text/event-stream",
)
Expand Down Expand Up @@ -494,8 +503,9 @@ def message2Vo(message: dict, order, model_name) -> MessageVo:
)


def is_fin_report_chat(dialogue: ConversationVo):
def _parse_domain_type(dialogue: ConversationVo) -> Optional[str]:
if dialogue.chat_mode == ChatScene.ChatKnowledge.value():
# Supported in the knowledge chat
space_name = dialogue.select_param
spaces = knowledge_service.get_knowledge_space(
KnowledgeSpaceRequest(name=space_name)
Expand All @@ -504,18 +514,24 @@ def is_fin_report_chat(dialogue: ConversationVo):
return Result.failed(
code="E000X", msg=f"Knowledge space {space_name} not found"
)
if (
spaces[0].field_type
and spaces[0].field_type == BusinessFieldType.FINANCIAL_REPORT.value
):
return True
return False
if spaces[0].field_type:
return spaces[0].field_type
else:
return None


async def chat_with_business_flow(dialogue: ConversationVo):
"""Call the chat module"""
async def chat_with_domain_flow(dialogue: ConversationVo, domain_type: str):
"""Chat with domain flow"""
dag_manager = get_dag_manager()
dags = dag_manager.get_dags_by_tag(TAG_KEY_KNOWLEDGE_CHAT_DOMAIN_TYPE, domain_type)
if not dags or not dags[0].leaf_nodes:
raise ValueError(f"Cant find the DAG for domain type {domain_type}")

end_task = cast(BaseOperator, dags[0].leaf_nodes[0])

space = dialogue.select_param
connector_manager = CFG.local_db_manager
# TODO: Some flows may not use a connector
db_list = [item["db_name"] for item in connector_manager.get_db_list()]
db_names = [item for item in db_list if space in item]
if len(db_names) == 0:
Expand All @@ -534,11 +550,12 @@ async def chat_with_business_flow(dialogue: ConversationVo):
sys_code=dialogue.sys_code,
incremental=dialogue.incremental,
)
flow_service = FlowService.get_instance(CFG.SYSTEM_APP)
async for output in flow_service.chat_stream_flow_str(
"51166a4d-f59a-448f-994e-8f21b05ba1f9", request
):
text = output
async for output in safe_chat_stream_with_dag_task(end_task, request, False):
text = output.text
if text:
text = text.replace("\n", "\\n")
if output.error_code != 0:
yield f"data:[SERVER_ERROR]{text}\n\n"
break
else:
yield f"data:{text}\n\n"
7 changes: 7 additions & 0 deletions dbgpt/core/awel/dag/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,10 +618,12 @@ def __init__(
dag_id: str,
resource_group: Optional[ResourceGroup] = None,
tags: Optional[Dict[str, str]] = None,
description: Optional[str] = None,
) -> None:
"""Initialize a DAG."""
self._dag_id = dag_id
self._tags: Dict[str, str] = tags or {}
self._description = description
self.node_map: Dict[str, DAGNode] = {}
self.node_name_to_node: Dict[str, DAGNode] = {}
self._root_nodes: List[DAGNode] = []
Expand Down Expand Up @@ -661,6 +663,11 @@ def tags(self) -> Dict[str, str]:
"""Return the tags of current DAG."""
return self._tags

@property
def description(self) -> Optional[str]:
    """Return the description stored on this DAG, or None if it has none."""
    desc = self._description
    return desc

@property
def dev_mode(self) -> bool:
"""Whether the current DAG is in dev mode.
Expand Down
9 changes: 9 additions & 0 deletions dbgpt/core/awel/dag/dag_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,15 @@ def get_dags_by_tag(self, tag_key: str, tag_value) -> List[DAG]:
dag_ids = self._tags_to_dag_ids.get(tag_key, {}).get(tag_value, set())
return [self.dag_map[dag_id] for dag_id in dag_ids]

def get_dags_by_tag_key(self, tag_key: str) -> Dict[str, List[DAG]]:
    """Group the DAGs registered under ``tag_key`` by their tag value.

    Args:
        tag_key: The tag key to look up.

    Returns:
        A mapping from each tag value recorded for ``tag_key`` to the list of
        DAGs carrying that value; empty when the key is unknown.
    """
    # Hold the manager lock while reading the tag index and the DAG map.
    with self.lock:
        ids_by_value = self._tags_to_dag_ids.get(tag_key, {})
        return {
            tag_value: [self.dag_map[dag_id] for dag_id in dag_ids]
            for tag_value, dag_ids in ids_by_value.items()
        }

def get_dag_metadata(
self, dag_id: Optional[str] = None, alias_name: Optional[str] = None
) -> Optional[DAGMetadata]:
Expand Down
1 change: 1 addition & 0 deletions dbgpt/rag/knowledge/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class KnowledgeType(Enum):
DOCUMENT = "DOCUMENT"
URL = "URL"
TEXT = "TEXT"
# TODO: Remove this type
FIN_REPORT = "FIN_REPORT"

@property
Expand Down
23 changes: 22 additions & 1 deletion dbgpt/serve/rag/api/schemas.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import List, Optional

from fastapi import File, UploadFile

Expand Down Expand Up @@ -132,3 +132,24 @@ class DocumentVO(BaseModel):
summary: Optional[str] = Field(None, description="document summary")
gmt_created: str = Field(..., description="document create time")
gmt_modified: str = Field(..., description="document modify time")


class KnowledgeDomainType(BaseModel):
    """Schema for a knowledge domain type: a name plus its description."""

    # Both fields are required (``...`` marks them as having no default).
    name: str = Field(
        ...,
        description="The domain type name",
    )
    desc: str = Field(
        ...,
        description="The domain type description",
    )


class KnowledgeStorageType(BaseModel):
    """Schema for a knowledge storage type and its list of domain types."""

    # All fields are required (``...`` marks them as having no default).
    name: str = Field(
        ...,
        description="The storage type name",
    )
    desc: str = Field(
        ...,
        description="The storage type description",
    )
    domain_types: List[KnowledgeDomainType] = Field(
        ...,
        description="The domain types",
    )


class KnowledgeConfigResponse(BaseModel):
    """Response schema for the knowledge config: the list of storage types."""

    # Required field (``...`` marks it as having no default).
    storage: List[KnowledgeStorageType] = Field(
        ...,
        description="The storage types",
    )
1 change: 0 additions & 1 deletion dbgpt/serve/rag/service/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
from dbgpt.storage.metadata import BaseDao
from dbgpt.storage.metadata._base_dao import QUERY_SPEC
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.util.dbgpts.loader import DBGPTsLoader
from dbgpt.util.pagination_utils import PaginationResult
from dbgpt.util.tracer import root_tracer, trace

Expand Down

0 comments on commit 7d8f1f5

Please sign in to comment.