diff --git a/pilot/memory/chat_history/chat_hisotry_factory.py b/pilot/memory/chat_history/chat_hisotry_factory.py index 64d30e971..c1a8f9cab 100644 --- a/pilot/memory/chat_history/chat_hisotry_factory.py +++ b/pilot/memory/chat_history/chat_hisotry_factory.py @@ -1,5 +1,6 @@ from .base import MemoryStoreType from pilot.configs.config import Config +from pilot.memory.chat_history.base import BaseChatHistoryMemory CFG = Config() @@ -18,7 +19,15 @@ def __init__(self): self.mem_store_class_map[DbHistoryMemory.store_type] = DbHistoryMemory self.mem_store_class_map[MemHistoryMemory.store_type] = MemHistoryMemory - def get_store_instance(self, chat_session_id): + def get_store_instance(self, chat_session_id: str) -> BaseChatHistoryMemory: + """New store instance for store chat histories + + Args: + chat_session_id (str): conversation session id + + Returns: + BaseChatHistoryMemory: Store instance + """ return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE)( chat_session_id ) diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py index ea569b3e4..ae8e74a2b 100644 --- a/pilot/openapi/api_v1/api_v1.py +++ b/pilot/openapi/api_v1/api_v1.py @@ -19,6 +19,7 @@ from fastapi.exceptions import RequestValidationError from typing import List import tempfile +from concurrent.futures import Executor from pilot.component import ComponentType from pilot.openapi.api_view_model import ( @@ -46,6 +47,7 @@ from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory from pilot.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory from pilot.model.base import FlatSupportedModel +from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async router = APIRouter() CFG = Config() @@ -129,6 +131,13 @@ def get_worker_manager() -> WorkerManager: return worker_manager +def get_executor() -> Executor: + """Get the global default executor""" + return CFG.SYSTEM_APP.get_component( + ComponentType.EXECUTOR_DEFAULT, 
ExecutorFactory + ).create() + + @router.get("/v1/chat/db/list", response_model=Result[DBConfig]) async def db_connect_list(): return Result.succ(CFG.LOCAL_DB_MANAGE.get_db_list()) @@ -158,6 +167,7 @@ async def async_db_summary_embedding(db_name, db_type): @router.post("/v1/chat/db/test/connect", response_model=Result[bool]) async def test_connect(db_config: DBConfig = Body()): try: + # TODO Change the synchronous call to the asynchronous call CFG.LOCAL_DB_MANAGE.test_connect(db_config) return Result.succ(True) except Exception as e: @@ -166,6 +176,7 @@ async def test_connect(db_config: DBConfig = Body()): @router.post("/v1/chat/db/summary", response_model=Result[bool]) async def db_summary(db_name: str, db_type: str): + # TODO Change the synchronous call to the asynchronous call async_db_summary_embedding(db_name, db_type) return Result.succ(True) @@ -185,6 +196,7 @@ async def db_support_types(): async def dialogue_list(user_id: str = None): dialogues: List = [] chat_history_service = ChatHistory() + # TODO Change the synchronous call to the asynchronous call datas = chat_history_service.get_store_cls().conv_list(user_id) for item in datas: conv_uid = item.get("conv_uid") @@ -285,7 +297,7 @@ async def params_load( select_param=doc_file.filename, model_name=model_name, ) - chat: BaseChat = get_chat_instance(dialogue) + chat: BaseChat = await get_chat_instance(dialogue) resp = await chat.prepare() ### refresh messages @@ -299,6 +311,7 @@ async def params_load( async def dialogue_delete(con_uid: str): history_fac = ChatHistory() history_mem = history_fac.get_store_instance(con_uid) + # TODO Change the synchronous call to the asynchronous call history_mem.delete() return Result.succ(None) @@ -324,10 +337,11 @@ def get_hist_messages(conv_uid: str): @router.get("/v1/chat/dialogue/messages/history", response_model=Result[MessageVo]) async def dialogue_history_messages(con_uid: str): print(f"dialogue_history_messages:{con_uid}") + # TODO Change the synchronous call to the 
asynchronous call return Result.succ(get_hist_messages(con_uid)) -def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat: +async def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat: logger.info(f"get_chat_instance:{dialogue}") if not dialogue.chat_mode: dialogue.chat_mode = ChatScene.ChatNormal.value() @@ -346,8 +360,14 @@ def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat: "select_param": dialogue.select_param, "model_name": dialogue.model_name, } - chat: BaseChat = CHAT_FACTORY.get_implementation( - dialogue.chat_mode, **{"chat_param": chat_param} + # chat: BaseChat = CHAT_FACTORY.get_implementation( + # dialogue.chat_mode, **{"chat_param": chat_param} + # ) + chat: BaseChat = await blocking_func_to_async( + get_executor(), + CHAT_FACTORY.get_implementation, + dialogue.chat_mode, + **{"chat_param": chat_param}, ) return chat @@ -357,7 +377,7 @@ async def chat_prepare(dialogue: ConversationVo = Body()): # dialogue.model_name = CFG.LLM_MODEL logger.info(f"chat_prepare:{dialogue}") ## check conv_uid - chat: BaseChat = get_chat_instance(dialogue) + chat: BaseChat = await get_chat_instance(dialogue) if len(chat.history_message) > 0: return Result.succ(None) resp = await chat.prepare() @@ -369,7 +389,7 @@ async def chat_completions(dialogue: ConversationVo = Body()): print( f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},{dialogue.model_name}" ) - chat: BaseChat = get_chat_instance(dialogue) + chat: BaseChat = await get_chat_instance(dialogue) # background_tasks = BackgroundTasks() # background_tasks.add_task(release_model_semaphore) headers = { diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index 5d685573e..d22a0a85f 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -12,6 +12,7 @@ from pilot.scene.base_message import ModelMessage, ModelMessageRoleType from pilot.scene.message import OnceConversation from pilot.utils import get_or_create_event_loop +from 
pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async from pydantic import Extra from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory @@ -80,6 +81,10 @@ def __init__(self, chat_param: Dict): self.current_message.param_type = self.chat_mode.param_types()[0] self.current_message.param_value = chat_param["select_param"] self.current_tokens_used: int = 0 + # The executor to submit blocking function + self._executor = CFG.SYSTEM_APP.get_component( + ComponentType.EXECUTOR_DEFAULT, ExecutorFactory + ).create() class Config: """Configuration for this pydantic object.""" @@ -92,8 +97,14 @@ def chat_type(self) -> str: raise NotImplementedError("Not supported for this chat type.") @abstractmethod - def generate_input_values(self): - pass + async def generate_input_values(self) -> Dict: + """Generate input to LLM + + Please note that you must not perform any blocking operations in this function + + Returns: + a dictionary to be formatted by prompt template + """ def do_action(self, prompt_response): return prompt_response @@ -116,8 +127,8 @@ def get_llm_speak(self, prompt_define_response): speak_to_user = prompt_define_response return speak_to_user - def __call_base(self): - input_values = self.generate_input_values() + async def __call_base(self): + input_values = await self.generate_input_values() ### Chat sequence advance self.current_message.chat_order = len(self.history_message) + 1 self.current_message.add_user_message(self.current_user_input) @@ -159,7 +170,7 @@ async def check_iterator_end(iterator): async def stream_call(self): # TODO Retry when server connection error - payload = self.__call_base() + payload = await self.__call_base() self.skip_echo_len = len(payload.get("prompt").replace("", " ")) + 11 logger.info(f"Requert: \n{payload}") @@ -190,7 +201,7 @@ async def stream_call(self): self.memory.append(self.current_message) async def nostream_call(self): - payload = self.__call_base() + payload = await self.__call_base() 
logger.info(f"Request: \n{payload}") ai_response_text = "" try: @@ -216,14 +227,24 @@ async def nostream_call(self): ) ) ### run - result = self.do_action(prompt_define_response) + # result = self.do_action(prompt_define_response) + result = await blocking_func_to_async( + self._executor, self.do_action, prompt_define_response + ) ### llm speaker speak_to_user = self.get_llm_speak(prompt_define_response) - view_message = self.prompt_template.output_parser.parse_view_response( - speak_to_user, result + # view_message = self.prompt_template.output_parser.parse_view_response( + # speak_to_user, result + # ) + view_message = await blocking_func_to_async( + self._executor, + self.prompt_template.output_parser.parse_view_response, + speak_to_user, + result, ) + view_message = view_message.replace("\n", "\\n") self.current_message.add_view_message(view_message) except Exception as e: diff --git a/pilot/scene/chat_agent/chat.py b/pilot/scene/chat_agent/chat.py index 0fc5a0375..d9a8f60c1 100644 --- a/pilot/scene/chat_agent/chat.py +++ b/pilot/scene/chat_agent/chat.py @@ -51,7 +51,7 @@ def __init__(self, chat_param: Dict): self.api_call = ApiCall(plugin_generator=self.plugins_prompt_generator) - def generate_input_values(self): + async def generate_input_values(self) -> Dict[str, str]: input_values = { "user_goal": self.current_user_input, "expand_constraints": self.__list_to_prompt_str( diff --git a/pilot/scene/chat_dashboard/chat.py b/pilot/scene/chat_dashboard/chat.py index 7e4433670..211aa7c04 100644 --- a/pilot/scene/chat_dashboard/chat.py +++ b/pilot/scene/chat_dashboard/chat.py @@ -12,6 +12,7 @@ ) from pilot.scene.chat_dashboard.prompt import prompt from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader +from pilot.utils.executor_utils import blocking_func_to_async CFG = Config() @@ -52,7 +53,7 @@ def __load_dashboard_template(self, template_name): data = f.read() return json.loads(data) - def generate_input_values(self): + async def 
generate_input_values(self) -> Dict: try: from pilot.summary.db_summary_client import DBSummaryClient except ImportError: @@ -60,9 +61,16 @@ def generate_input_values(self): client = DBSummaryClient(system_app=CFG.SYSTEM_APP) try: - table_infos = client.get_similar_tables( - dbname=self.db_name, query=self.current_user_input, topk=self.top_k + table_infos = await blocking_func_to_async( + self._executor, + client.get_similar_tables, + self.db_name, + self.current_user_input, + self.top_k, ) + # table_infos = client.get_similar_tables( + # dbname=self.db_name, query=self.current_user_input, topk=self.top_k + # ) print("dashboard vector find tables:{}", table_infos) except Exception as e: print("db summary find error!" + str(e)) diff --git a/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py b/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py index 9a316cd34..064e7586c 100644 --- a/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py +++ b/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py @@ -62,7 +62,7 @@ def _generate_numbered_list(self) -> str: # ] return "\n".join(f"{i+1}. 
{item}" for i, item in enumerate(command_strings)) - def generate_input_values(self): + async def generate_input_values(self) -> Dict: input_values = { "user_input": self.current_user_input, "table_name": self.excel_reader.table_name, diff --git a/pilot/scene/chat_data/chat_excel/excel_learning/chat.py b/pilot/scene/chat_data/chat_excel/excel_learning/chat.py index c7eb82584..f05221eba 100644 --- a/pilot/scene/chat_data/chat_excel/excel_learning/chat.py +++ b/pilot/scene/chat_data/chat_excel/excel_learning/chat.py @@ -1,6 +1,5 @@ import json -import os -from typing import Any +from typing import Any, Dict from pilot.scene.base_message import ( HumanMessage, @@ -13,6 +12,7 @@ from pilot.scene.chat_data.chat_excel.excel_learning.prompt import prompt from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader from pilot.json_utils.utilities import DateTimeEncoder +from pilot.utils.executor_utils import blocking_func_to_async CFG = Config() @@ -44,13 +44,15 @@ def __init__( if parent_mode: self.current_message.chat_mode = parent_mode.value() - def generate_input_values(self): - colunms, datas = self.excel_reader.get_sample_data() + async def generate_input_values(self) -> Dict: + # colunms, datas = self.excel_reader.get_sample_data() + colunms, datas = await blocking_func_to_async( + self._executor, self.excel_reader.get_sample_data + ) + copy_datas = datas.copy() datas.insert(0, colunms) input_values = { - "data_example": json.dumps( - self.excel_reader.get_sample_data(), cls=DateTimeEncoder - ), + "data_example": json.dumps(copy_datas, cls=DateTimeEncoder), } return input_values diff --git a/pilot/scene/chat_db/auto_execute/chat.py b/pilot/scene/chat_db/auto_execute/chat.py index f92df7a3a..d9b901772 100644 --- a/pilot/scene/chat_db/auto_execute/chat.py +++ b/pilot/scene/chat_db/auto_execute/chat.py @@ -5,6 +5,7 @@ from pilot.common.sql_database import Database from pilot.configs.config import Config from pilot.scene.chat_db.auto_execute.prompt import 
prompt +from pilot.utils.executor_utils import blocking_func_to_async CFG = Config() @@ -38,7 +39,7 @@ def __init__(self, chat_param: Dict): self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name) self.top_k: int = 200 - def generate_input_values(self): + async def generate_input_values(self) -> Dict: """ generate input values """ @@ -47,19 +48,27 @@ def generate_input_values(self): except ImportError: raise ValueError("Could not import DBSummaryClient. ") client = DBSummaryClient(system_app=CFG.SYSTEM_APP) + table_infos = None try: - table_infos = client.get_db_summary( - dbname=self.db_name, - query=self.current_user_input, - topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE, + # table_infos = client.get_db_summary( + # dbname=self.db_name, + # query=self.current_user_input, + # topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE, + # ) + table_infos = await blocking_func_to_async( + self._executor, + client.get_db_summary, + self.db_name, + self.current_user_input, + CFG.KNOWLEDGE_SEARCH_TOP_SIZE, ) except Exception as e: print("db summary find error!" 
+ str(e)) - table_infos = self.database.table_simple_info() if not table_infos: - table_infos = self.database.table_simple_info() - - # table_infos = self.database.table_simple_info() + # table_infos = self.database.table_simple_info() + table_infos = await blocking_func_to_async( + self._executor, self.database.table_simple_info + ) input_values = { "input": self.current_user_input, diff --git a/pilot/scene/chat_db/professional_qa/chat.py b/pilot/scene/chat_db/professional_qa/chat.py index abdfd9f00..5ae76d37d 100644 --- a/pilot/scene/chat_db/professional_qa/chat.py +++ b/pilot/scene/chat_db/professional_qa/chat.py @@ -5,6 +5,7 @@ from pilot.common.sql_database import Database from pilot.configs.config import Config from pilot.scene.chat_db.professional_qa.prompt import prompt +from pilot.utils.executor_utils import blocking_func_to_async CFG = Config() @@ -38,7 +39,7 @@ def __init__(self, chat_param: Dict): else len(self.tables) ) - def generate_input_values(self): + async def generate_input_values(self) -> Dict: table_info = "" dialect = "mysql" try: @@ -48,12 +49,22 @@ def generate_input_values(self): if self.db_name: client = DBSummaryClient(system_app=CFG.SYSTEM_APP) try: - table_infos = client.get_db_summary( - dbname=self.db_name, query=self.current_user_input, topk=self.top_k + # table_infos = client.get_db_summary( + # dbname=self.db_name, query=self.current_user_input, topk=self.top_k + # ) + table_infos = await blocking_func_to_async( + self._executor, + client.get_db_summary, + self.db_name, + self.current_user_input, + self.top_k, ) except Exception as e: print("db summary find error!" 
+ str(e)) - table_infos = self.database.table_simple_info() + # table_infos = self.database.table_simple_info() + table_infos = await blocking_func_to_async( + self._executor, self.database.table_simple_info + ) # table_infos = self.database.table_simple_info() dialect = self.database.dialect diff --git a/pilot/scene/chat_execution/chat.py b/pilot/scene/chat_execution/chat.py index c6d7bbe2f..bdd78d7b7 100644 --- a/pilot/scene/chat_execution/chat.py +++ b/pilot/scene/chat_execution/chat.py @@ -50,7 +50,7 @@ def __init__(self, chat_param: Dict): self.plugins_prompt_generator ) - def generate_input_values(self): + async def generate_input_values(self) -> Dict: input_values = { "input": self.current_user_input, "constraints": self.__list_to_prompt_str( diff --git a/pilot/scene/chat_knowledge/inner_db_summary/chat.py b/pilot/scene/chat_knowledge/inner_db_summary/chat.py index 34c8260e3..07a64aea9 100644 --- a/pilot/scene/chat_knowledge/inner_db_summary/chat.py +++ b/pilot/scene/chat_knowledge/inner_db_summary/chat.py @@ -1,3 +1,4 @@ +from typing import Dict from pilot.scene.base_chat import BaseChat from pilot.scene.base import ChatScene from pilot.configs.config import Config @@ -30,7 +31,7 @@ def __init__( self.db_input = db_select self.db_summary = db_summary - def generate_input_values(self): + async def generate_input_values(self) -> Dict: input_values = { "db_input": self.db_input, "db_profile_summary": self.db_summary, diff --git a/pilot/scene/chat_knowledge/v1/chat.py b/pilot/scene/chat_knowledge/v1/chat.py index 9681f13c6..236c9c1a7 100644 --- a/pilot/scene/chat_knowledge/v1/chat.py +++ b/pilot/scene/chat_knowledge/v1/chat.py @@ -12,6 +12,7 @@ from pilot.scene.chat_knowledge.v1.prompt import prompt from pilot.server.knowledge.service import KnowledgeService +from pilot.utils.executor_utils import blocking_func_to_async CFG = Config() @@ -65,7 +66,7 @@ def __init__(self, chat_param: Dict): self.prompt_template.template_is_strict = False async def 
stream_call(self): - input_values = self.generate_input_values() + input_values = await self.generate_input_values() # Source of knowledge file relations = input_values.get("relations") last_output = None @@ -85,12 +86,18 @@ async def stream_call(self): ) yield last_output - def generate_input_values(self): + async def generate_input_values(self) -> Dict: if self.space_context: self.prompt_template.template_define = self.space_context["prompt"]["scene"] self.prompt_template.template = self.space_context["prompt"]["template"] - docs = self.knowledge_embedding_client.similar_search( - self.current_user_input, self.top_k + # docs = self.knowledge_embedding_client.similar_search( + # self.current_user_input, self.top_k + # ) + docs = await blocking_func_to_async( + self._executor, + self.knowledge_embedding_client.similar_search, + self.current_user_input, + self.top_k, ) if not docs: raise ValueError( diff --git a/pilot/scene/chat_normal/chat.py b/pilot/scene/chat_normal/chat.py index 47d1e70b5..5999d5c3c 100644 --- a/pilot/scene/chat_normal/chat.py +++ b/pilot/scene/chat_normal/chat.py @@ -21,7 +21,7 @@ def __init__(self, chat_param: Dict): chat_param=chat_param, ) - def generate_input_values(self): + async def generate_input_values(self) -> Dict: input_values = {"input": self.current_user_input} return input_values diff --git a/pilot/server/componet_configs.py b/pilot/server/componet_configs.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/pilot/utils/executor_utils.py b/pilot/utils/executor_utils.py index dbc3dac81..2aac0d04d 100644 --- a/pilot/utils/executor_utils.py +++ b/pilot/utils/executor_utils.py @@ -1,5 +1,8 @@ +from typing import Callable, Awaitable, Any +import asyncio from abc import ABC, abstractmethod from concurrent.futures import Executor, ThreadPoolExecutor +from functools import partial from pilot.component import BaseComponent, ComponentType, SystemApp @@ -24,3 +27,34 @@ def init_app(self, system_app: SystemApp): def 
create(self) -> Executor: return self._executor + + +BlockingFunction = Callable[..., Any] + + +async def blocking_func_to_async( + executor: Executor, func: BlockingFunction, *args, **kwargs +): + """Run a potentially blocking function within an executor. + + Args: + executor (Executor): The concurrent.futures.Executor to run the function within. + func (BlockingFunction): The callable function, which should be a synchronous function. + It may accept any number and type of arguments and must not be a coroutine function. + *args (Any): Positional arguments to pass to the function. + **kwargs (Any): Keyword arguments to pass to the function. + + Returns: + Any: The result of the function's execution. + + Raises: + ValueError: If the provided function 'func' is an asynchronous coroutine function. + + This function allows you to execute a potentially blocking function within an executor. + It expects 'func' to be a synchronous function and will raise an error if 'func' is a coroutine function.
+ """ + if asyncio.iscoroutinefunction(func): + raise ValueError(f"The function {func} is not blocking function") + loop = asyncio.get_event_loop() + sync_function_noargs = partial(func, *args, **kwargs) + return await loop.run_in_executor(executor, sync_function_noargs) diff --git a/scripts/setup_autodl_env.sh b/scripts/setup_autodl_env.sh index 2eafbcebb..13d08736d 100644 --- a/scripts/setup_autodl_env.sh +++ b/scripts/setup_autodl_env.sh @@ -35,15 +35,14 @@ clone_repositories() { cd /root && git clone https://github.com/eosphoros-ai/DB-GPT.git mkdir -p /root/DB-GPT/models && cd /root/DB-GPT/models git clone https://huggingface.co/GanymedeNil/text2vec-large-chinese - git clone https://huggingface.co/THUDM/chatglm2-6b-int4 + git clone https://huggingface.co/THUDM/chatglm2-6b rm -rf /root/DB-GPT/models/text2vec-large-chinese/.git - rm -rf /root/DB-GPT/models/chatglm2-6b-int4/.git + rm -rf /root/DB-GPT/models/chatglm2-6b/.git } install_dbgpt_packages() { - conda activate dbgpt && cd /root/DB-GPT && pip install -e . && cp .env.template .env - cp .env.template .env && sed -i 's/LLM_MODEL=vicuna-13b-v1.5/LLM_MODEL=chatglm2-6b-int4/' .env - + conda activate dbgpt && cd /root/DB-GPT && pip install -e ".[default]" + cp .env.template .env && sed -i 's/LLM_MODEL=vicuna-13b-v1.5/LLM_MODEL=chatglm2-6b/' .env } clean_up() { diff --git a/setup.py b/setup.py index 2c807b010..8a62faf4a 100644 --- a/setup.py +++ b/setup.py @@ -317,6 +317,8 @@ def core_requires(): # TODO move transformers to default "transformers>=4.31.0", "alembic==1.12.0", + # for excel + "openpyxl", ] @@ -361,6 +363,8 @@ def quantization_requires(): ) pkgs = [f"bitsandbytes @ {local_pkg}"] print(pkgs) + # For chatglm2-6b-int4 + pkgs += ["cpm_kernels"] setup_spec.extras["quantization"] = pkgs