From eeff46487d5e94c4aace7bef6e39810fa97be8d0 Mon Sep 17 00:00:00 2001
From: FangYin Cheng
Date: Mon, 27 Nov 2023 20:17:56 +0800
Subject: [PATCH 1/2] feat(core): MTB supports multi-user and multi-system
fields (#854)
---
assets/schema/knowledge_management.sql | 4 +
pilot/base_modules/agent/db/my_plugin_db.py | 6 +
pilot/configs/model_config.py | 4 +
.../connections/manages/connect_config_db.py | 1 +
pilot/memory/chat_history/base.py | 15 +-
.../chat_history/chat_hisotry_factory.py | 3 +-
pilot/memory/chat_history/chat_history_db.py | 10 +-
.../chat_history/store_type/duckdb_history.py | 38 +++--
.../store_type/meta_db_history.py | 11 +-
pilot/openapi/api_v1/api_v1.py | 88 ++++++-----
pilot/openapi/api_view_model.py | 2 +
pilot/scene/base_chat.py | 5 +-
pilot/scene/message.py | 10 +-
pilot/server/prompt/prompt_manage_db.py | 1 +
setup.py | 147 ++++++++++++++++--
15 files changed, 262 insertions(+), 83 deletions(-)
diff --git a/assets/schema/knowledge_management.sql b/assets/schema/knowledge_management.sql
index a6f1bc478..eb292b358 100644
--- a/assets/schema/knowledge_management.sql
+++ b/assets/schema/knowledge_management.sql
@@ -66,6 +66,7 @@ CREATE TABLE `connect_config` (
`db_user` varchar(255) DEFAULT NULL COMMENT 'db user',
`db_pwd` varchar(255) DEFAULT NULL COMMENT 'db password',
`comment` text COMMENT 'db comment',
+ `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
PRIMARY KEY (`id`),
UNIQUE KEY `uk_db` (`db_name`),
KEY `idx_q_db_type` (`db_type`)
@@ -78,6 +79,7 @@ CREATE TABLE `chat_history` (
`summary` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record summary',
`user_name` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'interlocutor',
`messages` text COLLATE utf8mb4_unicode_ci COMMENT 'Conversation details',
+ `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
PRIMARY KEY (`id`)
) ENGINE=InnoDB AUTO_INCREMENT=2 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT 'Chat history';
@@ -110,6 +112,7 @@ CREATE TABLE `my_plugin` (
`version` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin version',
`use_count` int DEFAULT NULL COMMENT 'plugin total use count',
`succ_count` int DEFAULT NULL COMMENT 'plugin total success count',
+ `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
`gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'plugin install time',
PRIMARY KEY (`id`),
UNIQUE KEY `name` (`name`)
@@ -141,6 +144,7 @@ CREATE TABLE `prompt_manage` (
`prompt_name` varchar(512) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'prompt name',
`content` longtext COLLATE utf8mb4_unicode_ci COMMENT 'Prompt content',
`user_name` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'User name',
+ `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
`gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
`gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
PRIMARY KEY (`id`),
diff --git a/pilot/base_modules/agent/db/my_plugin_db.py b/pilot/base_modules/agent/db/my_plugin_db.py
index fb4fe25ee..cca4e176d 100644
--- a/pilot/base_modules/agent/db/my_plugin_db.py
+++ b/pilot/base_modules/agent/db/my_plugin_db.py
@@ -32,6 +32,7 @@ class MyPluginEntity(Base):
succ_count = Column(
Integer, nullable=True, default=0, comment="plugin total success count"
)
+ sys_code = Column(String(128), index=True, nullable=True, comment="System code")
gmt_created = Column(
DateTime, default=datetime.utcnow, comment="plugin install time"
)
@@ -58,6 +59,7 @@ def add(self, engity: MyPluginEntity):
version=engity.version,
use_count=engity.use_count or 0,
succ_count=engity.succ_count or 0,
+ sys_code=engity.sys_code,
gmt_created=datetime.now(),
)
session.add(my_plugin)
@@ -107,6 +109,8 @@ def list(self, query: MyPluginEntity, page=1, page_size=20) -> list[MyPluginEnti
my_plugins = my_plugins.filter(MyPluginEntity.user_code == query.user_code)
if query.user_name is not None:
my_plugins = my_plugins.filter(MyPluginEntity.user_name == query.user_name)
+ if query.sys_code is not None:
+ my_plugins = my_plugins.filter(MyPluginEntity.sys_code == query.sys_code)
my_plugins = my_plugins.order_by(MyPluginEntity.id.desc())
my_plugins = my_plugins.offset((page - 1) * page_size).limit(page_size)
@@ -133,6 +137,8 @@ def count(self, query: MyPluginEntity):
my_plugins = my_plugins.filter(MyPluginEntity.user_code == query.user_code)
if query.user_name is not None:
my_plugins = my_plugins.filter(MyPluginEntity.user_name == query.user_name)
+ if query.sys_code is not None:
+ my_plugins = my_plugins.filter(MyPluginEntity.sys_code == query.sys_code)
count = my_plugins.scalar()
session.close()
return count
diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index 890d380f9..b66102e24 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -128,6 +128,10 @@ def get_device() -> str:
"xwin-lm-70b-v0.1": os.path.join(MODEL_PATH, "Xwin-LM-70B-V0.1"),
# https://huggingface.co/01-ai/Yi-34B-Chat
"yi-34b-chat": os.path.join(MODEL_PATH, "Yi-34B-Chat"),
+ # https://huggingface.co/01-ai/Yi-34B-Chat-8bits
+ "yi-34b-chat-8bits": os.path.join(MODEL_PATH, "Yi-34B-Chat-8bits"),
+ # https://huggingface.co/01-ai/Yi-34B-Chat-4bits
+ "yi-34b-chat-4bits": os.path.join(MODEL_PATH, "Yi-34B-Chat-4bits"),
"yi-6b-chat": os.path.join(MODEL_PATH, "Yi-6B-Chat"),
}
diff --git a/pilot/connections/manages/connect_config_db.py b/pilot/connections/manages/connect_config_db.py
index 0898bfbea..602ff719b 100644
--- a/pilot/connections/manages/connect_config_db.py
+++ b/pilot/connections/manages/connect_config_db.py
@@ -24,6 +24,7 @@ class ConnectConfigEntity(Base):
db_user = Column(String(255), nullable=True, comment="db user")
db_pwd = Column(String(255), nullable=True, comment="db password")
comment = Column(Text, nullable=True, comment="db comment")
+ sys_code = Column(String(128), index=True, nullable=True, comment="System code")
__table_args__ = (
UniqueConstraint("db_name", name="uk_db"),
diff --git a/pilot/memory/chat_history/base.py b/pilot/memory/chat_history/base.py
index a8a09153c..6ad9a604b 100644
--- a/pilot/memory/chat_history/base.py
+++ b/pilot/memory/chat_history/base.py
@@ -1,7 +1,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
-from typing import List
+from typing import List, Optional, Dict
from enum import Enum
from pilot.scene.message import OnceConversation
@@ -35,11 +35,6 @@ def append(self, message: OnceConversation) -> None:
# def clear(self) -> None:
# """Clear session memory from the local file"""
- @abstractmethod
- def conv_list(self, user_name: str = None) -> None:
- """get user's conversation list"""
- pass
-
@abstractmethod
def update(self, messages: List[OnceConversation]) -> None:
pass
@@ -49,7 +44,7 @@ def delete(self) -> bool:
pass
@abstractmethod
- def conv_info(self, conv_uid: str = None) -> None:
+ def conv_info(self, conv_uid: Optional[str] = None) -> None:
pass
@abstractmethod
@@ -57,5 +52,7 @@ def get_messages(self) -> List[OnceConversation]:
pass
@staticmethod
- def conv_list(cls, user_name: str = None) -> None:
- pass
+ def conv_list(
+ user_name: Optional[str] = None, sys_code: Optional[str] = None
+ ) -> List[Dict]:
+ """get user's conversation list"""
diff --git a/pilot/memory/chat_history/chat_hisotry_factory.py b/pilot/memory/chat_history/chat_hisotry_factory.py
index c1a8f9cab..f8f02f346 100644
--- a/pilot/memory/chat_history/chat_hisotry_factory.py
+++ b/pilot/memory/chat_history/chat_hisotry_factory.py
@@ -1,3 +1,4 @@
+from typing import Type
from .base import MemoryStoreType
from pilot.configs.config import Config
from pilot.memory.chat_history.base import BaseChatHistoryMemory
@@ -32,5 +33,5 @@ def get_store_instance(self, chat_session_id: str) -> BaseChatHistoryMemory:
chat_session_id
)
- def get_store_cls(self):
+ def get_store_cls(self) -> Type[BaseChatHistoryMemory]:
return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE)
diff --git a/pilot/memory/chat_history/chat_history_db.py b/pilot/memory/chat_history/chat_history_db.py
index 8ef898c27..010699d45 100644
--- a/pilot/memory/chat_history/chat_history_db.py
+++ b/pilot/memory/chat_history/chat_history_db.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Optional
from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean, Text
from sqlalchemy import UniqueConstraint
@@ -32,7 +32,7 @@ class ChatHistoryEntity(Base):
messages = Column(
Text(length=2**31 - 1), nullable=True, comment="Conversation details"
)
-
+ sys_code = Column(String(128), index=True, nullable=True, comment="System code")
UniqueConstraint("conv_uid", name="uk_conversation")
Index("idx_q_user", "user_name")
Index("idx_q_mode", "chat_mode")
@@ -48,11 +48,15 @@ def __init__(self):
session=session,
)
- def list_last_20(self, user_name: str = None):
+ def list_last_20(
+ self, user_name: Optional[str] = None, sys_code: Optional[str] = None
+ ):
session = self.get_session()
chat_history = session.query(ChatHistoryEntity)
if user_name:
chat_history = chat_history.filter(ChatHistoryEntity.user_name == user_name)
+ if sys_code:
+ chat_history = chat_history.filter(ChatHistoryEntity.sys_code == sys_code)
chat_history = chat_history.order_by(ChatHistoryEntity.id.desc())
diff --git a/pilot/memory/chat_history/store_type/duckdb_history.py b/pilot/memory/chat_history/store_type/duckdb_history.py
index 28c92a142..790577343 100644
--- a/pilot/memory/chat_history/store_type/duckdb_history.py
+++ b/pilot/memory/chat_history/store_type/duckdb_history.py
@@ -1,7 +1,7 @@
import json
import os
import duckdb
-from typing import List
+from typing import List, Dict, Optional
from pilot.configs.config import Config
from pilot.memory.chat_history.base import BaseChatHistoryMemory
@@ -37,7 +37,7 @@ def __init_chat_history_tables(self):
if not result:
# 如果表不存在,则创建新表
self.connect.execute(
- "CREATE TABLE chat_history (id integer primary key, conv_uid VARCHAR(100) UNIQUE, chat_mode VARCHAR(50), summary VARCHAR(255), user_name VARCHAR(100), messages TEXT)"
+ "CREATE TABLE chat_history (id integer primary key, conv_uid VARCHAR(100) UNIQUE, chat_mode VARCHAR(50), summary VARCHAR(255), user_name VARCHAR(100), sys_code VARCHAR(128), messages TEXT)"
)
self.connect.execute("CREATE SEQUENCE seq_id START 1;")
@@ -61,8 +61,8 @@ def create(self, chat_mode, summary: str, user_name: str) -> None:
try:
cursor = self.connect.cursor()
cursor.execute(
- "INSERT INTO chat_history(id, conv_uid, chat_mode summary, user_name, messages)VALUES(nextval('seq_id'),?,?,?,?,?)",
- [self.chat_seesion_id, chat_mode, summary, user_name, ""],
+ "INSERT INTO chat_history(id, conv_uid, chat_mode summary, user_name, sys_code, messages)VALUES(nextval('seq_id'),?,?,?,?,?,?)",
+ [self.chat_seesion_id, chat_mode, summary, user_name, "", ""],
)
cursor.commit()
self.connect.commit()
@@ -83,12 +83,13 @@ def append(self, once_message: OnceConversation) -> None:
)
else:
cursor.execute(
- "INSERT INTO chat_history(id, conv_uid, chat_mode, summary, user_name, messages)VALUES(nextval('seq_id'),?,?,?,?,?)",
+ "INSERT INTO chat_history(id, conv_uid, chat_mode, summary, user_name, sys_code, messages)VALUES(nextval('seq_id'),?,?,?,?,?,?)",
[
self.chat_seesion_id,
once_message.chat_mode,
once_message.get_user_conv().content,
- "",
+ once_message.user_name,
+ once_message.sys_code,
json.dumps(conversations, ensure_ascii=False),
],
)
@@ -149,17 +150,26 @@ def get_messages(self) -> List[OnceConversation]:
return None
@staticmethod
- def conv_list(cls, user_name: str = None) -> None:
+ def conv_list(
+ user_name: Optional[str] = None, sys_code: Optional[str] = None
+ ) -> List[Dict]:
if os.path.isfile(duckdb_path):
cursor = duckdb.connect(duckdb_path).cursor()
+ query = "SELECT * FROM chat_history"
+ params = []
+ conditions = []
if user_name:
- cursor.execute(
- "SELECT * FROM chat_history where user_name=? order by id desc limit 20",
- [user_name],
- )
- else:
- cursor.execute("SELECT * FROM chat_history order by id desc limit 20")
- # 获取查询结果字段名
+ conditions.append("user_name = ?")
+ params.append(user_name)
+ if sys_code:
+ conditions.append("sys_code = ?")
+ params.append(sys_code)
+
+ if conditions:
+ query += " WHERE " + " AND ".join(conditions)
+
+ query += " ORDER BY id DESC LIMIT 20"
+ cursor.execute(query, params)
fields = [field[0] for field in cursor.description]
data = []
for row in cursor.fetchall():
diff --git a/pilot/memory/chat_history/store_type/meta_db_history.py b/pilot/memory/chat_history/store_type/meta_db_history.py
index f1c25d633..2e69e3ec5 100644
--- a/pilot/memory/chat_history/store_type/meta_db_history.py
+++ b/pilot/memory/chat_history/store_type/meta_db_history.py
@@ -1,6 +1,6 @@
import json
import logging
-from typing import List
+from typing import List, Dict, Optional
from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean, Text
from sqlalchemy import UniqueConstraint
from pilot.configs.config import Config
@@ -62,7 +62,8 @@ def append(self, once_message: OnceConversation) -> None:
chat_history: ChatHistoryEntity = ChatHistoryEntity()
chat_history.conv_uid = self.chat_seesion_id
chat_history.chat_mode = once_message.chat_mode
- chat_history.user_name = "default"
+ chat_history.user_name = once_message.user_name
+ chat_history.sys_code = once_message.sys_code
chat_history.summary = once_message.get_user_conv().content
conversations.append(_conversation_to_dic(once_message))
@@ -92,9 +93,11 @@ def get_messages(self) -> List[OnceConversation]:
return []
@staticmethod
- def conv_list(cls, user_name: str = None) -> None:
+ def conv_list(
+ user_name: Optional[str] = None, sys_code: Optional[str] = None
+ ) -> List[Dict]:
chat_history_dao = ChatHistoryDao()
- history_list = chat_history_dao.list_last_20()
+ history_list = chat_history_dao.list_last_20(user_name, sys_code)
result = []
for history in history_list:
result.append(history.__dict__)
diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py
index 70bfc2e9f..262ff6f6c 100644
--- a/pilot/openapi/api_v1/api_v1.py
+++ b/pilot/openapi/api_v1/api_v1.py
@@ -2,7 +2,7 @@
import uuid
import asyncio
import os
-import shutil
+import aiofiles
import logging
from fastapi import (
APIRouter,
@@ -17,7 +17,7 @@
from fastapi.responses import StreamingResponse
from fastapi.exceptions import RequestValidationError
-from typing import List
+from typing import List, Optional
import tempfile
from concurrent.futures import Executor
@@ -48,7 +48,11 @@
from pilot.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory
from pilot.model.base import FlatSupportedModel
from pilot.utils.tracer import root_tracer, SpanType
-from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async
+from pilot.utils.executor_utils import (
+ ExecutorFactory,
+ blocking_func_to_async,
+ DefaultExecutorFactory,
+)
router = APIRouter()
CFG = Config()
@@ -68,9 +72,11 @@ def __get_conv_user_message(conversations: dict):
return ""
-def __new_conversation(chat_mode, user_id) -> ConversationVo:
+def __new_conversation(chat_mode, user_name: str, sys_code: str) -> ConversationVo:
unique_id = uuid.uuid1()
- return ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode)
+ return ConversationVo(
+ conv_uid=str(unique_id), chat_mode=chat_mode, sys_code=sys_code
+ )
def get_db_list():
@@ -141,7 +147,9 @@ def get_worker_manager() -> WorkerManager:
def get_executor() -> Executor:
"""Get the global default executor"""
return CFG.SYSTEM_APP.get_component(
- ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
+ ComponentType.EXECUTOR_DEFAULT,
+ ExecutorFactory,
+ or_register_component=DefaultExecutorFactory,
).create()
@@ -166,7 +174,6 @@ async def db_connect_delete(db_name: str = None):
async def async_db_summary_embedding(db_name, db_type):
- # 在这里执行需要异步运行的代码
db_summary_client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
db_summary_client.db_summary_embedding(db_name, db_type)
@@ -200,16 +207,21 @@ async def db_support_types():
@router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo])
-async def dialogue_list(user_id: str = None):
+async def dialogue_list(
+ user_name: str = None, user_id: str = None, sys_code: str = None
+):
dialogues: List = []
chat_history_service = ChatHistory()
# TODO Change the synchronous call to the asynchronous call
- datas = chat_history_service.get_store_cls().conv_list(user_id)
+ user_name = user_name or user_id
+ datas = chat_history_service.get_store_cls().conv_list(user_name, sys_code)
for item in datas:
conv_uid = item.get("conv_uid")
summary = item.get("summary")
chat_mode = item.get("chat_mode")
model_name = item.get("model_name", CFG.LLM_MODEL)
+ user_name = item.get("user_name")
+ sys_code = item.get("sys_code")
messages = json.loads(item.get("messages"))
last_round = max(messages, key=lambda x: x["chat_order"])
@@ -223,6 +235,8 @@ async def dialogue_list(user_id: str = None):
chat_mode=chat_mode,
model_name=model_name,
select_param=select_param,
+ user_name=user_name,
+ sys_code=sys_code,
)
dialogues.append(conv_vo)
@@ -254,9 +268,14 @@ async def dialogue_scenes():
@router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
async def dialogue_new(
- chat_mode: str = ChatScene.ChatNormal.value(), user_id: str = None
+ chat_mode: str = ChatScene.ChatNormal.value(),
+ user_name: str = None,
+ # TODO remove user id
+ user_id: str = None,
+ sys_code: str = None,
):
- conv_vo = __new_conversation(chat_mode, user_id)
+ user_name = user_name or user_id
+ conv_vo = __new_conversation(chat_mode, user_name, sys_code)
return Result.succ(conv_vo)
@@ -280,40 +299,40 @@ async def params_list(chat_mode: str = ChatScene.ChatNormal.value()):
@router.post("/v1/chat/mode/params/file/load")
async def params_load(
- conv_uid: str, chat_mode: str, model_name: str, doc_file: UploadFile = File(...)
+ conv_uid: str,
+ chat_mode: str,
+ model_name: str,
+ user_name: Optional[str] = None,
+ sys_code: Optional[str] = None,
+ doc_file: UploadFile = File(...),
):
print(f"params_load: {conv_uid},{chat_mode},{model_name}")
try:
if doc_file:
- ## file save
- if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode)):
- os.makedirs(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode))
- # We can not move temp file in windows system when we open file in context of `with`
- tmp_fd, tmp_path = tempfile.mkstemp(
- dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode)
- )
- # TODO Use noblocking file save with aiofiles
- with os.fdopen(tmp_fd, "wb") as tmp:
- tmp.write(await doc_file.read())
- shutil.move(
- tmp_path,
- os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode, doc_file.filename),
- )
- ## chat prepare
+ # Save the uploaded file
+ upload_dir = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode)
+ os.makedirs(upload_dir, exist_ok=True)
+ upload_path = os.path.join(upload_dir, doc_file.filename)
+ async with aiofiles.open(upload_path, "wb") as f:
+ await f.write(await doc_file.read())
+
+ # Prepare the chat
dialogue = ConversationVo(
conv_uid=conv_uid,
chat_mode=chat_mode,
select_param=doc_file.filename,
model_name=model_name,
+ user_name=user_name,
+ sys_code=sys_code,
)
chat: BaseChat = await get_chat_instance(dialogue)
resp = await chat.prepare()
- ### refresh messages
+ # Refresh messages
return Result.succ(get_hist_messages(conv_uid))
except Exception as e:
logger.error("excel load error!", e)
- return Result.failed(code="E000X", msg=f"File Load Error {e}")
+ return Result.failed(code="E000X", msg=f"File Load Error {str(e)}")
@router.post("/v1/chat/dialogue/delete")
@@ -354,7 +373,9 @@ async def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
if not dialogue.chat_mode:
dialogue.chat_mode = ChatScene.ChatNormal.value()
if not dialogue.conv_uid:
- conv_vo = __new_conversation(dialogue.chat_mode, dialogue.user_name)
+ conv_vo = __new_conversation(
+ dialogue.chat_mode, dialogue.user_name, dialogue.sys_code
+ )
dialogue.conv_uid = conv_vo.conv_uid
if not ChatScene.is_valid_mode(dialogue.chat_mode):
@@ -364,13 +385,12 @@ async def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
chat_param = {
"chat_session_id": dialogue.conv_uid,
+ "user_name": dialogue.user_name,
+ "sys_code": dialogue.sys_code,
"current_user_input": dialogue.user_input,
"select_param": dialogue.select_param,
"model_name": dialogue.model_name,
}
- # chat: BaseChat = CHAT_FACTORY.get_implementation(
- # dialogue.chat_mode, **{"chat_param": chat_param}
- # )
chat: BaseChat = await blocking_func_to_async(
get_executor(),
CHAT_FACTORY.get_implementation,
@@ -401,8 +421,6 @@ async def chat_completions(dialogue: ConversationVo = Body()):
"get_chat_instance", span_type=SpanType.CHAT, metadata=dialogue.dict()
):
chat: BaseChat = await get_chat_instance(dialogue)
- # background_tasks = BackgroundTasks()
- # background_tasks.add_task(release_model_semaphore)
headers = {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
diff --git a/pilot/openapi/api_view_model.py b/pilot/openapi/api_view_model.py
index af1aa4b9c..5212076e8 100644
--- a/pilot/openapi/api_view_model.py
+++ b/pilot/openapi/api_view_model.py
@@ -66,6 +66,8 @@ class ConversationVo(BaseModel):
"""
incremental: bool = False
+ sys_code: Optional[str] = None
+
class MessageVo(BaseModel):
"""
diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py
index 0e263f7e5..864eb34b6 100644
--- a/pilot/scene/base_chat.py
+++ b/pilot/scene/base_chat.py
@@ -78,7 +78,9 @@ def __init__(self, chat_param: Dict):
self.history_message: List[OnceConversation] = self.memory.messages()
self.current_message: OnceConversation = OnceConversation(
- self.chat_mode.value()
+ self.chat_mode.value(),
+ user_name=chat_param.get("user_name"),
+ sys_code=chat_param.get("sys_code"),
)
self.current_message.model_name = self.llm_model
if chat_param["select_param"]:
@@ -171,7 +173,6 @@ async def __call_base(self):
"messages": llm_messages,
"temperature": float(self.prompt_template.temperature),
"max_new_tokens": int(self.prompt_template.max_new_tokens),
- # "stop": self.prompt_template.sep,
"echo": self.llm_echo,
}
return payload
diff --git a/pilot/scene/message.py b/pilot/scene/message.py
index 4d5a5c383..592f3cda0 100644
--- a/pilot/scene/message.py
+++ b/pilot/scene/message.py
@@ -18,7 +18,7 @@ class OnceConversation:
All the information of a conversation, the current single service in memory, can expand cache and database support distributed services
"""
- def __init__(self, chat_mode):
+ def __init__(self, chat_mode, user_name: str = None, sys_code: str = None):
self.chat_mode: str = chat_mode
self.messages: List[BaseMessage] = []
self.start_date: str = ""
@@ -28,6 +28,8 @@ def __init__(self, chat_mode):
self.param_value: str = ""
self.cost: int = 0
self.tokens: int = 0
+ self.user_name: str = user_name
+ self.sys_code: str = sys_code
def add_user_message(self, message: str) -> None:
"""Add a user message to the store"""
@@ -113,6 +115,8 @@ def _conversation_to_dic(once: OnceConversation) -> dict:
"messages": messages_to_dict(once.messages),
"param_type": once.param_type,
"param_value": once.param_value,
+ "user_name": once.user_name,
+ "sys_code": once.sys_code,
}
@@ -121,7 +125,9 @@ def conversations_to_dict(conversations: List[OnceConversation]) -> List[dict]:
def conversation_from_dict(once: dict) -> OnceConversation:
- conversation = OnceConversation()
+ conversation = OnceConversation(
+ once.get("chat_mode"), once.get("user_name"), once.get("sys_code")
+ )
conversation.cost = once.get("cost", 0)
conversation.chat_mode = once.get("chat_mode", "chat_normal")
conversation.tokens = once.get("tokens", 0)
diff --git a/pilot/server/prompt/prompt_manage_db.py b/pilot/server/prompt/prompt_manage_db.py
index 3245b3450..eff75c270 100644
--- a/pilot/server/prompt/prompt_manage_db.py
+++ b/pilot/server/prompt/prompt_manage_db.py
@@ -29,6 +29,7 @@ class PromptManageEntity(Base):
prompt_name = Column(String(512))
content = Column(Text)
user_name = Column(String(128))
+ sys_code = Column(String(128), index=True, nullable=True, comment="System code")
gmt_created = Column(DateTime)
gmt_modified = Column(DateTime)
diff --git a/setup.py b/setup.py
index ca0bd818f..d1386c79e 100644
--- a/setup.py
+++ b/setup.py
@@ -1,5 +1,4 @@
-from typing import List, Tuple
-
+from typing import List, Tuple, Optional, Callable
import setuptools
import platform
import subprocess
@@ -10,6 +9,7 @@
import re
import shutil
from setuptools import find_packages
+import functools
with open("README.md", mode="r", encoding="utf-8") as fh:
long_description = fh.read()
@@ -34,8 +34,15 @@ def parse_requirements(file_name: str) -> List[str]:
def get_latest_version(package_name: str, index_url: str, default_version: str):
+ python_command = shutil.which("python")
+ if not python_command:
+ python_command = shutil.which("python3")
+ if not python_command:
+ print("Python command not found.")
+ return default_version
+
command = [
- "python",
+ python_command,
"-m",
"pip",
"index",
@@ -125,6 +132,7 @@ class OSType(Enum):
OTHER = "other"
+@functools.cache
def get_cpu_avx_support() -> Tuple[OSType, AVXType]:
system = platform.system()
os_type = OSType.OTHER
@@ -206,6 +214,57 @@ def get_cuda_version() -> str:
return None
+def _build_wheels(
+ pkg_name: str,
+ pkg_version: str,
+ base_url: str = None,
+ base_url_func: Callable[[str, str, str], str] = None,
+ pkg_file_func: Callable[[str, str, str, str, OSType], str] = None,
+ supported_cuda_versions: List[str] = ["11.7", "11.8"],
+) -> Optional[str]:
+ """
+ Build the URL for the package wheel file based on the package name, version, and CUDA version.
+ Args:
+ pkg_name (str): The name of the package.
+ pkg_version (str): The version of the package.
+ base_url (str): The base URL for downloading the package.
+ base_url_func (Callable): A function to generate the base URL.
+ pkg_file_func (Callable): build package file function.
+ function params: pkg_name, pkg_version, cuda_version, py_version, OSType
+ supported_cuda_versions (List[str]): The list of supported CUDA versions.
+ Returns:
+ Optional[str]: The URL for the package wheel file.
+ """
+ os_type, _ = get_cpu_avx_support()
+ cuda_version = get_cuda_version()
+ py_version = platform.python_version()
+ py_version = "cp" + "".join(py_version.split(".")[0:2])
+ if os_type == OSType.DARWIN or not cuda_version:
+ return None
+ if cuda_version not in supported_cuda_versions:
+ print(
+ f"Warnning: {pkg_name} supported cuda version: {supported_cuda_versions}, replace to {supported_cuda_versions[-1]}"
+ )
+ cuda_version = supported_cuda_versions[-1]
+
+ cuda_version = "cu" + cuda_version.replace(".", "")
+ os_pkg_name = "linux_x86_64" if os_type == OSType.LINUX else "win_amd64"
+ if base_url_func:
+ base_url = base_url_func(pkg_version, cuda_version, py_version)
+ if base_url and base_url.endswith("/"):
+ base_url = base_url[:-1]
+ if pkg_file_func:
+ full_pkg_file = pkg_file_func(
+ pkg_name, pkg_version, cuda_version, py_version, os_type
+ )
+ else:
+ full_pkg_file = f"{pkg_name}-{pkg_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl"
+ if not base_url:
+ return full_pkg_file
+ else:
+ return f"{base_url}/{full_pkg_file}"
+
+
def torch_requires(
torch_version: str = "2.0.1",
torchvision_version: str = "0.15.2",
@@ -222,16 +281,20 @@ def torch_requires(
cuda_version = get_cuda_version()
if cuda_version:
supported_versions = ["11.7", "11.8"]
- if cuda_version not in supported_versions:
- print(
- f"PyTorch version {torch_version} supported cuda version: {supported_versions}, replace to {supported_versions[-1]}"
- )
- cuda_version = supported_versions[-1]
- cuda_version = "cu" + cuda_version.replace(".", "")
- py_version = "cp310"
- os_pkg_name = "linux_x86_64" if os_type == OSType.LINUX else "win_amd64"
- torch_url = f"https://download.pytorch.org/whl/{cuda_version}/torch-{torch_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl"
- torchvision_url = f"https://download.pytorch.org/whl/{cuda_version}/torchvision-{torchvision_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl"
+ # torch_url = f"https://download.pytorch.org/whl/{cuda_version}/torch-{torch_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl"
+ # torchvision_url = f"https://download.pytorch.org/whl/{cuda_version}/torchvision-{torchvision_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl"
+ torch_url = _build_wheels(
+ "torch",
+ torch_version,
+ base_url_func=lambda v, x, y: f"https://download.pytorch.org/whl/{x}",
+ supported_cuda_versions=supported_versions,
+ )
+ torchvision_url = _build_wheels(
+ "torchvision",
+        torchvision_version,
+ base_url_func=lambda v, x, y: f"https://download.pytorch.org/whl/{x}",
+ supported_cuda_versions=supported_versions,
+ )
torch_url_cached = cache_package(
torch_url, "torch", os_type == OSType.WINDOWS
)
@@ -327,6 +390,7 @@ def core_requires():
"xlrd==2.0.1",
# for cache, TODO pympler has not been updated for a long time and needs to find a new toolkit.
"pympler",
+ "aiofiles",
]
if BUILD_FROM_SOURCE:
setup_spec.extras["framework"].append(
@@ -360,6 +424,41 @@ def llama_cpp_requires():
llama_cpp_python_cuda_requires()
+def _build_autoawq_requires() -> Optional[str]:
+ os_type, _ = get_cpu_avx_support()
+ if os_type == OSType.DARWIN:
+ return None
+ auto_gptq_version = get_latest_version(
+ "auto-gptq", "https://huggingface.github.io/autogptq-index/whl/cu118/", "0.5.1"
+ )
+ # eg. 0.5.1+cu118
+ auto_gptq_version = auto_gptq_version.split("+")[0]
+
+ def pkg_file_func(pkg_name, pkg_version, cuda_version, py_version, os_type):
+ pkg_name = pkg_name.replace("-", "_")
+ if os_type == OSType.DARWIN:
+ return None
+ os_pkg_name = (
+ "manylinux_2_17_x86_64.manylinux2014_x86_64.whl"
+ if os_type == OSType.LINUX
+ else "win_amd64.whl"
+ )
+ return f"{pkg_name}-{pkg_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}"
+
+ auto_gptq_url = _build_wheels(
+ "auto-gptq",
+ auto_gptq_version,
+ base_url_func=lambda v, x, y: f"https://huggingface.github.io/autogptq-index/whl/{x}/auto-gptq",
+ pkg_file_func=pkg_file_func,
+ supported_cuda_versions=["11.8"],
+ )
+ if auto_gptq_url:
+ print(f"Install auto-gptq from {auto_gptq_url}")
+ return f"auto-gptq @ {auto_gptq_url}"
+ else:
+ "auto-gptq"
+
+
def quantization_requires():
pkgs = []
os_type, _ = get_cpu_avx_support()
@@ -379,6 +478,28 @@ def quantization_requires():
print(pkgs)
# For chatglm2-6b-int4
pkgs += ["cpm_kernels"]
+
+ # Since transformers 4.35.0, the GPT-Q/AWQ model can be loaded using AutoModelForCausalLM.
+ # autoawq requirements:
+ # 1. Compute Capability 7.5 (sm75). Turing and later architectures are supported.
+ # 2. CUDA Toolkit 11.8 and later.
+ autoawq_url = _build_wheels(
+ "autoawq",
+ "0.1.7",
+ base_url_func=lambda v, x, y: f"https://github.com/casper-hansen/AutoAWQ/releases/download/v{v}",
+ supported_cuda_versions=["11.8"],
+ )
+ if autoawq_url:
+ print(f"Install autoawq from {autoawq_url}")
+ pkgs.append(f"autoawq @ {autoawq_url}")
+ else:
+ pkgs.append("autoawq")
+
+ auto_gptq_pkg = _build_autoawq_requires()
+ if auto_gptq_pkg:
+ pkgs.append(auto_gptq_pkg)
+ pkgs.append("optimum")
+
setup_spec.extras["quantization"] = pkgs
From d9cc1020c1372a301e3adfe59b067319c5c35c2c Mon Sep 17 00:00:00 2001
From: "magic.chen"
Date: Mon, 27 Nov 2023 20:54:27 +0800
Subject: [PATCH 2/2] docs: redirect documents to newdocs (#856)
---
README.md | 76 ++++++-------------------------------------------------
1 file changed, 8 insertions(+), 68 deletions(-)
diff --git a/README.md b/README.md
index 524541209..4ee75a3b8 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,6 @@
-
@@ -73,24 +72,11 @@ In the era of Data 3.0, enterprises and developers can take the ability to creat
![macOS](https://img.shields.io/badge/mac%20os-000000?style=for-the-badge&logo=macos&logoColor=F0F0F0)
![Windows](https://img.shields.io/badge/Windows-0078D6?style=for-the-badge&logo=windows&logoColor=white)
-[**Usage Tutorial**](https://db-gpt.readthedocs.io/en/latest/getting_started/install/deploy/deploy.html)
-- [**Install**](https://db-gpt.readthedocs.io/en/latest/getting_started/install/deploy.html)
- - [**Install Step by Step**](https://db-gpt.readthedocs.io/en/latest/getting_started/install/deploy.html)
- - [**Docker Install**](https://db-gpt.readthedocs.io/en/latest/getting_started/install/docker/docker.html)
- - [**Docker Compose**](https://db-gpt.readthedocs.io/en/latest/getting_started/install/docker_compose/docker_compose.html)
-- [**How to Use**](https://db-gpt.readthedocs.io/en/latest/getting_started/application/chatdb/chatdb.html)
- - [**ChatData**](https://db-gpt.readthedocs.io/en/latest/getting_started/application/chatdb/chatdb.html)
- - [**ChatKnowledge**](https://db-gpt.readthedocs.io/en/latest/getting_started/application/kbqa/kbqa.html)
- - [**ChatExcel**](https://db-gpt.readthedocs.io/en/latest/getting_started/application/chatexcel/chatexcel.html)
- - [**Dashboard**](https://db-gpt.readthedocs.io/en/latest/getting_started/application/dashboard/dashboard.html)
- - [**LLM Management**](https://db-gpt.readthedocs.io/en/latest/getting_started/application/model/model.html)
- - [**Chat Agent**](https://db-gpt.readthedocs.io/en/latest/getting_started/application/chatagent/chatagent.html)
-- [**How to Deploy LLM**](https://db-gpt.readthedocs.io/en/latest/getting_started/install/cluster/cluster.html)
- - [**Standalone**](https://db-gpt.readthedocs.io/en/latest/getting_started/install/cluster/vms/standalone.html)
- - [**Cluster**](https://db-gpt.readthedocs.io/en/latest/getting_started/install/cluster/vms/index.html)
- - [**vLLM**](https://db-gpt.readthedocs.io/en/latest/getting_started/install/llm/vllm/vllm.html)
-- [**How to Debug**](https://db-gpt.readthedocs.io/en/latest/getting_started/observability.html)
-- [**FAQ**](https://db-gpt.readthedocs.io/en/latest/getting_started/faq/deploy/deploy_faq.html)
+[**Usage Tutorial**](http://docs.dbgpt.site/docs/overview)
+- [**Install**](http://docs.dbgpt.site/docs/installation)
+- [**Quickstart**](http://docs.dbgpt.site/docs/quickstart)
+- [**Application**](http://docs.dbgpt.site/docs/operation_manual)
+- [**Debugging**](http://docs.dbgpt.site/docs/operation_manual/advanced_tutorial/debugging)
## Features
@@ -114,61 +100,15 @@ At present, we have introduced several key features to showcase our current capa
- **SMMF(Service-oriented Multi-model Management Framework)**
- We offer extensive model support, including dozens of large language models (LLMs) from both open-source and API agents, such as LLaMA/LLaMA2, Baichuan, ChatGLM, Wenxin, Tongyi, Zhipu, and many more.
-
- - [Vicuna](https://huggingface.co/Tribbiani/vicuna-13b)
- - [vicuna-13b-v1.5](https://huggingface.co/lmsys/vicuna-13b-v1.5)
- - [LLama2](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
- - [baichuan2-13b](https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat)
- - [baichuan2-7b](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat)
- - [chatglm-6b](https://huggingface.co/THUDM/chatglm-6b)
- - [chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b)
- - [chatglm3-6b](https://huggingface.co/THUDM/chatglm3-6b)
- - [falcon-40b](https://huggingface.co/tiiuae/falcon-40b)
- - [internlm-chat-7b](https://huggingface.co/internlm/internlm-chat-7b)
- - [internlm-chat-20b](https://huggingface.co/internlm/internlm-chat-20b)
- - [qwen-7b-chat](https://huggingface.co/Qwen/Qwen-7B-Chat)
- - [qwen-14b-chat](https://huggingface.co/Qwen/Qwen-14B-Chat)
- - [wizardlm-13b](https://huggingface.co/WizardLM/WizardLM-13B-V1.2)
- - [orca-2-7b](https://huggingface.co/microsoft/Orca-2-7b)
- - [orca-2-13b](https://huggingface.co/microsoft/Orca-2-13b)
- - [openchat_3.5](https://huggingface.co/openchat/openchat_3.5)
- - [zephyr-7b-alpha](https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha)
- - [mistral-7b-instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
- - [Yi-34B-Chat](https://huggingface.co/01-ai/Yi-34B-Chat)
-
- - Support API Proxy LLMs
- - [x] [ChatGPT](https://api.openai.com/)
- - [x] [Tongyi](https://www.aliyun.com/product/dashscope)
- - [x] [Wenxin](https://cloud.baidu.com/product/wenxinworkshop?track=dingbutonglan)
- - [x] [ChatGLM](http://open.bigmodel.cn/)
+ We offer extensive model support, including dozens of large language models (LLMs) from both open-source and API agents, such as LLaMA/LLaMA2, Baichuan, ChatGLM, Wenxin, Tongyi, Zhipu, and many more.
+ - [Current Supported LLMs](http://docs.dbgpt.site/docs/modules/smmf)
- **Privacy and Security**
We ensure the privacy and security of data through the implementation of various technologies, including privatized large models and proxy desensitization.
- Support Datasources
-
-| DataSource | support | Notes |
-| ------------------------------------------------------------------------------ | ----------- | ------------------------------------------- |
-| [MySQL](https://www.mysql.com/) | Yes | |
-| [PostgreSQL](https://www.postgresql.org/) | Yes | |
-| [Spark](https://github.com/apache/spark) | Yes | |
-| [DuckDB](https://github.com/duckdb/duckdb) | Yes | |
-| [Sqlite](https://github.com/sqlite/sqlite) | Yes | |
-| [MSSQL](https://github.com/microsoft/mssql-jdbc) | Yes | |
-| [ClickHouse](https://github.com/ClickHouse/ClickHouse) | Yes | |
-| [Oracle](https://github.com/oracle) | No | TODO |
-| [Redis](https://github.com/redis/redis) | No | TODO |
-| [MongoDB](https://github.com/mongodb/mongo) | No | TODO |
-| [HBase](https://github.com/apache/hbase) | No | TODO |
-| [Doris](https://github.com/apache/doris) | No | TODO |
-| [DB2](https://github.com/IBM/Db2) | No | TODO |
-| [Couchbase](https://github.com/couchbase) | No | TODO |
-| [Elasticsearch](https://github.com/elastic/elasticsearch) | No | TODO |
-| [OceanBase](https://github.com/OceanBase) | No | TODO |
-| [TiDB](https://github.com/pingcap/tidb) | No | TODO |
-| [StarRocks](https://github.com/StarRocks/starrocks) | No | TODO |
+ - [Datasources](http://docs.dbgpt.site/docs/modules/connections)
## Introduction
The architecture of DB-GPT is shown in the following figure: