feat:llm manage
Aries-ckt committed Sep 21, 2023
2 parents ce3b2e6 + 3590d7b commit d512dde
Showing 14 changed files with 310 additions and 17 deletions.
18 changes: 18 additions & 0 deletions assets/schema/history.sql
@@ -0,0 +1,18 @@
CREATE DATABASE history;
USE history;
CREATE TABLE `chat_feed_back` (
  `id` bigint(20) NOT NULL AUTO_INCREMENT,
  `conv_uid` varchar(128) DEFAULT NULL COMMENT 'conversation id',
  `conv_index` int(4) DEFAULT NULL COMMENT 'conversation round index',
  `score` int(1) DEFAULT NULL COMMENT 'rating score',
  `ques_type` varchar(32) DEFAULT NULL COMMENT 'user question category',
  `question` longtext DEFAULT NULL COMMENT 'user question',
  `knowledge_space` varchar(128) DEFAULT NULL COMMENT 'knowledge space',
  `messages` longtext DEFAULT NULL COMMENT 'feedback details',
  `user_name` varchar(128) DEFAULT NULL COMMENT 'rater name',
  `gmt_created` datetime DEFAULT NULL,
  `gmt_modified` datetime DEFAULT NULL,
  PRIMARY KEY (`id`),
  UNIQUE KEY `uk_conv` (`conv_uid`,`conv_index`),
  KEY `idx_conv` (`conv_uid`,`conv_index`)
) ENGINE=InnoDB AUTO_INCREMENT=0 DEFAULT CHARSET=utf8mb4 COMMENT='user rating feedback table';
3 changes: 3 additions & 0 deletions pilot/configs/config.py
@@ -185,6 +185,9 @@ def __init__(self) -> None:
        self.KNOWLEDGE_SEARCH_MAX_TOKEN = int(
            os.getenv("KNOWLEDGE_SEARCH_MAX_TOKEN", 2000)
        )
+        ### Control whether to display the source documents of knowledge on the front end.
+        self.KNOWLEDGE_CHAT_SHOW_RELATIONS = False
+
        ### SUMMARY_CONFIG Configuration
        self.SUMMARY_CONFIG = os.getenv("SUMMARY_CONFIG", "FAST")

4 changes: 2 additions & 2 deletions pilot/model/cluster/worker/default_worker.py
@@ -8,7 +8,7 @@
from pilot.model.parameter import ModelParameters
from pilot.model.cluster.worker_base import ModelWorker
from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter
-from pilot.utils.model_utils import _clear_torch_cache
+from pilot.utils.model_utils import _clear_model_cache
from pilot.utils.parameter_utils import EnvArgumentParser

logger = logging.getLogger(__name__)
@@ -87,7 +87,7 @@ def stop(self) -> None:
        del self.tokenizer
        self.model = None
        self.tokenizer = None
-        _clear_torch_cache(self._model_params.device)
+        _clear_model_cache(self._model_params.device)

    def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
        torch_imported = False
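The rename from _clear_torch_cache to _clear_model_cache suggests the helper now covers more than torch's CUDA cache. Its body is not part of this diff; the sketch below is only an assumption about what such a helper plausibly does, not the repository's actual implementation.

import gc
import logging

logger = logging.getLogger(__name__)


def _clear_model_cache(device: str = "cuda"):
    """Hypothetical sketch: free framework caches after a model is unloaded."""
    try:
        import torch  # optional dependency; skip CUDA cleanup if absent

        if device != "cpu" and torch.cuda.is_available():
            torch.cuda.empty_cache()
    except ImportError:
        logger.info("torch not installed, skipping CUDA cache cleanup")
    # Always reclaim Python-level references as well.
    gc.collect()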
4 changes: 2 additions & 2 deletions pilot/model/cluster/worker/embedding_worker.py
@@ -11,7 +11,7 @@
)
from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.embedding.loader import EmbeddingLoader
-from pilot.utils.model_utils import _clear_torch_cache
+from pilot.utils.model_utils import _clear_model_cache
from pilot.utils.parameter_utils import EnvArgumentParser

logger = logging.getLogger(__name__)
@@ -79,7 +79,7 @@ def stop(self) -> None:
            return
        del self._embeddings_impl
        self._embeddings_impl = None
-        _clear_torch_cache(self._model_params.device)
+        _clear_model_cache(self._model_params.device)

    def generate_stream(self, params: Dict):
        """Generate stream result, chat scene"""
1 change: 1 addition & 0 deletions pilot/model/loader.py
@@ -18,6 +18,7 @@
def _check_multi_gpu_or_4bit_quantization(model_params: ModelParameters):
    # TODO: vicuna-v1.5 8-bit quantization info is slow
    # TODO: support wizardlm quantization, see: https://huggingface.co/WizardLM/WizardLM-13B-V1.2/discussions/5
+    # TODO: support internlm quantization
    model_name = model_params.model_name.lower()
    supported_models = ["llama", "baichuan", "vicuna"]
    return any(m in model_name for m in supported_models)
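For intuition, the check above is a plain substring match on the lowercased model name. A small standalone illustration (the model names below are arbitrary examples, not an exhaustive list):

supported_models = ["llama", "baichuan", "vicuna"]

for name in ["vicuna-13b-v1.5", "chatglm2-6b", "baichuan2-13b"]:
    matched = any(m in name.lower() for m in supported_models)
    print(f"{name}: {matched}")
# vicuna-13b-v1.5: True, chatglm2-6b: False, baichuan2-13b: True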
64 changes: 58 additions & 6 deletions pilot/openapi/api_v1/api_v1.py
@@ -26,6 +26,9 @@
    ConversationVo,
    MessageVo,
    ChatSceneVo,
+    ChatCompletionResponseStreamChoice,
+    DeltaMessage,
+    ChatCompletionStreamResponse,
)
from pilot.connections.db_conn_info import DBConfig, DbTypeInfo
from pilot.configs.config import Config
@@ -85,6 +88,26 @@ def plugins_select_info():
    return plugins_infos


+def get_db_list_info():
+    dbs = CFG.LOCAL_DB_MANAGE.get_db_list()
+    params: dict = {}
+    for item in dbs:
+        comment = item["comment"]
+        if comment is not None and len(comment) > 0:
+            params.update({item["db_name"]: comment})
+    return params
+
+
+def knowledge_list_info():
+    """Return a mapping of knowledge space name to description."""
+    params: dict = {}
+    request = KnowledgeSpaceRequest()
+    spaces = knowledge_service.get_knowledge_space(request)
+    for space in spaces:
+        params.update({space.name: space.desc})
+    return params
+
+
def knowledge_list():
    """return knowledge space list"""
    params: dict = {}
@@ -363,7 +386,7 @@ async def chat_completions(dialogue: ConversationVo = Body()):
        )
    else:
        return StreamingResponse(
-            stream_generator(chat),
+            stream_generator(chat, dialogue.incremental, dialogue.model_name),
            headers=headers,
            media_type="text/plain",
        )
@@ -401,19 +424,48 @@ async def no_stream_generator(chat):
yield f"data: {msg}\n\n"


async def stream_generator(chat):
async def stream_generator(chat, incremental: bool, model_name: str):
"""Generate streaming responses
Our goal is to generate an openai-compatible streaming responses.
Currently, the incremental response is compatible, and the full response will be transformed in the future.
Args:
chat (BaseChat): Chat instance.
incremental (bool): Used to control whether the content is returned incrementally or in full each time.
model_name (str): The model name
Yields:
_type_: streaming responses
"""
msg = "[LLM_ERROR]: llm server has no output, maybe your prompt template is wrong."

stream_id = f"chatcmpl-{str(uuid.uuid1())}"
previous_response = ""
async for chunk in chat.stream_call():
if chunk:
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
chunk, chat.skip_echo_len
)

msg = msg.replace("\n", "\\n")
yield f"data:{msg}\n\n"
msg = msg.replace("\ufffd", "")
if incremental:
incremental_output = msg[len(previous_response) :]
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant", content=incremental_output),
)
chunk = ChatCompletionStreamResponse(
id=stream_id, choices=[choice_data], model=model_name
)
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
else:
# TODO generate an openai-compatible streaming responses
msg = msg.replace("\n", "\\n")
yield f"data:{msg}\n\n"
previous_response = msg
await asyncio.sleep(0.02)

if incremental:
yield "data: [DONE]\n\n"
chat.current_message.add_ai_message(msg)
chat.current_message.add_view_message(msg)
chat.memory.append(chat.current_message)
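To show the streaming contract end to end, here is a minimal client sketch. The base URL, port, and the user_input/chat_mode payload fields are assumptions about the surrounding API rather than confirmed by this diff, and only the incremental mode is parsed:

import json

import requests

resp = requests.post(
    "http://localhost:5000/api/v1/chat/completions",  # assumed mount path
    json={
        "user_input": "Hello",
        "chat_mode": "chat_normal",
        "model_name": "vicuna-13b-v1.5",
        "incremental": True,  # request OpenAI-style delta chunks
    },
    stream=True,
)

for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith("data: "):
        continue
    data = line[len("data: "):]
    if data == "[DONE]":  # end-of-stream sentinel in incremental mode
        break
    chunk = json.loads(data)
    print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)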
Empty file.
48 changes: 48 additions & 0 deletions pilot/openapi/api_v1/feedback/api_fb_v1.py
@@ -0,0 +1,48 @@
from fastapi import APIRouter, Body, Request

from pilot.openapi.api_v1.feedback.feed_back_model import FeedBackBody
from pilot.openapi.api_v1.feedback.feed_back_db import (
    ChatFeedBackDao,
    ChatFeedBackEntity,
)
from pilot.openapi.api_view_model import Result

router = APIRouter()
chat_feed_back = ChatFeedBackDao()


@router.get("/v1/feedback/find", response_model=Result[FeedBackBody])
async def feed_back_find(conv_uid: str, conv_index: int):
    rt = chat_feed_back.get_chat_feed_back(conv_uid, conv_index)
    if rt is not None:
        return Result.succ(
            FeedBackBody(
                conv_uid=rt.conv_uid,
                conv_index=rt.conv_index,
                question=rt.question,
                knowledge_space=rt.knowledge_space,
                score=rt.score,
                ques_type=rt.ques_type,
                messages=rt.messages,
            )
        )
    else:
        return Result.succ(None)


@router.post("/v1/feedback/commit", response_model=Result[bool])
async def feed_back_commit(request: Request, feed_back_body: FeedBackBody = Body()):
    chat_feed_back.create_or_update_chat_feed_back(feed_back_body)
    return Result.succ(True)


@router.get("/v1/feedback/select", response_model=Result[dict])
async def feed_back_select():
    return Result.succ(
        {
            "information": "信息查询",  # information inquiry
            "work_study": "工作学习",  # work and study
            "just_fun": "互动闲聊",  # casual chat
            "others": "其他",  # others
        }
    )
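A usage sketch for the three endpoints above; the base URL is an assumption about where this router is mounted, and the payload fields mirror FeedBackBody defined below:

import requests

BASE = "http://localhost:5000/api"  # assumed mount point of this router

# List the question-type options (keys are stable identifiers, values are Chinese labels).
print(requests.get(f"{BASE}/v1/feedback/select").json())

# Commit a rating for round 1 of a conversation.
requests.post(
    f"{BASE}/v1/feedback/commit",
    json={
        "conv_uid": "some-conversation-uuid",
        "conv_index": 1,
        "question": "What is DB-GPT?",
        "knowledge_space": "default",
        "score": 5,
        "ques_type": "information",
        "messages": "Accurate and well sourced.",
    },
)

# Read the rating back.
print(
    requests.get(
        f"{BASE}/v1/feedback/find",
        params={"conv_uid": "some-conversation-uuid", "conv_index": 1},
    ).json()
)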
84 changes: 84 additions & 0 deletions pilot/openapi/api_v1/feedback/feed_back_db.py
@@ -0,0 +1,84 @@
from datetime import datetime

from sqlalchemy import Column, Integer, Text, String, DateTime
from sqlalchemy.ext.declarative import declarative_base

from pilot.connections.rdbms.base_dao import BaseDao
from pilot.openapi.api_v1.feedback.feed_back_model import FeedBackBody

Base = declarative_base()


class ChatFeedBackEntity(Base):
    __tablename__ = "chat_feed_back"
    id = Column(Integer, primary_key=True)
    conv_uid = Column(String(128))
    conv_index = Column(Integer)
    score = Column(Integer)
    ques_type = Column(String(32))
    question = Column(Text)
    knowledge_space = Column(String(128))
    messages = Column(Text)
    user_name = Column(String(128))
    gmt_created = Column(DateTime)
    gmt_modified = Column(DateTime)

    def __repr__(self):
        return (
            f"ChatFeedBackEntity(id={self.id}, conv_uid='{self.conv_uid}', conv_index='{self.conv_index}', "
            f"score='{self.score}', ques_type='{self.ques_type}', question='{self.question}', knowledge_space='{self.knowledge_space}', "
            f"messages='{self.messages}', user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
        )


class ChatFeedBackDao(BaseDao):
    def __init__(self):
        super().__init__(database="history", orm_base=Base, create_not_exist_table=True)

    def create_or_update_chat_feed_back(self, feed_back: FeedBackBody):
        # TODO: We need to have user information first.
        def_user_name = ""

        session = self.Session()
        chat_feed_back = ChatFeedBackEntity(
            conv_uid=feed_back.conv_uid,
            conv_index=feed_back.conv_index,
            score=feed_back.score,
            ques_type=feed_back.ques_type,
            question=feed_back.question,
            knowledge_space=feed_back.knowledge_space,
            messages=feed_back.messages,
            user_name=def_user_name,
            gmt_created=datetime.now(),
            gmt_modified=datetime.now(),
        )
        result = (
            session.query(ChatFeedBackEntity)
            .filter(ChatFeedBackEntity.conv_uid == feed_back.conv_uid)
            .filter(ChatFeedBackEntity.conv_index == feed_back.conv_index)
            .first()
        )
        if result is not None:
            result.score = feed_back.score
            result.ques_type = feed_back.ques_type
            result.question = feed_back.question
            result.knowledge_space = feed_back.knowledge_space
            result.messages = feed_back.messages
            result.user_name = def_user_name
            result.gmt_modified = datetime.now()  # keep gmt_created from the original insert
        else:
            session.merge(chat_feed_back)
        session.commit()
        session.close()

    def get_chat_feed_back(self, conv_uid: str, conv_index: int):
        session = self.Session()
        result = (
            session.query(ChatFeedBackEntity)
            .filter(ChatFeedBackEntity.conv_uid == conv_uid)
            .filter(ChatFeedBackEntity.conv_index == conv_index)
            .first()
        )
        session.close()
        return result
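The DAO can also be driven directly, which makes the upsert behavior visible: a second call with the same (conv_uid, conv_index) pair updates the existing row, mirroring the uk_conv unique key in assets/schema/history.sql. A sketch, assuming the history database is reachable through BaseDao's configured connection:

from pilot.openapi.api_v1.feedback.feed_back_db import ChatFeedBackDao
from pilot.openapi.api_v1.feedback.feed_back_model import FeedBackBody

dao = ChatFeedBackDao()
feedback = FeedBackBody(
    conv_uid="some-conversation-uuid",
    conv_index=1,
    question="What is DB-GPT?",
    knowledge_space="default",
    score=5,
    ques_type="information",
    messages="Accurate and well sourced.",
)
dao.create_or_update_chat_feed_back(feedback)  # first call inserts
feedback.score = 3
dao.create_or_update_chat_feed_back(feedback)  # second call updates the same row
print(dao.get_chat_feed_back("some-conversation-uuid", 1))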
25 changes: 25 additions & 0 deletions pilot/openapi/api_v1/feedback/feed_back_model.py
@@ -0,0 +1,25 @@
from pydantic.main import BaseModel


class FeedBackBody(BaseModel):
    """conv_uid: conversation id"""

    conv_uid: str

    """conv_index: conversation index"""
    conv_index: int

    """question: human question"""
    question: str

    """knowledge_space: knowledge space"""
    knowledge_space: str

    """score: rating of the llm's answer"""
    score: int

    """ques_type: question type"""
    ques_type: str

    """messages: rating detail"""
    messages: str
27 changes: 26 additions & 1 deletion pilot/openapi/api_view_model.py
@@ -1,5 +1,7 @@
from pydantic import BaseModel, Field
-from typing import TypeVar, Generic, Any
+from typing import TypeVar, Generic, Any, Optional, Literal, List
+import uuid
+import time

T = TypeVar("T")

@@ -59,6 +61,11 @@ class ConversationVo(BaseModel):
"""
model_name: str = None

"""Used to control whether the content is returned incrementally or in full each time.
If this parameter is not provided, the default is full return.
"""
incremental: bool = False


class MessageVo(BaseModel):
"""
@@ -83,3 +90,21 @@ class MessageVo(BaseModel):
    model_name
    """
    model_name: str
+
+
+class DeltaMessage(BaseModel):
+    role: Optional[str] = None
+    content: Optional[str] = None
+
+
+class ChatCompletionResponseStreamChoice(BaseModel):
+    index: int
+    delta: DeltaMessage
+    finish_reason: Optional[Literal["stop", "length"]] = None
+
+
+class ChatCompletionStreamResponse(BaseModel):
+    id: str = Field(default_factory=lambda: f"chatcmpl-{str(uuid.uuid1())}")
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionResponseStreamChoice]
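To see the wire format these models produce, a small self-contained sketch (assuming pydantic v1, which this codebase used at the time). Note that json(exclude_unset=True) drops fields left at their defaults, which is why stream_generator sets id and model explicitly; created and finish_reason are therefore omitted from the payload below:

import time
import uuid
from typing import List, Literal, Optional

from pydantic import BaseModel, Field


class DeltaMessage(BaseModel):
    role: Optional[str] = None
    content: Optional[str] = None


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length"]] = None


class ChatCompletionStreamResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{str(uuid.uuid1())}")
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]


chunk = ChatCompletionStreamResponse(
    id=f"chatcmpl-{uuid.uuid1()}",  # set explicitly, as stream_generator does
    model="vicuna-13b-v1.5",
    choices=[
        ChatCompletionResponseStreamChoice(
            index=0, delta=DeltaMessage(role="assistant", content="Hello")
        )
    ],
)
print(f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}")
# data: {"id": "chatcmpl-...", "model": "vicuna-13b-v1.5",
#        "choices": [{"index": 0, "delta": {"role": "assistant", "content": "Hello"}}]}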