diff --git a/docker/compose_examples/cluster-docker-compose.yml b/docker/compose_examples/cluster-docker-compose.yml
index c4928a53a..8d4763532 100644
--- a/docker/compose_examples/cluster-docker-compose.yml
+++ b/docker/compose_examples/cluster-docker-compose.yml
@@ -4,9 +4,10 @@ services:
controller:
image: eosphorosai/dbgpt:latest
command: dbgpt start controller
+ restart: unless-stopped
networks:
- dbgptnet
- worker:
+ llm-worker:
image: eosphorosai/dbgpt:latest
command: dbgpt start worker --model_name vicuna-13b-v1.5 --model_path /app/models/vicuna-13b-v1.5 --port 8001 --controller_addr http://controller:8000
environment:
@@ -17,6 +18,27 @@ services:
- /data:/data
# Please modify it to your own model directory
- /data/models:/app/models
+ restart: unless-stopped
+ networks:
+ - dbgptnet
+ deploy:
+ resources:
+ reservations:
+ devices:
+ - driver: nvidia
+ capabilities: [gpu]
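+ # Note: GPU reservations like this require the NVIDIA Container Toolkit on the Docker host.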
+ embedding-worker:
+ image: eosphorosai/dbgpt:latest
+ command: dbgpt start worker --model_name text2vec --worker_type text2vec --model_path /app/models/text2vec-large-chinese --port 8002 --controller_addr http://controller:8000
+ environment:
+ - DBGPT_LOG_LEVEL=DEBUG
+ depends_on:
+ - controller
+ volumes:
+ - /data:/data
+ # Please modify it to your own model directory
+ - /data/models:/app/models
+ restart: unless-stopped
networks:
- dbgptnet
deploy:
@@ -37,7 +59,8 @@ services:
- MODEL_SERVER=http://controller:8000
depends_on:
- controller
- - worker
+ - llm-worker
+ - embedding-worker
volumes:
- /data:/data
# Please modify it to your own model directory
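+ # A sketch of bringing the whole cluster up from the repository root:
+ # docker compose -f docker/compose_examples/cluster-docker-compose.yml up -d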
diff --git a/docs/_static/img/muti-model-cluster-overview.png b/docs/_static/img/muti-model-cluster-overview.png
new file mode 100644
index 000000000..e087a0034
Binary files /dev/null and b/docs/_static/img/muti-model-cluster-overview.png differ
diff --git a/docs/getting_started/install.rst b/docs/getting_started/install.rst
index abb90ed6e..d6e6a15f8 100644
--- a/docs/getting_started/install.rst
+++ b/docs/getting_started/install.rst
@@ -9,6 +9,7 @@ DB-GPT product is a Web application that you can chat database, chat knowledge,
- docker
- docker_compose
- environment
+- cluster deployment
- deploy_faq
.. toctree::
@@ -20,6 +21,7 @@ DB-GPT product is a Web application that you can chat database, chat knowledge,
./install/deploy/deploy.md
./install/docker/docker.md
./install/docker_compose/docker_compose.md
+ ./install/cluster/cluster.rst
./install/llm/llm.rst
./install/environment/environment.md
./install/faq/deploy_faq.md
\ No newline at end of file
diff --git a/docs/getting_started/install/cluster/cluster.rst b/docs/getting_started/install/cluster/cluster.rst
new file mode 100644
index 000000000..f81fd4d0d
--- /dev/null
+++ b/docs/getting_started/install/cluster/cluster.rst
@@ -0,0 +1,19 @@
+Cluster deployment
+==================================
+
+In order to deploy DB-GPT to multiple nodes, you can deploy a cluster. The cluster architecture diagram is as follows:
+
+.. raw:: html
+
+   <img src="../../../_static/img/muti-model-cluster-overview.png" />
+
+* On :ref:`Deploying on local machine <local-cluster-index>`. Local cluster deployment.
+
+.. toctree::
+ :maxdepth: 2
+ :caption: Cluster deployment
+ :name: cluster_deploy
+ :hidden:
+
+ ./vms/index.md
diff --git a/docs/getting_started/install/cluster/kubernetes/index.md b/docs/getting_started/install/cluster/kubernetes/index.md
new file mode 100644
index 000000000..385a8b054
--- /dev/null
+++ b/docs/getting_started/install/cluster/kubernetes/index.md
@@ -0,0 +1,3 @@
+Kubernetes cluster deployment
+==================================
+(kubernetes-cluster-index)=
\ No newline at end of file
diff --git a/docs/getting_started/install/llm/cluster/model_cluster.md b/docs/getting_started/install/cluster/vms/index.md
similarity index 76%
rename from docs/getting_started/install/llm/cluster/model_cluster.md
rename to docs/getting_started/install/cluster/vms/index.md
index 5576dfc1e..03b5b0293 100644
--- a/docs/getting_started/install/llm/cluster/model_cluster.md
+++ b/docs/getting_started/install/cluster/vms/index.md
@@ -1,6 +1,6 @@
-Cluster deployment
+Local cluster deployment
==================================
-
+(local-cluster-index)=
## Model cluster deployment
@@ -17,7 +17,7 @@ dbgpt start controller
By default, the Model Controller starts on port 8000.
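+
+You can also check that the controller is reachable with a plain HTTP request. This is a sketch: `/api/controller/models` is the registry path used by `ModelRegistryClient` in this change, and the response format may vary:
+
+```bash
+curl http://127.0.0.1:8000/api/controller/models
+```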
-### Launch Model Worker
+### Launch LLM Model Worker
If you are starting `chatglm2-6b`:
@@ -39,6 +39,18 @@ dbgpt start worker --model_name vicuna-13b-v1.5 \
Note: Be sure to use your own model name and model path.
+### Launch Embedding Model Worker
+
+```bash
+
+dbgpt start worker --model_name text2vec \
+--model_path /app/models/text2vec-large-chinese \
+--worker_type text2vec \
+--port 8003 \
+--controller_addr http://127.0.0.1:8000
+```
+
+Note: Be sure to use your own model name and model path.
Check your model:
@@ -51,8 +63,12 @@ You will see the following output:
+-----------------+------------+------------+------+---------+---------+-----------------+----------------------------+
| Model Name | Model Type | Host | Port | Healthy | Enabled | Prompt Template | Last Heartbeat |
+-----------------+------------+------------+------+---------+---------+-----------------+----------------------------+
-| chatglm2-6b | llm | 172.17.0.6 | 8001 | True | True | None | 2023-08-31T04:48:45.252939 |
-| vicuna-13b-v1.5 | llm | 172.17.0.6 | 8002 | True | True | None | 2023-08-31T04:48:55.136676 |
+| chatglm2-6b | llm | 172.17.0.2 | 8001 | True | True | | 2023-09-12T23:04:31.287654 |
+| WorkerManager | service | 172.17.0.2 | 8001 | True | True | | 2023-09-12T23:04:31.286668 |
+| WorkerManager | service | 172.17.0.2 | 8003 | True | True | | 2023-09-12T23:04:29.845617 |
+| WorkerManager | service | 172.17.0.2 | 8002 | True | True | | 2023-09-12T23:04:24.598439 |
+| text2vec | text2vec | 172.17.0.2 | 8003 | True | True | | 2023-09-12T23:04:29.844796 |
+| vicuna-13b-v1.5 | llm | 172.17.0.2 | 8002 | True | True | | 2023-09-12T23:04:24.597775 |
+-----------------+------------+------------+------+---------+---------+-----------------+----------------------------+
```
@@ -69,7 +85,7 @@ MODEL_SERVER=http://127.0.0.1:8000
#### Start the webserver
```bash
-python pilot/server/dbgpt_server.py --light
+dbgpt start webserver --light
```
`--light` indicates not to start the embedded model service.
@@ -77,7 +93,7 @@ python pilot/server/dbgpt_server.py --light
Alternatively, you can prepend the command with `LLM_MODEL=chatglm2-6b` to start:
```bash
-LLM_MODEL=chatglm2-6b python pilot/server/dbgpt_server.py --light
+LLM_MODEL=chatglm2-6b dbgpt start webserver --light
```
@@ -101,9 +117,11 @@ Options:
--help Show this message and exit.
Commands:
- model Clients that manage model serving
- start Start specific server.
- stop Start specific server.
+ install Install dependencies, plugins, etc.
+ knowledge Knowledge command line tool
+ model Clients that manage model serving
+ start Start specific server.
+ stop Stop specific server.
```
**View the `dbgpt start` help**
@@ -146,10 +164,11 @@ Options:
--model_name TEXT Model name [required]
--model_path TEXT Model path [required]
--worker_type TEXT Worker type
- --worker_class TEXT Model worker class, pilot.model.worker.defau
- lt_worker.DefaultModelWorker
+ --worker_class TEXT Model worker class,
+ pilot.model.cluster.DefaultModelWorker
--host TEXT Model worker deploy host [default: 0.0.0.0]
- --port INTEGER Model worker deploy port [default: 8000]
+ --port INTEGER Model worker deploy port [default: 8001]
+ --daemon Run Model Worker in background
--limit_model_concurrency INTEGER
Model concurrency limit [default: 5]
--standalone Standalone mode. If True, embedded Run
@@ -166,7 +185,7 @@ Options:
(seconds) [default: 20]
--device TEXT Device to run model. If None, the device is
automatically determined
- --model_type TEXT Model type, huggingface or llama.cpp
+ --model_type TEXT Model type, huggingface, llama.cpp and proxy
[default: huggingface]
--prompt_template TEXT Prompt template. If None, the prompt
template is automatically determined from
@@ -190,7 +209,7 @@ Options:
--compute_dtype TEXT Model compute type
--trust_remote_code Trust remote code [default: True]
--verbose Show verbose output.
- --help Show this message and exit.
+ --help Show this message and exit.
```
**View the `dbgpt model`help**
@@ -208,10 +227,13 @@ Usage: dbgpt model [OPTIONS] COMMAND [ARGS]...
Options:
--address TEXT Address of the Model Controller to connect to. Just support
- light deploy model [default: http://127.0.0.1:8000]
+ light deploy model, If the environment variable
+ CONTROLLER_ADDRESS is configured, read from the environment
+ variable
--help Show this message and exit.
Commands:
+ chat Interact with your bot from the command line
list List model instances
restart Restart model instances
start Start model instances
diff --git a/docs/getting_started/install/llm/llm.rst b/docs/getting_started/install/llm/llm.rst
index accde0250..a678c3e2a 100644
--- a/docs/getting_started/install/llm/llm.rst
+++ b/docs/getting_started/install/llm/llm.rst
@@ -6,6 +6,7 @@ DB-GPT provides a management and deployment solution for multiple models. This c
Multi LLMs Support, Supports multiple large language models, currently supporting
+ - 🔥 Baichuan2(7b,13b)
- 🔥 Vicuna-v1.5(7b,13b)
- 🔥 llama-2(7b,13b,70b)
- WizardLM-v1.2(13b)
@@ -19,7 +20,6 @@ Multi LLMs Support, Supports multiple large language models, currently supportin
- llama_cpp
- quantization
-- cluster deployment
.. toctree::
:maxdepth: 2
@@ -29,4 +29,3 @@ Multi LLMs Support, Supports multiple large language models, currently supportin
./llama/llama_cpp.md
./quantization/quantization.md
- ./cluster/model_cluster.md
diff --git a/docs/locales/zh_CN/LC_MESSAGES/getting_started/install.po b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install.po
index 9e31b74e5..a4cfc8e66 100644
--- a/docs/locales/zh_CN/LC_MESSAGES/getting_started/install.po
+++ b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install.po
@@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: DB-GPT 👏👏 0.3.5\n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2023-08-16 18:31+0800\n"
+"POT-Creation-Date: 2023-09-13 09:06+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: zh_CN\n"
@@ -19,34 +19,38 @@ msgstr ""
"Content-Transfer-Encoding: 8bit\n"
"Generated-By: Babel 2.12.1\n"
-#: ../../getting_started/install.rst:2 ../../getting_started/install.rst:14
-#: 2861085e63144eaca1bb825e5f05d089
+#: ../../getting_started/install.rst:2 ../../getting_started/install.rst:15
+#: e2c13385046b4da6b6838db6ba2ea59c
msgid "Install"
msgstr "Install"
-#: ../../getting_started/install.rst:3 01a6603d91fa4520b0f839379d4eda23
+#: ../../getting_started/install.rst:3 3cb6cd251ed440dabe5d4f556435f405
msgid ""
"DB-GPT product is a Web application that you can chat database, chat "
"knowledge, text2dashboard."
msgstr "DB-GPT 可以生成sql,智能报表, 知识库问答的产品"
-#: ../../getting_started/install.rst:8 beca85cddc9b4406aecf83d5dfcce1f7
+#: ../../getting_started/install.rst:8 6fe8104b70d24f5fbfe2ad9ebf3bc3ba
msgid "deploy"
msgstr "部署"
-#: ../../getting_started/install.rst:9 601e9b9eb91f445fb07d2f1c807f0370
+#: ../../getting_started/install.rst:9 e67974b3672346809febf99a3b9a55d3
msgid "docker"
msgstr "docker"
-#: ../../getting_started/install.rst:10 6d1e094ac9284458a32a3e7fa6241c81
+#: ../../getting_started/install.rst:10 64de16a047c74598966e19a656bf6c4f
msgid "docker_compose"
msgstr "docker_compose"
-#: ../../getting_started/install.rst:11 ff1d1c60bbdc4e8ca82b7a9f303dd167
+#: ../../getting_started/install.rst:11 9f87d65e8675435b87cb9376a5bfd85c
msgid "environment"
msgstr "environment"
-#: ../../getting_started/install.rst:12 33bfbe8defd74244bfc24e8fbfd640f6
+#: ../../getting_started/install.rst:12 e60fa13bb24544ed9d4f902337093ebc
+msgid "cluster deployment"
+msgstr "集群部署"
+
+#: ../../getting_started/install.rst:13 7451712679c2412e858e7d3e2af6b174
msgid "deploy_faq"
msgstr "deploy_faq"
diff --git a/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/cluster/cluster.po b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/cluster/cluster.po
new file mode 100644
index 000000000..577f15089
--- /dev/null
+++ b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/cluster/cluster.po
@@ -0,0 +1,42 @@
+# SOME DESCRIPTIVE TITLE.
+# Copyright (C) 2023, csunny
+# This file is distributed under the same license as the DB-GPT package.
+# FIRST AUTHOR <EMAIL@ADDRESS>, 2023.
+#
+#, fuzzy
+msgid ""
+msgstr ""
+"Project-Id-Version: DB-GPT 👏👏 0.3.6\n"
+"Report-Msgid-Bugs-To: \n"
+"POT-Creation-Date: 2023-09-13 10:11+0800\n"
+"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
+"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
+"Language: zh_CN\n"
+"Language-Team: zh_CN <LL@li.org>\n"
+"Plural-Forms: nplurals=1; plural=0;\n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=utf-8\n"
+"Content-Transfer-Encoding: 8bit\n"
+"Generated-By: Babel 2.12.1\n"
+
+#: ../../getting_started/install/cluster/cluster.rst:2
+#: ../../getting_started/install/cluster/cluster.rst:13
+#: 69804208b580447798d6946150da7bdf
+msgid "Cluster deployment"
+msgstr "集群部署"
+
+#: ../../getting_started/install/cluster/cluster.rst:4
+#: fa3e4e0ae60a45eb836bcd256baa9d91
+msgid ""
+"In order to deploy DB-GPT to multiple nodes, you can deploy a cluster. "
+"The cluster architecture diagram is as follows:"
+msgstr "为了能将 DB-GPT 部署到多个节点上,你可以部署一个集群,集群的架构图如下:"
+
+#: ../../getting_started/install/cluster/cluster.rst:11
+#: e739449099ca43cabe9883233ca7e572
+#, fuzzy
+msgid ""
+"On :ref:`Deploying on local machine <local-cluster-index>`. Local cluster"
+" deployment."
+msgstr "关于 :ref:`在本地机器上部署 <local-cluster-index>`。本地集群部署。"
+
diff --git a/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/cluster/kubernetes/index.po b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/cluster/kubernetes/index.po
new file mode 100644
index 000000000..b785f94f4
--- /dev/null
+++ b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/cluster/kubernetes/index.po
@@ -0,0 +1,26 @@
+# SOME DESCRIPTIVE TITLE.
+# Copyright (C) 2023, csunny
+# This file is distributed under the same license as the DB-GPT package.
+# FIRST AUTHOR <EMAIL@ADDRESS>, 2023.
+#
+#, fuzzy
+msgid ""
+msgstr ""
+"Project-Id-Version: DB-GPT 👏👏 0.3.6\n"
+"Report-Msgid-Bugs-To: \n"
+"POT-Creation-Date: 2023-09-13 09:06+0800\n"
+"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
+"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
+"Language: zh_CN\n"
+"Language-Team: zh_CN <LL@li.org>\n"
+"Plural-Forms: nplurals=1; plural=0;\n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=utf-8\n"
+"Content-Transfer-Encoding: 8bit\n"
+"Generated-By: Babel 2.12.1\n"
+
+#: ../../getting_started/install/cluster/kubernetes/index.md:1
+#: 48e6f08f27c74f31a8b12758fe33dc24
+msgid "Kubernetes cluster deployment"
+msgstr "Kubernetes 集群部署"
+
diff --git a/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/cluster/vms/index.po b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/cluster/vms/index.po
new file mode 100644
index 000000000..af535d519
--- /dev/null
+++ b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/cluster/vms/index.po
@@ -0,0 +1,176 @@
+# SOME DESCRIPTIVE TITLE.
+# Copyright (C) 2023, csunny
+# This file is distributed under the same license as the DB-GPT package.
+# FIRST AUTHOR <EMAIL@ADDRESS>, 2023.
+#
+#, fuzzy
+msgid ""
+msgstr ""
+"Project-Id-Version: DB-GPT 👏👏 0.3.6\n"
+"Report-Msgid-Bugs-To: \n"
+"POT-Creation-Date: 2023-09-13 09:06+0800\n"
+"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
+"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
+"Language: zh_CN\n"
+"Language-Team: zh_CN <LL@li.org>\n"
+"Plural-Forms: nplurals=1; plural=0;\n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=utf-8\n"
+"Content-Transfer-Encoding: 8bit\n"
+"Generated-By: Babel 2.12.1\n"
+
+#: ../../getting_started/install/cluster/vms/index.md:1
+#: 2d2e04ba49364eae9b8493bb274765a6
+msgid "Local cluster deployment"
+msgstr "本地集群部署"
+
+#: ../../getting_started/install/cluster/vms/index.md:4
+#: e405d0e7ad8c4b2da4b4ca27c77f5fea
+msgid "Model cluster deployment"
+msgstr "模型集群部署"
+
+#: ../../getting_started/install/cluster/vms/index.md:7
+#: bba397ddac754a2bab8edca163875b65
+msgid "**Installing Command-Line Tool**"
+msgstr "**安装命令行工具**"
+
+#: ../../getting_started/install/cluster/vms/index.md:9
+#: bc45851124354522af8c9bb9748ff1fa
+msgid ""
+"All operations below are performed using the `dbgpt` command. To use the "
+"`dbgpt` command, you need to install the DB-GPT project with `pip install"
+" -e .`. Alternatively, you can use `python pilot/scripts/cli_scripts.py` "
+"as a substitute for the `dbgpt` command."
+msgstr ""
+"以下所有操作都使用 `dbgpt` 命令完成。要使用 `dbgpt` 命令,您需要安装DB-GPT项目,方法是使用`pip install -e .`。或者,您可以使用 `python pilot/scripts/cli_scripts.py` 作为 `dbgpt` 命令的替代。"
+
+#: ../../getting_started/install/cluster/vms/index.md:11
+#: 9d11f7807fd140c8949b634700adc966
+msgid "Launch Model Controller"
+msgstr "启动 Model Controller"
+
+#: ../../getting_started/install/cluster/vms/index.md:17
+#: 97716be92ba64ce9a215433bddf77add
+msgid "By default, the Model Controller starts on port 8000."
+msgstr "默认情况下,Model Controller 启动在 8000 端口。"
+
+#: ../../getting_started/install/cluster/vms/index.md:20
+#: 3f65e6a1e59248a59c033891d1ab7ba8
+msgid "Launch LLM Model Worker"
+msgstr "启动 LLM Model Worker"
+
+#: ../../getting_started/install/cluster/vms/index.md:22
+#: 60241d97573e4265b7fb150c378c4a08
+msgid "If you are starting `chatglm2-6b`:"
+msgstr "如果您启动的是 `chatglm2-6b`:"
+
+#: ../../getting_started/install/cluster/vms/index.md:31
+#: 18bbeb1de110438fa96dd5c736b9a7b1
+msgid "If you are starting `vicuna-13b-v1.5`:"
+msgstr "如果您启动的是 `vicuna-13b-v1.5`:"
+
+#: ../../getting_started/install/cluster/vms/index.md:40
+#: ../../getting_started/install/cluster/vms/index.md:53
+#: 24b1a27313c64224aaeab6cbfad1fe19 fc94a698a7904c6893eef7e7a6e52972
+msgid "Note: Be sure to use your own model name and model path."
+msgstr "注意:确保使用您自己的模型名称和模型路径。"
+
+#: ../../getting_started/install/cluster/vms/index.md:42
+#: 19746195e85f4784bf66a9e67378c04b
+msgid "Launch Embedding Model Worker"
+msgstr "启动 Embedding Model Worker"
+
+#: ../../getting_started/install/cluster/vms/index.md:55
+#: e93ce68091f64d0294b3f912a66cc18b
+msgid "Check your model:"
+msgstr "检查您的模型:"
+
+#: ../../getting_started/install/cluster/vms/index.md:61
+#: fa0b8f3a18fe4bab88fbf002bf26d32e
+msgid "You will see the following output:"
+msgstr "您将看到以下输出:"
+
+#: ../../getting_started/install/cluster/vms/index.md:75
+#: 695262fb4f224101902bc7865ac7871f
+msgid "Connect to the model service in the webserver (dbgpt_server)"
+msgstr "在 webserver (dbgpt_server) 中连接到模型服务"
+
+#: ../../getting_started/install/cluster/vms/index.md:77
+#: 73bf4c2ae5c64d938e3b7e77c06fa21e
+msgid ""
+"**First, modify the `.env` file to change the model name and the Model "
+"Controller connection address.**"
+msgstr ""
+"**首先,修改 `.env` 文件以更改模型名称和模型控制器连接地址。**"
+
+#: ../../getting_started/install/cluster/vms/index.md:85
+#: 8ab126fd72ed4368a79b821ba50e62c8
+msgid "Start the webserver"
+msgstr "启动 webserver"
+
+#: ../../getting_started/install/cluster/vms/index.md:91
+#: 5a7e25c84ca2412bb64310bfad9e2403
+msgid "`--light` indicates not to start the embedded model service."
+msgstr "`--light` 表示不启动嵌入式模型服务。"
+
+#: ../../getting_started/install/cluster/vms/index.md:93
+#: 8cd9ec4fa9cb4c0fa8ff05c05a85ea7f
+msgid ""
+"Alternatively, you can prepend the command with `LLM_MODEL=chatglm2-6b` "
+"to start:"
+msgstr ""
+"或者,您可以在命令前加上 `LLM_MODEL=chatglm2-6b` 来启动:"
+
+#: ../../getting_started/install/cluster/vms/index.md:100
+#: 13ed16758a104860b5fc982d36638b17
+msgid "More Command-Line Usages"
+msgstr "更多命令行用法"
+
+#: ../../getting_started/install/cluster/vms/index.md:102
+#: 175f614d547a4391bab9a77762f9174e
+msgid "You can view more command-line usages through the help command."
+msgstr "您可以通过帮助命令查看更多命令行用法。"
+
+#: ../../getting_started/install/cluster/vms/index.md:104
+#: 6a4475d271c347fbbb35f2936a86823f
+msgid "**View the `dbgpt` help**"
+msgstr "**查看 `dbgpt` 帮助**"
+
+#: ../../getting_started/install/cluster/vms/index.md:109
+#: 3eb11234cf504cc9ac369d8462daa14b
+msgid "You will see the basic command parameters and usage:"
+msgstr "您将看到基本的命令参数和用法:"
+
+#: ../../getting_started/install/cluster/vms/index.md:127
+#: 6eb47aecceec414e8510fe022b6fddbd
+msgid "**View the `dbgpt start` help**"
+msgstr "**查看 `dbgpt start` 帮助**"
+
+#: ../../getting_started/install/cluster/vms/index.md:133
+#: 1f4c0a4ce0704ca8ac33178bd13c69ad
+msgid "Here you can see the related commands and usage for start:"
+msgstr "在这里,您可以看到启动的相关命令和用法:"
+
+#: ../../getting_started/install/cluster/vms/index.md:150
+#: 22e8e67bc55244e79764d091f334560b
+msgid "**View the `dbgpt start worker`help**"
+msgstr "**查看 `dbgpt start worker` 帮助**"
+
+#: ../../getting_started/install/cluster/vms/index.md:156
+#: 5631b83fda714780855e99e90d4eb542
+msgid "Here you can see the parameters to start Model Worker:"
+msgstr "在这里,您可以看到启动 Model Worker 的参数:"
+
+#: ../../getting_started/install/cluster/vms/index.md:215
+#: cf4a31fd3368481cba1b3ab382615f53
+msgid "**View the `dbgpt model`help**"
+msgstr "**查看 `dbgpt model` 帮助**"
+
+#: ../../getting_started/install/cluster/vms/index.md:221
+#: 3740774ec4b240f2882b5b59da224d55
+msgid ""
+"The `dbgpt model ` command can connect to the Model Controller via the "
+"Model Controller address and then manage a remote model:"
+msgstr ""
+"`dbgpt model` 命令可以通过 Model Controller 地址连接到 Model Controller,然后管理远程模型:"
+
diff --git a/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/llm/llm.po b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/llm/llm.po
index b5531e9f5..5041134b7 100644
--- a/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/llm/llm.po
+++ b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/llm/llm.po
@@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: DB-GPT 👏👏 0.3.5\n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2023-08-31 16:38+0800\n"
+"POT-Creation-Date: 2023-09-13 10:46+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: zh_CN\n"
@@ -21,84 +21,88 @@ msgstr ""
#: ../../getting_started/install/llm/llm.rst:2
#: ../../getting_started/install/llm/llm.rst:24
-#: b348d4df8ca44dd78b42157a8ff6d33d
+#: e693a8d3769b4d9e99c4442ca77dc43c
msgid "LLM Usage"
msgstr "LLM使用"
-#: ../../getting_started/install/llm/llm.rst:3 7f5960a7e5634254b330da27be87594b
+#: ../../getting_started/install/llm/llm.rst:3 0a73562d18ba455bab04277b715c3840
msgid ""
"DB-GPT provides a management and deployment solution for multiple models."
" This chapter mainly discusses how to deploy different models."
msgstr "DB-GPT提供了多模型的管理和部署方案,本章主要讲解针对不同的模型该怎么部署"
-#: ../../getting_started/install/llm/llm.rst:18
-#: b844ab204ec740ec9d7d191bb841f09e
+#: ../../getting_started/install/llm/llm.rst:19
+#: d7e4de2a7e004888897204ec76b6030b
msgid ""
"Multi LLMs Support, Supports multiple large language models, currently "
"supporting"
msgstr "目前DB-GPT已适配如下模型"
-#: ../../getting_started/install/llm/llm.rst:9 c141437ddaf84c079360008343041b2f
+#: ../../getting_started/install/llm/llm.rst:9 4616886b8b2244bd93355e871356d89e
+#, fuzzy
+msgid "🔥 Baichuan2(7b,13b)"
+msgstr "Baichuan(7b,13b)"
+
+#: ../../getting_started/install/llm/llm.rst:10
+#: ad0e4793d4e744c1bdf59f5a3d9c84be
msgid "🔥 Vicuna-v1.5(7b,13b)"
msgstr "🔥 Vicuna-v1.5(7b,13b)"
-#: ../../getting_started/install/llm/llm.rst:10
-#: d32b1e3f114c4eab8782b497097c1b37
+#: ../../getting_started/install/llm/llm.rst:11
+#: d291e58001ae487bbbf2a1f9f889f5fd
msgid "🔥 llama-2(7b,13b,70b)"
msgstr "🔥 llama-2(7b,13b,70b)"
-#: ../../getting_started/install/llm/llm.rst:11
-#: 0a417ee4d008421da07fff7add5d05eb
+#: ../../getting_started/install/llm/llm.rst:12
+#: 1e49702ee40b4655945a2a13efaad536
msgid "WizardLM-v1.2(13b)"
msgstr "WizardLM-v1.2(13b)"
-#: ../../getting_started/install/llm/llm.rst:12
-#: 199e1a9fe3324dc8a1bcd9cd0b1ef047
+#: ../../getting_started/install/llm/llm.rst:13
+#: 4ef5913ddfe840d7a12289e6e1d4cb60
msgid "Vicuna (7b,13b)"
msgstr "Vicuna (7b,13b)"
-#: ../../getting_started/install/llm/llm.rst:13
-#: a9e4c5100534450db3a583fa5850e4be
+#: ../../getting_started/install/llm/llm.rst:14
+#: ea46c2211257459285fa48083cb59561
msgid "ChatGLM-6b (int4,int8)"
msgstr "ChatGLM-6b (int4,int8)"
-#: ../../getting_started/install/llm/llm.rst:14
-#: 943324289eb94042b52fd824189cd93f
+#: ../../getting_started/install/llm/llm.rst:15
+#: 90688302bae4452a84f14e8ecb7f1a21
msgid "ChatGLM2-6b (int4,int8)"
msgstr "ChatGLM2-6b (int4,int8)"
-#: ../../getting_started/install/llm/llm.rst:15
-#: f1226fdfac3b4e9d88642ffa69d75682
+#: ../../getting_started/install/llm/llm.rst:16
+#: ee1469545a314696a36e7296c7b71960
msgid "guanaco(7b,13b,33b)"
msgstr "guanaco(7b,13b,33b)"
-#: ../../getting_started/install/llm/llm.rst:16
-#: 3f2457f56eb341b6bc431c9beca8f4df
+#: ../../getting_started/install/llm/llm.rst:17
+#: 25abad241f4d4eee970d5938bf71311f
msgid "Gorilla(7b,13b)"
msgstr "Gorilla(7b,13b)"
-#: ../../getting_started/install/llm/llm.rst:17
-#: 86c8ce37be1c4a7ea3fc382100d77a9c
+#: ../../getting_started/install/llm/llm.rst:18
+#: 8e3d0399431a4c6a9065a8ae0ad3c8ac
msgid "Baichuan(7b,13b)"
msgstr "Baichuan(7b,13b)"
-#: ../../getting_started/install/llm/llm.rst:18
-#: 538111af95ad414cb2e631a89f9af379
+#: ../../getting_started/install/llm/llm.rst:19
+#: c285fa7c9c6c4e3e9840761a09955348
msgid "OpenAI"
msgstr "OpenAI"
-#: ../../getting_started/install/llm/llm.rst:20
-#: a203325b7ec248f7bff61ae89226a000
+#: ../../getting_started/install/llm/llm.rst:21
+#: 4ac13a21f323455982750bd2e0243b72
msgid "llama_cpp"
msgstr "llama_cpp"
-#: ../../getting_started/install/llm/llm.rst:21
-#: 21a50634198047228bc51a03d2c31292
+#: ../../getting_started/install/llm/llm.rst:22
+#: 7231edceef584724a6f569c6b363e083
msgid "quantization"
msgstr "quantization"
-#: ../../getting_started/install/llm/llm.rst:22
-#: dfaec4b04e6e45ff9c884b41534b1a79
-msgid "cluster deployment"
-msgstr ""
+#~ msgid "cluster deployment"
+#~ msgstr ""
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 1b7f62282..7d712119a 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,4 +1,4 @@
-autodoc_pydantic==1.8.0
+autodoc_pydantic
myst_parser
nbsphinx==0.8.9
sphinx==4.5.0
diff --git a/pilot/componet.py b/pilot/componet.py
new file mode 100644
index 000000000..705eb1193
--- /dev/null
+++ b/pilot/componet.py
@@ -0,0 +1,143 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import Type, Dict, TypeVar, Optional, TYPE_CHECKING
+import asyncio
+
+# Imported only for static type checking (not at runtime)
+if TYPE_CHECKING:
+ from fastapi import FastAPI
+
+
+class LifeCycle:
+ """This class defines hooks for lifecycle events of a component."""
+
+ def before_start(self):
+ """Called before the component starts."""
+ pass
+
+ async def async_before_start(self):
+ """Asynchronous version of before_start."""
+ pass
+
+ def after_start(self):
+ """Called after the component has started."""
+ pass
+
+ async def async_after_start(self):
+ """Asynchronous version of after_start."""
+ pass
+
+ def before_stop(self):
+ """Called before the component stops."""
+ pass
+
+ async def async_before_stop(self):
+ """Asynchronous version of before_stop."""
+ pass
+
+
+class BaseComponet(LifeCycle, ABC):
+ """Abstract Base Component class. All custom components should extend this."""
+
+ name = "base_dbgpt_componet"
+
+ def __init__(self, system_app: Optional[SystemApp] = None):
+ if system_app is not None:
+ self.init_app(system_app)
+
+ @abstractmethod
+ def init_app(self, system_app: SystemApp):
+ """Initialize the component with the main application.
+
+ This method needs to be implemented by every component to define how it integrates
+ with the main system app.
+ """
+ pass
+
+
+T = TypeVar("T", bound=BaseComponet)
+
+
+class SystemApp(LifeCycle):
+ """Main System Application class that manages the lifecycle and registration of components."""
+
+ def __init__(self, asgi_app: Optional["FastAPI"] = None) -> None:
+ self.componets: Dict[
+ str, BaseComponet
+ ] = {} # Dictionary to store registered components.
+ self._asgi_app = asgi_app
+
+ @property
+ def app(self) -> Optional["FastAPI"]:
+ """Returns the internal ASGI app."""
+ return self._asgi_app
+
+ def register(self, componet: Type[BaseComponet], *args, **kwargs):
+ """Register a new component by its type."""
+ instance = componet(self, *args, **kwargs)
+ self.register_instance(instance)
+
+ def register_instance(self, instance: T):
+ """Register an already initialized component."""
+ self.componets[instance.name] = instance
+ instance.init_app(self)
+
+ def get_componet(self, name: str, componet_type: Type[T]) -> T:
+ """Retrieve a registered component by its name and type."""
+ component = self.componets.get(name)
+ if not component:
+ raise ValueError(f"No component found with name {name}")
+ if not isinstance(component, componet_type):
+ raise TypeError(f"Component {name} is not of type {componet_type}")
+ return component
+
+ def before_start(self):
+ """Invoke the before_start hooks for all registered components."""
+ for _, v in self.componets.items():
+ v.before_start()
+
+ async def async_before_start(self):
+ """Asynchronously invoke the before_start hooks for all registered components."""
+ tasks = [v.async_before_start() for _, v in self.componets.items()]
+ await asyncio.gather(*tasks)
+
+ def after_start(self):
+ """Invoke the after_start hooks for all registered components."""
+ for _, v in self.componets.items():
+ v.after_start()
+
+ async def async_after_start(self):
+ """Asynchronously invoke the after_start hooks for all registered components."""
+ tasks = [v.async_after_start() for _, v in self.componets.items()]
+ await asyncio.gather(*tasks)
+
+ def before_stop(self):
+ """Invoke the before_stop hooks for all registered components."""
+ for _, v in self.componets.items():
+ try:
+ v.before_stop()
+ except Exception:
+ # Ignore errors on stop so the remaining components still get stopped.
+ pass
+
+ async def async_before_stop(self):
+ """Asynchronously invoke the before_stop hooks for all registered components."""
+ tasks = [v.async_before_stop() for _, v in self.componets.items()]
+ await asyncio.gather(*tasks)
+
+ def _build(self):
+ """Integrate lifecycle events with the internal ASGI app if available."""
+ if not self.app:
+ return
+
+ @self.app.on_event("startup")
+ async def startup_event():
+ """ASGI app startup event handler."""
+ asyncio.create_task(self.async_after_start())
+ self.after_start()
+
+ @self.app.on_event("shutdown")
+ async def shutdown_event():
+ """ASGI app shutdown event handler."""
+ await self.async_before_stop()
+ self.before_stop()
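+
+
+# Usage sketch (illustrative): registering a component and fetching it back
+# by name and type. DefaultEmbeddingFactory/EmbeddingFactory come from
+# pilot/embedding_engine/embedding_factory.py, added in this change.
+#
+# from pilot.embedding_engine.embedding_factory import (
+# DefaultEmbeddingFactory,
+# EmbeddingFactory,
+# )
+#
+# system_app = SystemApp()
+# system_app.register(DefaultEmbeddingFactory, model_name="/app/models/text2vec")
+# factory = system_app.get_componet("embedding_factory", EmbeddingFactory)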
diff --git a/pilot/configs/config.py b/pilot/configs/config.py
index 8bfaed757..0276c2a17 100644
--- a/pilot/configs/config.py
+++ b/pilot/configs/config.py
@@ -189,6 +189,10 @@ def __init__(self) -> None:
### Log level
self.DBGPT_LOG_LEVEL = os.getenv("DBGPT_LOG_LEVEL", "INFO")
+ from pilot.componet import SystemApp
+
+ self.SYSTEM_APP: SystemApp = None
+
def set_debug_mode(self, value: bool) -> None:
"""Set the debug mode value"""
self.debug_mode = value
diff --git a/pilot/connections/manages/connection_manager.py b/pilot/connections/manages/connection_manager.py
index a933205c2..b5cfbbfdd 100644
--- a/pilot/connections/manages/connection_manager.py
+++ b/pilot/connections/manages/connection_manager.py
@@ -4,6 +4,7 @@
from pilot.configs.config import Config
from pilot.connections.manages.connect_storage_duckdb import DuckdbConnectConfig
from pilot.common.schema import DBType
+from pilot.componet import SystemApp
from pilot.connections.rdbms.conn_mysql import MySQLConnect
from pilot.connections.base import BaseConnect
@@ -46,9 +47,9 @@ def get_cls_by_dbtype(self, db_type):
raise ValueError("Unsupport Db Type!" + db_type)
return result
- def __init__(self):
+ def __init__(self, system_app: SystemApp):
self.storage = DuckdbConnectConfig()
- self.db_summary_client = DBSummaryClient()
+ self.db_summary_client = DBSummaryClient(system_app)
self.__load_config_db()
def __load_config_db(self):
diff --git a/pilot/connections/rdbms/tests/mange_t.py b/pilot/connections/rdbms/tests/mange_t.py
index 5140341d6..282bb8042 100644
--- a/pilot/connections/rdbms/tests/mange_t.py
+++ b/pilot/connections/rdbms/tests/mange_t.py
@@ -2,6 +2,6 @@
from pilot.connections.manages.connection_manager import ConnectManager
if __name__ == "__main__":
- mange = ConnectManager()
+ mange = ConnectManager(system_app=None)
types = mange.get_all_completed_types()
print(str(types))
diff --git a/pilot/embedding_engine/embedding_engine.py b/pilot/embedding_engine/embedding_engine.py
index f3b09ae5d..27f94eeee 100644
--- a/pilot/embedding_engine/embedding_engine.py
+++ b/pilot/embedding_engine/embedding_engine.py
@@ -1,9 +1,12 @@
from typing import Optional
from chromadb.errors import NotEnoughElementsException
-from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import TextSplitter
+from pilot.embedding_engine.embedding_factory import (
+ EmbeddingFactory,
+ DefaultEmbeddingFactory,
+)
from pilot.embedding_engine.knowledge_type import get_knowledge_embedding, KnowledgeType
from pilot.vector_store.connector import VectorStoreConnector
@@ -24,13 +27,16 @@ def __init__(
knowledge_source: Optional[str] = None,
source_reader: Optional = None,
text_splitter: Optional[TextSplitter] = None,
+ embedding_factory: EmbeddingFactory = None,
):
"""Initialize with knowledge embedding client, model_name, vector_store_config, knowledge_type, knowledge_source"""
self.knowledge_source = knowledge_source
self.model_name = model_name
self.vector_store_config = vector_store_config
self.knowledge_type = knowledge_type
- self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
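+ # The factory indirection lets callers inject cluster-backed embeddings
+ # (for example, the RemoteEmbeddings added in this change) in place of a
+ # locally loaded HuggingFace model.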
+ if not embedding_factory:
+ embedding_factory = DefaultEmbeddingFactory()
+ self.embeddings = embedding_factory.create(model_name=self.model_name)
self.vector_store_config["embeddings"] = self.embeddings
self.source_reader = source_reader
self.text_splitter = text_splitter
diff --git a/pilot/embedding_engine/embedding_factory.py b/pilot/embedding_engine/embedding_factory.py
new file mode 100644
index 000000000..e7345d952
--- /dev/null
+++ b/pilot/embedding_engine/embedding_factory.py
@@ -0,0 +1,39 @@
+from abc import ABC, abstractmethod
+from typing import Any, Type, TYPE_CHECKING
+
+from pilot.componet import BaseComponet
+
+if TYPE_CHECKING:
+ from langchain.embeddings.base import Embeddings
+
+
+class EmbeddingFactory(BaseComponet, ABC):
+ name = "embedding_factory"
+
+ @abstractmethod
+ def create(
+ self, model_name: str = None, embedding_cls: Type = None
+ ) -> "Embeddings":
+ """Create embedding"""
+
+
+class DefaultEmbeddingFactory(EmbeddingFactory):
+ def __init__(self, system_app=None, model_name: str = None, **kwargs: Any) -> None:
+ super().__init__(system_app=system_app)
+ self._default_model_name = model_name
+ self.kwargs = kwargs
+
+ def init_app(self, system_app):
+ pass
+
+ def create(
+ self, model_name: str = None, embedding_cls: Type = None
+ ) -> "Embeddings":
+ if not model_name:
+ model_name = self._default_model_name
+ if embedding_cls:
+ return embedding_cls(model_name=model_name, **self.kwargs)
+ else:
+ from langchain.embeddings import HuggingFaceEmbeddings
+
+ return HuggingFaceEmbeddings(model_name=model_name, **self.kwargs)
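+
+
+# Usage sketch (assumes a local HuggingFace-compatible model directory):
+#
+# factory = DefaultEmbeddingFactory(model_name="/app/models/text2vec-large-chinese")
+# embeddings = factory.create()
+# vectors = embeddings.embed_documents(["hello world"])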
diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py
index 5d3acf505..0b93faa40 100644
--- a/pilot/model/adapter.py
+++ b/pilot/model/adapter.py
@@ -96,13 +96,18 @@ def get_llm_model_adapter(model_name: str, model_path: str) -> BaseLLMAdaper:
def _dynamic_model_parser() -> Callable[[None], List[Type]]:
from pilot.utils.parameter_utils import _SimpleArgParser
+ from pilot.model.parameter import EmbeddingModelParameters, WorkerType
- pre_args = _SimpleArgParser("model_name", "model_path")
+ pre_args = _SimpleArgParser("model_name", "model_path", "worker_type")
pre_args.parse()
model_name = pre_args.get("model_name")
model_path = pre_args.get("model_path")
+ worker_type = pre_args.get("worker_type")
if model_name is None:
return None
+ if worker_type == WorkerType.TEXT2VEC:
+ return [EmbeddingModelParameters]
+
llm_adapter = get_llm_model_adapter(model_name, model_path)
param_class = llm_adapter.model_param_class()
return [param_class]
diff --git a/pilot/model/cli.py b/pilot/model/cli.py
index 2406f6920..3e94c7045 100644
--- a/pilot/model/cli.py
+++ b/pilot/model/cli.py
@@ -167,7 +167,6 @@ def stop(model_name: str, model_type: str, host: str, port: int):
def _remote_model_dynamic_factory() -> Callable[[None], List[Type]]:
- from pilot.model.adapter import _dynamic_model_parser
from pilot.utils.parameter_utils import _SimpleArgParser
from pilot.model.cluster import RemoteWorkerManager
from pilot.model.parameter import WorkerType
diff --git a/pilot/model/cluster/__init__.py b/pilot/model/cluster/__init__.py
index b518a756b..b73fd7873 100644
--- a/pilot/model/cluster/__init__.py
+++ b/pilot/model/cluster/__init__.py
@@ -5,6 +5,9 @@
WorkerParameterRequest,
WorkerStartupRequest,
)
+from pilot.model.cluster.worker_base import ModelWorker
+from pilot.model.cluster.worker.default_worker import DefaultModelWorker
+
from pilot.model.cluster.worker.manager import (
initialize_worker_manager_in_client,
run_worker_manager,
@@ -23,11 +26,15 @@
"EmbeddingsRequest",
"PromptRequest",
"WorkerApplyRequest",
- "WorkerParameterRequest"
- "WorkerStartupRequest"
- "worker_manager"
+ "WorkerParameterRequest",
+ "WorkerStartupRequest",
+ "ModelWorker",
+ "DefaultModelWorker",
+ "worker_manager",
"run_worker_manager",
"initialize_worker_manager_in_client",
"ModelRegistry",
- "ModelRegistryClient" "RemoteWorkerManager" "run_model_controller",
+ "ModelRegistryClient",
+ "RemoteWorkerManager",
+ "run_model_controller",
]
diff --git a/pilot/model/cluster/controller/controller.py b/pilot/model/cluster/controller/controller.py
index d48e1362f..e1d55eab7 100644
--- a/pilot/model/cluster/controller/controller.py
+++ b/pilot/model/cluster/controller/controller.py
@@ -8,7 +8,10 @@
from pilot.model.parameter import ModelControllerParameters
from pilot.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry
from pilot.utils.parameter_utils import EnvArgumentParser
-from pilot.utils.api_utils import _api_remote as api_remote
+from pilot.utils.api_utils import (
+ _api_remote as api_remote,
+ _sync_api_remote as sync_api_remote,
+)
class BaseModelController(ABC):
@@ -89,6 +92,12 @@ class ModelRegistryClient(_RemoteModelController, ModelRegistry):
async def get_all_model_instances(self) -> List[ModelInstance]:
return await self.get_all_instances()
+ @sync_api_remote(path="/api/controller/models")
+ def sync_get_all_instances(
+ self, model_name: str, healthy_only: bool = False
+ ) -> List[ModelInstance]:
+ pass
+
class ModelControllerAdapter(BaseModelController):
def __init__(self, backend: BaseModelController = None) -> None:
diff --git a/pilot/model/cluster/embedding/__init__.py b/pilot/model/cluster/embedding/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/model/cluster/embedding/remote_embedding.py b/pilot/model/cluster/embedding/remote_embedding.py
new file mode 100644
index 000000000..98fb8af16
--- /dev/null
+++ b/pilot/model/cluster/embedding/remote_embedding.py
@@ -0,0 +1,28 @@
+from typing import List
+from langchain.embeddings.base import Embeddings
+
+from pilot.model.cluster.manager_base import WorkerManager
+
+
+class RemoteEmbeddings(Embeddings):
+ def __init__(self, model_name: str, worker_manager: WorkerManager) -> None:
+ self.model_name = model_name
+ self.worker_manager = worker_manager
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Embed search docs."""
+ params = {"model": self.model_name, "input": texts}
+ return self.worker_manager.sync_embeddings(params)
+
+ def embed_query(self, text: str) -> List[float]:
+ """Embed query text."""
+ return self.embed_documents([text])[0]
+
+ async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Asynchronously embed search docs."""
+ params = {"model": self.model_name, "input": texts}
+ return await self.worker_manager.embeddings(params)
+
+ async def aembed_query(self, text: str) -> List[float]:
+ """Asynchronously embed query text."""
+ return (await self.aembed_documents([text]))[0]
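+
+
+# Usage sketch: embed documents through the cluster. Assumes a `text2vec`
+# worker is registered with a Model Controller at http://127.0.0.1:8000.
+#
+# from pilot.model.cluster import ModelRegistryClient, RemoteWorkerManager
+#
+# registry = ModelRegistryClient("http://127.0.0.1:8000")
+# worker_manager = RemoteWorkerManager(registry)
+# embeddings = RemoteEmbeddings("text2vec", worker_manager)
+# vectors = embeddings.embed_documents(["hello", "world"])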
diff --git a/pilot/model/cluster/manager_base.py b/pilot/model/cluster/manager_base.py
index 4f3fd27e3..d66991cae 100644
--- a/pilot/model/cluster/manager_base.py
+++ b/pilot/model/cluster/manager_base.py
@@ -1,6 +1,6 @@
import asyncio
from dataclasses import dataclass
-from typing import List, Optional, Dict, Iterator
+from typing import List, Optional, Dict, Iterator, Callable
from abc import ABC, abstractmethod
from datetime import datetime
from concurrent.futures import Future
@@ -35,15 +35,31 @@ async def start(self):
async def stop(self):
"""Stop worker manager"""
+ @abstractmethod
+ def after_start(self, listener: Callable[["WorkerManager"], None]):
+ """Add a listener after WorkerManager startup"""
+
@abstractmethod
async def get_model_instances(
self, worker_type: str, model_name: str, healthy_only: bool = True
+ ) -> List[WorkerRunData]:
+ """Asynchronously get model instances by worker type and model name"""
+
+ @abstractmethod
+ def sync_get_model_instances(
+ self, worker_type: str, model_name: str, healthy_only: bool = True
) -> List[WorkerRunData]:
"""Get model instances by worker type and model name"""
@abstractmethod
async def select_one_instance(
self, worker_type: str, model_name: str, healthy_only: bool = True
+ ) -> WorkerRunData:
+ """Asynchronously select one instance"""
+
+ @abstractmethod
+ def sync_select_one_instance(
+ self, worker_type: str, model_name: str, healthy_only: bool = True
) -> WorkerRunData:
"""Select one instance"""
@@ -69,7 +85,15 @@ async def generate(self, params: Dict) -> ModelOutput:
@abstractmethod
async def embeddings(self, params: Dict) -> List[List[float]]:
- """Embed input"""
+ """Asynchronously embed input"""
+
+ @abstractmethod
+ def sync_embeddings(self, params: Dict) -> List[List[float]]:
+ """Embed input.
+
+ This method may be handed to third-party code that calls it synchronously,
+ so a synchronous version must be provided alongside the async one.
+ """
@abstractmethod
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
diff --git a/pilot/model/cluster/registry.py b/pilot/model/cluster/registry.py
index e5e9b0618..398882eb9 100644
--- a/pilot/model/cluster/registry.py
+++ b/pilot/model/cluster/registry.py
@@ -58,6 +58,12 @@ async def get_all_instances(
- List[ModelInstance]: A list of instances for the given model.
"""
+ @abstractmethod
+ def sync_get_all_instances(
+ self, model_name: str, healthy_only: bool = False
+ ) -> List[ModelInstance]:
+ """Fetch all instances of a given model. Optionally, fetch only the healthy instances."""
+
@abstractmethod
async def get_all_model_instances(self) -> List[ModelInstance]:
"""
@@ -163,6 +169,11 @@ async def deregister_instance(self, instance: ModelInstance) -> bool:
async def get_all_instances(
self, model_name: str, healthy_only: bool = False
+ ) -> List[ModelInstance]:
+ return self.sync_get_all_instances(model_name, healthy_only)
+
+ def sync_get_all_instances(
+ self, model_name: str, healthy_only: bool = False
) -> List[ModelInstance]:
instances = self.registry[model_name]
if healthy_only:
@@ -179,7 +190,7 @@ async def send_heartbeat(self, instance: ModelInstance) -> bool:
)
if not exist_ins:
# register new install from heartbeat
- self.register_instance(instance)
+ await self.register_instance(instance)
return True
ins = exist_ins[0]
diff --git a/pilot/model/cluster/worker/embedding_worker.py b/pilot/model/cluster/worker/embedding_worker.py
index a8824f228..80f06b145 100644
--- a/pilot/model/cluster/worker/embedding_worker.py
+++ b/pilot/model/cluster/worker/embedding_worker.py
@@ -24,7 +24,7 @@ def __init__(self) -> None:
"Could not import langchain.embeddings.HuggingFaceEmbeddings python package. "
"Please install it with `pip install langchain`."
) from exc
- self.embeddings: Embeddings = None
+ self._embeddings_impl: Embeddings = None
self._model_params = None
def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
@@ -75,16 +75,16 @@ def start(
kwargs = model_params.build_kwargs(model_name=model_params.model_path)
logger.info(f"Start HuggingFaceEmbeddings with kwargs: {kwargs}")
- self.embeddings = HuggingFaceEmbeddings(**kwargs)
+ self._embeddings_impl = HuggingFaceEmbeddings(**kwargs)
def __del__(self):
self.stop()
def stop(self) -> None:
- if not self.embeddings:
+ if not self._embeddings_impl:
return
- del self.embeddings
- self.embeddings = None
+ del self._embeddings_impl
+ self._embeddings_impl = None
_clear_torch_cache(self._model_params.device)
def generate_stream(self, params: Dict):
@@ -96,5 +96,7 @@ def generate(self, params: Dict):
raise NotImplementedError("Not supported generate for embeddings model")
def embeddings(self, params: Dict) -> List[List[float]]:
+ model = params.get("model")
+ logger.info(f"Receive embeddings request, model: {model}")
input: List[str] = params["input"]
- return self.embeddings.embed_documents(input)
+ return self._embeddings_impl.embed_documents(input)
diff --git a/pilot/model/cluster/worker/manager.py b/pilot/model/cluster/worker/manager.py
index 7e4852527..93f2a373a 100644
--- a/pilot/model/cluster/worker/manager.py
+++ b/pilot/model/cluster/worker/manager.py
@@ -72,6 +72,7 @@ def __init__(
self.model_registry = model_registry
self.host = host
self.port = port
+ self.start_listeners = []
self.run_data = WorkerRunData(
host=self.host,
@@ -105,6 +106,8 @@ async def start(self):
asyncio.create_task(
_async_heartbeat_sender(self.run_data, 20, self.send_heartbeat_func)
)
+ for listener in self.start_listeners:
+ listener(self)
async def stop(self):
if not self.run_data.stop_event.is_set():
@@ -116,6 +119,9 @@ async def stop(self):
stop_tasks.append(self.deregister_func(self.run_data))
await asyncio.gather(*stop_tasks)
+ def after_start(self, listener: Callable[["WorkerManager"], None]):
+ self.start_listeners.append(listener)
+
def add_worker(
self,
worker: ModelWorker,
@@ -137,14 +143,7 @@ def add_worker(
worker_key = self._worker_key(
worker_params.worker_type, worker_params.model_name
)
- host = worker_params.host
- port = worker_params.port
- instances = self.workers.get(worker_key)
- if not instances:
- instances = []
- self.workers[worker_key] = instances
- logger.info(f"Init empty instances list for {worker_key}")
# Load model params from persist storage
model_params = worker.parse_parameters(command_args=command_args)
@@ -159,14 +158,15 @@ def add_worker(
semaphore=asyncio.Semaphore(worker_params.limit_model_concurrency),
command_args=command_args,
)
- exist_instances = [
- ins for ins in instances if ins.host == host and ins.port == port
- ]
- if not exist_instances:
- instances.append(worker_run_data)
+ instances = self.workers.get(worker_key)
+ if not instances:
+ instances = [worker_run_data]
+ self.workers[worker_key] = instances
+ logger.info(f"Init empty instances list for {worker_key}")
return True
else:
# TODO Update worker
+ logger.warning(f"Instance {worker_key} already exists")
return False
async def model_startup(self, startup_req: WorkerStartupRequest) -> bool:
@@ -222,16 +222,18 @@ async def supported_models(self) -> List[WorkerSupportedModel]:
async def get_model_instances(
self, worker_type: str, model_name: str, healthy_only: bool = True
+ ) -> List[WorkerRunData]:
+ return self.sync_get_model_instances(worker_type, model_name, healthy_only)
+
+ def sync_get_model_instances(
+ self, worker_type: str, model_name: str, healthy_only: bool = True
) -> List[WorkerRunData]:
worker_key = self._worker_key(worker_type, model_name)
return self.workers.get(worker_key)
- async def select_one_instance(
- self, worker_type: str, model_name: str, healthy_only: bool = True
+ def _simple_select(
+ self, worker_type: str, model_name: str, worker_instances: List[WorkerRunData]
) -> WorkerRunData:
- worker_instances = await self.get_model_instances(
- worker_type, model_name, healthy_only
- )
if not worker_instances:
raise Exception(
f"Cound not found worker instances for model name {model_name} and worker type {worker_type}"
@@ -239,12 +241,34 @@ async def select_one_instance(
worker_run_data = random.choice(worker_instances)
return worker_run_data
+ async def select_one_instance(
+ self, worker_type: str, model_name: str, healthy_only: bool = True
+ ) -> WorkerRunData:
+ worker_instances = await self.get_model_instances(
+ worker_type, model_name, healthy_only
+ )
+ return self._simple_select(worker_type, model_name, worker_instances)
+
+ def sync_select_one_instance(
+ self, worker_type: str, model_name: str, healthy_only: bool = True
+ ) -> WorkerRunData:
+ worker_instances = self.sync_get_model_instances(
+ worker_type, model_name, healthy_only
+ )
+ return self._simple_select(worker_type, model_name, worker_instances)
+
async def _get_model(self, params: Dict, worker_type: str = "llm") -> WorkerRunData:
model = params.get("model")
if not model:
raise Exception("Model name count not be empty")
return await self.select_one_instance(worker_type, model, healthy_only=True)
+ def _sync_get_model(self, params: Dict, worker_type: str = "llm") -> WorkerRunData:
+ model = params.get("model")
+ if not model:
+ raise Exception("Model name cannot be empty")
+ return self.sync_select_one_instance(worker_type, model, healthy_only=True)
+
async def generate_stream(
self, params: Dict, async_wrapper=None, **kwargs
) -> Iterator[ModelOutput]:
@@ -304,6 +328,10 @@ async def embeddings(self, params: Dict) -> List[List[float]]:
worker_run_data.worker.embeddings, params
)
+ def sync_embeddings(self, params: Dict) -> List[List[float]]:
+ worker_run_data = self._sync_get_model(params, worker_type="text2vec")
+ return worker_run_data.worker.embeddings(params)
+
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
apply_func: Callable[[WorkerApplyRequest], Awaitable[str]] = None
if apply_req.apply_type == WorkerApplyType.START:
@@ -458,6 +486,10 @@ async def start(self):
async def stop(self):
return await self.worker_manager.stop()
+ def after_start(self, listener: Callable[["WorkerManager"], None]):
+ if listener is not None:
+ self.worker_manager.after_start(listener)
+
async def supported_models(self) -> List[WorkerSupportedModel]:
return await self.worker_manager.supported_models()
@@ -474,6 +506,13 @@ async def get_model_instances(
worker_type, model_name, healthy_only
)
+ def sync_get_model_instances(
+ self, worker_type: str, model_name: str, healthy_only: bool = True
+ ) -> List[WorkerRunData]:
+ return self.worker_manager.sync_get_model_instances(
+ worker_type, model_name, healthy_only
+ )
+
async def select_one_instance(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> WorkerRunData:
@@ -481,6 +520,13 @@ async def select_one_instance(
worker_type, model_name, healthy_only
)
+ def sync_select_one_instance(
+ self, worker_type: str, model_name: str, healthy_only: bool = True
+ ) -> WorkerRunData:
+ return self.worker_manager.sync_select_one_instance(
+ worker_type, model_name, healthy_only
+ )
+
async def generate_stream(self, params: Dict, **kwargs) -> Iterator[ModelOutput]:
async for output in self.worker_manager.generate_stream(params, **kwargs):
yield output
@@ -491,6 +537,9 @@ async def generate(self, params: Dict) -> ModelOutput:
async def embeddings(self, params: Dict) -> List[List[float]]:
return await self.worker_manager.embeddings(params)
+ def sync_embeddings(self, params: Dict) -> List[List[float]]:
+ return self.worker_manager.sync_embeddings(params)
+
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
return await self.worker_manager.worker_apply(apply_req)
@@ -586,11 +635,11 @@ def _setup_fastapi(worker_params: ModelWorkerParameters, app=None):
@app.on_event("startup")
async def startup_event():
- asyncio.create_task(worker_manager.worker_manager.start())
+ asyncio.create_task(worker_manager.start())
@app.on_event("shutdown")
async def startup_event():
- await worker_manager.worker_manager.stop()
+ await worker_manager.stop()
return app
@@ -666,29 +715,60 @@ async def send_heartbeat_func(worker_run_data: WorkerRunData):
def _build_worker(worker_params: ModelWorkerParameters):
- if worker_params.worker_class:
+ worker_class = worker_params.worker_class
+ if worker_class:
from pilot.utils.module_utils import import_from_checked_string
- worker_cls = import_from_checked_string(worker_params.worker_class, ModelWorker)
- logger.info(
- f"Import worker class from {worker_params.worker_class} successfully"
- )
- worker: ModelWorker = worker_cls()
+ worker_cls = import_from_checked_string(worker_class, ModelWorker)
+ logger.info(f"Import worker class from {worker_class} successfully")
else:
- from pilot.model.cluster.worker.default_worker import DefaultModelWorker
+ if (
+ worker_params.worker_type is None
+ or worker_params.worker_type == WorkerType.LLM
+ ):
+ from pilot.model.cluster.worker.default_worker import DefaultModelWorker
+
+ worker_cls = DefaultModelWorker
+ elif worker_params.worker_type == WorkerType.TEXT2VEC:
+ from pilot.model.cluster.worker.embedding_worker import (
+ EmbeddingsModelWorker,
+ )
- worker = DefaultModelWorker()
- return worker
+ worker_cls = EmbeddingsModelWorker
+ else:
+ raise Exception(f"Unsupported worker type: {worker_params.worker_type}")
+
+ return worker_cls()
def _start_local_worker(
worker_manager: WorkerManagerAdapter, worker_params: ModelWorkerParameters
):
worker = _build_worker(worker_params)
- worker_manager.worker_manager = _create_local_model_manager(worker_params)
+ if not worker_manager.worker_manager:
+ worker_manager.worker_manager = _create_local_model_manager(worker_params)
worker_manager.worker_manager.add_worker(worker, worker_params)
+def _start_local_embedding_worker(
+ worker_manager: WorkerManagerAdapter,
+ embedding_model_name: str = None,
+ embedding_model_path: str = None,
+):
+ if not embedding_model_name or not embedding_model_path:
+ return
+ embedding_worker_params = ModelWorkerParameters(
+ model_name=embedding_model_name,
+ model_path=embedding_model_path,
+ worker_type=WorkerType.TEXT2VEC,
+ worker_class="pilot.model.cluster.worker.embedding_worker.EmbeddingsModelWorker",
+ )
+ logger.info(
+ f"Start local embedding worker with embedding parameters\n{embedding_worker_params}"
+ )
+ _start_local_worker(worker_manager, embedding_worker_params)
+
+
def initialize_worker_manager_in_client(
app=None,
include_router: bool = True,
@@ -697,6 +777,9 @@ def initialize_worker_manager_in_client(
run_locally: bool = True,
controller_addr: str = None,
local_port: int = 5000,
+ embedding_model_name: str = None,
+ embedding_model_path: str = None,
+ start_listener: Callable[["WorkerManager"], None] = None,
):
"""Initialize WorkerManager in client.
If run_locally is True:
@@ -728,6 +811,10 @@ def initialize_worker_manager_in_client(
logger.info(f"Worker params: {worker_params}")
_setup_fastapi(worker_params, app)
_start_local_worker(worker_manager, worker_params)
+ worker_manager.after_start(start_listener)
+ _start_local_embedding_worker(
+ worker_manager, embedding_model_name, embedding_model_path
+ )
else:
from pilot.model.cluster.controller.controller import (
ModelRegistryClient,
@@ -741,9 +828,12 @@ def initialize_worker_manager_in_client(
logger.info(f"Worker params: {worker_params}")
client = ModelRegistryClient(worker_params.controller_addr)
worker_manager.worker_manager = RemoteWorkerManager(client)
+ worker_manager.after_start(start_listener)
initialize_controller(
app=app, remote_controller_addr=worker_params.controller_addr
)
+ loop = asyncio.get_event_loop()
+ loop.run_until_complete(worker_manager.start())
if include_router and app:
# mount WorkerManager router
@@ -757,6 +847,8 @@ def run_worker_manager(
model_path: str = None,
standalone: bool = False,
port: int = None,
+ embedding_model_name: str = None,
+ embedding_model_path: str = None,
):
global worker_manager
@@ -765,15 +857,22 @@ def run_worker_manager(
)
embedded_mod = True
+ logger.info(f"Worker params: {worker_params}")
if not app:
# Run worker manager independently
embedded_mod = False
app = _setup_fastapi(worker_params)
_start_local_worker(worker_manager, worker_params)
+ _start_local_embedding_worker(
+ worker_manager, embedding_model_name, embedding_model_path
+ )
else:
_start_local_worker(worker_manager, worker_params)
+ _start_local_embedding_worker(
+ worker_manager, embedding_model_name, embedding_model_path
+ )
loop = asyncio.get_event_loop()
- loop.run_until_complete(worker_manager.worker_manager.start())
+ loop.run_until_complete(worker_manager.start())
if include_router:
app.include_router(router, prefix="/api")
diff --git a/pilot/model/cluster/worker/remote_manager.py b/pilot/model/cluster/worker/remote_manager.py
index afdf10774..7f9de4d62 100644
--- a/pilot/model/cluster/worker/remote_manager.py
+++ b/pilot/model/cluster/worker/remote_manager.py
@@ -1,14 +1,12 @@
-from typing import Callable, Any
-import httpx
import asyncio
+from typing import Any, Callable
+
+import httpx
+from pilot.model.base import ModelInstance, WorkerApplyOutput, WorkerSupportedModel
+from pilot.model.cluster.base import *
from pilot.model.cluster.registry import ModelRegistry
from pilot.model.cluster.worker.manager import LocalWorkerManager, WorkerRunData, logger
-from pilot.model.cluster.base import *
-from pilot.model.base import (
- ModelInstance,
- WorkerApplyOutput,
- WorkerSupportedModel,
-)
+from pilot.model.cluster.worker.remote_worker import RemoteModelWorker
class RemoteWorkerManager(LocalWorkerManager):
@@ -16,7 +14,8 @@ def __init__(self, model_registry: ModelRegistry = None) -> None:
super().__init__(model_registry=model_registry)
async def start(self):
- pass
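+        # No local model processes to launch for remote workers; just notify listeners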
+ for listener in self.start_listeners:
+ listener(self)
async def stop(self):
pass
@@ -125,15 +124,9 @@ async def model_shutdown(self, shutdown_req: WorkerStartupRequest) -> bool:
success_handler=lambda x: True,
)
- async def get_model_instances(
- self, worker_type: str, model_name: str, healthy_only: bool = True
+ def _build_worker_instances(
+ self, model_name: str, instances: List[ModelInstance]
) -> List[WorkerRunData]:
- from pilot.model.cluster.worker.remote_worker import RemoteModelWorker
-
- worker_key = self._worker_key(worker_type, model_name)
- instances: List[ModelInstance] = await self.model_registry.get_all_instances(
- worker_key, healthy_only
- )
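+        # Wrap each registered model instance in a RemoteModelWorker proxy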
worker_instances = []
for ins in instances:
worker = RemoteModelWorker()
@@ -151,6 +144,24 @@ async def get_model_instances(
worker_instances.append(wr)
return worker_instances
+ async def get_model_instances(
+ self, worker_type: str, model_name: str, healthy_only: bool = True
+ ) -> List[WorkerRunData]:
+ worker_key = self._worker_key(worker_type, model_name)
+ instances: List[ModelInstance] = await self.model_registry.get_all_instances(
+ worker_key, healthy_only
+ )
+ return self._build_worker_instances(model_name, instances)
+
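+    # Synchronous counterpart of get_model_instances for non-async call sites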
+ def sync_get_model_instances(
+ self, worker_type: str, model_name: str, healthy_only: bool = True
+ ) -> List[WorkerRunData]:
+ worker_key = self._worker_key(worker_type, model_name)
+ instances: List[ModelInstance] = self.model_registry.sync_get_all_instances(
+ worker_key, healthy_only
+ )
+ return self._build_worker_instances(model_name, instances)
+
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
async def _remote_apply_func(worker_run_data: WorkerRunData):
return await self._fetch_from_worker(
diff --git a/pilot/model/cluster/worker/remote_worker.py b/pilot/model/cluster/worker/remote_worker.py
index b123f1aa7..a0c306bbf 100644
--- a/pilot/model/cluster/worker/remote_worker.py
+++ b/pilot/model/cluster/worker/remote_worker.py
@@ -87,7 +87,15 @@ async def async_generate(self, params: Dict) -> ModelOutput:
def embeddings(self, params: Dict) -> List[List[float]]:
"""Get embeddings for input"""
- raise NotImplementedError
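+    # Synchronous call to the remote worker's /embeddings endpoint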
+ import requests
+
+ response = requests.post(
+ self.worker_addr + "/embeddings",
+ headers=self.headers,
+ json=params,
+ timeout=self.timeout,
+ )
+ return response.json()
async def async_embeddings(self, params: Dict) -> List[List[float]]:
"""Asynchronous get embeddings for input"""
diff --git a/pilot/model/parameter.py b/pilot/model/parameter.py
index f8bc91a7c..400ba0396 100644
--- a/pilot/model/parameter.py
+++ b/pilot/model/parameter.py
@@ -46,9 +46,7 @@ class ModelWorkerParameters(BaseModelParameters):
)
worker_class: Optional[str] = field(
default=None,
- metadata={
- "help": "Model worker class, pilot.model.worker.default_worker.DefaultModelWorker"
- },
+ metadata={"help": "Model worker class, pilot.model.cluster.DefaultModelWorker"},
)
host: Optional[str] = field(
default="0.0.0.0", metadata={"help": "Model worker deploy host"}
diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py
index fdbfcf2b1..f45a6af95 100644
--- a/pilot/openapi/api_v1/api_v1.py
+++ b/pilot/openapi/api_v1/api_v1.py
@@ -111,7 +111,7 @@ async def db_connect_delete(db_name: str = None):
async def async_db_summary_embedding(db_name, db_type):
    # Execute the code that needs to run asynchronously here
- db_summary_client = DBSummaryClient()
+ db_summary_client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
db_summary_client.db_summary_embedding(db_name, db_type)
diff --git a/pilot/scene/chat_dashboard/chat.py b/pilot/scene/chat_dashboard/chat.py
index a40ae768f..1aeca8406 100644
--- a/pilot/scene/chat_dashboard/chat.py
+++ b/pilot/scene/chat_dashboard/chat.py
@@ -61,7 +61,7 @@ def generate_input_values(self):
except ImportError:
raise ValueError("Could not import DBSummaryClient. ")
- client = DBSummaryClient()
+ client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
try:
table_infos = client.get_similar_tables(
dbname=self.db_name, query=self.current_user_input, topk=self.top_k
diff --git a/pilot/scene/chat_db/auto_execute/chat.py b/pilot/scene/chat_db/auto_execute/chat.py
index 785881746..689222c0a 100644
--- a/pilot/scene/chat_db/auto_execute/chat.py
+++ b/pilot/scene/chat_db/auto_execute/chat.py
@@ -35,7 +35,7 @@ def generate_input_values(self):
from pilot.summary.db_summary_client import DBSummaryClient
except ImportError:
raise ValueError("Could not import DBSummaryClient. ")
- client = DBSummaryClient()
+ client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
try:
table_infos = client.get_db_summary(
dbname=self.db_name,
diff --git a/pilot/scene/chat_db/professional_qa/chat.py b/pilot/scene/chat_db/professional_qa/chat.py
index 1c246700c..374131bfe 100644
--- a/pilot/scene/chat_db/professional_qa/chat.py
+++ b/pilot/scene/chat_db/professional_qa/chat.py
@@ -41,7 +41,7 @@ def generate_input_values(self):
except ImportError:
raise ValueError("Could not import DBSummaryClient. ")
if self.db_name:
- client = DBSummaryClient()
+ client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
try:
table_infos = client.get_db_summary(
dbname=self.db_name, query=self.current_user_input, topk=self.top_k
diff --git a/pilot/scene/chat_knowledge/v1/chat.py b/pilot/scene/chat_knowledge/v1/chat.py
index 65402b4e3..c41345f1d 100644
--- a/pilot/scene/chat_knowledge/v1/chat.py
+++ b/pilot/scene/chat_knowledge/v1/chat.py
@@ -23,6 +23,7 @@ class ChatKnowledge(BaseChat):
def __init__(self, chat_session_id, user_input, select_param: str = None):
""" """
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
+ from pilot.embedding_engine.embedding_factory import EmbeddingFactory
self.knowledge_space = select_param
super().__init__(
@@ -47,9 +48,13 @@ def __init__(self, chat_session_id, user_input, select_param: str = None):
"vector_store_type": CFG.VECTOR_STORE_TYPE,
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
+ embedding_factory = CFG.SYSTEM_APP.get_componet(
+ "embedding_factory", EmbeddingFactory
+ )
self.knowledge_embedding_client = EmbeddingEngine(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config=vector_store_config,
+ embedding_factory=embedding_factory,
)
def generate_input_values(self):
diff --git a/pilot/server/base.py b/pilot/server/base.py
index 6b44e8472..59116e09c 100644
--- a/pilot/server/base.py
+++ b/pilot/server/base.py
@@ -2,10 +2,11 @@
import os
import threading
import sys
-from typing import Optional
+from typing import Optional, Any
from dataclasses import dataclass, field
from pilot.configs.config import Config
+from pilot.componet import SystemApp
from pilot.utils.parameter_utils import BaseParameters
@@ -18,30 +19,28 @@ def signal_handler(sig, frame):
os._exit(0)
-def async_db_summery():
+def async_db_summery(system_app: SystemApp):
from pilot.summary.db_summary_client import DBSummaryClient
- client = DBSummaryClient()
+ client = DBSummaryClient(system_app=system_app)
thread = threading.Thread(target=client.init_db_summary)
thread.start()
-def server_init(args):
+def server_init(args, system_app: SystemApp):
from pilot.commands.command_mange import CommandRegistry
- from pilot.connections.manages.connection_manager import ConnectManager
+
from pilot.common.plugins import scan_plugins
# logger.info(f"args: {args}")
# init config
cfg = Config()
- # init connect manage
- conn_manage = ConnectManager()
- cfg.LOCAL_DB_MANAGE = conn_manage
+ cfg.SYSTEM_APP = system_app
# load_native_plugins(cfg)
signal.signal(signal.SIGINT, signal_handler)
- async_db_summery()
+
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
# Loader plugins and commands
@@ -70,6 +69,22 @@ def server_init(args):
cfg.command_disply = command_disply_registry
+def _create_model_start_listener(system_app: SystemApp):
+ from pilot.connections.manages.connection_manager import ConnectManager
+ from pilot.model.cluster import worker_manager
+
+ cfg = Config()
+
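+    # Called by the worker manager once the model workers have started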
+ def startup_event(wh):
+ # init connect manage
+ print("begin run _add_app_startup_event")
+ conn_manage = ConnectManager(system_app)
+ cfg.LOCAL_DB_MANAGE = conn_manage
+ async_db_summery(system_app)
+
+ return startup_event
+
+
@dataclass
class WebWerverParameters(BaseParameters):
host: Optional[str] = field(
diff --git a/pilot/server/componet_configs.py b/pilot/server/componet_configs.py
new file mode 100644
index 000000000..745937068
--- /dev/null
+++ b/pilot/server/componet_configs.py
@@ -0,0 +1,38 @@
+from typing import Any, Type, TYPE_CHECKING
+
+from pilot.componet import SystemApp
+from pilot.embedding_engine.embedding_factory import EmbeddingFactory
+
+if TYPE_CHECKING:
+ from langchain.embeddings.base import Embeddings
+
+
+def initialize_componets(system_app: SystemApp, embedding_model_name: str):
+ from pilot.model.cluster import worker_manager
+
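+    # Register the embedding factory component, backed by the shared worker manager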
+ system_app.register(
+ RemoteEmbeddingFactory, worker_manager, model_name=embedding_model_name
+ )
+
+
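+# Embedding factory that delegates to the model cluster's worker manager
+# instead of loading a local embedding model in the web server process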
+class RemoteEmbeddingFactory(EmbeddingFactory):
+ def __init__(
+ self, system_app, worker_manager, model_name: str = None, **kwargs: Any
+ ) -> None:
+ super().__init__(system_app=system_app)
+ self._worker_manager = worker_manager
+ self._default_model_name = model_name
+ self.kwargs = kwargs
+
+ def init_app(self, system_app):
+ pass
+
+ def create(
+ self, model_name: str = None, embedding_cls: Type = None
+ ) -> "Embeddings":
+ from pilot.model.cluster.embedding.remote_embedding import RemoteEmbeddings
+
+ if embedding_cls:
+ raise NotImplementedError
+        # The model_name argument is ignored; the factory always uses its default model
+ return RemoteEmbeddings(self._default_model_name, self._worker_manager)
diff --git a/pilot/server/dbgpt_server.py b/pilot/server/dbgpt_server.py
index 011b1146a..3019f067f 100644
--- a/pilot/server/dbgpt_server.py
+++ b/pilot/server/dbgpt_server.py
@@ -7,9 +7,15 @@
sys.path.append(ROOT_PATH)
import signal
from pilot.configs.config import Config
-from pilot.configs.model_config import LLM_MODEL_CONFIG
+from pilot.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG
+from pilot.componet import SystemApp
-from pilot.server.base import server_init, WebWerverParameters
+from pilot.server.base import (
+ server_init,
+ WebWerverParameters,
+ _create_model_start_listener,
+)
+from pilot.server.componet_configs import initialize_componets
from fastapi.staticfiles import StaticFiles
from fastapi import FastAPI, applications
@@ -48,6 +54,8 @@ def swagger_monkey_patch(*args, **kwargs):
applications.get_swagger_ui_html = swagger_monkey_patch
app = FastAPI()
+system_app = SystemApp(app)
+
origins = ["*"]
# Add cross-origin (CORS) middleware
@@ -98,7 +106,12 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
param = WebWerverParameters(**vars(parser.parse_args(args=args)))
setup_logging(logging_level=param.log_level)
- server_init(param)
+    # Run component before-start hooks
+ system_app.before_start()
+
+ server_init(param, system_app)
+ model_start_listener = _create_model_start_listener(system_app)
+ initialize_componets(system_app, CFG.EMBEDDING_MODEL)
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
if not param.light:
@@ -108,6 +121,9 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
model_name=CFG.LLM_MODEL,
model_path=model_path,
local_port=param.port,
+ embedding_model_name=CFG.EMBEDDING_MODEL,
+ embedding_model_path=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
+ start_listener=model_start_listener,
)
CFG.NEW_SERVER_MODE = True
@@ -120,6 +136,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
run_locally=False,
controller_addr=CFG.MODEL_SERVER,
local_port=param.port,
+ start_listener=model_start_listener,
)
CFG.SERVER_LIGHT_MODE = True
diff --git a/pilot/server/knowledge/api.py b/pilot/server/knowledge/api.py
index 5d9c7522b..f43333aa1 100644
--- a/pilot/server/knowledge/api.py
+++ b/pilot/server/knowledge/api.py
@@ -4,8 +4,6 @@
from fastapi import APIRouter, File, UploadFile, Form
-from langchain.embeddings import HuggingFaceEmbeddings
-
from pilot.configs.config import Config
from pilot.configs.model_config import (
EMBEDDING_MODEL_CONFIG,
@@ -14,6 +12,7 @@
from pilot.openapi.api_view_model import Result
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
+from pilot.embedding_engine.embedding_factory import EmbeddingFactory
from pilot.server.knowledge.service import KnowledgeService
from pilot.server.knowledge.request.request import (
@@ -32,10 +31,6 @@
router = APIRouter()
-embeddings = HuggingFaceEmbeddings(
- model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
-)
-
knowledge_space_service = KnowledgeService()
@@ -186,8 +181,13 @@ def document_list(space_name: str, query_request: ChunkQueryRequest):
@router.post("/knowledge/{vector_name}/query")
def similar_query(space_name: str, query_request: KnowledgeQueryRequest):
print(f"Received params: {space_name}, {query_request}")
+ embedding_factory = CFG.SYSTEM_APP.get_componet(
+ "embedding_factory", EmbeddingFactory
+ )
client = EmbeddingEngine(
- model_name=embeddings, vector_store_config={"vector_store_name": space_name}
+ model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
+ vector_store_config={"vector_store_name": space_name},
+ embedding_factory=embedding_factory,
)
docs = client.similar_search(query_request.query, query_request.top_k)
res = [
diff --git a/pilot/server/knowledge/service.py b/pilot/server/knowledge/service.py
index dc95cf9c9..0c04dee3a 100644
--- a/pilot/server/knowledge/service.py
+++ b/pilot/server/knowledge/service.py
@@ -154,6 +154,7 @@ def get_knowledge_documents(self, space, request: DocumentQueryRequest):
def sync_knowledge_document(self, space_name, doc_ids):
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
+ from pilot.embedding_engine.embedding_factory import EmbeddingFactory
from langchain.text_splitter import (
RecursiveCharacterTextSplitter,
SpacyTextSplitter,
@@ -204,6 +205,9 @@ def sync_knowledge_document(self, space_name, doc_ids):
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
+ embedding_factory = CFG.SYSTEM_APP.get_componet(
+ "embedding_factory", EmbeddingFactory
+ )
client = EmbeddingEngine(
knowledge_source=doc.content,
knowledge_type=doc.doc_type.upper(),
@@ -214,6 +218,7 @@ def sync_knowledge_document(self, space_name, doc_ids):
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
},
text_splitter=text_splitter,
+ embedding_factory=embedding_factory,
)
chunk_docs = client.read()
# update document status
diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py
index 2e67b370f..d521a062d 100644
--- a/pilot/server/llmserver.py
+++ b/pilot/server/llmserver.py
@@ -8,7 +8,7 @@
sys.path.append(ROOT_PATH)
from pilot.configs.config import Config
-from pilot.configs.model_config import LLM_MODEL_CONFIG
+from pilot.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG
from pilot.model.cluster import run_worker_manager
CFG = Config()
@@ -21,4 +21,6 @@
model_path=model_path,
standalone=True,
port=CFG.MODEL_PORT,
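+    # Also serve the configured embedding model from this model server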
+ embedding_model_name=CFG.EMBEDDING_MODEL,
+ embedding_model_path=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
)
diff --git a/pilot/summary/db_summary_client.py b/pilot/summary/db_summary_client.py
index 1b6cc251f..f41043601 100644
--- a/pilot/summary/db_summary_client.py
+++ b/pilot/summary/db_summary_client.py
@@ -2,6 +2,7 @@
import uuid
from pilot.common.schema import DBType
+from pilot.componet import SystemApp
from pilot.configs.config import Config
from pilot.configs.model_config import (
KNOWLEDGE_UPLOAD_ROOT_PATH,
@@ -26,16 +27,19 @@ class DBSummaryClient:
, get_similar_tables method(get user query related tables info)
"""
- def __init__(self):
- pass
+ def __init__(self, system_app: SystemApp):
+ self.system_app = system_app
def db_summary_embedding(self, dbname, db_type):
"""put db profile and table profile summary into vector store"""
- from langchain.embeddings import HuggingFaceEmbeddings
from pilot.embedding_engine.string_embedding import StringEmbedding
+ from pilot.embedding_engine.embedding_factory import EmbeddingFactory
db_summary_client = RdbmsSummary(dbname, db_type)
- embeddings = HuggingFaceEmbeddings(
+ embedding_factory = self.system_app.get_componet(
+ "embedding_factory", EmbeddingFactory
+ )
+ embeddings = embedding_factory.create(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
)
vector_store_config = {
@@ -83,15 +87,20 @@ def db_summary_embedding(self, dbname, db_type):
def get_db_summary(self, dbname, query, topk):
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
+ from pilot.embedding_engine.embedding_factory import EmbeddingFactory
vector_store_config = {
"vector_store_name": dbname + "_profile",
"vector_store_type": CFG.VECTOR_STORE_TYPE,
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
+ embedding_factory = CFG.SYSTEM_APP.get_componet(
+ "embedding_factory", EmbeddingFactory
+ )
knowledge_embedding_client = EmbeddingEngine(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config=vector_store_config,
+ embedding_factory=embedding_factory,
)
table_docs = knowledge_embedding_client.similar_search(query, topk)
ans = [d.page_content for d in table_docs]
@@ -100,6 +109,7 @@ def get_db_summary(self, dbname, query, topk):
def get_similar_tables(self, dbname, query, topk):
"""get user query related tables info"""
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
+ from pilot.embedding_engine.embedding_factory import EmbeddingFactory
vector_store_config = {
"vector_store_name": dbname + "_summary",
@@ -107,9 +117,13 @@ def get_similar_tables(self, dbname, query, topk):
"vector_store_type": CFG.VECTOR_STORE_TYPE,
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
+ embedding_factory = CFG.SYSTEM_APP.get_componet(
+ "embedding_factory", EmbeddingFactory
+ )
knowledge_embedding_client = EmbeddingEngine(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config=vector_store_config,
+ embedding_factory=embedding_factory,
)
if CFG.SUMMARY_CONFIG == "FAST":
table_docs = knowledge_embedding_client.similar_search(query, topk)
@@ -136,6 +150,7 @@ def get_similar_tables(self, dbname, query, topk):
knowledge_embedding_client = EmbeddingEngine(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config=vector_store_config,
+ embedding_factory=embedding_factory,
)
table_summery = knowledge_embedding_client.similar_search(query, 1)
related_table_summaries.append(table_summery[0].page_content)
diff --git a/pilot/utils/api_utils.py b/pilot/utils/api_utils.py
index 58dbd1a9c..93b280188 100644
--- a/pilot/utils/api_utils.py
+++ b/pilot/utils/api_utils.py
@@ -15,60 +15,60 @@ def _extract_dataclass_from_generic(type_hint: Type[T]) -> Union[Type[T], None]:
return None
-def _api_remote(path, method="GET"):
- def decorator(func):
- return_type = get_type_hints(func).get("return")
- if return_type is None:
- raise TypeError("Return type must be annotated in the decorated function.")
-
- actual_dataclass = _extract_dataclass_from_generic(return_type)
- logging.debug(
- f"return_type: {return_type}, actual_dataclass: {actual_dataclass}"
- )
- if not actual_dataclass:
- actual_dataclass = return_type
- sig = signature(func)
-
- async def wrapper(self, *args, **kwargs):
- import httpx
-
- base_url = self.base_url # Get base_url from class instance
-
- bound = sig.bind(self, *args, **kwargs)
- bound.apply_defaults()
-
- formatted_url = base_url + path.format(**bound.arguments)
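+# Shared request construction logic for the async and sync remote decorators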
+def _build_request(self, func, path, method, *args, **kwargs):
+ return_type = get_type_hints(func).get("return")
+ if return_type is None:
+ raise TypeError("Return type must be annotated in the decorated function.")
+
+ actual_dataclass = _extract_dataclass_from_generic(return_type)
+ logging.debug(f"return_type: {return_type}, actual_dataclass: {actual_dataclass}")
+ if not actual_dataclass:
+ actual_dataclass = return_type
+ sig = signature(func)
+ base_url = self.base_url # Get base_url from class instance
+
+ bound = sig.bind(self, *args, **kwargs)
+ bound.apply_defaults()
+
+ formatted_url = base_url + path.format(**bound.arguments)
+
+ # Extract args names from signature, except "self"
+ arg_names = list(sig.parameters.keys())[1:]
+
+ # Combine args and kwargs into a single dictionary
+ combined_args = dict(zip(arg_names, args))
+ combined_args.update(kwargs)
+
+ request_data = {}
+ for key, value in combined_args.items():
+ if is_dataclass(value):
+ # Here, instead of adding it as a nested dictionary,
+ # we set request_data directly to its dictionary representation.
+ request_data = asdict(value)
+ else:
+ request_data[key] = value
- # Extract args names from signature, except "self"
- arg_names = list(sig.parameters.keys())[1:]
+ request_params = {"method": method, "url": formatted_url}
- # Combine args and kwargs into a single dictionary
- combined_args = dict(zip(arg_names, args))
- combined_args.update(kwargs)
+ if method in ["POST", "PUT", "PATCH"]:
+ request_params["json"] = request_data
+ else: # For GET, DELETE, etc.
+ request_params["params"] = request_data
- request_data = {}
- for key, value in combined_args.items():
- if is_dataclass(value):
- # Here, instead of adding it as a nested dictionary,
- # we set request_data directly to its dictionary representation.
- request_data = asdict(value)
- else:
- request_data[key] = value
+ logging.debug(f"request_params: {request_params}, args: {args}, kwargs: {kwargs}")
+ return return_type, actual_dataclass, request_params
- request_params = {"method": method, "url": formatted_url}
- if method in ["POST", "PUT", "PATCH"]:
- request_params["json"] = request_data
- else: # For GET, DELETE, etc.
- request_params["params"] = request_data
+def _api_remote(path, method="GET"):
+ def decorator(func):
+ async def wrapper(self, *args, **kwargs):
+ import httpx
- logging.info(
- f"request_params: {request_params}, args: {args}, kwargs: {kwargs}"
+ return_type, actual_dataclass, request_params = _build_request(
+ self, func, path, method, *args, **kwargs
)
-
async with httpx.AsyncClient() as client:
response = await client.request(**request_params)
-
if response.status_code == 200:
return _parse_response(
response.json(), return_type, actual_dataclass
@@ -82,6 +82,28 @@ async def wrapper(self, *args, **kwargs):
return decorator
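+# Synchronous variant of _api_remote, using requests instead of httpx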
+def _sync_api_remote(path, method="GET"):
+ def decorator(func):
+ def wrapper(self, *args, **kwargs):
+ import requests
+
+ return_type, actual_dataclass, request_params = _build_request(
+ self, func, path, method, *args, **kwargs
+ )
+
+ response = requests.request(**request_params)
+
+ if response.status_code == 200:
+ return _parse_response(response.json(), return_type, actual_dataclass)
+ else:
+ error_msg = f"Remote request error, error code: {response.status_code}, error msg: {response.text}"
+ raise Exception(error_msg)
+
+ return wrapper
+
+ return decorator
+
+
def _parse_response(json_response, return_type, actual_dataclass):
# print(f'return_type.__origin__: {return_type.__origin__}, actual_dataclass: {actual_dataclass}, json_response: {json_response}')
if is_dataclass(actual_dataclass):
diff --git a/pilot/vector_store/extract_tovec.py b/pilot/vector_store/extract_tovec.py
index c3cf2577f..b37a48ddd 100644
--- a/pilot/vector_store/extract_tovec.py
+++ b/pilot/vector_store/extract_tovec.py
@@ -30,8 +30,9 @@ def knownledge_tovec_st(filename):
https://github.com/UKPLab/sentence-transformers
"""
from pilot.configs.model_config import EMBEDDING_MODEL_CONFIG
+ from pilot.embedding_engine.embedding_factory import DefaultEmbeddingFactory
- embeddings = HuggingFaceEmbeddings(
+ embeddings = DefaultEmbeddingFactory().create(
model_name=EMBEDDING_MODEL_CONFIG["sentence-transforms"]
)
@@ -58,8 +59,9 @@ def load_knownledge_from_doc():
)
from pilot.configs.model_config import EMBEDDING_MODEL_CONFIG
+ from pilot.embedding_engine.embedding_factory import DefaultEmbeddingFactory
- embeddings = HuggingFaceEmbeddings(
+ embeddings = DefaultEmbeddingFactory().create(
model_name=EMBEDDING_MODEL_CONFIG["sentence-transforms"]
)
diff --git a/pilot/vector_store/file_loader.py b/pilot/vector_store/file_loader.py
index 76411004d..f01f96867 100644
--- a/pilot/vector_store/file_loader.py
+++ b/pilot/vector_store/file_loader.py
@@ -44,7 +44,11 @@ class KnownLedge2Vector:
def __init__(self, model_name=None) -> None:
if not model_name:
# use default embedding model
- self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
+ from pilot.embedding_engine.embedding_factory import DefaultEmbeddingFactory
+
+ self.embeddings = DefaultEmbeddingFactory().create(
+ model_name=self.model_name
+ )
def init_vector_store(self):
persist_dir = os.path.join(VECTORE_PATH, ".vectordb")