Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(core): Support system code for feedback and prompt #873

Merged
merged 1 commit into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ DB-GPT是一个开源的数据库领域大模型框架。目的是构建大模
- [internlm-chat-20b](https://huggingface.co/internlm/internlm-chat-20b)
- [qwen-7b-chat](https://huggingface.co/Qwen/Qwen-7B-Chat)
- [qwen-14b-chat](https://huggingface.co/Qwen/Qwen-14B-Chat)
- [qwen-72b-chat](https://huggingface.co/Qwen/Qwen-72B-Chat)
- [wizardlm-13b](https://huggingface.co/WizardLM/WizardLM-13B-V1.2)
- [orca-2-7b](https://huggingface.co/microsoft/Orca-2-7b)
- [orca-2-13b](https://huggingface.co/microsoft/Orca-2-13b)
Expand Down
12 changes: 12 additions & 0 deletions pilot/configs/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,18 @@ def get_device() -> str:
"qwen-14b-chat-int8": os.path.join(MODEL_PATH, "Qwen-14B-Chat-Int8"),
# https://huggingface.co/Qwen/Qwen-14B-Chat-Int4
"qwen-14b-chat-int4": os.path.join(MODEL_PATH, "Qwen-14B-Chat-Int4"),
# https://huggingface.co/Qwen/Qwen-72B-Chat
"qwen-72b-chat": os.path.join(MODEL_PATH, "Qwen-72B-Chat"),
# https://huggingface.co/Qwen/Qwen-72B-Chat-Int8
"qwen-72b-chat-int8": os.path.join(MODEL_PATH, "Qwen-72B-Chat-Int8"),
# https://huggingface.co/Qwen/Qwen-72B-Chat-Int4
"qwen-72b-chat-int4": os.path.join(MODEL_PATH, "Qwen-72B-Chat-Int4"),
# https://huggingface.co/Qwen/Qwen-1_8B-Chat
"qwen-1.8b-chat": os.path.join(MODEL_PATH, "Qwen-1_8B-Chat"),
# https://huggingface.co/Qwen/Qwen-1_8B-Chat-Int8
    "qwen-1.8b-chat-int8": os.path.join(MODEL_PATH, "Qwen-1_8B-Chat-Int8"),
# https://huggingface.co/Qwen/Qwen-1_8B-Chat-Int4
"qwen-1.8b-chat-int4": os.path.join(MODEL_PATH, "Qwen-1_8B-Chat-Int4"),
# (Llama2 based) We only support WizardLM-13B-V1.2 for now, which is trained from Llama-2 13b, see https://huggingface.co/WizardLM/WizardLM-13B-V1.2
"wizardlm-13b": os.path.join(MODEL_PATH, "WizardLM-13B-V1.2"),
# wget https://huggingface.co/TheBloke/vicuna-13B-v1.5-GGUF/resolve/main/vicuna-13b-v1.5.Q4_K_M.gguf -O models/ggml-model-q4_0.gguf
Expand Down
5 changes: 2 additions & 3 deletions pilot/openapi/api_v1/feedback/feed_back_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ def __init__(self):

def create_or_update_chat_feed_back(self, feed_back: FeedBackBody):
# Todo: We need to have user information first.
def_user_name = ""

session = self.get_session()
chat_feed_back = ChatFeedBackEntity(
Expand All @@ -60,7 +59,7 @@ def create_or_update_chat_feed_back(self, feed_back: FeedBackBody):
question=feed_back.question,
knowledge_space=feed_back.knowledge_space,
messages=feed_back.messages,
user_name=def_user_name,
user_name=feed_back.user_name,
gmt_created=datetime.now(),
gmt_modified=datetime.now(),
)
Expand All @@ -76,7 +75,7 @@ def create_or_update_chat_feed_back(self, feed_back: FeedBackBody):
result.question = feed_back.question
result.knowledge_space = feed_back.knowledge_space
result.messages = feed_back.messages
result.user_name = def_user_name
result.user_name = feed_back.user_name
result.gmt_created = datetime.now()
result.gmt_modified = datetime.now()
else:
Expand Down
11 changes: 7 additions & 4 deletions pilot/openapi/api_v1/feedback/feed_back_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pydantic import BaseModel
from typing import Optional


class FeedBackBody(BaseModel):
Expand All @@ -12,14 +13,16 @@ class FeedBackBody(BaseModel):
"""question: human question"""
question: str

"""knowledge_space: knowledge space"""
knowledge_space: str

"""score: rating of the llm's answer"""
score: int

"""ques_type: question type"""
ques_type: str

user_name: Optional[str] = None

"""messages: rating detail"""
messages: str
messages: Optional[str] = None

"""knowledge_space: knowledge space"""
knowledge_space: Optional[str] = None
3 changes: 3 additions & 0 deletions pilot/server/prompt/prompt_manage_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def create_prompt(self, prompt: PromptManageRequest):
prompt_name=prompt.prompt_name,
content=prompt.content,
user_name=prompt.user_name,
sys_code=prompt.sys_code,
gmt_created=datetime.now(),
gmt_modified=datetime.now(),
)
Expand Down Expand Up @@ -83,6 +84,8 @@ def get_prompts(self, query: PromptManageEntity):
prompts = prompts.filter(
PromptManageEntity.prompt_name == query.prompt_name
)
if query.sys_code is not None:
prompts = prompts.filter(PromptManageEntity.sys_code == query.sys_code)

prompts = prompts.order_by(PromptManageEntity.gmt_created.desc())
result = prompts.all()
Expand Down
56 changes: 38 additions & 18 deletions pilot/server/prompt/request/request.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,44 @@
from typing import List

from pydantic import BaseModel
from typing import Optional
from pydantic import BaseModel


class PromptManageRequest(BaseModel):
"""chat_scene: for example: chat_with_db_execute, chat_excel, chat_with_db_qa"""

chat_scene: str = None

"""sub_chat_scene: sub chat scene"""
sub_chat_scene: str = None

"""prompt_type: common or private"""
prompt_type: str = None

"""content: prompt content"""
content: str = None

"""user_name: user name"""
user_name: str = None

"""prompt_name: prompt name"""
prompt_name: str = None
"""Model for managing prompts."""

chat_scene: Optional[str] = None
"""
The chat scene, e.g. chat_with_db_execute, chat_excel, chat_with_db_qa.
"""

sub_chat_scene: Optional[str] = None
"""
The sub chat scene.
"""

prompt_type: Optional[str] = None
"""
The prompt type, either common or private.
"""

content: Optional[str] = None
"""
The prompt content.
"""

user_name: Optional[str] = None
"""
The user name.
"""

sys_code: Optional[str] = None
"""
System code
"""

prompt_name: Optional[str] = None
"""
The prompt name.
"""
9 changes: 8 additions & 1 deletion pilot/server/prompt/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,15 @@ def create_prompt(self, request: PromptManageRequest):
query = PromptManageRequest(
prompt_name=request.prompt_name,
)
err_sys_str = ""
        if request.sys_code:
query.sys_code = request.sys_code
err_sys_str = f" and sys_code: {request.sys_code}"
prompt_name = prompt_manage_dao.get_prompts(query)
if len(prompt_name) > 0:
raise Exception(f"prompt name:{request.prompt_name} have already named")
raise Exception(
                f"prompt name: {request.prompt_name}{err_sys_str} already exists"
)
prompt_manage_dao.create_prompt(request)
return True

Expand All @@ -32,6 +38,7 @@ def get_prompts(self, request: PromptManageRequest):
prompt_type=request.prompt_type,
prompt_name=request.prompt_name,
user_name=request.user_name,
sys_code=request.sys_code,
)
responses = []
prompts = prompt_manage_dao.get_prompts(query)
Expand Down