From 74ddb02d92ab6f145518ca8f951f8a7cbadb0988 Mon Sep 17 00:00:00 2001 From: liushaodong03 Date: Wed, 20 Sep 2023 21:31:45 +0800 Subject: [PATCH 1/2] feat: 1. Support score feedback 2. Add method to return topk document source (KBQA) --- assets/schema/history.sql | 18 +++++ pilot/configs/config.py | 3 + pilot/openapi/api_v1/api_v1.py | 32 ++++++++ pilot/openapi/api_v1/feedback/__init__.py | 0 pilot/openapi/api_v1/feedback/api_fb_v1.py | 40 ++++++++++ pilot/openapi/api_v1/feedback/feed_back_db.py | 80 +++++++++++++++++++ .../api_v1/feedback/feed_back_model.py | 24 ++++++ pilot/scene/chat_knowledge/v1/chat.py | 14 +++- pilot/server/dbgpt_server.py | 2 + 9 files changed, 212 insertions(+), 1 deletion(-) create mode 100644 assets/schema/history.sql create mode 100644 pilot/openapi/api_v1/feedback/__init__.py create mode 100644 pilot/openapi/api_v1/feedback/api_fb_v1.py create mode 100644 pilot/openapi/api_v1/feedback/feed_back_db.py create mode 100644 pilot/openapi/api_v1/feedback/feed_back_model.py diff --git a/assets/schema/history.sql b/assets/schema/history.sql new file mode 100644 index 000000000..3323b73a0 --- /dev/null +++ b/assets/schema/history.sql @@ -0,0 +1,18 @@ +CREATE DATABASE history; +use history; +CREATE TABLE `chat_feed_back` ( + `id` bigint(20) NOT NULL AUTO_INCREMENT, + `conv_uid` varchar(128) DEFAULT NULL COMMENT '会话id', + `conv_index` int(4) DEFAULT NULL COMMENT '第几轮会话', + `score` int(1) DEFAULT NULL COMMENT '评分', + `ques_type` varchar(32) DEFAULT NULL COMMENT '用户问题类别', + `question` longtext DEFAULT NULL COMMENT '用户问题', + `knowledge_space` varchar(128) DEFAULT NULL COMMENT '知识库', + `messages` longtext DEFAULT NULL COMMENT '评价详情', + `user_name` varchar(128) DEFAULT NULL COMMENT '评价人', + `gmt_created` datetime DEFAULT NULL, + `gmt_modified` datetime DEFAULT NULL, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_conv` (`conv_uid`,`conv_index`), + KEY `idx_conv` (`conv_uid`,`conv_index`) +) ENGINE=InnoDB AUTO_INCREMENT=0 DEFAULT CHARSET=utf8mb4 COMMENT='用户评分反馈表'; \ No newline at end of file diff --git a/pilot/configs/config.py b/pilot/configs/config.py index 5dde6ff99..8bd6c4007 100644 --- a/pilot/configs/config.py +++ b/pilot/configs/config.py @@ -185,6 +185,9 @@ def __init__(self) -> None: self.KNOWLEDGE_SEARCH_MAX_TOKEN = int( os.getenv("KNOWLEDGE_SEARCH_MAX_TOKEN", 2000) ) + ### Control whether to display the source document of knowledge on the front end. + self.KNOWLEDGE_CHAT_SHOW_RELATIONS = False + ### SUMMARY_CONFIG Configuration self.SUMMARY_CONFIG = os.getenv("SUMMARY_CONFIG", "FAST") diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py index 111487f00..70b9d236d 100644 --- a/pilot/openapi/api_v1/api_v1.py +++ b/pilot/openapi/api_v1/api_v1.py @@ -84,6 +84,22 @@ def plugins_select_info(): plugins_infos.update({f"【{plugin._name}】=>{plugin._description}": plugin._name}) return plugins_infos +def get_db_list_info(): + dbs = CFG.LOCAL_DB_MANAGE.get_db_list() + params: dict = {} + for item in dbs: + comment = item["comment"] + if comment is not None and len(comment) > 0: + params.update({item["db_name"]: comment}) + return params +def knowledge_list_info(): + """return knowledge space list""" + params: dict = {} + request = KnowledgeSpaceRequest() + spaces = knowledge_service.get_knowledge_space(request) + for space in spaces: + params.update({space.name: space.desc}) + return params def knowledge_list(): """return knowledge space list""" @@ -236,6 +252,22 @@ async def params_list(chat_mode: str = ChatScene.ChatNormal.value()): return Result.succ(None) +@router.post("/v1/chat/mode/params/info", response_model=Result[dict]) +async def params_list_info(chat_mode: str = ChatScene.ChatNormal.value()): + if ChatScene.ChatWithDbQA.value() == chat_mode: + return Result.succ(get_db_list_info()) + elif ChatScene.ChatWithDbExecute.value() == chat_mode: + return Result.succ(get_db_list_info()) + elif ChatScene.ChatDashboard.value() == chat_mode: + return Result.succ(get_db_list_info()) + elif ChatScene.ChatExecution.value() == chat_mode: + return Result.succ(plugins_select_info()) + elif ChatScene.ChatKnowledge.value() == chat_mode: + return Result.succ(knowledge_list_info()) + else: + return Result.succ(None) + + @router.post("/v1/chat/mode/params/file/load") async def params_load( conv_uid: str, chat_mode: str, model_name: str, doc_file: UploadFile = File(...) diff --git a/pilot/openapi/api_v1/feedback/__init__.py b/pilot/openapi/api_v1/feedback/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/openapi/api_v1/feedback/api_fb_v1.py b/pilot/openapi/api_v1/feedback/api_fb_v1.py new file mode 100644 index 000000000..d380cfa32 --- /dev/null +++ b/pilot/openapi/api_v1/feedback/api_fb_v1.py @@ -0,0 +1,40 @@ +from fastapi import ( + APIRouter, + Body, + Request +) + +from pilot.openapi.api_v1.feedback.feed_back_model import FeedBackBody +from pilot.openapi.api_v1.feedback.feed_back_db import ChatFeedBackDao, ChatFeedBackEntity +from pilot.openapi.api_view_model import Result + +router = APIRouter() +chat_feed_back = ChatFeedBackDao() + + +@router.get("/v1/feedback/find", response_model=Result[FeedBackBody]) +async def feed_back_find(conv_uid: str, conv_index: int): + rt = chat_feed_back.get_chat_feed_back(conv_uid, conv_index) + if rt is not None: + return Result.succ(FeedBackBody( + conv_uid=rt.conv_uid, + conv_index=rt.conv_index, + question=rt.question, + knowledge_space=rt.knowledge_space, + score=rt.score, + ques_type=rt.ques_type, + messages=rt.messages + )) + else: + return Result.succ(None) + + +@router.post("/v1/feedback/commit", response_model=Result[bool]) +async def feed_back_commit(request: Request, feed_back_body: FeedBackBody = Body()): + chat_feed_back.create_or_update_chat_feed_back(feed_back_body) + return Result.succ(True) + + +@router.get("/v1/feedback/select", response_model=Result[dict]) +async def feed_back_select(): + return Result.succ({'information': '信息查询', 'work_study': '工作学习', 'just_fun': '互动闲聊', 'others': '其他'}) diff --git a/pilot/openapi/api_v1/feedback/feed_back_db.py b/pilot/openapi/api_v1/feedback/feed_back_db.py new file mode 100644 index 000000000..3cbe1ce67 --- /dev/null +++ b/pilot/openapi/api_v1/feedback/feed_back_db.py @@ -0,0 +1,80 @@ +from datetime import datetime + +from sqlalchemy import Column, Integer, Text, String, DateTime +from sqlalchemy.ext.declarative import declarative_base + +from pilot.connections.rdbms.base_dao import BaseDao +from pilot.openapi.api_v1.feedback.feed_back_model import FeedBackBody + +Base = declarative_base() + + +class ChatFeedBackEntity(Base): + __tablename__ = "chat_feed_back" + id = Column(Integer, primary_key=True) + conv_uid = Column(String(128)) + conv_index = Column(Integer) + score = Column(Integer) + ques_type = Column(String(32)) + question = Column(Text) + knowledge_space = Column(String(128)) + messages = Column(Text) + user_name = Column(String(128)) + gmt_created = Column(DateTime) + gmt_modified = Column(DateTime) + + def __repr__(self): + return (f"ChatFeekBackEntity(id={self.id}, conv_index='{self.conv_index}', conv_index='{self.conv_index}', " + f"score='{self.score}', ques_type='{self.ques_type}', question='{self.question}', knowledge_space='{self.knowledge_space}', " + f"messages='{self.messages}', user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')") + + +class ChatFeedBackDao(BaseDao): + def __init__(self): + super().__init__( + database="history", orm_base=Base, create_not_exist_table=True + ) + + def create_or_update_chat_feed_back(self, feed_back: FeedBackBody): + # Todo: We need to have user information first. + def_user_name = "" + + session = self.Session() + chat_feed_back = ChatFeedBackEntity( + conv_uid=feed_back.conv_uid, + conv_index=feed_back.conv_index, + score=feed_back.score, + ques_type=feed_back.ques_type, + question=feed_back.question, + knowledge_space=feed_back.knowledge_space, + messages=feed_back.messages, + user_name=def_user_name, + gmt_created=datetime.now(), + gmt_modified=datetime.now(), + ) + result = (session.query(ChatFeedBackEntity) + .filter(ChatFeedBackEntity.conv_uid == feed_back.conv_uid) + .filter(ChatFeedBackEntity.conv_index == feed_back.conv_index) + .first()) + if result is not None: + result.score = feed_back.score + result.ques_type = feed_back.ques_type + result.question = feed_back.question + result.knowledge_space = feed_back.knowledge_space + result.messages = feed_back.messages + result.user_name = def_user_name + result.gmt_created = datetime.now() + result.gmt_modified = datetime.now() + else: + session.merge(chat_feed_back) + session.commit() + session.close() + + def get_chat_feed_back(self, conv_uid: str, conv_index: int): + session = self.Session() + result = (session.query(ChatFeedBackEntity) + .filter(ChatFeedBackEntity.conv_uid == conv_uid) + .filter(ChatFeedBackEntity.conv_index == conv_index) + .first()) + session.close() + return result diff --git a/pilot/openapi/api_v1/feedback/feed_back_model.py b/pilot/openapi/api_v1/feedback/feed_back_model.py new file mode 100644 index 000000000..fe04ab23b --- /dev/null +++ b/pilot/openapi/api_v1/feedback/feed_back_model.py @@ -0,0 +1,24 @@ +from pydantic.main import BaseModel + + +class FeedBackBody(BaseModel): + """conv_uid: conversation id""" + conv_uid: str + + """conv_index: conversation index""" + conv_index: int + + """question: human question""" + question: str + + """knowledge_space: knowledge space""" + knowledge_space: str + + """score: rating of the llm's answer""" + score: int + + """ques_type: question type""" + ques_type: str + + """messages: rating detail""" + messages: str diff --git a/pilot/scene/chat_knowledge/v1/chat.py b/pilot/scene/chat_knowledge/v1/chat.py index 230ae1523..626c1cd63 100644 --- a/pilot/scene/chat_knowledge/v1/chat.py +++ b/pilot/scene/chat_knowledge/v1/chat.py @@ -1,3 +1,4 @@ +import os from typing import Dict from pilot.scene.base_chat import BaseChat @@ -55,6 +56,16 @@ def __init__(self, chat_param: Dict): vector_store_config=vector_store_config, embedding_factory=embedding_factory, ) + self.prompt_template.template_is_strict=False + + async def stream_call(self): + input_values = self.generate_input_values() + async for output in super().stream_call(): + # Source of knowledge file + relations = input_values.get("relations") + if CFG.KNOWLEDGE_CHAT_SHOW_RELATIONS and type(relations) == list and len(relations) > 0 and hasattr(output, 'text'): + output.text = output.text + "\trelations:" + ",".join(relations) + yield output def generate_input_values(self): if self.space_context: @@ -69,7 +80,8 @@ def generate_input_values(self): ) context = [d.page_content for d in docs] context = context[: self.max_token] - input_values = {"context": context, "question": self.current_user_input} + relations = list(set([os.path.basename(d.metadata.get('source')) for d in docs])) + input_values = {"context": context, "question": self.current_user_input, "relations": relations} return input_values @property diff --git a/pilot/server/dbgpt_server.py b/pilot/server/dbgpt_server.py index 3df72ff32..8225b2d53 100644 --- a/pilot/server/dbgpt_server.py +++ b/pilot/server/dbgpt_server.py @@ -29,6 +29,7 @@ from pilot.openapi.api_v1.api_v1 import router as api_v1 from pilot.openapi.base import validation_exception_handler from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1 +from pilot.openapi.api_v1.feedback.api_fb_v1 import router as api_fb_v1 from pilot.commands.disply_type.show_chart_gen import static_message_img_path from pilot.model.cluster import initialize_worker_manager_in_client from pilot.utils.utils import setup_logging, logging_str_to_uvicorn_level @@ -72,6 +73,7 @@ def swagger_monkey_patch(*args, **kwargs): app.include_router(api_v1, prefix="/api") app.include_router(knowledge_router, prefix="/api") app.include_router(api_editor_route_v1, prefix="/api") +app.include_router(api_fb_v1, prefix="/api") # app.include_router(api_v1) app.include_router(knowledge_router) From 2765ab494d4e1ed519ee17eb92896bd2b218136c Mon Sep 17 00:00:00 2001 From: liushaodong03 Date: Wed, 20 Sep 2023 22:11:40 +0800 Subject: [PATCH 2/2] feat: 1. Support score feedback 2. Add method to return topk document source (KBQA) --- pilot/openapi/api_v1/api_v1.py | 4 ++ pilot/openapi/api_v1/feedback/api_fb_v1.py | 40 +++++++++++-------- pilot/openapi/api_v1/feedback/feed_back_db.py | 34 +++++++++------- .../api_v1/feedback/feed_back_model.py | 15 +++---- pilot/scene/chat_knowledge/v1/chat.py | 19 +++++++-- 5 files changed, 70 insertions(+), 42 deletions(-) diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py index 70b9d236d..4273bc2d8 100644 --- a/pilot/openapi/api_v1/api_v1.py +++ b/pilot/openapi/api_v1/api_v1.py @@ -84,6 +84,7 @@ def plugins_select_info(): plugins_infos.update({f"【{plugin._name}】=>{plugin._description}": plugin._name}) return plugins_infos + def get_db_list_info(): dbs = CFG.LOCAL_DB_MANAGE.get_db_list() params: dict = {} @@ -92,6 +93,8 @@ def get_db_list_info(): if comment is not None and len(comment) > 0: params.update({item["db_name"]: comment}) return params + + def knowledge_list_info(): """return knowledge space list""" params: dict = {} @@ -101,6 +104,7 @@ def knowledge_list_info(): params.update({space.name: space.desc}) return params + def knowledge_list(): """return knowledge space list""" params: dict = {} diff --git a/pilot/openapi/api_v1/feedback/api_fb_v1.py b/pilot/openapi/api_v1/feedback/api_fb_v1.py index d380cfa32..d185cff11 100644 --- a/pilot/openapi/api_v1/feedback/api_fb_v1.py +++ b/pilot/openapi/api_v1/feedback/api_fb_v1.py @@ -1,11 +1,10 @@ -from fastapi import ( - APIRouter, - Body, - Request -) +from fastapi import APIRouter, Body, Request from pilot.openapi.api_v1.feedback.feed_back_model import FeedBackBody -from pilot.openapi.api_v1.feedback.feed_back_db import ChatFeedBackDao, ChatFeedBackEntity +from pilot.openapi.api_v1.feedback.feed_back_db import ( + ChatFeedBackDao, + ChatFeedBackEntity, +) from pilot.openapi.api_view_model import Result router = APIRouter() @@ -16,15 +15,17 @@ async def feed_back_find(conv_uid: str, conv_index: int): rt = chat_feed_back.get_chat_feed_back(conv_uid, conv_index) if rt is not None: - return Result.succ(FeedBackBody( - conv_uid=rt.conv_uid, - conv_index=rt.conv_index, - question=rt.question, - knowledge_space=rt.knowledge_space, - score=rt.score, - ques_type=rt.ques_type, - messages=rt.messages - )) + return Result.succ( + FeedBackBody( + conv_uid=rt.conv_uid, + conv_index=rt.conv_index, + question=rt.question, + knowledge_space=rt.knowledge_space, + score=rt.score, + ques_type=rt.ques_type, + messages=rt.messages, + ) + ) else: return Result.succ(None) @@ -37,4 +38,11 @@ async def feed_back_commit(request: Request, feed_back_body: FeedBackBody = Body @router.get("/v1/feedback/select", response_model=Result[dict]) async def feed_back_select(): - return Result.succ({'information': '信息查询', 'work_study': '工作学习', 'just_fun': '互动闲聊', 'others': '其他'}) + return Result.succ( + { + "information": "信息查询", + "work_study": "工作学习", + "just_fun": "互动闲聊", + "others": "其他", + } + ) diff --git a/pilot/openapi/api_v1/feedback/feed_back_db.py b/pilot/openapi/api_v1/feedback/feed_back_db.py index 3cbe1ce67..2b57c4bde 100644 --- a/pilot/openapi/api_v1/feedback/feed_back_db.py +++ b/pilot/openapi/api_v1/feedback/feed_back_db.py @@ -24,19 +24,19 @@ class ChatFeedBackEntity(Base): gmt_modified = Column(DateTime) def __repr__(self): - return (f"ChatFeekBackEntity(id={self.id}, conv_index='{self.conv_index}', conv_index='{self.conv_index}', " - f"score='{self.score}', ques_type='{self.ques_type}', question='{self.question}', knowledge_space='{self.knowledge_space}', " - f"messages='{self.messages}', user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')") + return ( + f"ChatFeekBackEntity(id={self.id}, conv_index='{self.conv_index}', conv_index='{self.conv_index}', " + f"score='{self.score}', ques_type='{self.ques_type}', question='{self.question}', knowledge_space='{self.knowledge_space}', " + f"messages='{self.messages}', user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')" + ) class ChatFeedBackDao(BaseDao): def __init__(self): - super().__init__( - database="history", orm_base=Base, create_not_exist_table=True - ) + super().__init__(database="history", orm_base=Base, create_not_exist_table=True) def create_or_update_chat_feed_back(self, feed_back: FeedBackBody): - # Todo: We need to have user information first. + # Todo: We need to have user information first. def_user_name = "" session = self.Session() @@ -52,10 +52,12 @@ def create_or_update_chat_feed_back(self, feed_back: FeedBackBody): gmt_created=datetime.now(), gmt_modified=datetime.now(), ) - result = (session.query(ChatFeedBackEntity) - .filter(ChatFeedBackEntity.conv_uid == feed_back.conv_uid) - .filter(ChatFeedBackEntity.conv_index == feed_back.conv_index) - .first()) + result = ( + session.query(ChatFeedBackEntity) + .filter(ChatFeedBackEntity.conv_uid == feed_back.conv_uid) + .filter(ChatFeedBackEntity.conv_index == feed_back.conv_index) + .first() + ) if result is not None: result.score = feed_back.score result.ques_type = feed_back.ques_type @@ -72,9 +74,11 @@ def create_or_update_chat_feed_back(self, feed_back: FeedBackBody): def get_chat_feed_back(self, conv_uid: str, conv_index: int): session = self.Session() - result = (session.query(ChatFeedBackEntity) - .filter(ChatFeedBackEntity.conv_uid == conv_uid) - .filter(ChatFeedBackEntity.conv_index == conv_index) - .first()) + result = ( + session.query(ChatFeedBackEntity) + .filter(ChatFeedBackEntity.conv_uid == conv_uid) + .filter(ChatFeedBackEntity.conv_index == conv_index) + .first() + ) session.close() return result diff --git a/pilot/openapi/api_v1/feedback/feed_back_model.py b/pilot/openapi/api_v1/feedback/feed_back_model.py index fe04ab23b..fabc30c09 100644 --- a/pilot/openapi/api_v1/feedback/feed_back_model.py +++ b/pilot/openapi/api_v1/feedback/feed_back_model.py @@ -2,23 +2,24 @@ class FeedBackBody(BaseModel): - """conv_uid: conversation id""" + """conv_uid: conversation id""" + conv_uid: str - """conv_index: conversation index""" + """conv_index: conversation index""" conv_index: int - """question: human question""" + """question: human question""" question: str - """knowledge_space: knowledge space""" + """knowledge_space: knowledge space""" knowledge_space: str - """score: rating of the llm's answer""" + """score: rating of the llm's answer""" score: int - """ques_type: question type""" + """ques_type: question type""" ques_type: str - """messages: rating detail""" + """messages: rating detail""" messages: str diff --git a/pilot/scene/chat_knowledge/v1/chat.py b/pilot/scene/chat_knowledge/v1/chat.py index 626c1cd63..18d7e5060 100644 --- a/pilot/scene/chat_knowledge/v1/chat.py +++ b/pilot/scene/chat_knowledge/v1/chat.py @@ -56,14 +56,19 @@ def __init__(self, chat_param: Dict): vector_store_config=vector_store_config, embedding_factory=embedding_factory, ) - self.prompt_template.template_is_strict=False + self.prompt_template.template_is_strict = False async def stream_call(self): input_values = self.generate_input_values() async for output in super().stream_call(): # Source of knowledge file relations = input_values.get("relations") - if CFG.KNOWLEDGE_CHAT_SHOW_RELATIONS and type(relations) == list and len(relations) > 0 and hasattr(output, 'text'): + if ( + CFG.KNOWLEDGE_CHAT_SHOW_RELATIONS + and type(relations) == list + and len(relations) > 0 + and hasattr(output, "text") + ): output.text = output.text + "\trelations:" + ",".join(relations) yield output @@ -80,8 +85,14 @@ def generate_input_values(self): ) context = [d.page_content for d in docs] context = context[: self.max_token] - relations = list(set([os.path.basename(d.metadata.get('source')) for d in docs])) - input_values = {"context": context, "question": self.current_user_input, "relations": relations} + relations = list( + set([os.path.basename(d.metadata.get("source")) for d in docs]) + ) + input_values = { + "context": context, + "question": self.current_user_input, + "relations": relations, + } return input_values @property