Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ChatKnowledge): return topk document source (KBQA) #608

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions assets/schema/history.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
-- Schema for user feedback on chat answers.
-- IF NOT EXISTS makes the script safe to re-run against an existing install.
CREATE DATABASE IF NOT EXISTS history;
use history;
-- Table comment: "user rating feedback table".
CREATE TABLE IF NOT EXISTS `chat_feed_back` (
  `id` bigint(20) NOT NULL AUTO_INCREMENT,
  -- conversation id
  `conv_uid` varchar(128) DEFAULT NULL COMMENT '会话id',
  -- index of the round within the conversation
  `conv_index` int(4) DEFAULT NULL COMMENT '第几轮会话',
  -- rating score
  `score` int(1) DEFAULT NULL COMMENT '评分',
  -- category of the user's question
  `ques_type` varchar(32) DEFAULT NULL COMMENT '用户问题类别',
  -- the user's question
  `question` longtext DEFAULT NULL COMMENT '用户问题',
  -- knowledge space
  `knowledge_space` varchar(128) DEFAULT NULL COMMENT '知识库',
  -- feedback detail
  `messages` longtext DEFAULT NULL COMMENT '评价详情',
  -- name of the user who gave the feedback
  `user_name` varchar(128) DEFAULT NULL COMMENT '评价人',
  `gmt_created` datetime DEFAULT NULL,
  `gmt_modified` datetime DEFAULT NULL,
  PRIMARY KEY (`id`),
  -- The unique key already provides an index on (conv_uid, conv_index), so a
  -- second plain KEY on the same columns would only add write overhead.
  UNIQUE KEY `uk_conv` (`conv_uid`,`conv_index`)
) ENGINE=InnoDB AUTO_INCREMENT=0 DEFAULT CHARSET=utf8mb4 COMMENT='用户评分反馈表';
3 changes: 3 additions & 0 deletions pilot/configs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ def __init__(self) -> None:
self.KNOWLEDGE_SEARCH_MAX_TOKEN = int(
os.getenv("KNOWLEDGE_SEARCH_MAX_TOKEN", 2000)
)
### Control whether to display the source document of knowledge on the front end.
self.KNOWLEDGE_CHAT_SHOW_RELATIONS = False

### SUMMARY_CONFIG Configuration
self.SUMMARY_CONFIG = os.getenv("SUMMARY_CONFIG", "FAST")

Expand Down
36 changes: 36 additions & 0 deletions pilot/openapi/api_v1/api_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,26 @@ def plugins_select_info():
return plugins_infos


def get_db_list_info():
    """Return a mapping of database name to its comment.

    Databases whose comment is ``None`` or empty are left out of the result.
    """
    db_comments: dict = {}
    for db in CFG.LOCAL_DB_MANAGE.get_db_list():
        description = db["comment"]
        if description is None or len(description) == 0:
            continue
        db_comments[db["db_name"]] = description
    return db_comments


def knowledge_list_info():
    """Return a mapping of knowledge space name to its description."""
    all_spaces = knowledge_service.get_knowledge_space(KnowledgeSpaceRequest())
    return {space.name: space.desc for space in all_spaces}


def knowledge_list():
"""return knowledge space list"""
params: dict = {}
Expand Down Expand Up @@ -236,6 +256,22 @@ async def params_list(chat_mode: str = ChatScene.ChatNormal.value()):
return Result.succ(None)


@router.post("/v1/chat/mode/params/info", response_model=Result[dict])
async def params_list_info(chat_mode: str = ChatScene.ChatNormal.value()):
    """Return the selectable parameters (with descriptions) for a chat mode.

    The database-backed modes share the database listing, the execution mode
    returns the plugin listing, and the knowledge mode returns the knowledge
    space listing. Any other mode yields a successful result carrying ``None``.
    """
    db_backed_modes = (
        ChatScene.ChatWithDbQA.value(),
        ChatScene.ChatWithDbExecute.value(),
        ChatScene.ChatDashboard.value(),
    )
    if chat_mode in db_backed_modes:
        return Result.succ(get_db_list_info())
    if chat_mode == ChatScene.ChatExecution.value():
        return Result.succ(plugins_select_info())
    if chat_mode == ChatScene.ChatKnowledge.value():
        return Result.succ(knowledge_list_info())
    return Result.succ(None)


@router.post("/v1/chat/mode/params/file/load")
async def params_load(
conv_uid: str, chat_mode: str, model_name: str, doc_file: UploadFile = File(...)
Expand Down
Empty file.
48 changes: 48 additions & 0 deletions pilot/openapi/api_v1/feedback/api_fb_v1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from fastapi import APIRouter, Body, Request

from pilot.openapi.api_v1.feedback.feed_back_model import FeedBackBody
from pilot.openapi.api_v1.feedback.feed_back_db import (
ChatFeedBackDao,
ChatFeedBackEntity,
)
from pilot.openapi.api_view_model import Result

# Router that collects the feedback endpoints defined in this module.
router = APIRouter()
# Module-level DAO used by the feedback endpoints below; it is constructed
# once, when this module is imported.
chat_feed_back = ChatFeedBackDao()


@router.get("/v1/feedback/find", response_model=Result[FeedBackBody])
async def feed_back_find(conv_uid: str, conv_index: int):
    """Look up the stored feedback for one round of a conversation.

    Returns a successful ``Result`` wrapping a ``FeedBackBody`` built from the
    stored row, or wrapping ``None`` when no row matches.
    """
    record = chat_feed_back.get_chat_feed_back(conv_uid, conv_index)
    if record is None:
        return Result.succ(None)

    body = FeedBackBody(
        conv_uid=record.conv_uid,
        conv_index=record.conv_index,
        question=record.question,
        knowledge_space=record.knowledge_space,
        score=record.score,
        ques_type=record.ques_type,
        messages=record.messages,
    )
    return Result.succ(body)


@router.post("/v1/feedback/commit", response_model=Result[bool])
async def feed_back_commit(request: Request, feed_back_body: FeedBackBody = Body()):
    """Store the submitted feedback (create or update) and report success."""
    # ``request`` is not read in this handler; it is kept in the signature as-is.
    chat_feed_back.create_or_update_chat_feed_back(feed_back_body)
    stored = True
    return Result.succ(stored)


@router.get("/v1/feedback/select", response_model=Result[dict])
async def feed_back_select():
    """Return the selectable question categories as key -> display label."""
    question_types = {
        "information": "信息查询",  # information lookup
        "work_study": "工作学习",  # work and study
        "just_fun": "互动闲聊",  # casual chat
        "others": "其他",  # other
    }
    return Result.succ(question_types)
84 changes: 84 additions & 0 deletions pilot/openapi/api_v1/feedback/feed_back_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from datetime import datetime

from sqlalchemy import Column, Integer, Text, String, DateTime
from sqlalchemy.ext.declarative import declarative_base

from pilot.connections.rdbms.base_dao import BaseDao
from pilot.openapi.api_v1.feedback.feed_back_model import FeedBackBody

# Declarative base for the ORM entity declared in this module; it is also
# handed to the DAO below as ``orm_base``.
Base = declarative_base()


class ChatFeedBackEntity(Base):
    """ORM mapping for one row of the ``chat_feed_back`` table."""

    __tablename__ = "chat_feed_back"
    id = Column(Integer, primary_key=True)
    conv_uid = Column(String(128))
    conv_index = Column(Integer)
    score = Column(Integer)
    ques_type = Column(String(32))
    question = Column(Text)
    knowledge_space = Column(String(128))
    messages = Column(Text)
    user_name = Column(String(128))
    gmt_created = Column(DateTime)
    gmt_modified = Column(DateTime)

    def __repr__(self):
        """Return a debug representation listing every mapped column."""
        # Fixed: the class name was misspelled and ``conv_index`` was printed
        # twice, so ``conv_uid`` never appeared in the output.
        return (
            f"ChatFeedBackEntity(id={self.id}, conv_uid='{self.conv_uid}', conv_index='{self.conv_index}', "
            f"score='{self.score}', ques_type='{self.ques_type}', question='{self.question}', knowledge_space='{self.knowledge_space}', "
            f"messages='{self.messages}', user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
        )


class ChatFeedBackDao(BaseDao):
    """Data-access object for ``ChatFeedBackEntity`` rows."""

    def __init__(self):
        super().__init__(database="history", orm_base=Base, create_not_exist_table=True)

    def create_or_update_chat_feed_back(self, feed_back: FeedBackBody):
        """Insert feedback for a conversation round, or update the existing row.

        A row is looked up by ``(conv_uid, conv_index)``. When one exists its
        feedback fields and ``gmt_modified`` are refreshed; otherwise a new row
        is added.

        Args:
            feed_back: The feedback payload to persist.

        Raises:
            Exception: Any error raised by the query or commit is re-raised
                after the transaction has been rolled back.
        """
        # Todo: We need to have user information first.
        def_user_name = ""
        # One timestamp per call so created/modified agree on a fresh row.
        now = datetime.now()

        session = self.Session()
        try:
            existing = (
                session.query(ChatFeedBackEntity)
                .filter(ChatFeedBackEntity.conv_uid == feed_back.conv_uid)
                .filter(ChatFeedBackEntity.conv_index == feed_back.conv_index)
                .first()
            )
            if existing is not None:
                existing.score = feed_back.score
                existing.ques_type = feed_back.ques_type
                existing.question = feed_back.question
                existing.knowledge_space = feed_back.knowledge_space
                existing.messages = feed_back.messages
                existing.user_name = def_user_name
                # gmt_created is deliberately left untouched on update so it
                # keeps recording when the feedback was first submitted.
                existing.gmt_modified = now
            else:
                # Only build the entity when a new row is actually needed.
                session.add(
                    ChatFeedBackEntity(
                        conv_uid=feed_back.conv_uid,
                        conv_index=feed_back.conv_index,
                        score=feed_back.score,
                        ques_type=feed_back.ques_type,
                        question=feed_back.question,
                        knowledge_space=feed_back.knowledge_space,
                        messages=feed_back.messages,
                        user_name=def_user_name,
                        gmt_created=now,
                        gmt_modified=now,
                    )
                )
            session.commit()
        except Exception:
            # Leave no half-applied transaction behind before propagating.
            session.rollback()
            raise
        finally:
            # Always release the session, even when the query or commit fails.
            session.close()

    def get_chat_feed_back(self, conv_uid: str, conv_index: int):
        """Return the feedback row for ``(conv_uid, conv_index)``, or ``None``.

        Args:
            conv_uid: Conversation id to match.
            conv_index: Round index within the conversation to match.
        """
        session = self.Session()
        try:
            return (
                session.query(ChatFeedBackEntity)
                .filter(ChatFeedBackEntity.conv_uid == conv_uid)
                .filter(ChatFeedBackEntity.conv_index == conv_index)
                .first()
            )
        finally:
            # Close the session even if the query raises.
            session.close()
25 changes: 25 additions & 0 deletions pilot/openapi/api_v1/feedback/feed_back_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from pydantic.main import BaseModel


class FeedBackBody(BaseModel):
    """Feedback payload for one round of a conversation."""

    # The per-field notes are ``#`` comments rather than bare strings: the
    # first bare string used to become the class docstring by accident, and
    # the others were no-op expression statements.

    # conv_uid: conversation id
    conv_uid: str

    # conv_index: conversation index
    conv_index: int

    # question: human question
    question: str

    # knowledge_space: knowledge space
    knowledge_space: str

    # score: rating of the llm's answer
    score: int

    # ques_type: question type
    ques_type: str

    # messages: rating detail
    messages: str
25 changes: 24 additions & 1 deletion pilot/scene/chat_knowledge/v1/chat.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from typing import Dict

from pilot.scene.base_chat import BaseChat
Expand Down Expand Up @@ -55,6 +56,21 @@ def __init__(self, chat_param: Dict):
vector_store_config=vector_store_config,
embedding_factory=embedding_factory,
)
self.prompt_template.template_is_strict = False

async def stream_call(self):
    """Stream the parent's outputs, optionally tagging them with their sources.

    When ``CFG.KNOWLEDGE_CHAT_SHOW_RELATIONS`` is enabled and the input values
    carry a non-empty ``relations`` list, the text of every streamed output
    that has a ``text`` attribute gets ``"\\trelations:<a>,<b>,..."`` appended.

    Yields:
        Each output produced by the parent ``stream_call``.
    """
    # NOTE(review): the parent stream_call presumably builds its own input
    # values too, in which case retrieval runs twice per request — confirm.
    input_values = self.generate_input_values()
    # Source of knowledge file
    relations = input_values.get("relations")

    # The suffix does not depend on the streamed output, so build it once
    # instead of re-checking the config and re-joining for every chunk.
    suffix = ""
    if (
        CFG.KNOWLEDGE_CHAT_SHOW_RELATIONS
        and isinstance(relations, list)
        and relations
    ):
        suffix = "\trelations:" + ",".join(relations)

    async for output in super().stream_call():
        if suffix and hasattr(output, "text"):
            output.text = output.text + suffix
        yield output

def generate_input_values(self):
if self.space_context:
Expand All @@ -69,7 +85,14 @@ def generate_input_values(self):
)
context = [d.page_content for d in docs]
context = context[: self.max_token]
input_values = {"context": context, "question": self.current_user_input}
relations = list(
set([os.path.basename(d.metadata.get("source")) for d in docs])
)
input_values = {
"context": context,
"question": self.current_user_input,
"relations": relations,
}
return input_values

@property
Expand Down
2 changes: 2 additions & 0 deletions pilot/server/dbgpt_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from pilot.openapi.api_v1.api_v1 import router as api_v1
from pilot.openapi.base import validation_exception_handler
from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
from pilot.openapi.api_v1.feedback.api_fb_v1 import router as api_fb_v1
from pilot.commands.disply_type.show_chart_gen import static_message_img_path
from pilot.model.cluster import initialize_worker_manager_in_client
from pilot.utils.utils import setup_logging, logging_str_to_uvicorn_level
Expand Down Expand Up @@ -72,6 +73,7 @@ def swagger_monkey_patch(*args, **kwargs):
app.include_router(api_v1, prefix="/api")
app.include_router(knowledge_router, prefix="/api")
app.include_router(api_editor_route_v1, prefix="/api")
app.include_router(api_fb_v1, prefix="/api")

# app.include_router(api_v1)
app.include_router(knowledge_router)
Expand Down