diff --git a/docker/compose_examples/cluster-docker-compose.yml b/docker/compose_examples/cluster-docker-compose.yml index c4928a53a..8d4763532 100644 --- a/docker/compose_examples/cluster-docker-compose.yml +++ b/docker/compose_examples/cluster-docker-compose.yml @@ -4,9 +4,10 @@ services: controller: image: eosphorosai/dbgpt:latest command: dbgpt start controller + restart: unless-stopped networks: - dbgptnet - worker: + llm-worker: image: eosphorosai/dbgpt:latest command: dbgpt start worker --model_name vicuna-13b-v1.5 --model_path /app/models/vicuna-13b-v1.5 --port 8001 --controller_addr http://controller:8000 environment: @@ -17,6 +18,27 @@ services: - /data:/data # Please modify it to your own model directory - /data/models:/app/models + restart: unless-stopped + networks: + - dbgptnet + deploy: + resources: + reservations: + devices: + - driver: nvidia + capabilities: [gpu] + embedding-worker: + image: eosphorosai/dbgpt:latest + command: dbgpt start worker --model_name text2vec --worker_type text2vec --model_path /app/models/text2vec-large-chinese --port 8002 --controller_addr http://controller:8000 + environment: + - DBGPT_LOG_LEVEL=DEBUG + depends_on: + - controller + volumes: + - /data:/data + # Please modify it to your own model directory + - /data/models:/app/models + restart: unless-stopped networks: - dbgptnet deploy: @@ -37,7 +59,8 @@ services: - MODEL_SERVER=http://controller:8000 depends_on: - controller - - worker + - llm-worker + - embedding-worker volumes: - /data:/data # Please modify it to your own model directory diff --git a/docs/_static/img/muti-model-cluster-overview.png b/docs/_static/img/muti-model-cluster-overview.png new file mode 100644 index 000000000..e087a0034 Binary files /dev/null and b/docs/_static/img/muti-model-cluster-overview.png differ diff --git a/docs/getting_started/install.rst b/docs/getting_started/install.rst index abb90ed6e..d6e6a15f8 100644 --- a/docs/getting_started/install.rst +++ b/docs/getting_started/install.rst @@ -9,6 +9,7 @@ DB-GPT product is a Web application that you can chat database, chat knowledge, - docker - docker_compose - environment +- cluster deployment - deploy_faq .. toctree:: @@ -20,6 +21,7 @@ DB-GPT product is a Web application that you can chat database, chat knowledge, ./install/deploy/deploy.md ./install/docker/docker.md ./install/docker_compose/docker_compose.md + ./install/cluster/cluster.rst ./install/llm/llm.rst ./install/environment/environment.md ./install/faq/deploy_faq.md \ No newline at end of file diff --git a/docs/getting_started/install/cluster/cluster.rst b/docs/getting_started/install/cluster/cluster.rst new file mode 100644 index 000000000..f81fd4d0d --- /dev/null +++ b/docs/getting_started/install/cluster/cluster.rst @@ -0,0 +1,19 @@ +Cluster deployment +================================== + +In order to deploy DB-GPT to multiple nodes, you can deploy a cluster. The cluster architecture diagram is as follows: + +.. raw:: html + + + + +* On :ref:`Deploying on local machine `. Local cluster deployment. + +.. 
toctree:: + :maxdepth: 2 + :caption: Cluster deployment + :name: cluster_deploy + :hidden: + + ./vms/index.md diff --git a/docs/getting_started/install/cluster/kubernetes/index.md b/docs/getting_started/install/cluster/kubernetes/index.md new file mode 100644 index 000000000..385a8b054 --- /dev/null +++ b/docs/getting_started/install/cluster/kubernetes/index.md @@ -0,0 +1,3 @@ +Kubernetes cluster deployment +================================== +(kubernetes-cluster-index)= \ No newline at end of file diff --git a/docs/getting_started/install/llm/cluster/model_cluster.md b/docs/getting_started/install/cluster/vms/index.md similarity index 76% rename from docs/getting_started/install/llm/cluster/model_cluster.md rename to docs/getting_started/install/cluster/vms/index.md index 5576dfc1e..03b5b0293 100644 --- a/docs/getting_started/install/llm/cluster/model_cluster.md +++ b/docs/getting_started/install/cluster/vms/index.md @@ -1,6 +1,6 @@ -Cluster deployment +Local cluster deployment ================================== - +(local-cluster-index)= ## Model cluster deployment @@ -17,7 +17,7 @@ dbgpt start controller By default, the Model Controller starts on port 8000. -### Launch Model Worker +### Launch LLM Model Worker If you are starting `chatglm2-6b`: @@ -39,6 +39,18 @@ dbgpt start worker --model_name vicuna-13b-v1.5 \ Note: Be sure to use your own model name and model path. +### Launch Embedding Model Worker + +```bash + +dbgpt start worker --model_name text2vec \ +--model_path /app/models/text2vec-large-chinese \ +--worker_type text2vec \ +--port 8003 \ +--controller_addr http://127.0.0.1:8000 +``` + +Note: Be sure to use your own model name and model path. Check your model: @@ -51,8 +63,12 @@ You will see the following output: +-----------------+------------+------------+------+---------+---------+-----------------+----------------------------+ | Model Name | Model Type | Host | Port | Healthy | Enabled | Prompt Template | Last Heartbeat | +-----------------+------------+------------+------+---------+---------+-----------------+----------------------------+ -| chatglm2-6b | llm | 172.17.0.6 | 8001 | True | True | None | 2023-08-31T04:48:45.252939 | -| vicuna-13b-v1.5 | llm | 172.17.0.6 | 8002 | True | True | None | 2023-08-31T04:48:55.136676 | +| chatglm2-6b | llm | 172.17.0.2 | 8001 | True | True | | 2023-09-12T23:04:31.287654 | +| WorkerManager | service | 172.17.0.2 | 8001 | True | True | | 2023-09-12T23:04:31.286668 | +| WorkerManager | service | 172.17.0.2 | 8003 | True | True | | 2023-09-12T23:04:29.845617 | +| WorkerManager | service | 172.17.0.2 | 8002 | True | True | | 2023-09-12T23:04:24.598439 | +| text2vec | text2vec | 172.17.0.2 | 8003 | True | True | | 2023-09-12T23:04:29.844796 | +| vicuna-13b-v1.5 | llm | 172.17.0.2 | 8002 | True | True | | 2023-09-12T23:04:24.597775 | +-----------------+------------+------------+------+---------+---------+-----------------+----------------------------+ ``` @@ -69,7 +85,7 @@ MODEL_SERVER=http://127.0.0.1:8000 #### Start the webserver ```bash -python pilot/server/dbgpt_server.py --light +dbgpt start webserver --light ``` `--light` indicates not to start the embedded model service. @@ -77,7 +93,7 @@ python pilot/server/dbgpt_server.py --light Alternatively, you can prepend the command with `LLM_MODEL=chatglm2-6b` to start: ```bash -LLM_MODEL=chatglm2-6b python pilot/server/dbgpt_server.py --light +LLM_MODEL=chatglm2-6b dbgpt start webserver --light ``` @@ -101,9 +117,11 @@ Options: --help Show this message and exit. 
Commands: - model Clients that manage model serving - start Start specific server. - stop Start specific server. + install Install dependencies, plugins, etc. + knowledge Knowledge command line tool + model Clients that manage model serving + start Start specific server. + stop Start specific server. ``` **View the `dbgpt start` help** @@ -146,10 +164,11 @@ Options: --model_name TEXT Model name [required] --model_path TEXT Model path [required] --worker_type TEXT Worker type - --worker_class TEXT Model worker class, pilot.model.worker.defau - lt_worker.DefaultModelWorker + --worker_class TEXT Model worker class, + pilot.model.cluster.DefaultModelWorker --host TEXT Model worker deploy host [default: 0.0.0.0] - --port INTEGER Model worker deploy port [default: 8000] + --port INTEGER Model worker deploy port [default: 8001] + --daemon Run Model Worker in background --limit_model_concurrency INTEGER Model concurrency limit [default: 5] --standalone Standalone mode. If True, embedded Run @@ -166,7 +185,7 @@ Options: (seconds) [default: 20] --device TEXT Device to run model. If None, the device is automatically determined - --model_type TEXT Model type, huggingface or llama.cpp + --model_type TEXT Model type, huggingface, llama.cpp and proxy [default: huggingface] --prompt_template TEXT Prompt template. If None, the prompt template is automatically determined from @@ -190,7 +209,7 @@ Options: --compute_dtype TEXT Model compute type --trust_remote_code Trust remote code [default: True] --verbose Show verbose output. - --help Show this message and exit. + --help Show this message and exit. ``` **View the `dbgpt model`help** @@ -208,10 +227,13 @@ Usage: dbgpt model [OPTIONS] COMMAND [ARGS]... Options: --address TEXT Address of the Model Controller to connect to. Just support - light deploy model [default: http://127.0.0.1:8000] + light deploy model, If the environment variable + CONTROLLER_ADDRESS is configured, read from the environment + variable --help Show this message and exit. Commands: + chat Interact with your bot from the command line list List model instances restart Restart model instances start Start model instances diff --git a/docs/getting_started/install/llm/llm.rst b/docs/getting_started/install/llm/llm.rst index accde0250..a678c3e2a 100644 --- a/docs/getting_started/install/llm/llm.rst +++ b/docs/getting_started/install/llm/llm.rst @@ -6,6 +6,7 @@ DB-GPT provides a management and deployment solution for multiple models. This c Multi LLMs Support, Supports multiple large language models, currently supporting + - 🔥 Baichuan2(7b,13b) - 🔥 Vicuna-v1.5(7b,13b) - 🔥 llama-2(7b,13b,70b) - WizardLM-v1.2(13b) @@ -19,7 +20,6 @@ Multi LLMs Support, Supports multiple large language models, currently supportin - llama_cpp - quantization -- cluster deployment .. 
toctree:: :maxdepth: 2 @@ -29,4 +29,3 @@ Multi LLMs Support, Supports multiple large language models, currently supportin ./llama/llama_cpp.md ./quantization/quantization.md - ./cluster/model_cluster.md diff --git a/docs/locales/zh_CN/LC_MESSAGES/getting_started/install.po b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install.po index 9e31b74e5..a4cfc8e66 100644 --- a/docs/locales/zh_CN/LC_MESSAGES/getting_started/install.po +++ b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install.po @@ -8,7 +8,7 @@ msgid "" msgstr "" "Project-Id-Version: DB-GPT 👏👏 0.3.5\n" "Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2023-08-16 18:31+0800\n" +"POT-Creation-Date: 2023-09-13 09:06+0800\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "Last-Translator: FULL NAME \n" "Language: zh_CN\n" @@ -19,34 +19,38 @@ msgstr "" "Content-Transfer-Encoding: 8bit\n" "Generated-By: Babel 2.12.1\n" -#: ../../getting_started/install.rst:2 ../../getting_started/install.rst:14 -#: 2861085e63144eaca1bb825e5f05d089 +#: ../../getting_started/install.rst:2 ../../getting_started/install.rst:15 +#: e2c13385046b4da6b6838db6ba2ea59c msgid "Install" msgstr "Install" -#: ../../getting_started/install.rst:3 01a6603d91fa4520b0f839379d4eda23 +#: ../../getting_started/install.rst:3 3cb6cd251ed440dabe5d4f556435f405 msgid "" "DB-GPT product is a Web application that you can chat database, chat " "knowledge, text2dashboard." msgstr "DB-GPT 可以生成sql,智能报表, 知识库问答的产品" -#: ../../getting_started/install.rst:8 beca85cddc9b4406aecf83d5dfcce1f7 +#: ../../getting_started/install.rst:8 6fe8104b70d24f5fbfe2ad9ebf3bc3ba msgid "deploy" msgstr "部署" -#: ../../getting_started/install.rst:9 601e9b9eb91f445fb07d2f1c807f0370 +#: ../../getting_started/install.rst:9 e67974b3672346809febf99a3b9a55d3 msgid "docker" msgstr "docker" -#: ../../getting_started/install.rst:10 6d1e094ac9284458a32a3e7fa6241c81 +#: ../../getting_started/install.rst:10 64de16a047c74598966e19a656bf6c4f msgid "docker_compose" msgstr "docker_compose" -#: ../../getting_started/install.rst:11 ff1d1c60bbdc4e8ca82b7a9f303dd167 +#: ../../getting_started/install.rst:11 9f87d65e8675435b87cb9376a5bfd85c msgid "environment" msgstr "environment" -#: ../../getting_started/install.rst:12 33bfbe8defd74244bfc24e8fbfd640f6 +#: ../../getting_started/install.rst:12 e60fa13bb24544ed9d4f902337093ebc +msgid "cluster deployment" +msgstr "集群部署" + +#: ../../getting_started/install.rst:13 7451712679c2412e858e7d3e2af6b174 msgid "deploy_faq" msgstr "deploy_faq" diff --git a/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/cluster/cluster.po b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/cluster/cluster.po new file mode 100644 index 000000000..577f15089 --- /dev/null +++ b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/cluster/cluster.po @@ -0,0 +1,42 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) 2023, csunny +# This file is distributed under the same license as the DB-GPT package. +# FIRST AUTHOR , 2023. 
+# +#, fuzzy +msgid "" +msgstr "" +"Project-Id-Version: DB-GPT 👏👏 0.3.6\n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2023-09-13 10:11+0800\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME \n" +"Language: zh_CN\n" +"Language-Team: zh_CN \n" +"Plural-Forms: nplurals=1; plural=0;\n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=utf-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Generated-By: Babel 2.12.1\n" + +#: ../../getting_started/install/cluster/cluster.rst:2 +#: ../../getting_started/install/cluster/cluster.rst:13 +#: 69804208b580447798d6946150da7bdf +msgid "Cluster deployment" +msgstr "集群部署" + +#: ../../getting_started/install/cluster/cluster.rst:4 +#: fa3e4e0ae60a45eb836bcd256baa9d91 +msgid "" +"In order to deploy DB-GPT to multiple nodes, you can deploy a cluster. " +"The cluster architecture diagram is as follows:" +msgstr "为了能将 DB-GPT 部署到多个节点上,你可以部署一个集群,集群的架构图如下:" + +#: ../../getting_started/install/cluster/cluster.rst:11 +#: e739449099ca43cabe9883233ca7e572 +#, fuzzy +msgid "" +"On :ref:`Deploying on local machine `. Local cluster" +" deployment." +msgstr "关于 :ref:`在本地机器上部署 `。本地集群部署。" + diff --git a/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/cluster/kubernetes/index.po b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/cluster/kubernetes/index.po new file mode 100644 index 000000000..b785f94f4 --- /dev/null +++ b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/cluster/kubernetes/index.po @@ -0,0 +1,26 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) 2023, csunny +# This file is distributed under the same license as the DB-GPT package. +# FIRST AUTHOR , 2023. +# +#, fuzzy +msgid "" +msgstr "" +"Project-Id-Version: DB-GPT 👏👏 0.3.6\n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2023-09-13 09:06+0800\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME \n" +"Language: zh_CN\n" +"Language-Team: zh_CN \n" +"Plural-Forms: nplurals=1; plural=0;\n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=utf-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Generated-By: Babel 2.12.1\n" + +#: ../../getting_started/install/cluster/kubernetes/index.md:1 +#: 48e6f08f27c74f31a8b12758fe33dc24 +msgid "Kubernetes cluster deployment" +msgstr "Kubernetes 集群部署" + diff --git a/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/cluster/vms/index.po b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/cluster/vms/index.po new file mode 100644 index 000000000..af535d519 --- /dev/null +++ b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/cluster/vms/index.po @@ -0,0 +1,176 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) 2023, csunny +# This file is distributed under the same license as the DB-GPT package. +# FIRST AUTHOR , 2023. 
+# +#, fuzzy +msgid "" +msgstr "" +"Project-Id-Version: DB-GPT 👏👏 0.3.6\n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2023-09-13 09:06+0800\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME \n" +"Language: zh_CN\n" +"Language-Team: zh_CN \n" +"Plural-Forms: nplurals=1; plural=0;\n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=utf-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Generated-By: Babel 2.12.1\n" + +#: ../../getting_started/install/cluster/vms/index.md:1 +#: 2d2e04ba49364eae9b8493bb274765a6 +msgid "Local cluster deployment" +msgstr "本地集群部署" + +#: ../../getting_started/install/cluster/vms/index.md:4 +#: e405d0e7ad8c4b2da4b4ca27c77f5fea +msgid "Model cluster deployment" +msgstr "模型集群部署" + +#: ../../getting_started/install/cluster/vms/index.md:7 +#: bba397ddac754a2bab8edca163875b65 +msgid "**Installing Command-Line Tool**" +msgstr "**安装命令行工具**" + +#: ../../getting_started/install/cluster/vms/index.md:9 +#: bc45851124354522af8c9bb9748ff1fa +msgid "" +"All operations below are performed using the `dbgpt` command. To use the " +"`dbgpt` command, you need to install the DB-GPT project with `pip install" +" -e .`. Alternatively, you can use `python pilot/scripts/cli_scripts.py` " +"as a substitute for the `dbgpt` command." +msgstr "" +"以下所有操作都使用 `dbgpt` 命令完成。要使用 `dbgpt` 命令,您需要安装DB-GPT项目,方法是使用`pip install -e .`。或者,您可以使用 `python pilot/scripts/cli_scripts.py` 作为 `dbgpt` 命令的替代。" + +#: ../../getting_started/install/cluster/vms/index.md:11 +#: 9d11f7807fd140c8949b634700adc966 +msgid "Launch Model Controller" +msgstr "启动 Model Controller" + +#: ../../getting_started/install/cluster/vms/index.md:17 +#: 97716be92ba64ce9a215433bddf77add +msgid "By default, the Model Controller starts on port 8000." +msgstr "默认情况下,Model Controller 启动在 8000 端口。" + +#: ../../getting_started/install/cluster/vms/index.md:20 +#: 3f65e6a1e59248a59c033891d1ab7ba8 +msgid "Launch LLM Model Worker" +msgstr "启动 LLM Model Worker" + +#: ../../getting_started/install/cluster/vms/index.md:22 +#: 60241d97573e4265b7fb150c378c4a08 +msgid "If you are starting `chatglm2-6b`:" +msgstr "如果您启动的是 `chatglm2-6b`:" + +#: ../../getting_started/install/cluster/vms/index.md:31 +#: 18bbeb1de110438fa96dd5c736b9a7b1 +msgid "If you are starting `vicuna-13b-v1.5`:" +msgstr "如果您启动的是 `vicuna-13b-v1.5`:" + +#: ../../getting_started/install/cluster/vms/index.md:40 +#: ../../getting_started/install/cluster/vms/index.md:53 +#: 24b1a27313c64224aaeab6cbfad1fe19 fc94a698a7904c6893eef7e7a6e52972 +msgid "Note: Be sure to use your own model name and model path." 
+msgstr "注意:确保使用您自己的模型名称和模型路径。" + +#: ../../getting_started/install/cluster/vms/index.md:42 +#: 19746195e85f4784bf66a9e67378c04b +msgid "Launch Embedding Model Worker" +msgstr "启动 Embedding Model Worker" + +#: ../../getting_started/install/cluster/vms/index.md:55 +#: e93ce68091f64d0294b3f912a66cc18b +msgid "Check your model:" +msgstr "检查您的模型:" + +#: ../../getting_started/install/cluster/vms/index.md:61 +#: fa0b8f3a18fe4bab88fbf002bf26d32e +msgid "You will see the following output:" +msgstr "您将看到以下输出:" + +#: ../../getting_started/install/cluster/vms/index.md:75 +#: 695262fb4f224101902bc7865ac7871f +msgid "Connect to the model service in the webserver (dbgpt_server)" +msgstr "在 webserver (dbgpt_server) 中连接到模型服务 (dbgpt_server)" + +#: ../../getting_started/install/cluster/vms/index.md:77 +#: 73bf4c2ae5c64d938e3b7e77c06fa21e +msgid "" +"**First, modify the `.env` file to change the model name and the Model " +"Controller connection address.**" +msgstr "" +"**首先,修改 `.env` 文件以更改模型名称和模型控制器连接地址。**" + +#: ../../getting_started/install/cluster/vms/index.md:85 +#: 8ab126fd72ed4368a79b821ba50e62c8 +msgid "Start the webserver" +msgstr "启动 webserver" + +#: ../../getting_started/install/cluster/vms/index.md:91 +#: 5a7e25c84ca2412bb64310bfad9e2403 +msgid "`--light` indicates not to start the embedded model service." +msgstr "`--light` 表示不启动嵌入式模型服务。" + +#: ../../getting_started/install/cluster/vms/index.md:93 +#: 8cd9ec4fa9cb4c0fa8ff05c05a85ea7f +msgid "" +"Alternatively, you can prepend the command with `LLM_MODEL=chatglm2-6b` " +"to start:" +msgstr "" +"或者,您可以在命令前加上 `LLM_MODEL=chatglm2-6b` 来启动:" + +#: ../../getting_started/install/cluster/vms/index.md:100 +#: 13ed16758a104860b5fc982d36638b17 +msgid "More Command-Line Usages" +msgstr "更多命令行用法" + +#: ../../getting_started/install/cluster/vms/index.md:102 +#: 175f614d547a4391bab9a77762f9174e +msgid "You can view more command-line usages through the help command." 
+msgstr "您可以通过帮助命令查看更多命令行用法。" + +#: ../../getting_started/install/cluster/vms/index.md:104 +#: 6a4475d271c347fbbb35f2936a86823f +msgid "**View the `dbgpt` help**" +msgstr "**查看 `dbgpt` 帮助**" + +#: ../../getting_started/install/cluster/vms/index.md:109 +#: 3eb11234cf504cc9ac369d8462daa14b +msgid "You will see the basic command parameters and usage:" +msgstr "您将看到基本的命令参数和用法:" + +#: ../../getting_started/install/cluster/vms/index.md:127 +#: 6eb47aecceec414e8510fe022b6fddbd +msgid "**View the `dbgpt start` help**" +msgstr "**查看 `dbgpt start` 帮助**" + +#: ../../getting_started/install/cluster/vms/index.md:133 +#: 1f4c0a4ce0704ca8ac33178bd13c69ad +msgid "Here you can see the related commands and usage for start:" +msgstr "在这里,您可以看到启动的相关命令和用法:" + +#: ../../getting_started/install/cluster/vms/index.md:150 +#: 22e8e67bc55244e79764d091f334560b +msgid "**View the `dbgpt start worker`help**" +msgstr "**查看 `dbgpt start worker` 帮助**" + +#: ../../getting_started/install/cluster/vms/index.md:156 +#: 5631b83fda714780855e99e90d4eb542 +msgid "Here you can see the parameters to start Model Worker:" +msgstr "在这里,您可以看到启动 Model Worker 的参数:" + +#: ../../getting_started/install/cluster/vms/index.md:215 +#: cf4a31fd3368481cba1b3ab382615f53 +msgid "**View the `dbgpt model`help**" +msgstr "**查看 `dbgpt model` 帮助**" + +#: ../../getting_started/install/cluster/vms/index.md:221 +#: 3740774ec4b240f2882b5b59da224d55 +msgid "" +"The `dbgpt model ` command can connect to the Model Controller via the " +"Model Controller address and then manage a remote model:" +msgstr "" +"`dbgpt model` 命令可以通过 Model Controller 地址连接到 Model Controller,然后管理远程模型:" + diff --git a/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/llm/llm.po b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/llm/llm.po index b5531e9f5..5041134b7 100644 --- a/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/llm/llm.po +++ b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/llm/llm.po @@ -8,7 +8,7 @@ msgid "" msgstr "" "Project-Id-Version: DB-GPT 👏👏 0.3.5\n" "Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2023-08-31 16:38+0800\n" +"POT-Creation-Date: 2023-09-13 10:46+0800\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "Last-Translator: FULL NAME \n" "Language: zh_CN\n" @@ -21,84 +21,88 @@ msgstr "" #: ../../getting_started/install/llm/llm.rst:2 #: ../../getting_started/install/llm/llm.rst:24 -#: b348d4df8ca44dd78b42157a8ff6d33d +#: e693a8d3769b4d9e99c4442ca77dc43c msgid "LLM Usage" msgstr "LLM使用" -#: ../../getting_started/install/llm/llm.rst:3 7f5960a7e5634254b330da27be87594b +#: ../../getting_started/install/llm/llm.rst:3 0a73562d18ba455bab04277b715c3840 msgid "" "DB-GPT provides a management and deployment solution for multiple models." " This chapter mainly discusses how to deploy different models." 
msgstr "DB-GPT提供了多模型的管理和部署方案,本章主要讲解针对不同的模型该怎么部署" -#: ../../getting_started/install/llm/llm.rst:18 -#: b844ab204ec740ec9d7d191bb841f09e +#: ../../getting_started/install/llm/llm.rst:19 +#: d7e4de2a7e004888897204ec76b6030b msgid "" "Multi LLMs Support, Supports multiple large language models, currently " "supporting" msgstr "目前DB-GPT已适配如下模型" -#: ../../getting_started/install/llm/llm.rst:9 c141437ddaf84c079360008343041b2f +#: ../../getting_started/install/llm/llm.rst:9 4616886b8b2244bd93355e871356d89e +#, fuzzy +msgid "🔥 Baichuan2(7b,13b)" +msgstr "Baichuan(7b,13b)" + +#: ../../getting_started/install/llm/llm.rst:10 +#: ad0e4793d4e744c1bdf59f5a3d9c84be msgid "🔥 Vicuna-v1.5(7b,13b)" msgstr "🔥 Vicuna-v1.5(7b,13b)" -#: ../../getting_started/install/llm/llm.rst:10 -#: d32b1e3f114c4eab8782b497097c1b37 +#: ../../getting_started/install/llm/llm.rst:11 +#: d291e58001ae487bbbf2a1f9f889f5fd msgid "🔥 llama-2(7b,13b,70b)" msgstr "🔥 llama-2(7b,13b,70b)" -#: ../../getting_started/install/llm/llm.rst:11 -#: 0a417ee4d008421da07fff7add5d05eb +#: ../../getting_started/install/llm/llm.rst:12 +#: 1e49702ee40b4655945a2a13efaad536 msgid "WizardLM-v1.2(13b)" msgstr "WizardLM-v1.2(13b)" -#: ../../getting_started/install/llm/llm.rst:12 -#: 199e1a9fe3324dc8a1bcd9cd0b1ef047 +#: ../../getting_started/install/llm/llm.rst:13 +#: 4ef5913ddfe840d7a12289e6e1d4cb60 msgid "Vicuna (7b,13b)" msgstr "Vicuna (7b,13b)" -#: ../../getting_started/install/llm/llm.rst:13 -#: a9e4c5100534450db3a583fa5850e4be +#: ../../getting_started/install/llm/llm.rst:14 +#: ea46c2211257459285fa48083cb59561 msgid "ChatGLM-6b (int4,int8)" msgstr "ChatGLM-6b (int4,int8)" -#: ../../getting_started/install/llm/llm.rst:14 -#: 943324289eb94042b52fd824189cd93f +#: ../../getting_started/install/llm/llm.rst:15 +#: 90688302bae4452a84f14e8ecb7f1a21 msgid "ChatGLM2-6b (int4,int8)" msgstr "ChatGLM2-6b (int4,int8)" -#: ../../getting_started/install/llm/llm.rst:15 -#: f1226fdfac3b4e9d88642ffa69d75682 +#: ../../getting_started/install/llm/llm.rst:16 +#: ee1469545a314696a36e7296c7b71960 msgid "guanaco(7b,13b,33b)" msgstr "guanaco(7b,13b,33b)" -#: ../../getting_started/install/llm/llm.rst:16 -#: 3f2457f56eb341b6bc431c9beca8f4df +#: ../../getting_started/install/llm/llm.rst:17 +#: 25abad241f4d4eee970d5938bf71311f msgid "Gorilla(7b,13b)" msgstr "Gorilla(7b,13b)" -#: ../../getting_started/install/llm/llm.rst:17 -#: 86c8ce37be1c4a7ea3fc382100d77a9c +#: ../../getting_started/install/llm/llm.rst:18 +#: 8e3d0399431a4c6a9065a8ae0ad3c8ac msgid "Baichuan(7b,13b)" msgstr "Baichuan(7b,13b)" -#: ../../getting_started/install/llm/llm.rst:18 -#: 538111af95ad414cb2e631a89f9af379 +#: ../../getting_started/install/llm/llm.rst:19 +#: c285fa7c9c6c4e3e9840761a09955348 msgid "OpenAI" msgstr "OpenAI" -#: ../../getting_started/install/llm/llm.rst:20 -#: a203325b7ec248f7bff61ae89226a000 +#: ../../getting_started/install/llm/llm.rst:21 +#: 4ac13a21f323455982750bd2e0243b72 msgid "llama_cpp" msgstr "llama_cpp" -#: ../../getting_started/install/llm/llm.rst:21 -#: 21a50634198047228bc51a03d2c31292 +#: ../../getting_started/install/llm/llm.rst:22 +#: 7231edceef584724a6f569c6b363e083 msgid "quantization" msgstr "quantization" -#: ../../getting_started/install/llm/llm.rst:22 -#: dfaec4b04e6e45ff9c884b41534b1a79 -msgid "cluster deployment" -msgstr "" +#~ msgid "cluster deployment" +#~ msgstr "" diff --git a/docs/requirements.txt b/docs/requirements.txt index 1b7f62282..7d712119a 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,4 +1,4 @@ -autodoc_pydantic==1.8.0 +autodoc_pydantic 
myst_parser nbsphinx==0.8.9 sphinx==4.5.0 diff --git a/pilot/componet.py b/pilot/componet.py new file mode 100644 index 000000000..705eb1193 --- /dev/null +++ b/pilot/componet.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Type, Dict, TypeVar, Optional, TYPE_CHECKING +import asyncio + +# Checking for type hints during runtime +if TYPE_CHECKING: + from fastapi import FastAPI + + +class LifeCycle: + """This class defines hooks for lifecycle events of a component.""" + + def before_start(self): + """Called before the component starts.""" + pass + + async def async_before_start(self): + """Asynchronous version of before_start.""" + pass + + def after_start(self): + """Called after the component has started.""" + pass + + async def async_after_start(self): + """Asynchronous version of after_start.""" + pass + + def before_stop(self): + """Called before the component stops.""" + pass + + async def async_before_stop(self): + """Asynchronous version of before_stop.""" + pass + + +class BaseComponet(LifeCycle, ABC): + """Abstract Base Component class. All custom components should extend this.""" + + name = "base_dbgpt_componet" + + def __init__(self, system_app: Optional[SystemApp] = None): + if system_app is not None: + self.init_app(system_app) + + @abstractmethod + def init_app(self, system_app: SystemApp): + """Initialize the component with the main application. + + This method needs to be implemented by every component to define how it integrates + with the main system app. + """ + pass + + +T = TypeVar("T", bound=BaseComponet) + + +class SystemApp(LifeCycle): + """Main System Application class that manages the lifecycle and registration of components.""" + + def __init__(self, asgi_app: Optional["FastAPI"] = None) -> None: + self.componets: Dict[ + str, BaseComponet + ] = {} # Dictionary to store registered components. 
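
As an illustration of how this new component container is meant to be used: the sketch below registers a custom component built on `BaseComponet` and drives its lifecycle through `SystemApp`. It is a minimal sketch, not part of this patch; the `CacheComponet` class and its `simple_cache` name are invented for the example.

```python
# Minimal sketch (not part of this patch): wiring a hypothetical component
# into the SystemApp container introduced in pilot/componet.py.
from pilot.componet import BaseComponet, SystemApp


class CacheComponet(BaseComponet):
    # Components are looked up by this name via SystemApp.get_componet().
    name = "simple_cache"

    def init_app(self, system_app: SystemApp):
        # Called when the component is registered with the app.
        self.system_app = system_app
        self._cache = {}

    def after_start(self):
        # Lifecycle hook fanned out by SystemApp.after_start().
        print(f"{self.name} is ready")


system_app = SystemApp()
system_app.register(CacheComponet)  # instantiate and register by type
system_app.after_start()            # invokes after_start() on every component
cache = system_app.get_componet("simple_cache", CacheComponet)
```
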
+ self._asgi_app = asgi_app + + @property + def app(self) -> Optional["FastAPI"]: + """Returns the internal ASGI app.""" + return self._asgi_app + + def register(self, componet: Type[BaseComponet], *args, **kwargs): + """Register a new component by its type.""" + instance = componet(self, *args, **kwargs) + self.register_instance(instance) + + def register_instance(self, instance: T): + """Register an already initialized component.""" + self.componets[instance.name] = instance + instance.init_app(self) + + def get_componet(self, name: str, componet_type: Type[T]) -> T: + """Retrieve a registered component by its name and type.""" + component = self.componets.get(name) + if not component: + raise ValueError(f"No component found with name {name}") + if not isinstance(component, componet_type): + raise TypeError(f"Component {name} is not of type {componet_type}") + return component + + def before_start(self): + """Invoke the before_start hooks for all registered components.""" + for _, v in self.componets.items(): + v.before_start() + + async def async_before_start(self): + """Asynchronously invoke the before_start hooks for all registered components.""" + tasks = [v.async_before_start() for _, v in self.componets.items()] + await asyncio.gather(*tasks) + + def after_start(self): + """Invoke the after_start hooks for all registered components.""" + for _, v in self.componets.items(): + v.after_start() + + async def async_after_start(self): + """Asynchronously invoke the after_start hooks for all registered components.""" + tasks = [v.async_after_start() for _, v in self.componets.items()] + await asyncio.gather(*tasks) + + def before_stop(self): + """Invoke the before_stop hooks for all registered components.""" + for _, v in self.componets.items(): + try: + v.before_stop() + except Exception as e: + pass + + async def async_before_stop(self): + """Asynchronously invoke the before_stop hooks for all registered components.""" + tasks = [v.async_before_stop() for _, v in self.componets.items()] + await asyncio.gather(*tasks) + + def _build(self): + """Integrate lifecycle events with the internal ASGI app if available.""" + if not self.app: + return + + @self.app.on_event("startup") + async def startup_event(): + """ASGI app startup event handler.""" + asyncio.create_task(self.async_after_start()) + self.after_start() + + @self.app.on_event("shutdown") + async def shutdown_event(): + """ASGI app shutdown event handler.""" + await self.async_before_stop() + self.before_stop() diff --git a/pilot/configs/config.py b/pilot/configs/config.py index 8bfaed757..0276c2a17 100644 --- a/pilot/configs/config.py +++ b/pilot/configs/config.py @@ -189,6 +189,10 @@ def __init__(self) -> None: ### Log level self.DBGPT_LOG_LEVEL = os.getenv("DBGPT_LOG_LEVEL", "INFO") + from pilot.componet import SystemApp + + self.SYSTEM_APP: SystemApp = None + def set_debug_mode(self, value: bool) -> None: """Set the debug mode value""" self.debug_mode = value diff --git a/pilot/connections/manages/connection_manager.py b/pilot/connections/manages/connection_manager.py index a933205c2..b5cfbbfdd 100644 --- a/pilot/connections/manages/connection_manager.py +++ b/pilot/connections/manages/connection_manager.py @@ -4,6 +4,7 @@ from pilot.configs.config import Config from pilot.connections.manages.connect_storage_duckdb import DuckdbConnectConfig from pilot.common.schema import DBType +from pilot.componet import SystemApp from pilot.connections.rdbms.conn_mysql import MySQLConnect from pilot.connections.base import BaseConnect @@ 
-46,9 +47,9 @@ def get_cls_by_dbtype(self, db_type): raise ValueError("Unsupport Db Type!" + db_type) return result - def __init__(self): + def __init__(self, system_app: SystemApp): self.storage = DuckdbConnectConfig() - self.db_summary_client = DBSummaryClient() + self.db_summary_client = DBSummaryClient(system_app) self.__load_config_db() def __load_config_db(self): diff --git a/pilot/connections/rdbms/tests/mange_t.py b/pilot/connections/rdbms/tests/mange_t.py index 5140341d6..282bb8042 100644 --- a/pilot/connections/rdbms/tests/mange_t.py +++ b/pilot/connections/rdbms/tests/mange_t.py @@ -2,6 +2,6 @@ from pilot.connections.manages.connection_manager import ConnectManager if __name__ == "__main__": - mange = ConnectManager() + mange = ConnectManager(system_app=None) types = mange.get_all_completed_types() print(str(types)) diff --git a/pilot/embedding_engine/embedding_engine.py b/pilot/embedding_engine/embedding_engine.py index f3b09ae5d..27f94eeee 100644 --- a/pilot/embedding_engine/embedding_engine.py +++ b/pilot/embedding_engine/embedding_engine.py @@ -1,9 +1,12 @@ from typing import Optional from chromadb.errors import NotEnoughElementsException -from langchain.embeddings import HuggingFaceEmbeddings from langchain.text_splitter import TextSplitter +from pilot.embedding_engine.embedding_factory import ( + EmbeddingFactory, + DefaultEmbeddingFactory, +) from pilot.embedding_engine.knowledge_type import get_knowledge_embedding, KnowledgeType from pilot.vector_store.connector import VectorStoreConnector @@ -24,13 +27,16 @@ def __init__( knowledge_source: Optional[str] = None, source_reader: Optional = None, text_splitter: Optional[TextSplitter] = None, + embedding_factory: EmbeddingFactory = None, ): """Initialize with knowledge embedding client, model_name, vector_store_config, knowledge_type, knowledge_source""" self.knowledge_source = knowledge_source self.model_name = model_name self.vector_store_config = vector_store_config self.knowledge_type = knowledge_type - self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name) + if not embedding_factory: + embedding_factory = DefaultEmbeddingFactory() + self.embeddings = embedding_factory.create(model_name=self.model_name) self.vector_store_config["embeddings"] = self.embeddings self.source_reader = source_reader self.text_splitter = text_splitter diff --git a/pilot/embedding_engine/embedding_factory.py b/pilot/embedding_engine/embedding_factory.py new file mode 100644 index 000000000..e7345d952 --- /dev/null +++ b/pilot/embedding_engine/embedding_factory.py @@ -0,0 +1,39 @@ +from abc import ABC, abstractmethod +from typing import Any, Type, TYPE_CHECKING + +from pilot.componet import BaseComponet + +if TYPE_CHECKING: + from langchain.embeddings.base import Embeddings + + +class EmbeddingFactory(BaseComponet, ABC): + name = "embedding_factory" + + @abstractmethod + def create( + self, model_name: str = None, embedding_cls: Type = None + ) -> "Embeddings": + """Create embedding""" + + +class DefaultEmbeddingFactory(EmbeddingFactory): + def __init__(self, system_app=None, model_name: str = None, **kwargs: Any) -> None: + super().__init__(system_app=system_app) + self._default_model_name = model_name + self.kwargs = kwargs + + def init_app(self, system_app): + pass + + def create( + self, model_name: str = None, embedding_cls: Type = None + ) -> "Embeddings": + if not model_name: + model_name = self._default_model_name + if embedding_cls: + return embedding_cls(model_name=model_name, **self.kwargs) + else: + from 
langchain.embeddings import HuggingFaceEmbeddings + + return HuggingFaceEmbeddings(model_name=model_name, **self.kwargs) diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py index 5d3acf505..0b93faa40 100644 --- a/pilot/model/adapter.py +++ b/pilot/model/adapter.py @@ -96,13 +96,18 @@ def get_llm_model_adapter(model_name: str, model_path: str) -> BaseLLMAdaper: def _dynamic_model_parser() -> Callable[[None], List[Type]]: from pilot.utils.parameter_utils import _SimpleArgParser + from pilot.model.parameter import EmbeddingModelParameters, WorkerType - pre_args = _SimpleArgParser("model_name", "model_path") + pre_args = _SimpleArgParser("model_name", "model_path", "worker_type") pre_args.parse() model_name = pre_args.get("model_name") model_path = pre_args.get("model_path") + worker_type = pre_args.get("worker_type") if model_name is None: return None + if worker_type == WorkerType.TEXT2VEC: + return [EmbeddingModelParameters] + llm_adapter = get_llm_model_adapter(model_name, model_path) param_class = llm_adapter.model_param_class() return [param_class] diff --git a/pilot/model/cli.py b/pilot/model/cli.py index 2406f6920..3e94c7045 100644 --- a/pilot/model/cli.py +++ b/pilot/model/cli.py @@ -167,7 +167,6 @@ def stop(model_name: str, model_type: str, host: str, port: int): def _remote_model_dynamic_factory() -> Callable[[None], List[Type]]: - from pilot.model.adapter import _dynamic_model_parser from pilot.utils.parameter_utils import _SimpleArgParser from pilot.model.cluster import RemoteWorkerManager from pilot.model.parameter import WorkerType diff --git a/pilot/model/cluster/__init__.py b/pilot/model/cluster/__init__.py index b518a756b..b73fd7873 100644 --- a/pilot/model/cluster/__init__.py +++ b/pilot/model/cluster/__init__.py @@ -5,6 +5,9 @@ WorkerParameterRequest, WorkerStartupRequest, ) +from pilot.model.cluster.worker_base import ModelWorker +from pilot.model.cluster.worker.default_worker import DefaultModelWorker + from pilot.model.cluster.worker.manager import ( initialize_worker_manager_in_client, run_worker_manager, @@ -23,11 +26,15 @@ "EmbeddingsRequest", "PromptRequest", "WorkerApplyRequest", - "WorkerParameterRequest" - "WorkerStartupRequest" - "worker_manager" + "WorkerParameterRequest", + "WorkerStartupRequest", + "ModelWorker", + "DefaultModelWorker", + "worker_manager", "run_worker_manager", "initialize_worker_manager_in_client", "ModelRegistry", - "ModelRegistryClient" "RemoteWorkerManager" "run_model_controller", + "ModelRegistryClient", + "RemoteWorkerManager", + "run_model_controller", ] diff --git a/pilot/model/cluster/controller/controller.py b/pilot/model/cluster/controller/controller.py index d48e1362f..e1d55eab7 100644 --- a/pilot/model/cluster/controller/controller.py +++ b/pilot/model/cluster/controller/controller.py @@ -8,7 +8,10 @@ from pilot.model.parameter import ModelControllerParameters from pilot.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry from pilot.utils.parameter_utils import EnvArgumentParser -from pilot.utils.api_utils import _api_remote as api_remote +from pilot.utils.api_utils import ( + _api_remote as api_remote, + _sync_api_remote as sync_api_remote, +) class BaseModelController(ABC): @@ -89,6 +92,12 @@ class ModelRegistryClient(_RemoteModelController, ModelRegistry): async def get_all_model_instances(self) -> List[ModelInstance]: return await self.get_all_instances() + @sync_api_remote(path="/api/controller/models") + def sync_get_all_instances( + self, model_name: str, healthy_only: bool = False + ) -> 
List[ModelInstance]: + pass + class ModelControllerAdapter(BaseModelController): def __init__(self, backend: BaseModelController = None) -> None: diff --git a/pilot/model/cluster/embedding/__init__.py b/pilot/model/cluster/embedding/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/model/cluster/embedding/remote_embedding.py b/pilot/model/cluster/embedding/remote_embedding.py new file mode 100644 index 000000000..98fb8af16 --- /dev/null +++ b/pilot/model/cluster/embedding/remote_embedding.py @@ -0,0 +1,28 @@ +from typing import List +from langchain.embeddings.base import Embeddings + +from pilot.model.cluster.manager_base import WorkerManager + + +class RemoteEmbeddings(Embeddings): + def __init__(self, model_name: str, worker_manager: WorkerManager) -> None: + self.model_name = model_name + self.worker_manager = worker_manager + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Embed search docs.""" + params = {"model": self.model_name, "input": texts} + return self.worker_manager.sync_embeddings(params) + + def embed_query(self, text: str) -> List[float]: + """Embed query text.""" + return self.embed_documents([text])[0] + + async def aembed_documents(self, texts: List[str]) -> List[List[float]]: + """Asynchronous Embed search docs.""" + params = {"model": self.model_name, "input": texts} + return await self.worker_manager.embeddings(params) + + async def aembed_query(self, text: str) -> List[float]: + """Asynchronous Embed query text.""" + return await self.aembed_documents([text])[0] diff --git a/pilot/model/cluster/manager_base.py b/pilot/model/cluster/manager_base.py index 4f3fd27e3..d66991cae 100644 --- a/pilot/model/cluster/manager_base.py +++ b/pilot/model/cluster/manager_base.py @@ -1,6 +1,6 @@ import asyncio from dataclasses import dataclass -from typing import List, Optional, Dict, Iterator +from typing import List, Optional, Dict, Iterator, Callable from abc import ABC, abstractmethod from datetime import datetime from concurrent.futures import Future @@ -35,15 +35,31 @@ async def start(self): async def stop(self): """Stop worker manager""" + @abstractmethod + def after_start(self, listener: Callable[["WorkerManager"], None]): + """Add a listener after WorkerManager startup""" + @abstractmethod async def get_model_instances( self, worker_type: str, model_name: str, healthy_only: bool = True + ) -> List[WorkerRunData]: + """Asynchronous get model instances by worker type and model name""" + + @abstractmethod + def sync_get_model_instances( + self, worker_type: str, model_name: str, healthy_only: bool = True ) -> List[WorkerRunData]: """Get model instances by worker type and model name""" @abstractmethod async def select_one_instance( self, worker_type: str, model_name: str, healthy_only: bool = True + ) -> WorkerRunData: + """Asynchronous select one instance""" + + @abstractmethod + def sync_select_one_instance( + self, worker_type: str, model_name: str, healthy_only: bool = True ) -> WorkerRunData: """Select one instance""" @@ -69,7 +85,15 @@ async def generate(self, params: Dict) -> ModelOutput: @abstractmethod async def embeddings(self, params: Dict) -> List[List[float]]: - """Embed input""" + """Asynchronous embed input""" + + @abstractmethod + def sync_embeddings(self, params: Dict) -> List[List[float]]: + """Embed input + + This function may be passed to a third-party system call for synchronous calls. + We must provide a synchronous version. 
+ """ @abstractmethod async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput: diff --git a/pilot/model/cluster/registry.py b/pilot/model/cluster/registry.py index e5e9b0618..398882eb9 100644 --- a/pilot/model/cluster/registry.py +++ b/pilot/model/cluster/registry.py @@ -58,6 +58,12 @@ async def get_all_instances( - List[ModelInstance]: A list of instances for the given model. """ + @abstractmethod + def sync_get_all_instances( + self, model_name: str, healthy_only: bool = False + ) -> List[ModelInstance]: + """Fetch all instances of a given model. Optionally, fetch only the healthy instances.""" + @abstractmethod async def get_all_model_instances(self) -> List[ModelInstance]: """ @@ -163,6 +169,11 @@ async def deregister_instance(self, instance: ModelInstance) -> bool: async def get_all_instances( self, model_name: str, healthy_only: bool = False + ) -> List[ModelInstance]: + return self.sync_get_all_instances(model_name, healthy_only) + + def sync_get_all_instances( + self, model_name: str, healthy_only: bool = False ) -> List[ModelInstance]: instances = self.registry[model_name] if healthy_only: @@ -179,7 +190,7 @@ async def send_heartbeat(self, instance: ModelInstance) -> bool: ) if not exist_ins: # register new install from heartbeat - self.register_instance(instance) + await self.register_instance(instance) return True ins = exist_ins[0] diff --git a/pilot/model/cluster/worker/embedding_worker.py b/pilot/model/cluster/worker/embedding_worker.py index a8824f228..80f06b145 100644 --- a/pilot/model/cluster/worker/embedding_worker.py +++ b/pilot/model/cluster/worker/embedding_worker.py @@ -24,7 +24,7 @@ def __init__(self) -> None: "Could not import langchain.embeddings.HuggingFaceEmbeddings python package. " "Please install it with `pip install langchain`." 
) from exc - self.embeddings: Embeddings = None + self._embeddings_impl: Embeddings = None self._model_params = None def load_worker(self, model_name: str, model_path: str, **kwargs) -> None: @@ -75,16 +75,16 @@ def start( kwargs = model_params.build_kwargs(model_name=model_params.model_path) logger.info(f"Start HuggingFaceEmbeddings with kwargs: {kwargs}") - self.embeddings = HuggingFaceEmbeddings(**kwargs) + self._embeddings_impl = HuggingFaceEmbeddings(**kwargs) def __del__(self): self.stop() def stop(self) -> None: - if not self.embeddings: + if not self._embeddings_impl: return - del self.embeddings - self.embeddings = None + del self._embeddings_impl + self._embeddings_impl = None _clear_torch_cache(self._model_params.device) def generate_stream(self, params: Dict): @@ -96,5 +96,7 @@ def generate(self, params: Dict): raise NotImplementedError("Not supported generate for embeddings model") def embeddings(self, params: Dict) -> List[List[float]]: + model = params.get("model") + logger.info(f"Receive embeddings request, model: {model}") input: List[str] = params["input"] - return self.embeddings.embed_documents(input) + return self._embeddings_impl.embed_documents(input) diff --git a/pilot/model/cluster/worker/manager.py b/pilot/model/cluster/worker/manager.py index 7e4852527..93f2a373a 100644 --- a/pilot/model/cluster/worker/manager.py +++ b/pilot/model/cluster/worker/manager.py @@ -72,6 +72,7 @@ def __init__( self.model_registry = model_registry self.host = host self.port = port + self.start_listeners = [] self.run_data = WorkerRunData( host=self.host, @@ -105,6 +106,8 @@ async def start(self): asyncio.create_task( _async_heartbeat_sender(self.run_data, 20, self.send_heartbeat_func) ) + for listener in self.start_listeners: + listener(self) async def stop(self): if not self.run_data.stop_event.is_set(): @@ -116,6 +119,9 @@ async def stop(self): stop_tasks.append(self.deregister_func(self.run_data)) await asyncio.gather(*stop_tasks) + def after_start(self, listener: Callable[["WorkerManager"], None]): + self.start_listeners.append(listener) + def add_worker( self, worker: ModelWorker, @@ -137,14 +143,7 @@ def add_worker( worker_key = self._worker_key( worker_params.worker_type, worker_params.model_name ) - host = worker_params.host - port = worker_params.port - instances = self.workers.get(worker_key) - if not instances: - instances = [] - self.workers[worker_key] = instances - logger.info(f"Init empty instances list for {worker_key}") # Load model params from persist storage model_params = worker.parse_parameters(command_args=command_args) @@ -159,14 +158,15 @@ def add_worker( semaphore=asyncio.Semaphore(worker_params.limit_model_concurrency), command_args=command_args, ) - exist_instances = [ - ins for ins in instances if ins.host == host and ins.port == port - ] - if not exist_instances: - instances.append(worker_run_data) + instances = self.workers.get(worker_key) + if not instances: + instances = [worker_run_data] + self.workers[worker_key] = instances + logger.info(f"Init empty instances list for {worker_key}") return True else: # TODO Update worker + logger.warn(f"Instance {worker_key} exist") return False async def model_startup(self, startup_req: WorkerStartupRequest) -> bool: @@ -222,16 +222,18 @@ async def supported_models(self) -> List[WorkerSupportedModel]: async def get_model_instances( self, worker_type: str, model_name: str, healthy_only: bool = True + ) -> List[WorkerRunData]: + return self.sync_get_model_instances(worker_type, model_name, healthy_only) + + def 
sync_get_model_instances( + self, worker_type: str, model_name: str, healthy_only: bool = True ) -> List[WorkerRunData]: worker_key = self._worker_key(worker_type, model_name) return self.workers.get(worker_key) - async def select_one_instance( - self, worker_type: str, model_name: str, healthy_only: bool = True + def _simple_select( + self, worker_type: str, model_name: str, worker_instances: List[WorkerRunData] ) -> WorkerRunData: - worker_instances = await self.get_model_instances( - worker_type, model_name, healthy_only - ) if not worker_instances: raise Exception( f"Cound not found worker instances for model name {model_name} and worker type {worker_type}" @@ -239,12 +241,34 @@ async def select_one_instance( worker_run_data = random.choice(worker_instances) return worker_run_data + async def select_one_instance( + self, worker_type: str, model_name: str, healthy_only: bool = True + ) -> WorkerRunData: + worker_instances = await self.get_model_instances( + worker_type, model_name, healthy_only + ) + return self._simple_select(worker_type, model_name, worker_instances) + + def sync_select_one_instance( + self, worker_type: str, model_name: str, healthy_only: bool = True + ) -> WorkerRunData: + worker_instances = self.sync_get_model_instances( + worker_type, model_name, healthy_only + ) + return self._simple_select(worker_type, model_name, worker_instances) + async def _get_model(self, params: Dict, worker_type: str = "llm") -> WorkerRunData: model = params.get("model") if not model: raise Exception("Model name count not be empty") return await self.select_one_instance(worker_type, model, healthy_only=True) + def _sync_get_model(self, params: Dict, worker_type: str = "llm") -> WorkerRunData: + model = params.get("model") + if not model: + raise Exception("Model name count not be empty") + return self.sync_select_one_instance(worker_type, model, healthy_only=True) + async def generate_stream( self, params: Dict, async_wrapper=None, **kwargs ) -> Iterator[ModelOutput]: @@ -304,6 +328,10 @@ async def embeddings(self, params: Dict) -> List[List[float]]: worker_run_data.worker.embeddings, params ) + def sync_embeddings(self, params: Dict) -> List[List[float]]: + worker_run_data = self._sync_get_model(params, worker_type="text2vec") + return worker_run_data.worker.embeddings(params) + async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput: apply_func: Callable[[WorkerApplyRequest], Awaitable[str]] = None if apply_req.apply_type == WorkerApplyType.START: @@ -458,6 +486,10 @@ async def start(self): async def stop(self): return await self.worker_manager.stop() + def after_start(self, listener: Callable[["WorkerManager"], None]): + if listener is not None: + self.worker_manager.after_start(listener) + async def supported_models(self) -> List[WorkerSupportedModel]: return await self.worker_manager.supported_models() @@ -474,6 +506,13 @@ async def get_model_instances( worker_type, model_name, healthy_only ) + def sync_get_model_instances( + self, worker_type: str, model_name: str, healthy_only: bool = True + ) -> List[WorkerRunData]: + return self.worker_manager.sync_get_model_instances( + worker_type, model_name, healthy_only + ) + async def select_one_instance( self, worker_type: str, model_name: str, healthy_only: bool = True ) -> WorkerRunData: @@ -481,6 +520,13 @@ async def select_one_instance( worker_type, model_name, healthy_only ) + def sync_select_one_instance( + self, worker_type: str, model_name: str, healthy_only: bool = True + ) -> WorkerRunData: + return 
self.worker_manager.sync_select_one_instance( + worker_type, model_name, healthy_only + ) + async def generate_stream(self, params: Dict, **kwargs) -> Iterator[ModelOutput]: async for output in self.worker_manager.generate_stream(params, **kwargs): yield output @@ -491,6 +537,9 @@ async def generate(self, params: Dict) -> ModelOutput: async def embeddings(self, params: Dict) -> List[List[float]]: return await self.worker_manager.embeddings(params) + def sync_embeddings(self, params: Dict) -> List[List[float]]: + return self.worker_manager.sync_embeddings(params) + async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput: return await self.worker_manager.worker_apply(apply_req) @@ -586,11 +635,11 @@ def _setup_fastapi(worker_params: ModelWorkerParameters, app=None): @app.on_event("startup") async def startup_event(): - asyncio.create_task(worker_manager.worker_manager.start()) + asyncio.create_task(worker_manager.start()) @app.on_event("shutdown") async def startup_event(): - await worker_manager.worker_manager.stop() + await worker_manager.stop() return app @@ -666,29 +715,60 @@ async def send_heartbeat_func(worker_run_data: WorkerRunData): def _build_worker(worker_params: ModelWorkerParameters): - if worker_params.worker_class: + worker_class = worker_params.worker_class + if worker_class: from pilot.utils.module_utils import import_from_checked_string - worker_cls = import_from_checked_string(worker_params.worker_class, ModelWorker) - logger.info( - f"Import worker class from {worker_params.worker_class} successfully" - ) - worker: ModelWorker = worker_cls() + worker_cls = import_from_checked_string(worker_class, ModelWorker) + logger.info(f"Import worker class from {worker_class} successfully") else: - from pilot.model.cluster.worker.default_worker import DefaultModelWorker + if ( + worker_params.worker_type is None + or worker_params.worker_type == WorkerType.LLM + ): + from pilot.model.cluster.worker.default_worker import DefaultModelWorker + + worker_cls = DefaultModelWorker + elif worker_params.worker_type == WorkerType.TEXT2VEC: + from pilot.model.cluster.worker.embedding_worker import ( + EmbeddingsModelWorker, + ) - worker = DefaultModelWorker() - return worker + worker_cls = EmbeddingsModelWorker + else: + raise Exception("Unsupported worker type: {worker_params.worker_type}") + + return worker_cls() def _start_local_worker( worker_manager: WorkerManagerAdapter, worker_params: ModelWorkerParameters ): worker = _build_worker(worker_params) - worker_manager.worker_manager = _create_local_model_manager(worker_params) + if not worker_manager.worker_manager: + worker_manager.worker_manager = _create_local_model_manager(worker_params) worker_manager.worker_manager.add_worker(worker, worker_params) +def _start_local_embedding_worker( + worker_manager: WorkerManagerAdapter, + embedding_model_name: str = None, + embedding_model_path: str = None, +): + if not embedding_model_name or not embedding_model_path: + return + embedding_worker_params = ModelWorkerParameters( + model_name=embedding_model_name, + model_path=embedding_model_path, + worker_type=WorkerType.TEXT2VEC, + worker_class="pilot.model.cluster.worker.embedding_worker.EmbeddingsModelWorker", + ) + logger.info( + f"Start local embedding worker with embedding parameters\n{embedding_worker_params}" + ) + _start_local_worker(worker_manager, embedding_worker_params) + + def initialize_worker_manager_in_client( app=None, include_router: bool = True, @@ -697,6 +777,9 @@ def initialize_worker_manager_in_client( 
     run_locally: bool = True,
     controller_addr: str = None,
     local_port: int = 5000,
+    embedding_model_name: str = None,
+    embedding_model_path: str = None,
+    start_listener: Callable[["WorkerManager"], None] = None,
 ):
     """Initialize WorkerManager in client.
     If run_locally is True:
@@ -728,6 +811,10 @@ def initialize_worker_manager_in_client(
         logger.info(f"Worker params: {worker_params}")
         _setup_fastapi(worker_params, app)
         _start_local_worker(worker_manager, worker_params)
+        worker_manager.after_start(start_listener)
+        _start_local_embedding_worker(
+            worker_manager, embedding_model_name, embedding_model_path
+        )
     else:
         from pilot.model.cluster.controller.controller import (
             ModelRegistryClient,
@@ -741,9 +828,12 @@ def initialize_worker_manager_in_client(
         logger.info(f"Worker params: {worker_params}")
         client = ModelRegistryClient(worker_params.controller_addr)
         worker_manager.worker_manager = RemoteWorkerManager(client)
+        worker_manager.after_start(start_listener)
         initialize_controller(
             app=app, remote_controller_addr=worker_params.controller_addr
         )
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(worker_manager.start())
 
     if include_router and app:
         # mount WorkerManager router
@@ -757,6 +847,8 @@ def run_worker_manager(
     model_path: str = None,
     standalone: bool = False,
    port: int = None,
+    embedding_model_name: str = None,
+    embedding_model_path: str = None,
 ):
     global worker_manager
 
@@ -765,15 +857,22 @@
     )
     embedded_mod = True
+    logger.info(f"Worker params: {worker_params}")
     if not app:
         # Run worker manager independently
         embedded_mod = False
         app = _setup_fastapi(worker_params)
         _start_local_worker(worker_manager, worker_params)
+        _start_local_embedding_worker(
+            worker_manager, embedding_model_name, embedding_model_path
+        )
     else:
         _start_local_worker(worker_manager, worker_params)
+        _start_local_embedding_worker(
+            worker_manager, embedding_model_name, embedding_model_path
+        )
 
     loop = asyncio.get_event_loop()
-    loop.run_until_complete(worker_manager.worker_manager.start())
+    loop.run_until_complete(worker_manager.start())
 
     if include_router:
         app.include_router(router, prefix="/api")
diff --git a/pilot/model/cluster/worker/remote_manager.py b/pilot/model/cluster/worker/remote_manager.py
index afdf10774..7f9de4d62 100644
--- a/pilot/model/cluster/worker/remote_manager.py
+++ b/pilot/model/cluster/worker/remote_manager.py
@@ -1,14 +1,12 @@
-from typing import Callable, Any
-import httpx
 import asyncio
+from typing import Any, Callable
+
+import httpx
+from pilot.model.base import ModelInstance, WorkerApplyOutput, WorkerSupportedModel
+from pilot.model.cluster.base import *
 from pilot.model.cluster.registry import ModelRegistry
 from pilot.model.cluster.worker.manager import LocalWorkerManager, WorkerRunData, logger
-from pilot.model.cluster.base import *
-from pilot.model.base import (
-    ModelInstance,
-    WorkerApplyOutput,
-    WorkerSupportedModel,
-)
+from pilot.model.cluster.worker.remote_worker import RemoteModelWorker
 
 
 class RemoteWorkerManager(LocalWorkerManager):
@@ -16,7 +14,8 @@ def __init__(self, model_registry: ModelRegistry = None) -> None:
         super().__init__(model_registry=model_registry)
 
     async def start(self):
-        pass
+        for listener in self.start_listeners:
+            listener(self)
 
     async def stop(self):
         pass
@@ -125,15 +124,9 @@ async def model_shutdown(self, shutdown_req: WorkerStartupRequest) -> bool:
             success_handler=lambda x: True,
         )
 
-    async def get_model_instances(
-        self, worker_type: str, model_name: str, healthy_only: bool = True
+    def _build_worker_instances(
+        self, model_name: str, instances: List[ModelInstance]
     ) -> List[WorkerRunData]:
-        from pilot.model.cluster.worker.remote_worker import RemoteModelWorker
-
-        worker_key = self._worker_key(worker_type, model_name)
-        instances: List[ModelInstance] = await self.model_registry.get_all_instances(
-            worker_key, healthy_only
-        )
         worker_instances = []
         for ins in instances:
             worker = RemoteModelWorker()
@@ -151,6 +144,24 @@ async def get_model_instances(
             worker_instances.append(wr)
         return worker_instances
 
+    async def get_model_instances(
+        self, worker_type: str, model_name: str, healthy_only: bool = True
+    ) -> List[WorkerRunData]:
+        worker_key = self._worker_key(worker_type, model_name)
+        instances: List[ModelInstance] = await self.model_registry.get_all_instances(
+            worker_key, healthy_only
+        )
+        return self._build_worker_instances(model_name, instances)
+
+    def sync_get_model_instances(
+        self, worker_type: str, model_name: str, healthy_only: bool = True
+    ) -> List[WorkerRunData]:
+        worker_key = self._worker_key(worker_type, model_name)
+        instances: List[ModelInstance] = self.model_registry.sync_get_all_instances(
+            worker_key, healthy_only
+        )
+        return self._build_worker_instances(model_name, instances)
+
     async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
         async def _remote_apply_func(worker_run_data: WorkerRunData):
             return await self._fetch_from_worker(
diff --git a/pilot/model/cluster/worker/remote_worker.py b/pilot/model/cluster/worker/remote_worker.py
index b123f1aa7..a0c306bbf 100644
--- a/pilot/model/cluster/worker/remote_worker.py
+++ b/pilot/model/cluster/worker/remote_worker.py
@@ -87,7 +87,15 @@ async def async_generate(self, params: Dict) -> ModelOutput:
 
     def embeddings(self, params: Dict) -> List[List[float]]:
         """Get embeddings for input"""
-        raise NotImplementedError
+        import requests
+
+        response = requests.post(
+            self.worker_addr + "/embeddings",
+            headers=self.headers,
+            json=params,
+            timeout=self.timeout,
+        )
+        return response.json()
 
     async def async_embeddings(self, params: Dict) -> List[List[float]]:
         """Asynchronous get embeddings for input"""
diff --git a/pilot/model/parameter.py b/pilot/model/parameter.py
index f8bc91a7c..400ba0396 100644
--- a/pilot/model/parameter.py
+++ b/pilot/model/parameter.py
@@ -46,9 +46,7 @@ class ModelWorkerParameters(BaseModelParameters):
     )
     worker_class: Optional[str] = field(
         default=None,
-        metadata={
-            "help": "Model worker class, pilot.model.worker.default_worker.DefaultModelWorker"
-        },
+        metadata={"help": "Model worker class, pilot.model.cluster.DefaultModelWorker"},
     )
     host: Optional[str] = field(
         default="0.0.0.0", metadata={"help": "Model worker deploy host"}
diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py
index fdbfcf2b1..f45a6af95 100644
--- a/pilot/openapi/api_v1/api_v1.py
+++ b/pilot/openapi/api_v1/api_v1.py
@@ -111,7 +111,7 @@ async def db_connect_delete(db_name: str = None):
 
 
 async def async_db_summary_embedding(db_name, db_type):
     # Execute the code that needs to run asynchronously here
-    db_summary_client = DBSummaryClient()
+    db_summary_client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
     db_summary_client.db_summary_embedding(db_name, db_type)
 
 
diff --git a/pilot/scene/chat_dashboard/chat.py b/pilot/scene/chat_dashboard/chat.py
index a40ae768f..1aeca8406 100644
--- a/pilot/scene/chat_dashboard/chat.py
+++ b/pilot/scene/chat_dashboard/chat.py
@@ -61,7 +61,7 @@ def generate_input_values(self):
         except ImportError:
             raise ValueError("Could not import DBSummaryClient. ")
-        client = DBSummaryClient()
+        client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
         try:
             table_infos = client.get_similar_tables(
                 dbname=self.db_name, query=self.current_user_input, topk=self.top_k
diff --git a/pilot/scene/chat_db/auto_execute/chat.py b/pilot/scene/chat_db/auto_execute/chat.py
index 785881746..689222c0a 100644
--- a/pilot/scene/chat_db/auto_execute/chat.py
+++ b/pilot/scene/chat_db/auto_execute/chat.py
@@ -35,7 +35,7 @@ def generate_input_values(self):
             from pilot.summary.db_summary_client import DBSummaryClient
         except ImportError:
             raise ValueError("Could not import DBSummaryClient. ")
-        client = DBSummaryClient()
+        client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
         try:
             table_infos = client.get_db_summary(
                 dbname=self.db_name,
diff --git a/pilot/scene/chat_db/professional_qa/chat.py b/pilot/scene/chat_db/professional_qa/chat.py
index 1c246700c..374131bfe 100644
--- a/pilot/scene/chat_db/professional_qa/chat.py
+++ b/pilot/scene/chat_db/professional_qa/chat.py
@@ -41,7 +41,7 @@ def generate_input_values(self):
         except ImportError:
             raise ValueError("Could not import DBSummaryClient. ")
         if self.db_name:
-            client = DBSummaryClient()
+            client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
             try:
                 table_infos = client.get_db_summary(
                     dbname=self.db_name, query=self.current_user_input, topk=self.top_k
diff --git a/pilot/scene/chat_knowledge/v1/chat.py b/pilot/scene/chat_knowledge/v1/chat.py
index 65402b4e3..c41345f1d 100644
--- a/pilot/scene/chat_knowledge/v1/chat.py
+++ b/pilot/scene/chat_knowledge/v1/chat.py
@@ -23,6 +23,7 @@ class ChatKnowledge(BaseChat):
     def __init__(self, chat_session_id, user_input, select_param: str = None):
         """ """
         from pilot.embedding_engine.embedding_engine import EmbeddingEngine
+        from pilot.embedding_engine.embedding_factory import EmbeddingFactory
 
         self.knowledge_space = select_param
         super().__init__(
@@ -47,9 +48,13 @@ def __init__(self, chat_session_id, user_input, select_param: str = None):
             "vector_store_type": CFG.VECTOR_STORE_TYPE,
             "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
         }
+        embedding_factory = CFG.SYSTEM_APP.get_componet(
+            "embedding_factory", EmbeddingFactory
+        )
         self.knowledge_embedding_client = EmbeddingEngine(
             model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
             vector_store_config=vector_store_config,
+            embedding_factory=embedding_factory,
         )
 
     def generate_input_values(self):
diff --git a/pilot/server/base.py b/pilot/server/base.py
index 6b44e8472..59116e09c 100644
--- a/pilot/server/base.py
+++ b/pilot/server/base.py
@@ -2,10 +2,11 @@
 import os
 import threading
 import sys
-from typing import Optional
+from typing import Optional, Any
 from dataclasses import dataclass, field
 
 from pilot.configs.config import Config
+from pilot.componet import SystemApp
 from pilot.utils.parameter_utils import BaseParameters
 
 
@@ -18,30 +19,28 @@ def signal_handler(sig, frame):
     os._exit(0)
 
 
-def async_db_summery():
+def async_db_summery(system_app: SystemApp):
     from pilot.summary.db_summary_client import DBSummaryClient
 
-    client = DBSummaryClient()
+    client = DBSummaryClient(system_app=system_app)
     thread = threading.Thread(target=client.init_db_summary)
     thread.start()
 
 
-def server_init(args):
+def server_init(args, system_app: SystemApp):
     from pilot.commands.command_mange import CommandRegistry
-    from pilot.connections.manages.connection_manager import ConnectManager
+    from pilot.common.plugins import scan_plugins
 
     # logger.info(f"args: {args}")
 
     # init config
     cfg = Config()
-    # init connect manage
-    conn_manage = ConnectManager()
-    cfg.LOCAL_DB_MANAGE = conn_manage
+    cfg.SYSTEM_APP = system_app
 
     # load_native_plugins(cfg)
     signal.signal(signal.SIGINT, signal_handler)
-    async_db_summery()
+    cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
 
     # Loader plugins and commands
@@ -70,6 +69,22 @@ def server_init(args):
     cfg.command_disply = command_disply_registry
 
 
+def _create_model_start_listener(system_app: SystemApp):
+    from pilot.connections.manages.connection_manager import ConnectManager
+    from pilot.model.cluster import worker_manager
+
+    cfg = Config()
+
+    def startup_event(wh):
+        # init connect manage
+        print("begin run _add_app_startup_event")
+        conn_manage = ConnectManager(system_app)
+        cfg.LOCAL_DB_MANAGE = conn_manage
+        async_db_summery(system_app)
+
+    return startup_event
+
+
 @dataclass
 class WebWerverParameters(BaseParameters):
     host: Optional[str] = field(
diff --git a/pilot/server/componet_configs.py b/pilot/server/componet_configs.py
new file mode 100644
index 000000000..745937068
--- /dev/null
+++ b/pilot/server/componet_configs.py
@@ -0,0 +1,38 @@
+from typing import Any, Type, TYPE_CHECKING
+
+from pilot.componet import SystemApp
+from pilot.embedding_engine.embedding_factory import EmbeddingFactory
+
+if TYPE_CHECKING:
+    from langchain.embeddings.base import Embeddings
+
+
+def initialize_componets(system_app: SystemApp, embedding_model_name: str):
+    from pilot.model.cluster import worker_manager
+
+    system_app.register(
+        RemoteEmbeddingFactory, worker_manager, model_name=embedding_model_name
+    )
+
+
+class RemoteEmbeddingFactory(EmbeddingFactory):
+    def __init__(
+        self, system_app, worker_manager, model_name: str = None, **kwargs: Any
+    ) -> None:
+        super().__init__(system_app=system_app)
+        self._worker_manager = worker_manager
+        self._default_model_name = model_name
+        self.kwargs = kwargs
+
+    def init_app(self, system_app):
+        pass
+
+    def create(
+        self, model_name: str = None, embedding_cls: Type = None
+    ) -> "Embeddings":
+        from pilot.model.cluster.embedding.remote_embedding import RemoteEmbeddings
+
+        if embedding_cls:
+            raise NotImplementedError
+        # Ignore model_name args
+        return RemoteEmbeddings(self._default_model_name, self._worker_manager)
diff --git a/pilot/server/dbgpt_server.py b/pilot/server/dbgpt_server.py
index 011b1146a..3019f067f 100644
--- a/pilot/server/dbgpt_server.py
+++ b/pilot/server/dbgpt_server.py
@@ -7,9 +7,15 @@
 sys.path.append(ROOT_PATH)
 import signal
 from pilot.configs.config import Config
-from pilot.configs.model_config import LLM_MODEL_CONFIG
+from pilot.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG
+from pilot.componet import SystemApp
 
-from pilot.server.base import server_init, WebWerverParameters
+from pilot.server.base import (
+    server_init,
+    WebWerverParameters,
+    _create_model_start_listener,
+)
+from pilot.server.componet_configs import initialize_componets
 
 from fastapi.staticfiles import StaticFiles
 from fastapi import FastAPI, applications
@@ -48,6 +54,8 @@ def swagger_monkey_patch(*args, **kwargs):
 applications.get_swagger_ui_html = swagger_monkey_patch
 
 app = FastAPI()
+system_app = SystemApp(app)
+
 origins = ["*"]
 
 # Add cross-origin (CORS) middleware
@@ -98,7 +106,12 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
         param = WebWerverParameters(**vars(parser.parse_args(args=args)))
 
     setup_logging(logging_level=param.log_level)
-    server_init(param)
+    # Before start
+    system_app.before_start()
+
+    server_init(param, system_app)
+    model_start_listener = _create_model_start_listener(system_app)
+    initialize_componets(system_app, CFG.EMBEDDING_MODEL)
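+    # initialize_componets registers the "embedding_factory" component on the
+    # SystemApp; with a model cluster this resolves to RemoteEmbeddingFactory,
+    # which serves embeddings through the worker manager rather than loading
+    # a local HuggingFace model inside the webserver process.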
     model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
     if not param.light:
@@ -108,6 +121,9 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
             model_name=CFG.LLM_MODEL,
             model_path=model_path,
             local_port=param.port,
+            embedding_model_name=CFG.EMBEDDING_MODEL,
+            embedding_model_path=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
+            start_listener=model_start_listener,
         )
         CFG.NEW_SERVER_MODE = True
@@ -120,6 +136,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
             run_locally=False,
             controller_addr=CFG.MODEL_SERVER,
             local_port=param.port,
+            start_listener=model_start_listener,
         )
         CFG.SERVER_LIGHT_MODE = True
diff --git a/pilot/server/knowledge/api.py b/pilot/server/knowledge/api.py
index 5d9c7522b..f43333aa1 100644
--- a/pilot/server/knowledge/api.py
+++ b/pilot/server/knowledge/api.py
@@ -4,8 +4,6 @@
 
 from fastapi import APIRouter, File, UploadFile, Form
 
-from langchain.embeddings import HuggingFaceEmbeddings
-
 from pilot.configs.config import Config
 from pilot.configs.model_config import (
     EMBEDDING_MODEL_CONFIG,
@@ -14,6 +12,7 @@
 from pilot.openapi.api_view_model import Result
 
 from pilot.embedding_engine.embedding_engine import EmbeddingEngine
+from pilot.embedding_engine.embedding_factory import EmbeddingFactory
 
 from pilot.server.knowledge.service import KnowledgeService
 from pilot.server.knowledge.request.request import (
@@ -32,10 +31,6 @@
 
 router = APIRouter()
 
-embeddings = HuggingFaceEmbeddings(
-    model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
-)
-
 knowledge_space_service = KnowledgeService()
 
 
@@ -186,8 +181,13 @@ def document_list(space_name: str, query_request: ChunkQueryRequest):
 @router.post("/knowledge/{vector_name}/query")
 def similar_query(space_name: str, query_request: KnowledgeQueryRequest):
     print(f"Received params: {space_name}, {query_request}")
+    embedding_factory = CFG.SYSTEM_APP.get_componet(
+        "embedding_factory", EmbeddingFactory
+    )
     client = EmbeddingEngine(
-        model_name=embeddings, vector_store_config={"vector_store_name": space_name}
+        model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
+        vector_store_config={"vector_store_name": space_name},
+        embedding_factory=embedding_factory,
     )
     docs = client.similar_search(query_request.query, query_request.top_k)
     res = [
diff --git a/pilot/server/knowledge/service.py b/pilot/server/knowledge/service.py
index dc95cf9c9..0c04dee3a 100644
--- a/pilot/server/knowledge/service.py
+++ b/pilot/server/knowledge/service.py
@@ -154,6 +154,7 @@ def get_knowledge_documents(self, space, request: DocumentQueryRequest):
 
     def sync_knowledge_document(self, space_name, doc_ids):
         from pilot.embedding_engine.embedding_engine import EmbeddingEngine
+        from pilot.embedding_engine.embedding_factory import EmbeddingFactory
         from langchain.text_splitter import (
             RecursiveCharacterTextSplitter,
             SpacyTextSplitter,
@@ -204,6 +205,9 @@ def sync_knowledge_document(self, space_name, doc_ids):
                     chunk_size=chunk_size,
                     chunk_overlap=chunk_overlap,
                 )
+            embedding_factory = CFG.SYSTEM_APP.get_componet(
+                "embedding_factory", EmbeddingFactory
+            )
             client = EmbeddingEngine(
                 knowledge_source=doc.content,
                 knowledge_type=doc.doc_type.upper(),
@@ -214,6 +218,7 @@ def sync_knowledge_document(self, space_name, doc_ids):
                     "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
                 },
                 text_splitter=text_splitter,
+                embedding_factory=embedding_factory,
             )
             chunk_docs = client.read()
             # update document status
diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py
index 2e67b370f..d521a062d 100644
--- a/pilot/server/llmserver.py
+++ b/pilot/server/llmserver.py
@@ -8,7 +8,7 @@
 sys.path.append(ROOT_PATH)
 
 from pilot.configs.config import Config
-from pilot.configs.model_config import LLM_MODEL_CONFIG
+from pilot.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG
 from pilot.model.cluster import run_worker_manager
 
 CFG = Config()
@@ -21,4 +21,6 @@
         model_path=model_path,
         standalone=True,
         port=CFG.MODEL_PORT,
+        embedding_model_name=CFG.EMBEDDING_MODEL,
+        embedding_model_path=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
     )
diff --git a/pilot/summary/db_summary_client.py b/pilot/summary/db_summary_client.py
index 1b6cc251f..f41043601 100644
--- a/pilot/summary/db_summary_client.py
+++ b/pilot/summary/db_summary_client.py
@@ -2,6 +2,7 @@
 import uuid
 
 from pilot.common.schema import DBType
+from pilot.componet import SystemApp
 from pilot.configs.config import Config
 from pilot.configs.model_config import (
     KNOWLEDGE_UPLOAD_ROOT_PATH,
@@ -26,16 +27,19 @@ class DBSummaryClient:
     , get_similar_tables method(get user query related tables info)
     """
 
-    def __init__(self):
-        pass
+    def __init__(self, system_app: SystemApp):
+        self.system_app = system_app
 
     def db_summary_embedding(self, dbname, db_type):
         """put db profile and table profile summary into vector store"""
-        from langchain.embeddings import HuggingFaceEmbeddings
         from pilot.embedding_engine.string_embedding import StringEmbedding
+        from pilot.embedding_engine.embedding_factory import EmbeddingFactory
 
         db_summary_client = RdbmsSummary(dbname, db_type)
-        embeddings = HuggingFaceEmbeddings(
+        embedding_factory = self.system_app.get_componet(
+            "embedding_factory", EmbeddingFactory
+        )
+        embeddings = embedding_factory.create(
             model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
         )
         vector_store_config = {
@@ -83,15 +87,20 @@ def db_summary_embedding(self, dbname, db_type):
 
     def get_db_summary(self, dbname, query, topk):
         from pilot.embedding_engine.embedding_engine import EmbeddingEngine
+        from pilot.embedding_engine.embedding_factory import EmbeddingFactory
 
         vector_store_config = {
             "vector_store_name": dbname + "_profile",
             "vector_store_type": CFG.VECTOR_STORE_TYPE,
             "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
         }
+        embedding_factory = CFG.SYSTEM_APP.get_componet(
+            "embedding_factory", EmbeddingFactory
+        )
         knowledge_embedding_client = EmbeddingEngine(
             model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
             vector_store_config=vector_store_config,
+            embedding_factory=embedding_factory,
         )
         table_docs = knowledge_embedding_client.similar_search(query, topk)
         ans = [d.page_content for d in table_docs]
@@ -100,6 +109,7 @@ def get_db_summary(self, dbname, query, topk):
     def get_similar_tables(self, dbname, query, topk):
         """get user query related tables info"""
         from pilot.embedding_engine.embedding_engine import EmbeddingEngine
+        from pilot.embedding_engine.embedding_factory import EmbeddingFactory
 
         vector_store_config = {
             "vector_store_name": dbname + "_summary",
@@ -107,9 +117,13 @@ def get_similar_tables(self, dbname, query, topk):
             "vector_store_type": CFG.VECTOR_STORE_TYPE,
             "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
         }
+        embedding_factory = CFG.SYSTEM_APP.get_componet(
+            "embedding_factory", EmbeddingFactory
+        )
         knowledge_embedding_client = EmbeddingEngine(
             model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
             vector_store_config=vector_store_config,
+            embedding_factory=embedding_factory,
         )
         if CFG.SUMMARY_CONFIG == "FAST":
             table_docs = knowledge_embedding_client.similar_search(query, topk)
@@ -136,6 +150,7 @@ def get_similar_tables(self, dbname, query, topk):
                 knowledge_embedding_client = EmbeddingEngine(
                     model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
                     vector_store_config=vector_store_config,
+                    embedding_factory=embedding_factory,
                 )
                 table_summery = knowledge_embedding_client.similar_search(query, 1)
                 related_table_summaries.append(table_summery[0].page_content)
diff --git a/pilot/utils/api_utils.py b/pilot/utils/api_utils.py
index 58dbd1a9c..93b280188 100644
--- a/pilot/utils/api_utils.py
+++ b/pilot/utils/api_utils.py
@@ -15,60 +15,60 @@ def _extract_dataclass_from_generic(type_hint: Type[T]) -> Union[Type[T], None]:
     return None
 
 
-def _api_remote(path, method="GET"):
-    def decorator(func):
-        return_type = get_type_hints(func).get("return")
-        if return_type is None:
-            raise TypeError("Return type must be annotated in the decorated function.")
-
-        actual_dataclass = _extract_dataclass_from_generic(return_type)
-        logging.debug(
-            f"return_type: {return_type}, actual_dataclass: {actual_dataclass}"
-        )
-        if not actual_dataclass:
-            actual_dataclass = return_type
-        sig = signature(func)
-
-        async def wrapper(self, *args, **kwargs):
-            import httpx
-
-            base_url = self.base_url  # Get base_url from class instance
-
-            bound = sig.bind(self, *args, **kwargs)
-            bound.apply_defaults()
-
-            formatted_url = base_url + path.format(**bound.arguments)
+def _build_request(self, func, path, method, *args, **kwargs):
+    return_type = get_type_hints(func).get("return")
+    if return_type is None:
+        raise TypeError("Return type must be annotated in the decorated function.")
+
+    actual_dataclass = _extract_dataclass_from_generic(return_type)
+    logging.debug(f"return_type: {return_type}, actual_dataclass: {actual_dataclass}")
+    if not actual_dataclass:
+        actual_dataclass = return_type
+    sig = signature(func)
+    base_url = self.base_url  # Get base_url from class instance
+
+    bound = sig.bind(self, *args, **kwargs)
+    bound.apply_defaults()
+
+    formatted_url = base_url + path.format(**bound.arguments)
+
+    # Extract args names from signature, except "self"
+    arg_names = list(sig.parameters.keys())[1:]
+
+    # Combine args and kwargs into a single dictionary
+    combined_args = dict(zip(arg_names, args))
+    combined_args.update(kwargs)
+
+    request_data = {}
+    for key, value in combined_args.items():
+        if is_dataclass(value):
+            # Here, instead of adding it as a nested dictionary,
+            # we set request_data directly to its dictionary representation.
+            request_data = asdict(value)
+        else:
+            request_data[key] = value
 
-            # Extract args names from signature, except "self"
-            arg_names = list(sig.parameters.keys())[1:]
+    request_params = {"method": method, "url": formatted_url}
 
-            # Combine args and kwargs into a single dictionary
-            combined_args = dict(zip(arg_names, args))
-            combined_args.update(kwargs)
+    if method in ["POST", "PUT", "PATCH"]:
+        request_params["json"] = request_data
+    else:  # For GET, DELETE, etc.
 
-            request_data = {}
-            for key, value in combined_args.items():
-                if is_dataclass(value):
-                    # Here, instead of adding it as a nested dictionary,
-                    # we set request_data directly to its dictionary representation.
-                    request_data = asdict(value)
-                else:
-                    request_data[key] = value
+        request_params["params"] = request_data
+
+    logging.debug(f"request_params: {request_params}, args: {args}, kwargs: {kwargs}")
+    return return_type, actual_dataclass, request_params
 
-            request_params = {"method": method, "url": formatted_url}
-
-            if method in ["POST", "PUT", "PATCH"]:
-                request_params["json"] = request_data
-            else:  # For GET, DELETE, etc.
- request_params["params"] = request_data +def _api_remote(path, method="GET"): + def decorator(func): + async def wrapper(self, *args, **kwargs): + import httpx - logging.info( - f"request_params: {request_params}, args: {args}, kwargs: {kwargs}" + return_type, actual_dataclass, request_params = _build_request( + self, func, path, method, *args, **kwargs ) - async with httpx.AsyncClient() as client: response = await client.request(**request_params) - if response.status_code == 200: return _parse_response( response.json(), return_type, actual_dataclass @@ -82,6 +82,28 @@ async def wrapper(self, *args, **kwargs): return decorator +def _sync_api_remote(path, method="GET"): + def decorator(func): + def wrapper(self, *args, **kwargs): + import requests + + return_type, actual_dataclass, request_params = _build_request( + self, func, path, method, *args, **kwargs + ) + + response = requests.request(**request_params) + + if response.status_code == 200: + return _parse_response(response.json(), return_type, actual_dataclass) + else: + error_msg = f"Remote request error, error code: {response.status_code}, error msg: {response.text}" + raise Exception(error_msg) + + return wrapper + + return decorator + + def _parse_response(json_response, return_type, actual_dataclass): # print(f'return_type.__origin__: {return_type.__origin__}, actual_dataclass: {actual_dataclass}, json_response: {json_response}') if is_dataclass(actual_dataclass): diff --git a/pilot/vector_store/extract_tovec.py b/pilot/vector_store/extract_tovec.py index c3cf2577f..b37a48ddd 100644 --- a/pilot/vector_store/extract_tovec.py +++ b/pilot/vector_store/extract_tovec.py @@ -30,8 +30,9 @@ def knownledge_tovec_st(filename): https://github.com/UKPLab/sentence-transformers """ from pilot.configs.model_config import EMBEDDING_MODEL_CONFIG + from pilot.embedding_engine.embedding_factory import DefaultEmbeddingFactory - embeddings = HuggingFaceEmbeddings( + embeddings = DefaultEmbeddingFactory().create( model_name=EMBEDDING_MODEL_CONFIG["sentence-transforms"] ) @@ -58,8 +59,9 @@ def load_knownledge_from_doc(): ) from pilot.configs.model_config import EMBEDDING_MODEL_CONFIG + from pilot.embedding_engine.embedding_factory import DefaultEmbeddingFactory - embeddings = HuggingFaceEmbeddings( + embeddings = DefaultEmbeddingFactory().create( model_name=EMBEDDING_MODEL_CONFIG["sentence-transforms"] ) diff --git a/pilot/vector_store/file_loader.py b/pilot/vector_store/file_loader.py index 76411004d..f01f96867 100644 --- a/pilot/vector_store/file_loader.py +++ b/pilot/vector_store/file_loader.py @@ -44,7 +44,11 @@ class KnownLedge2Vector: def __init__(self, model_name=None) -> None: if not model_name: # use default embedding model - self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name) + from pilot.embedding_engine.embedding_factory import DefaultEmbeddingFactory + + self.embeddings = DefaultEmbeddingFactory().create( + model_name=self.model_name + ) def init_vector_store(self): persist_dir = os.path.join(VECTORE_PATH, ".vectordb")