diff --git a/assets/schema/knowledge_management.sql b/assets/schema/knowledge_management.sql index a6f1bc478..eb292b358 100644 --- a/assets/schema/knowledge_management.sql +++ b/assets/schema/knowledge_management.sql @@ -66,6 +66,7 @@ CREATE TABLE `connect_config` ( `db_user` varchar(255) DEFAULT NULL COMMENT 'db user', `db_pwd` varchar(255) DEFAULT NULL COMMENT 'db password', `comment` text COMMENT 'db comment', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', PRIMARY KEY (`id`), UNIQUE KEY `uk_db` (`db_name`), KEY `idx_q_db_type` (`db_type`) @@ -78,6 +79,7 @@ CREATE TABLE `chat_history` ( `summary` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record summary', `user_name` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'interlocutor', `messages` text COLLATE utf8mb4_unicode_ci COMMENT 'Conversation details', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', PRIMARY KEY (`id`) ) ENGINE=InnoDB AUTO_INCREMENT=2 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT 'Chat history'; @@ -110,6 +112,7 @@ CREATE TABLE `my_plugin` ( `version` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin version', `use_count` int DEFAULT NULL COMMENT 'plugin total use count', `succ_count` int DEFAULT NULL COMMENT 'plugin total success count', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'plugin install time', PRIMARY KEY (`id`), UNIQUE KEY `name` (`name`) @@ -141,6 +144,7 @@ CREATE TABLE `prompt_manage` ( `prompt_name` varchar(512) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'prompt name', `content` longtext COLLATE utf8mb4_unicode_ci COMMENT 'Prompt content', `user_name` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'User name', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', `gmt_modified` timestamp NULL DEFAULT 
CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', PRIMARY KEY (`id`), diff --git a/pilot/base_modules/agent/db/my_plugin_db.py b/pilot/base_modules/agent/db/my_plugin_db.py index fb4fe25ee..cca4e176d 100644 --- a/pilot/base_modules/agent/db/my_plugin_db.py +++ b/pilot/base_modules/agent/db/my_plugin_db.py @@ -32,6 +32,7 @@ class MyPluginEntity(Base): succ_count = Column( Integer, nullable=True, default=0, comment="plugin total success count" ) + sys_code = Column(String(128), index=True, nullable=True, comment="System code") gmt_created = Column( DateTime, default=datetime.utcnow, comment="plugin install time" ) @@ -58,6 +59,7 @@ def add(self, engity: MyPluginEntity): version=engity.version, use_count=engity.use_count or 0, succ_count=engity.succ_count or 0, + sys_code=engity.sys_code, gmt_created=datetime.now(), ) session.add(my_plugin) @@ -107,6 +109,8 @@ def list(self, query: MyPluginEntity, page=1, page_size=20) -> list[MyPluginEnti my_plugins = my_plugins.filter(MyPluginEntity.user_code == query.user_code) if query.user_name is not None: my_plugins = my_plugins.filter(MyPluginEntity.user_name == query.user_name) + if query.sys_code is not None: + my_plugins = my_plugins.filter(MyPluginEntity.sys_code == query.sys_code) my_plugins = my_plugins.order_by(MyPluginEntity.id.desc()) my_plugins = my_plugins.offset((page - 1) * page_size).limit(page_size) @@ -133,6 +137,8 @@ def count(self, query: MyPluginEntity): my_plugins = my_plugins.filter(MyPluginEntity.user_code == query.user_code) if query.user_name is not None: my_plugins = my_plugins.filter(MyPluginEntity.user_name == query.user_name) + if query.sys_code is not None: + my_plugins = my_plugins.filter(MyPluginEntity.sys_code == query.sys_code) count = my_plugins.scalar() session.close() return count diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py index 890d380f9..b66102e24 100644 --- a/pilot/configs/model_config.py +++ b/pilot/configs/model_config.py @@ -128,6 
+128,10 @@ def get_device() -> str: "xwin-lm-70b-v0.1": os.path.join(MODEL_PATH, "Xwin-LM-70B-V0.1"), # https://huggingface.co/01-ai/Yi-34B-Chat "yi-34b-chat": os.path.join(MODEL_PATH, "Yi-34B-Chat"), + # https://huggingface.co/01-ai/Yi-34B-Chat-8bits + "yi-34b-chat-8bits": os.path.join(MODEL_PATH, "Yi-34B-Chat-8bits"), + # https://huggingface.co/01-ai/Yi-34B-Chat-4bits + "yi-34b-chat-4bits": os.path.join(MODEL_PATH, "Yi-34B-Chat-4bits"), "yi-6b-chat": os.path.join(MODEL_PATH, "Yi-6B-Chat"), } diff --git a/pilot/connections/manages/connect_config_db.py b/pilot/connections/manages/connect_config_db.py index 0898bfbea..602ff719b 100644 --- a/pilot/connections/manages/connect_config_db.py +++ b/pilot/connections/manages/connect_config_db.py @@ -24,6 +24,7 @@ class ConnectConfigEntity(Base): db_user = Column(String(255), nullable=True, comment="db user") db_pwd = Column(String(255), nullable=True, comment="db password") comment = Column(Text, nullable=True, comment="db comment") + sys_code = Column(String(128), index=True, nullable=True, comment="System code") __table_args__ = ( UniqueConstraint("db_name", name="uk_db"), diff --git a/pilot/memory/chat_history/base.py b/pilot/memory/chat_history/base.py index a8a09153c..6ad9a604b 100644 --- a/pilot/memory/chat_history/base.py +++ b/pilot/memory/chat_history/base.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import List +from typing import List, Optional, Dict from enum import Enum from pilot.scene.message import OnceConversation @@ -35,11 +35,6 @@ def append(self, message: OnceConversation) -> None: # def clear(self) -> None: # """Clear session memory from the local file""" - @abstractmethod - def conv_list(self, user_name: str = None) -> None: - """get user's conversation list""" - pass - @abstractmethod def update(self, messages: List[OnceConversation]) -> None: pass @@ -49,7 +44,7 @@ def delete(self) -> bool: pass @abstractmethod - def conv_info(self, conv_uid: 
str = None) -> None: + def conv_info(self, conv_uid: Optional[str] = None) -> None: pass @abstractmethod @@ -57,5 +52,7 @@ def get_messages(self) -> List[OnceConversation]: pass @staticmethod - def conv_list(cls, user_name: str = None) -> None: - pass + def conv_list( + user_name: Optional[str] = None, sys_code: Optional[str] = None + ) -> List[Dict]: + """get user's conversation list""" diff --git a/pilot/memory/chat_history/chat_hisotry_factory.py b/pilot/memory/chat_history/chat_hisotry_factory.py index c1a8f9cab..f8f02f346 100644 --- a/pilot/memory/chat_history/chat_hisotry_factory.py +++ b/pilot/memory/chat_history/chat_hisotry_factory.py @@ -1,3 +1,4 @@ +from typing import Type from .base import MemoryStoreType from pilot.configs.config import Config from pilot.memory.chat_history.base import BaseChatHistoryMemory @@ -32,5 +33,5 @@ def get_store_instance(self, chat_session_id: str) -> BaseChatHistoryMemory: chat_session_id ) - def get_store_cls(self): + def get_store_cls(self) -> Type[BaseChatHistoryMemory]: return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE) diff --git a/pilot/memory/chat_history/chat_history_db.py b/pilot/memory/chat_history/chat_history_db.py index 8ef898c27..010699d45 100644 --- a/pilot/memory/chat_history/chat_history_db.py +++ b/pilot/memory/chat_history/chat_history_db.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean, Text from sqlalchemy import UniqueConstraint @@ -32,7 +32,7 @@ class ChatHistoryEntity(Base): messages = Column( Text(length=2**31 - 1), nullable=True, comment="Conversation details" ) - + sys_code = Column(String(128), index=True, nullable=True, comment="System code") UniqueConstraint("conv_uid", name="uk_conversation") Index("idx_q_user", "user_name") Index("idx_q_mode", "chat_mode") @@ -48,11 +48,15 @@ def __init__(self): session=session, ) - def list_last_20(self, user_name: str = None): + def 
list_last_20( + self, user_name: Optional[str] = None, sys_code: Optional[str] = None + ): session = self.get_session() chat_history = session.query(ChatHistoryEntity) if user_name: chat_history = chat_history.filter(ChatHistoryEntity.user_name == user_name) + if sys_code: + chat_history = chat_history.filter(ChatHistoryEntity.sys_code == sys_code) chat_history = chat_history.order_by(ChatHistoryEntity.id.desc()) diff --git a/pilot/memory/chat_history/store_type/duckdb_history.py b/pilot/memory/chat_history/store_type/duckdb_history.py index 28c92a142..790577343 100644 --- a/pilot/memory/chat_history/store_type/duckdb_history.py +++ b/pilot/memory/chat_history/store_type/duckdb_history.py @@ -1,7 +1,7 @@ import json import os import duckdb -from typing import List +from typing import List, Dict, Optional from pilot.configs.config import Config from pilot.memory.chat_history.base import BaseChatHistoryMemory @@ -37,7 +37,7 @@ def __init_chat_history_tables(self): if not result: # 如果表不存在,则创建新表 self.connect.execute( - "CREATE TABLE chat_history (id integer primary key, conv_uid VARCHAR(100) UNIQUE, chat_mode VARCHAR(50), summary VARCHAR(255), user_name VARCHAR(100), messages TEXT)" + "CREATE TABLE chat_history (id integer primary key, conv_uid VARCHAR(100) UNIQUE, chat_mode VARCHAR(50), summary VARCHAR(255), user_name VARCHAR(100), sys_code VARCHAR(128), messages TEXT)" ) self.connect.execute("CREATE SEQUENCE seq_id START 1;") @@ -61,8 +61,8 @@ def create(self, chat_mode, summary: str, user_name: str) -> None: try: cursor = self.connect.cursor() cursor.execute( - "INSERT INTO chat_history(id, conv_uid, chat_mode summary, user_name, messages)VALUES(nextval('seq_id'),?,?,?,?,?)", - [self.chat_seesion_id, chat_mode, summary, user_name, ""], + "INSERT INTO chat_history(id, conv_uid, chat_mode, summary, user_name, sys_code, messages)VALUES(nextval('seq_id'),?,?,?,?,?,?)", + [self.chat_seesion_id, chat_mode, summary, user_name, "", ""], ) cursor.commit() 
self.connect.commit() @@ -83,12 +83,13 @@ def append(self, once_message: OnceConversation) -> None: ) else: cursor.execute( - "INSERT INTO chat_history(id, conv_uid, chat_mode, summary, user_name, messages)VALUES(nextval('seq_id'),?,?,?,?,?)", + "INSERT INTO chat_history(id, conv_uid, chat_mode, summary, user_name, sys_code, messages)VALUES(nextval('seq_id'),?,?,?,?,?,?)", [ self.chat_seesion_id, once_message.chat_mode, once_message.get_user_conv().content, - "", + once_message.user_name, + once_message.sys_code, json.dumps(conversations, ensure_ascii=False), ], ) @@ -149,17 +150,26 @@ def get_messages(self) -> List[OnceConversation]: return None @staticmethod - def conv_list(cls, user_name: str = None) -> None: + def conv_list( + user_name: Optional[str] = None, sys_code: Optional[str] = None + ) -> List[Dict]: if os.path.isfile(duckdb_path): cursor = duckdb.connect(duckdb_path).cursor() + query = "SELECT * FROM chat_history" + params = [] + conditions = [] if user_name: - cursor.execute( - "SELECT * FROM chat_history where user_name=? 
order by id desc limit 20", - [user_name], - ) - else: - cursor.execute("SELECT * FROM chat_history order by id desc limit 20") - # 获取查询结果字段名 + conditions.append("user_name = ?") + params.append(user_name) + if sys_code: + conditions.append("sys_code = ?") + params.append(sys_code) + + if conditions: + query += " WHERE " + " AND ".join(conditions) + + query += " ORDER BY id DESC LIMIT 20" + cursor.execute(query, params) fields = [field[0] for field in cursor.description] data = [] for row in cursor.fetchall(): diff --git a/pilot/memory/chat_history/store_type/meta_db_history.py b/pilot/memory/chat_history/store_type/meta_db_history.py index f1c25d633..2e69e3ec5 100644 --- a/pilot/memory/chat_history/store_type/meta_db_history.py +++ b/pilot/memory/chat_history/store_type/meta_db_history.py @@ -1,6 +1,6 @@ import json import logging -from typing import List +from typing import List, Dict, Optional from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean, Text from sqlalchemy import UniqueConstraint from pilot.configs.config import Config @@ -62,7 +62,8 @@ def append(self, once_message: OnceConversation) -> None: chat_history: ChatHistoryEntity = ChatHistoryEntity() chat_history.conv_uid = self.chat_seesion_id chat_history.chat_mode = once_message.chat_mode - chat_history.user_name = "default" + chat_history.user_name = once_message.user_name + chat_history.sys_code = once_message.sys_code chat_history.summary = once_message.get_user_conv().content conversations.append(_conversation_to_dic(once_message)) @@ -92,9 +93,11 @@ def get_messages(self) -> List[OnceConversation]: return [] @staticmethod - def conv_list(cls, user_name: str = None) -> None: + def conv_list( + user_name: Optional[str] = None, sys_code: Optional[str] = None + ) -> List[Dict]: chat_history_dao = ChatHistoryDao() - history_list = chat_history_dao.list_last_20() + history_list = chat_history_dao.list_last_20(user_name, sys_code) result = [] for history in history_list: 
result.append(history.__dict__) diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py index 70bfc2e9f..262ff6f6c 100644 --- a/pilot/openapi/api_v1/api_v1.py +++ b/pilot/openapi/api_v1/api_v1.py @@ -2,7 +2,7 @@ import uuid import asyncio import os -import shutil +import aiofiles import logging from fastapi import ( APIRouter, @@ -17,7 +17,7 @@ from fastapi.responses import StreamingResponse from fastapi.exceptions import RequestValidationError -from typing import List +from typing import List, Optional import tempfile from concurrent.futures import Executor @@ -48,7 +48,11 @@ from pilot.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory from pilot.model.base import FlatSupportedModel from pilot.utils.tracer import root_tracer, SpanType -from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async +from pilot.utils.executor_utils import ( + ExecutorFactory, + blocking_func_to_async, + DefaultExecutorFactory, +) router = APIRouter() CFG = Config() @@ -68,9 +72,11 @@ def __get_conv_user_message(conversations: dict): return "" -def __new_conversation(chat_mode, user_id) -> ConversationVo: +def __new_conversation(chat_mode, user_name: str, sys_code: str) -> ConversationVo: unique_id = uuid.uuid1() - return ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode) + return ConversationVo( + conv_uid=str(unique_id), chat_mode=chat_mode, user_name=user_name, sys_code=sys_code + ) def get_db_list(): @@ -141,7 +147,9 @@ def get_worker_manager() -> WorkerManager: def get_executor() -> Executor: """Get the global default executor""" return CFG.SYSTEM_APP.get_component( - ComponentType.EXECUTOR_DEFAULT, ExecutorFactory + ComponentType.EXECUTOR_DEFAULT, + ExecutorFactory, + or_register_component=DefaultExecutorFactory, ).create() @@ -166,7 +174,6 @@ async def db_connect_delete(db_name: str = None): async def async_db_summary_embedding(db_name, db_type): - # 在这里执行需要异步运行的代码 db_summary_client = 
DBSummaryClient(system_app=CFG.SYSTEM_APP) db_summary_client.db_summary_embedding(db_name, db_type) @@ -200,16 +207,21 @@ async def db_support_types(): @router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo]) -async def dialogue_list(user_id: str = None): +async def dialogue_list( + user_name: str = None, user_id: str = None, sys_code: str = None +): dialogues: List = [] chat_history_service = ChatHistory() # TODO Change the synchronous call to the asynchronous call - datas = chat_history_service.get_store_cls().conv_list(user_id) + user_name = user_name or user_id + datas = chat_history_service.get_store_cls().conv_list(user_name, sys_code) for item in datas: conv_uid = item.get("conv_uid") summary = item.get("summary") chat_mode = item.get("chat_mode") model_name = item.get("model_name", CFG.LLM_MODEL) + user_name = item.get("user_name") + sys_code = item.get("sys_code") messages = json.loads(item.get("messages")) last_round = max(messages, key=lambda x: x["chat_order"]) @@ -223,6 +235,8 @@ async def dialogue_list(user_id: str = None): chat_mode=chat_mode, model_name=model_name, select_param=select_param, + user_name=user_name, + sys_code=sys_code, ) dialogues.append(conv_vo) @@ -254,9 +268,14 @@ async def dialogue_scenes(): @router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo]) async def dialogue_new( - chat_mode: str = ChatScene.ChatNormal.value(), user_id: str = None + chat_mode: str = ChatScene.ChatNormal.value(), + user_name: str = None, + # TODO remove user id + user_id: str = None, + sys_code: str = None, ): - conv_vo = __new_conversation(chat_mode, user_id) + user_name = user_name or user_id + conv_vo = __new_conversation(chat_mode, user_name, sys_code) return Result.succ(conv_vo) @@ -280,40 +299,40 @@ async def params_list(chat_mode: str = ChatScene.ChatNormal.value()): @router.post("/v1/chat/mode/params/file/load") async def params_load( - conv_uid: str, chat_mode: str, model_name: str, doc_file: UploadFile = 
File(...) + conv_uid: str, + chat_mode: str, + model_name: str, + user_name: Optional[str] = None, + sys_code: Optional[str] = None, + doc_file: UploadFile = File(...), ): print(f"params_load: {conv_uid},{chat_mode},{model_name}") try: if doc_file: - ## file save - if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode)): - os.makedirs(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode)) - # We can not move temp file in windows system when we open file in context of `with` - tmp_fd, tmp_path = tempfile.mkstemp( - dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode) - ) - # TODO Use noblocking file save with aiofiles - with os.fdopen(tmp_fd, "wb") as tmp: - tmp.write(await doc_file.read()) - shutil.move( - tmp_path, - os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode, doc_file.filename), - ) - ## chat prepare + # Save the uploaded file + upload_dir = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode) + os.makedirs(upload_dir, exist_ok=True) + upload_path = os.path.join(upload_dir, doc_file.filename) + async with aiofiles.open(upload_path, "wb") as f: + await f.write(await doc_file.read()) + + # Prepare the chat dialogue = ConversationVo( conv_uid=conv_uid, chat_mode=chat_mode, select_param=doc_file.filename, model_name=model_name, + user_name=user_name, + sys_code=sys_code, ) chat: BaseChat = await get_chat_instance(dialogue) resp = await chat.prepare() - ### refresh messages + # Refresh messages return Result.succ(get_hist_messages(conv_uid)) except Exception as e: logger.error("excel load error!", e) - return Result.failed(code="E000X", msg=f"File Load Error {e}") + return Result.failed(code="E000X", msg=f"File Load Error {str(e)}") @router.post("/v1/chat/dialogue/delete") @@ -354,7 +373,9 @@ async def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat: if not dialogue.chat_mode: dialogue.chat_mode = ChatScene.ChatNormal.value() if not dialogue.conv_uid: - conv_vo = __new_conversation(dialogue.chat_mode, dialogue.user_name) + 
conv_vo = __new_conversation( + dialogue.chat_mode, dialogue.user_name, dialogue.sys_code + ) dialogue.conv_uid = conv_vo.conv_uid if not ChatScene.is_valid_mode(dialogue.chat_mode): @@ -364,13 +385,12 @@ async def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat: chat_param = { "chat_session_id": dialogue.conv_uid, + "user_name": dialogue.user_name, + "sys_code": dialogue.sys_code, "current_user_input": dialogue.user_input, "select_param": dialogue.select_param, "model_name": dialogue.model_name, } - # chat: BaseChat = CHAT_FACTORY.get_implementation( - # dialogue.chat_mode, **{"chat_param": chat_param} - # ) chat: BaseChat = await blocking_func_to_async( get_executor(), CHAT_FACTORY.get_implementation, @@ -401,8 +421,6 @@ async def chat_completions(dialogue: ConversationVo = Body()): "get_chat_instance", span_type=SpanType.CHAT, metadata=dialogue.dict() ): chat: BaseChat = await get_chat_instance(dialogue) - # background_tasks = BackgroundTasks() - # background_tasks.add_task(release_model_semaphore) headers = { "Content-Type": "text/event-stream", "Cache-Control": "no-cache", diff --git a/pilot/openapi/api_view_model.py b/pilot/openapi/api_view_model.py index af1aa4b9c..5212076e8 100644 --- a/pilot/openapi/api_view_model.py +++ b/pilot/openapi/api_view_model.py @@ -66,6 +66,8 @@ class ConversationVo(BaseModel): """ incremental: bool = False + sys_code: Optional[str] = None + class MessageVo(BaseModel): """ diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index 0e263f7e5..864eb34b6 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -78,7 +78,9 @@ def __init__(self, chat_param: Dict): self.history_message: List[OnceConversation] = self.memory.messages() self.current_message: OnceConversation = OnceConversation( - self.chat_mode.value() + self.chat_mode.value(), + user_name=chat_param.get("user_name"), + sys_code=chat_param.get("sys_code"), ) self.current_message.model_name = self.llm_model if 
chat_param["select_param"]: @@ -171,7 +173,6 @@ async def __call_base(self): "messages": llm_messages, "temperature": float(self.prompt_template.temperature), "max_new_tokens": int(self.prompt_template.max_new_tokens), - # "stop": self.prompt_template.sep, "echo": self.llm_echo, } return payload diff --git a/pilot/scene/message.py b/pilot/scene/message.py index 4d5a5c383..592f3cda0 100644 --- a/pilot/scene/message.py +++ b/pilot/scene/message.py @@ -18,7 +18,7 @@ class OnceConversation: All the information of a conversation, the current single service in memory, can expand cache and database support distributed services """ - def __init__(self, chat_mode): + def __init__(self, chat_mode, user_name: str = None, sys_code: str = None): self.chat_mode: str = chat_mode self.messages: List[BaseMessage] = [] self.start_date: str = "" @@ -28,6 +28,8 @@ def __init__(self, chat_mode): self.param_value: str = "" self.cost: int = 0 self.tokens: int = 0 + self.user_name: str = user_name + self.sys_code: str = sys_code def add_user_message(self, message: str) -> None: """Add a user message to the store""" @@ -113,6 +115,8 @@ def _conversation_to_dic(once: OnceConversation) -> dict: "messages": messages_to_dict(once.messages), "param_type": once.param_type, "param_value": once.param_value, + "user_name": once.user_name, + "sys_code": once.sys_code, } @@ -121,7 +125,9 @@ def conversations_to_dict(conversations: List[OnceConversation]) -> List[dict]: def conversation_from_dict(once: dict) -> OnceConversation: - conversation = OnceConversation() + conversation = OnceConversation( + once.get("chat_mode"), once.get("user_name"), once.get("sys_code") + ) conversation.cost = once.get("cost", 0) conversation.chat_mode = once.get("chat_mode", "chat_normal") conversation.tokens = once.get("tokens", 0) diff --git a/pilot/server/prompt/prompt_manage_db.py b/pilot/server/prompt/prompt_manage_db.py index 3245b3450..eff75c270 100644 --- a/pilot/server/prompt/prompt_manage_db.py +++ 
b/pilot/server/prompt/prompt_manage_db.py @@ -29,6 +29,7 @@ class PromptManageEntity(Base): prompt_name = Column(String(512)) content = Column(Text) user_name = Column(String(128)) + sys_code = Column(String(128), index=True, nullable=True, comment="System code") gmt_created = Column(DateTime) gmt_modified = Column(DateTime) diff --git a/setup.py b/setup.py index ca0bd818f..d1386c79e 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,4 @@ -from typing import List, Tuple - +from typing import List, Tuple, Optional, Callable import setuptools import platform import subprocess @@ -10,6 +9,7 @@ import re import shutil from setuptools import find_packages +import functools with open("README.md", mode="r", encoding="utf-8") as fh: long_description = fh.read() @@ -34,8 +34,15 @@ def parse_requirements(file_name: str) -> List[str]: def get_latest_version(package_name: str, index_url: str, default_version: str): + python_command = shutil.which("python") + if not python_command: + python_command = shutil.which("python3") + if not python_command: + print("Python command not found.") + return default_version + command = [ - "python", + python_command, "-m", "pip", "index", @@ -125,6 +132,7 @@ class OSType(Enum): OTHER = "other" +@functools.cache def get_cpu_avx_support() -> Tuple[OSType, AVXType]: system = platform.system() os_type = OSType.OTHER @@ -206,6 +214,57 @@ def get_cuda_version() -> str: return None +def _build_wheels( + pkg_name: str, + pkg_version: str, + base_url: str = None, + base_url_func: Callable[[str, str, str], str] = None, + pkg_file_func: Callable[[str, str, str, str, OSType], str] = None, + supported_cuda_versions: List[str] = ["11.7", "11.8"], +) -> Optional[str]: + """ + Build the URL for the package wheel file based on the package name, version, and CUDA version. + Args: + pkg_name (str): The name of the package. + pkg_version (str): The version of the package. + base_url (str): The base URL for downloading the package. 
+ base_url_func (Callable): A function to generate the base URL. + pkg_file_func (Callable): build package file function. + function params: pkg_name, pkg_version, cuda_version, py_version, OSType + supported_cuda_versions (List[str]): The list of supported CUDA versions. + Returns: + Optional[str]: The URL for the package wheel file. + """ + os_type, _ = get_cpu_avx_support() + cuda_version = get_cuda_version() + py_version = platform.python_version() + py_version = "cp" + "".join(py_version.split(".")[0:2]) + if os_type == OSType.DARWIN or not cuda_version: + return None + if cuda_version not in supported_cuda_versions: + print( + f"Warnning: {pkg_name} supported cuda version: {supported_cuda_versions}, replace to {supported_cuda_versions[-1]}" + ) + cuda_version = supported_cuda_versions[-1] + + cuda_version = "cu" + cuda_version.replace(".", "") + os_pkg_name = "linux_x86_64" if os_type == OSType.LINUX else "win_amd64" + if base_url_func: + base_url = base_url_func(pkg_version, cuda_version, py_version) + if base_url and base_url.endswith("/"): + base_url = base_url[:-1] + if pkg_file_func: + full_pkg_file = pkg_file_func( + pkg_name, pkg_version, cuda_version, py_version, os_type + ) + else: + full_pkg_file = f"{pkg_name}-{pkg_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl" + if not base_url: + return full_pkg_file + else: + return f"{base_url}/{full_pkg_file}" + + def torch_requires( torch_version: str = "2.0.1", torchvision_version: str = "0.15.2", @@ -222,16 +281,20 @@ def torch_requires( cuda_version = get_cuda_version() if cuda_version: supported_versions = ["11.7", "11.8"] - if cuda_version not in supported_versions: - print( - f"PyTorch version {torch_version} supported cuda version: {supported_versions}, replace to {supported_versions[-1]}" - ) - cuda_version = supported_versions[-1] - cuda_version = "cu" + cuda_version.replace(".", "") - py_version = "cp310" - os_pkg_name = "linux_x86_64" if os_type == OSType.LINUX else 
"win_amd64" - torch_url = f"https://download.pytorch.org/whl/{cuda_version}/torch-{torch_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl" - torchvision_url = f"https://download.pytorch.org/whl/{cuda_version}/torchvision-{torchvision_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl" + # torch_url = f"https://download.pytorch.org/whl/{cuda_version}/torch-{torch_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl" + # torchvision_url = f"https://download.pytorch.org/whl/{cuda_version}/torchvision-{torchvision_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl" + torch_url = _build_wheels( + "torch", + torch_version, + base_url_func=lambda v, x, y: f"https://download.pytorch.org/whl/{x}", + supported_cuda_versions=supported_versions, + ) + torchvision_url = _build_wheels( + "torchvision", + torchvision_version, + base_url_func=lambda v, x, y: f"https://download.pytorch.org/whl/{x}", + supported_cuda_versions=supported_versions, + ) torch_url_cached = cache_package( torch_url, "torch", os_type == OSType.WINDOWS ) @@ -327,6 +390,7 @@ def core_requires(): "xlrd==2.0.1", # for cache, TODO pympler has not been updated for a long time and needs to find a new toolkit. "pympler", + "aiofiles", ] if BUILD_FROM_SOURCE: setup_spec.extras["framework"].append( @@ -360,6 +424,41 @@ def llama_cpp_requires(): llama_cpp_python_cuda_requires() +def _build_autoawq_requires() -> Optional[str]: + os_type, _ = get_cpu_avx_support() + if os_type == OSType.DARWIN: + return None + auto_gptq_version = get_latest_version( + "auto-gptq", "https://huggingface.github.io/autogptq-index/whl/cu118/", "0.5.1" + ) + # eg. 
0.5.1+cu118 + auto_gptq_version = auto_gptq_version.split("+")[0] + + def pkg_file_func(pkg_name, pkg_version, cuda_version, py_version, os_type): + pkg_name = pkg_name.replace("-", "_") + if os_type == OSType.DARWIN: + return None + os_pkg_name = ( + "manylinux_2_17_x86_64.manylinux2014_x86_64.whl" + if os_type == OSType.LINUX + else "win_amd64.whl" + ) + return f"{pkg_name}-{pkg_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}" + + auto_gptq_url = _build_wheels( + "auto-gptq", + auto_gptq_version, + base_url_func=lambda v, x, y: f"https://huggingface.github.io/autogptq-index/whl/{x}/auto-gptq", + pkg_file_func=pkg_file_func, + supported_cuda_versions=["11.8"], + ) + if auto_gptq_url: + print(f"Install auto-gptq from {auto_gptq_url}") + return f"auto-gptq @ {auto_gptq_url}" + else: + return "auto-gptq" + + def quantization_requires(): pkgs = [] os_type, _ = get_cpu_avx_support() @@ -379,6 +478,28 @@ def quantization_requires(): print(pkgs) # For chatglm2-6b-int4 pkgs += ["cpm_kernels"] + + # Since transformers 4.35.0, the GPT-Q/AWQ model can be loaded using AutoModelForCausalLM. + # autoawq requirements: + # 1. Compute Capability 7.5 (sm75). Turing and later architectures are supported. + # 2. CUDA Toolkit 11.8 and later. + autoawq_url = _build_wheels( + "autoawq", + "0.1.7", + base_url_func=lambda v, x, y: f"https://github.com/casper-hansen/AutoAWQ/releases/download/v{v}", + supported_cuda_versions=["11.8"], + ) + if autoawq_url: + print(f"Install autoawq from {autoawq_url}") + pkgs.append(f"autoawq @ {autoawq_url}") + else: + pkgs.append("autoawq") + + auto_gptq_pkg = _build_autoawq_requires() + if auto_gptq_pkg: + pkgs.append(auto_gptq_pkg) + pkgs.append("optimum") + setup_spec.extras["quantization"] = pkgs