Skip to content

Commit

Permalink
feat:
Browse files Browse the repository at this point in the history
1. Support score feedback
2. Add method to return topk document source (KBQA)
  • Loading branch information
cm-liushaodong committed Sep 20, 2023
1 parent 132814e commit a4832be
Show file tree
Hide file tree
Showing 9 changed files with 224 additions and 1 deletion.
18 changes: 18 additions & 0 deletions assets/schema/history.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
CREATE DATABASE history;
use history;
CREATE TABLE `chat_feed_back` (
`id` bigint(20) NOT NULL AUTO_INCREMENT,
`conv_uid` varchar(128) DEFAULT NULL COMMENT '会话id',
`conv_index` int(4) DEFAULT NULL COMMENT '第几轮会话',
`score` int(1) DEFAULT NULL COMMENT '评分',
`ques_type` varchar(32) DEFAULT NULL COMMENT '用户问题类别',
`question` longtext DEFAULT NULL COMMENT '用户问题',
`knowledge_space` varchar(128) DEFAULT NULL COMMENT '知识库',
`messages` longtext DEFAULT NULL COMMENT '评价详情',
`user_name` varchar(128) DEFAULT NULL COMMENT '评价人',
`gmt_created` datetime DEFAULT NULL,
`gmt_modified` datetime DEFAULT NULL,
PRIMARY KEY (`id`),
UNIQUE KEY `uk_conv` (`conv_uid`,`conv_index`),
KEY `idx_conv` (`conv_uid`,`conv_index`)
) ENGINE=InnoDB AUTO_INCREMENT=0 DEFAULT CHARSET=utf8mb4 COMMENT='用户评分反馈表';
3 changes: 3 additions & 0 deletions pilot/configs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ def __init__(self) -> None:
self.KNOWLEDGE_SEARCH_MAX_TOKEN = int(
os.getenv("KNOWLEDGE_SEARCH_MAX_TOKEN", 2000)
)
### Control whether to display the source document of knowledge on the front end.
self.KNOWLEDGE_CHAT_SHOW_RELATIONS = False

### SUMMARY_CONFIG Configuration
self.SUMMARY_CONFIG = os.getenv("SUMMARY_CONFIG", "FAST")

Expand Down
20 changes: 20 additions & 0 deletions pilot/openapi/api_v1/api_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,26 @@ def plugins_select_info():
return plugins_infos


def get_db_list_info():
dbs = CFG.LOCAL_DB_MANAGE.get_db_list()
params: dict = {}
for item in dbs:
comment = item["comment"]
if comment is not None and len(comment) > 0:
params.update({item["db_name"]: comment})
return params


def knowledge_list_info():
"""return knowledge space list"""
params: dict = {}
request = KnowledgeSpaceRequest()
spaces = knowledge_service.get_knowledge_space(request)
for space in spaces:
params.update({space.name: space.desc})
return params


def knowledge_list():
"""return knowledge space list"""
params: dict = {}
Expand Down
Empty file.
48 changes: 48 additions & 0 deletions pilot/openapi/api_v1/feedback/api_fb_v1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from fastapi import APIRouter, Body, Request

from pilot.openapi.api_v1.feedback.feed_back_model import FeedBackBody
from pilot.openapi.api_v1.feedback.feed_back_db import (
ChatFeedBackDao,
ChatFeedBackEntity,
)
from pilot.openapi.api_view_model import Result

router = APIRouter()
chat_feed_back = ChatFeedBackDao()


@router.get("/v1/feedback/find", response_model=Result[FeedBackBody])
async def feed_back_find(conv_uid: str, conv_index: int):
rt = chat_feed_back.get_chat_feed_back(conv_uid, conv_index)
if rt is not None:
return Result.succ(
FeedBackBody(
conv_uid=rt.conv_uid,
conv_index=rt.conv_index,
question=rt.question,
knowledge_space=rt.knowledge_space,
score=rt.score,
ques_type=rt.ques_type,
messages=rt.messages,
)
)
else:
return Result.succ(None)


@router.post("/v1/feedback/commit", response_model=Result[bool])
async def feed_back_commit(request: Request, feed_back_body: FeedBackBody = Body()):
chat_feed_back.create_or_update_chat_feed_back(feed_back_body)
return Result.succ(True)


@router.get("/v1/feedback/select", response_model=Result[dict])
async def feed_back_select():
return Result.succ(
{
"information": "信息查询",
"work_study": "工作学习",
"just_fun": "互动闲聊",
"others": "其他",
}
)
84 changes: 84 additions & 0 deletions pilot/openapi/api_v1/feedback/feed_back_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from datetime import datetime

from sqlalchemy import Column, Integer, Text, String, DateTime
from sqlalchemy.ext.declarative import declarative_base

from pilot.connections.rdbms.base_dao import BaseDao
from pilot.openapi.api_v1.feedback.feed_back_model import FeedBackBody

Base = declarative_base()


class ChatFeedBackEntity(Base):
__tablename__ = "chat_feed_back"
id = Column(Integer, primary_key=True)
conv_uid = Column(String(128))
conv_index = Column(Integer)
score = Column(Integer)
ques_type = Column(String(32))
question = Column(Text)
knowledge_space = Column(String(128))
messages = Column(Text)
user_name = Column(String(128))
gmt_created = Column(DateTime)
gmt_modified = Column(DateTime)

def __repr__(self):
return (
f"ChatFeekBackEntity(id={self.id}, conv_index='{self.conv_index}', conv_index='{self.conv_index}', "
f"score='{self.score}', ques_type='{self.ques_type}', question='{self.question}', knowledge_space='{self.knowledge_space}', "
f"messages='{self.messages}', user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
)


class ChatFeedBackDao(BaseDao):
def __init__(self):
super().__init__(database="history", orm_base=Base, create_not_exist_table=True)

def create_or_update_chat_feed_back(self, feed_back: FeedBackBody):
# Todo: We need to have user information first.
def_user_name = ""

session = self.Session()
chat_feed_back = ChatFeedBackEntity(
conv_uid=feed_back.conv_uid,
conv_index=feed_back.conv_index,
score=feed_back.score,
ques_type=feed_back.ques_type,
question=feed_back.question,
knowledge_space=feed_back.knowledge_space,
messages=feed_back.messages,
user_name=def_user_name,
gmt_created=datetime.now(),
gmt_modified=datetime.now(),
)
result = (
session.query(ChatFeedBackEntity)
.filter(ChatFeedBackEntity.conv_uid == feed_back.conv_uid)
.filter(ChatFeedBackEntity.conv_index == feed_back.conv_index)
.first()
)
if result is not None:
result.score = feed_back.score
result.ques_type = feed_back.ques_type
result.question = feed_back.question
result.knowledge_space = feed_back.knowledge_space
result.messages = feed_back.messages
result.user_name = def_user_name
result.gmt_created = datetime.now()
result.gmt_modified = datetime.now()
else:
session.merge(chat_feed_back)
session.commit()
session.close()

def get_chat_feed_back(self, conv_uid: str, conv_index: int):
session = self.Session()
result = (
session.query(ChatFeedBackEntity)
.filter(ChatFeedBackEntity.conv_uid == conv_uid)
.filter(ChatFeedBackEntity.conv_index == conv_index)
.first()
)
session.close()
return result
25 changes: 25 additions & 0 deletions pilot/openapi/api_v1/feedback/feed_back_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from pydantic.main import BaseModel


class FeedBackBody(BaseModel):
"""conv_uid: conversation id"""

conv_uid: str

"""conv_index: conversation index"""
conv_index: int

"""question: human question"""
question: str

"""knowledge_space: knowledge space"""
knowledge_space: str

"""score: rating of the llm's answer"""
score: int

"""ques_type: question type"""
ques_type: str

"""messages: rating detail"""
messages: str
25 changes: 24 additions & 1 deletion pilot/scene/chat_knowledge/v1/chat.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from typing import Dict

from pilot.scene.base_chat import BaseChat
Expand Down Expand Up @@ -55,6 +56,21 @@ def __init__(self, chat_param: Dict):
vector_store_config=vector_store_config,
embedding_factory=embedding_factory,
)
self.prompt_template.template_is_strict = False

async def stream_call(self):
input_values = self.generate_input_values()
async for output in super().stream_call():
# Source of knowledge file
relations = input_values.get("relations")
if (
CFG.KNOWLEDGE_CHAT_SHOW_RELATIONS
and type(relations) == list
and len(relations) > 0
and hasattr(output, "text")
):
output.text = output.text + "\trelations:" + ",".join(relations)
yield output

def generate_input_values(self):
if self.space_context:
Expand All @@ -69,7 +85,14 @@ def generate_input_values(self):
)
context = [d.page_content for d in docs]
context = context[: self.max_token]
input_values = {"context": context, "question": self.current_user_input}
relations = list(
set([os.path.basename(d.metadata.get("source")) for d in docs])
)
input_values = {
"context": context,
"question": self.current_user_input,
"relations": relations,
}
return input_values

@property
Expand Down
2 changes: 2 additions & 0 deletions pilot/server/dbgpt_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from pilot.openapi.api_v1.api_v1 import router as api_v1
from pilot.openapi.base import validation_exception_handler
from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
from pilot.openapi.api_v1.feedback.api_fb_v1 import router as api_fb_v1
from pilot.commands.disply_type.show_chart_gen import static_message_img_path
from pilot.model.cluster import initialize_worker_manager_in_client
from pilot.utils.utils import setup_logging, logging_str_to_uvicorn_level
Expand Down Expand Up @@ -72,6 +73,7 @@ def swagger_monkey_patch(*args, **kwargs):
app.include_router(api_v1, prefix="/api")
app.include_router(knowledge_router, prefix="/api")
app.include_router(api_editor_route_v1, prefix="/api")
app.include_router(api_fb_v1, prefix="/api")

# app.include_router(api_v1)
app.include_router(knowledge_router)
Expand Down

0 comments on commit a4832be

Please sign in to comment.