diff --git a/docker/compose_examples/cluster-docker-compose.yml b/docker/compose_examples/cluster-docker-compose.yml
index b41033458..0ad6be9ae 100644
--- a/docker/compose_examples/cluster-docker-compose.yml
+++ b/docker/compose_examples/cluster-docker-compose.yml
@@ -7,6 +7,16 @@ services:
     restart: unless-stopped
     networks:
       - dbgptnet
+  api-server:
+    image: eosphorosai/dbgpt:latest
+    command: dbgpt start apiserver --controller_addr http://controller:8000
+    restart: unless-stopped
+    depends_on:
+      - controller
+    networks:
+      - dbgptnet
+    ports:
+      - 8100:8100/tcp
   llm-worker:
     image: eosphorosai/dbgpt:latest
     command: dbgpt start worker --model_name vicuna-13b-v1.5 --model_path /app/models/vicuna-13b-v1.5 --port 8001 --controller_addr http://controller:8000
diff --git a/docs/getting_started/install/cluster/cluster.rst b/docs/getting_started/install/cluster/cluster.rst
index 93660d0a4..17895e7bc 100644
--- a/docs/getting_started/install/cluster/cluster.rst
+++ b/docs/getting_started/install/cluster/cluster.rst
@@ -77,3 +77,4 @@ By analyzing this information, we can identify performance bottlenecks in model

    ./vms/standalone.md
    ./vms/index.md
+   ./openai.md
diff --git a/docs/getting_started/install/cluster/openai.md b/docs/getting_started/install/cluster/openai.md
new file mode 100644
index 000000000..8f23ba0fa
--- /dev/null
+++ b/docs/getting_started/install/cluster/openai.md
@@ -0,0 +1,51 @@
+OpenAI-Compatible RESTful APIs
+==================================
+(openai-apis-index)=
+
+### Installation Preparation
+
+You must [deploy DB-GPT cluster](https://db-gpt.readthedocs.io/en/latest/getting_started/install/cluster/vms/index.html) first.
+
+### Launch Model API Server
+
+```bash
+dbgpt start apiserver --controller_addr http://127.0.0.1:8000 --api_keys EMPTY
+```
+By default, the Model API Server starts on port 8100.
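+
+If port 8100 conflicts with another service, the listen address and port can likely be overridden on the command line. A minimal sketch (the `--host` and `--port` flags below are assumptions derived from the fields of `ModelAPIServerParameters`; run `dbgpt start apiserver --help` for the authoritative flag list):
+
+```bash
+# Assumed flags, generated from ModelAPIServerParameters; verify with --help
+dbgpt start apiserver --controller_addr http://127.0.0.1:8000 --api_keys EMPTY \
+  --host 0.0.0.0 --port 8101
+```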
+
+### Validate with cURL
+
+#### List models
+
+```bash
+curl http://127.0.0.1:8100/api/v1/models \
+-H "Authorization: Bearer EMPTY" \
+-H "Content-Type: application/json"
+```
+
+#### Chat completions
+
+```bash
+curl http://127.0.0.1:8100/api/v1/chat/completions \
+-H "Authorization: Bearer EMPTY" \
+-H "Content-Type: application/json" \
+-d '{"model": "vicuna-13b-v1.5", "messages": [{"role": "user", "content": "hello"}]}'
+```
+
+### Validate with OpenAI Official SDK
+
+#### Chat completions
+
+```python
+import openai
+openai.api_key = "EMPTY"
+openai.api_base = "http://127.0.0.1:8100/api/v1"
+model = "vicuna-13b-v1.5"
+
+completion = openai.ChatCompletion.create(
+    model=model,
+    messages=[{"role": "user", "content": "hello"}]
+)
+# print the completion
+print(completion.choices[0].message.content)
+```
\ No newline at end of file
diff --git a/docs/getting_started/install/llm/proxyllm/proxyllm.md b/docs/getting_started/install/llm/proxyllm/proxyllm.md
index a04252d2f..fae549dd3 100644
--- a/docs/getting_started/install/llm/proxyllm/proxyllm.md
+++ b/docs/getting_started/install/llm/proxyllm/proxyllm.md
@@ -24,9 +24,12 @@ PROXY_SERVER_URL=https://api.openai.com/v1/chat/completions

 #Azure
 LLM_MODEL=chatgpt_proxyllm
-OPENAI_API_TYPE=azure
-PROXY_API_KEY={your-openai-sk}
-PROXY_SERVER_URL=https://xx.openai.azure.com/v1/chat/completions
+PROXY_API_KEY={your-azure-sk}
+PROXY_API_BASE=https://{your domain}.openai.azure.com/
+PROXY_API_TYPE=azure
+PROXY_SERVER_URL=xxxx
+PROXY_API_VERSION=2023-05-15
+PROXYLLM_BACKEND=gpt-35-turbo

 #Aliyun tongyi
 LLM_MODEL=tongyi_proxyllm
diff --git a/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/cluster/openai.po b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/cluster/openai.po
new file mode 100644
index 000000000..0ef41aa6d
--- /dev/null
+++ b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/cluster/openai.po
@@ -0,0 +1,71 @@
+# SOME DESCRIPTIVE TITLE.
+# Copyright (C) 2023, csunny
+# This file is distributed under the same license as the DB-GPT package.
+# FIRST AUTHOR <EMAIL@ADDRESS>, 2023.
+#
+#, fuzzy
+msgid ""
+msgstr ""
+"Project-Id-Version: DB-GPT 👏👏 0.4.0\n"
+"Report-Msgid-Bugs-To: \n"
+"POT-Creation-Date: 2023-11-02 21:09+0800\n"
+"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
+"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
+"Language: zh_CN\n"
+"Language-Team: zh_CN <LL@li.org>\n"
+"Plural-Forms: nplurals=1; plural=0;\n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=utf-8\n"
+"Content-Transfer-Encoding: 8bit\n"
+"Generated-By: Babel 2.12.1\n"
+
+#: ../../getting_started/install/cluster/openai.md:1
+#: 01f4e2bf853341198633b367efec1522
+msgid "OpenAI-Compatible RESTful APIs"
+msgstr "OpenAI RESTful 兼容接口"
+
+#: ../../getting_started/install/cluster/openai.md:5
+#: d8717e42335e4027bf4e76b3d28768ee
+msgid "Installation Preparation"
+msgstr "安装准备"
+
+#: ../../getting_started/install/cluster/openai.md:7
+#: 9a48d8ee116942468de4c6faf9a64758
+msgid ""
+"You must [deploy DB-GPT cluster](https://db-"
+"gpt.readthedocs.io/en/latest/getting_started/install/cluster/vms/index.html)"
+" first."
+msgstr "你必须先部署 [DB-GPT 集群]"
+"(https://db-gpt.readthedocs.io/projects/db-gpt-docs-zh-cn/zh-cn/latest/getting_started/install/cluster/vms/index.html)。"
+
+#: ../../getting_started/install/cluster/openai.md:9
+#: 7673a7121f004f7ca6b1a94a7e238fa3
+msgid "Launch Model API Server"
+msgstr "启动模型 API Server"
+
+#: ../../getting_started/install/cluster/openai.md:14
+#: 84a925c2cbcd4e4895a1d2d2fe8f720f
+msgid "By default, the Model API Server starts on port 8100."
+msgstr "默认情况下,模型 API Server 使用 8100 端口启动。" + +#: ../../getting_started/install/cluster/openai.md:16 +#: e53ed41977cd4721becd51eba05c6609 +msgid "Validate with cURL" +msgstr "通过 cURL 验证" + +#: ../../getting_started/install/cluster/openai.md:18 +#: 7c883b410b5c4e53a256bf17c1ded80d +msgid "List models" +msgstr "列出模型" + +#: ../../getting_started/install/cluster/openai.md:26 +#: ../../getting_started/install/cluster/openai.md:37 +#: 7cf0ed13f0754f149ec085cd6cf7a45a 990d5d5ed5d64ab49550e68495b9e7a0 +msgid "Chat completions" +msgstr "" + +#: ../../getting_started/install/cluster/openai.md:35 +#: 81583edd22df44e091d18a0832278131 +msgid "Validate with OpenAI Official SDK" +msgstr "通过 OpenAI 官方 SDK 验证" + diff --git a/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/environment/environment.po b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/environment/environment.po index 09b3c8fa2..addf53bc6 100644 --- a/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/environment/environment.po +++ b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/environment/environment.po @@ -8,7 +8,7 @@ msgid "" msgstr "" "Project-Id-Version: DB-GPT 👏👏 0.3.5\n" "Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2023-08-17 13:07+0800\n" +"POT-Creation-Date: 2023-11-02 21:04+0800\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "Last-Translator: FULL NAME \n" "Language: zh_CN\n" @@ -20,290 +20,292 @@ msgstr "" "Generated-By: Babel 2.12.1\n" #: ../../getting_started/install/environment/environment.md:1 -#: be341d16f7b24bf4ad123ab78a6d855a +#: a17719d2f4374285a7beb4d1db470146 #, fuzzy msgid "Environment Parameter" msgstr "环境变量说明" #: ../../getting_started/install/environment/environment.md:4 -#: 46eddb27c90f41548ea9a724bbcebd37 +#: 9a62e6fff7914eeaa2d195ddef4fcb61 msgid "LLM MODEL Config" msgstr "模型配置" #: ../../getting_started/install/environment/environment.md:5 -#: 7deaa85df4a04fb098f5994547a8724f +#: 90e3991538324ecfac8cac7ef2103ac2 msgid "LLM Model Name, see /pilot/configs/model_config.LLM_MODEL_CONFIG" msgstr "LLM Model Name, see /pilot/configs/model_config.LLM_MODEL_CONFIG" #: ../../getting_started/install/environment/environment.md:6 -#: 3902801c546547b3a4009df681ef7d52 +#: 1f45af01100c4586acbc05469e3006bc msgid "LLM_MODEL=vicuna-13b" msgstr "LLM_MODEL=vicuna-13b" #: ../../getting_started/install/environment/environment.md:8 -#: 84b0fdbfa1544ec28751e9b69b00cc02 +#: bed14b704f154c2db525f7fafd3aa5a4 msgid "MODEL_SERVER_ADDRESS" msgstr "MODEL_SERVER_ADDRESS" #: ../../getting_started/install/environment/environment.md:9 -#: 0b430bfab77d405989470d00ca3f6fe0 +#: ea42946cfe4f4ad996bf82c1996e7344 msgid "MODEL_SERVER=http://127.0.0.1:8000 LIMIT_MODEL_CONCURRENCY" msgstr "MODEL_SERVER=http://127.0.0.1:8000 LIMIT_MODEL_CONCURRENCY" #: ../../getting_started/install/environment/environment.md:12 -#: b477a25586c546729a93fb6785b7b2ec +#: 021c261231f342fdba34098b1baa06fd msgid "LIMIT_MODEL_CONCURRENCY=5" msgstr "LIMIT_MODEL_CONCURRENCY=5" #: ../../getting_started/install/environment/environment.md:14 -#: 1d6ea800af384fff9c265610f71cc94e +#: afaf0ba7fd09463d8ff74b514ed7264c msgid "MAX_POSITION_EMBEDDINGS" msgstr "MAX_POSITION_EMBEDDINGS" #: ../../getting_started/install/environment/environment.md:16 -#: 388e758ce4ea4692a4c34294cebce7f2 +#: e4517a942bca4361a64a00408f993f5b msgid "MAX_POSITION_EMBEDDINGS=4096" msgstr "MAX_POSITION_EMBEDDINGS=4096" #: ../../getting_started/install/environment/environment.md:18 -#: 16a307dce1294ceba892ff93ae4e81c0 +#: 78d2ef04ed4548b9b7b0fb8ae35c9d5c msgid "QUANTIZE_QLORA" msgstr 
"QUANTIZE_QLORA" #: ../../getting_started/install/environment/environment.md:20 -#: 93ceb2b2fcd5454b82eefb0ae8c7ae77 +#: bfa65db03c6d46bba293331f03ab15ac msgid "QUANTIZE_QLORA=True" msgstr "QUANTIZE_QLORA=True" #: ../../getting_started/install/environment/environment.md:22 -#: 15ffa35d023a4530b02a85ee6168dd4b +#: 1947d45a7f184821910b4834ad5f1897 msgid "QUANTIZE_8bit" msgstr "QUANTIZE_8bit" #: ../../getting_started/install/environment/environment.md:24 -#: 81df248ac5cb4ab0b13a711505f6a177 +#: 4a2ee2919d0e4bdaa13c9d92eefd2aac msgid "QUANTIZE_8bit=True" msgstr "QUANTIZE_8bit=True" #: ../../getting_started/install/environment/environment.md:27 -#: 15cc7b7d41ad44f0891c1189709f00f1 +#: 348dc1e411b54ab09414f40a20e934e4 msgid "LLM PROXY Settings" msgstr "LLM PROXY Settings" #: ../../getting_started/install/environment/environment.md:28 -#: e6c1115a39404f11b193a1593bc51a22 +#: a692e78425a040f5828ab54ff9a33f77 msgid "OPENAI Key" msgstr "OPENAI Key" #: ../../getting_started/install/environment/environment.md:30 -#: 8157e0a831fe4506a426822b7565e4f6 +#: 940d00e25a424acf92951a314a64e5ea msgid "PROXY_API_KEY={your-openai-sk}" msgstr "PROXY_API_KEY={your-openai-sk}" #: ../../getting_started/install/environment/environment.md:31 -#: 89b34d00bdb64e738bd9bc8c086b1f02 +#: 4bd27547ae6041679e91f2a363cd1deb msgid "PROXY_SERVER_URL=https://api.openai.com/v1/chat/completions" msgstr "PROXY_SERVER_URL=https://api.openai.com/v1/chat/completions" #: ../../getting_started/install/environment/environment.md:33 -#: 7a97df730aeb484daf19c8172e61a290 +#: cfa3071afb0b47baad6bd729d4a02cb9 msgid "from https://bard.google.com/ f12-> application-> __Secure-1PSID" msgstr "from https://bard.google.com/ f12-> application-> __Secure-1PSID" #: ../../getting_started/install/environment/environment.md:35 -#: d430ddf726a049c0a9e0a9bfd5a6fe0e +#: a17efa03b10f47f68afac9e865982a75 msgid "BARD_PROXY_API_KEY={your-bard-token}" msgstr "BARD_PROXY_API_KEY={your-bard-token}" #: ../../getting_started/install/environment/environment.md:38 -#: 23d6b0da3e7042abb55f6181c4a382d2 +#: 6bcfe90574da4d82a459e8e11bf73cba msgid "DATABASE SETTINGS" msgstr "DATABASE SETTINGS" #: ../../getting_started/install/environment/environment.md:39 -#: dbae0a2d847f41f5be9396a160ef88d0 +#: 2b1e62d9bf5d4af5a22f68c8248eaafb msgid "SQLite database (Current default database)" msgstr "SQLite database (Current default database)" #: ../../getting_started/install/environment/environment.md:40 -#: bdb55b7280c341a981e9d338cce53345 +#: 8a909ac3b3c943da8dbc4e8dd596c80c msgid "LOCAL_DB_PATH=data/default_sqlite.db" msgstr "LOCAL_DB_PATH=data/default_sqlite.db" #: ../../getting_started/install/environment/environment.md:41 -#: 739d67927a9d46b28500deba1917916b +#: 90ae6507932f4815b6e180051738bb93 msgid "LOCAL_DB_TYPE=sqlite # Database Type default:sqlite" msgstr "LOCAL_DB_TYPE=sqlite # Database Type default:sqlite" #: ../../getting_started/install/environment/environment.md:43 -#: eb4717bce6a6483b86d9780d924c5ff1 +#: d2ce34e0dcf44ccf9e8007d548ba7b0a msgid "MYSQL database" msgstr "MYSQL database" #: ../../getting_started/install/environment/environment.md:44 -#: 0f4cdf0ff5dd4ff0b397dfa88541a2e1 +#: c07159d63c334f6cbb95fcc30bfb7ea5 msgid "LOCAL_DB_TYPE=mysql" msgstr "LOCAL_DB_TYPE=mysql" #: ../../getting_started/install/environment/environment.md:45 -#: c971ead492c34487bd766300730a9cba +#: e16700b2ea8d411e91d010c1cde7aecc msgid "LOCAL_DB_USER=root" msgstr "LOCAL_DB_USER=root" #: ../../getting_started/install/environment/environment.md:46 -#: 02828b29ad044eeab890a2f8af0e5907 +#: 
bfc2dce1bf374121b6861e677b4e1ffa msgid "LOCAL_DB_PASSWORD=aa12345678" msgstr "LOCAL_DB_PASSWORD=aa12345678" #: ../../getting_started/install/environment/environment.md:47 -#: 53dc7f15b3934987b1f4c2e2d0b11299 +#: bc384739f5b04e21a34d0d2b78e7906c msgid "LOCAL_DB_HOST=127.0.0.1" msgstr "LOCAL_DB_HOST=127.0.0.1" #: ../../getting_started/install/environment/environment.md:48 -#: 1ac95fc482934247a118bab8dcebeb57 +#: e5253d452e0d42b7ac308fe6fbfb5017 msgid "LOCAL_DB_PORT=3306" msgstr "LOCAL_DB_PORT=3306" #: ../../getting_started/install/environment/environment.md:51 -#: 34e46aa926844be19c7196759b03af63 +#: 9ca8f6fe06ed4cbab390f94be252e165 msgid "EMBEDDING SETTINGS" msgstr "EMBEDDING SETTINGS" #: ../../getting_started/install/environment/environment.md:52 -#: 2b5aa08cc995495e85a1f7dc4f97b5d7 +#: 76c7c260293c4b49bae057143fd48377 msgid "EMBEDDING MODEL Name, see /pilot/configs/model_config.LLM_MODEL_CONFIG" msgstr "EMBEDDING模型, 参考see /pilot/configs/model_config.LLM_MODEL_CONFIG" #: ../../getting_started/install/environment/environment.md:53 -#: 0de0ca551ed040248406f848feca541d +#: f1d63a0128ce493cae37d34f1976bcca msgid "EMBEDDING_MODEL=text2vec" msgstr "EMBEDDING_MODEL=text2vec" #: ../../getting_started/install/environment/environment.md:55 -#: 43019fb570904c9981eb68f33e64569c +#: b8fbb99109d04781b2dd5bc5d6efa5bd msgid "Embedding Chunk size, default 500" msgstr "Embedding 切片大小, 默认500" #: ../../getting_started/install/environment/environment.md:57 -#: 7e3f93854873461286e96887e04167aa +#: bf8256576ea34f6a9c5f261ab9aab676 msgid "KNOWLEDGE_CHUNK_SIZE=500" msgstr "KNOWLEDGE_CHUNK_SIZE=500" #: ../../getting_started/install/environment/environment.md:59 -#: 9504f4a59ae74352a524b7741113e2d6 +#: 9b156c6b599b4c02a58ce023b4ff25f2 msgid "Embedding Chunk Overlap, default 100" msgstr "Embedding chunk Overlap, 文本块之间的最大重叠量。保留一些重叠可以保持文本块之间的连续性(例如使用滑动窗口),默认100" #: ../../getting_started/install/environment/environment.md:60 -#: 24e6119c2051479bbd9dba71a9c23dbe +#: dcafd903c36041ac85ac99a14dbee512 msgid "KNOWLEDGE_CHUNK_OVERLAP=100" msgstr "KNOWLEDGE_CHUNK_OVERLAP=100" #: ../../getting_started/install/environment/environment.md:62 -#: 0d180d7f2230442abee901c19526e442 -msgid "embeding recall top k,5" +#: 6c3244b7e5e24b0188c7af4bb52e9134 +#, fuzzy +msgid "embedding recall top k,5" msgstr "embedding 召回topk, 默认5" #: ../../getting_started/install/environment/environment.md:64 -#: a5bb9ab2ba50411cbbe87f7836bfbb6d +#: f4a2f30551cf4fe1a7ff3c7c74ec77be msgid "KNOWLEDGE_SEARCH_TOP_SIZE=5" msgstr "KNOWLEDGE_SEARCH_TOP_SIZE=5" #: ../../getting_started/install/environment/environment.md:66 -#: 183b8dd78cba4ae19bd2e08d69d21e0b -msgid "embeding recall max token ,2000" +#: 593f2512362f467e92fdaa60dd5903a0 +#, fuzzy +msgid "embedding recall max token ,2000" msgstr "embedding向量召回最大token, 默认2000" #: ../../getting_started/install/environment/environment.md:68 -#: ce0c711febcb44c18ae0fc858c3718d1 +#: 83d6d28914be4d6282d457272e508ddc msgid "KNOWLEDGE_SEARCH_MAX_TOKEN=5" msgstr "KNOWLEDGE_SEARCH_MAX_TOKEN=5" #: ../../getting_started/install/environment/environment.md:71 #: ../../getting_started/install/environment/environment.md:87 -#: 4cab1f399cc245b4a1a1976d2c4fc926 ec9cec667a1c4473bf9a796a26e1ce20 +#: 6bc1b9d995e74294a1c78e783c550db7 d33c77ded834438e9f4a2df06e7e041a msgid "Vector Store SETTINGS" msgstr "Vector Store SETTINGS" #: ../../getting_started/install/environment/environment.md:72 #: ../../getting_started/install/environment/environment.md:88 -#: 4dd04aadd46948a5b1dcf01fdb0ef074 bab7d512f33e40cf9e10f0da67e699c8 +#: 
9cafa06e2d584f70afd848184e0fa52a f01057251b8b4ffea806192dfe1048ed msgid "Chroma" msgstr "Chroma" #: ../../getting_started/install/environment/environment.md:73 #: ../../getting_started/install/environment/environment.md:89 -#: 13eec36741b14e028e2d3859a320826e ab3ffbcf9358401993af636ba9ab2e2d +#: e6c16fab37484769b819aeecbc13e6db faad299722e5400e95ec6ac3c1e018b8 msgid "VECTOR_STORE_TYPE=Chroma" msgstr "VECTOR_STORE_TYPE=Chroma" #: ../../getting_started/install/environment/environment.md:74 #: ../../getting_started/install/environment/environment.md:90 -#: d15b91e2a2884f23a1dd2d54783b0638 d1f856d571b547098bb0c2a18f9f1979 +#: 4eca3a51716d406f8ffd49c06550e871 581ee9dd38064b119660c44bdd00cbaa msgid "MILVUS" msgstr "MILVUS" #: ../../getting_started/install/environment/environment.md:75 #: ../../getting_started/install/environment/environment.md:91 -#: 1e165f6c934343c7808459cc7a65bc70 985dd60c2b7d4baaa6601a810a6522d7 +#: 814c93048bed46589358a854d6c99683 b72b1269a2224f5f961214e41c019f21 msgid "VECTOR_STORE_TYPE=Milvus" msgstr "VECTOR_STORE_TYPE=Milvus" #: ../../getting_started/install/environment/environment.md:76 #: ../../getting_started/install/environment/environment.md:92 -#: a1a53f051cee40ed886346a94babd75a d263e8eaee684935a58f0a4fe61c6f0e +#: 73ae665f1db9402883662734588fd02c c4da20319c994e83ba5a7706db967178 msgid "MILVUS_URL=127.0.0.1" msgstr "MILVUS_URL=127.0.0.1" #: ../../getting_started/install/environment/environment.md:77 #: ../../getting_started/install/environment/environment.md:93 -#: 2741a312db1a4c6a8a1c1d62415c5fba d03bbf921ddd4f4bb715fe5610c3d0aa +#: e30c5288516d42aa858a485db50490c1 f843b2e58bcb4e4594e3c28499c341d0 msgid "MILVUS_PORT=19530" msgstr "MILVUS_PORT=19530" #: ../../getting_started/install/environment/environment.md:78 #: ../../getting_started/install/environment/environment.md:94 -#: d0786490d38c4e4f971cc14f62fe1fc8 e9e0854873dc4c209861ee4eb77d25cd +#: 158669efcc7d4bcaac1c8dd01b499029 24e88ffd32f242f281c56c0ec3ad2639 msgid "MILVUS_USERNAME" msgstr "MILVUS_USERNAME" #: ../../getting_started/install/environment/environment.md:79 #: ../../getting_started/install/environment/environment.md:95 -#: 9a82d07153cc432ebe754b5bc02fde0d a6485c1cfa7d4069a6894c43674c8c2b +#: 111a985297184c8aa5a0dd8e14a58445 6602093a6bb24d6792548e2392105c82 msgid "MILVUS_PASSWORD" msgstr "MILVUS_PASSWORD" #: ../../getting_started/install/environment/environment.md:80 #: ../../getting_started/install/environment/environment.md:96 -#: 2f233f32b8ba408a9fbadb21fabb99ec 809b3219dd824485bc2cfc898530d708 +#: 47bdfcd78fbe4ccdb5f49b717a6d01a6 b96c0545b2044926a8a8190caf94ad25 msgid "MILVUS_SECURE=" msgstr "MILVUS_SECURE=" #: ../../getting_started/install/environment/environment.md:82 #: ../../getting_started/install/environment/environment.md:98 -#: f00603661f2b42e1bd2bca74ad1e3c31 f378e16fdec44c559e34c6929de812e8 +#: 755c32b5d6c54607907a138b5474c0ec ff4f2a7ddaa14f089dda7a14e1062c36 msgid "WEAVIATE" msgstr "WEAVIATE" #: ../../getting_started/install/environment/environment.md:83 -#: da2049ebc6874cf0a6b562e0e2fd9ec7 +#: 23b2ce83385d40a589a004709f9864be msgid "VECTOR_STORE_TYPE=Weaviate" msgstr "VECTOR_STORE_TYPE=Weaviate" #: ../../getting_started/install/environment/environment.md:84 #: ../../getting_started/install/environment/environment.md:99 -#: 25f1246629934289aad7ef01c7304097 c9fe0e413d9a4fc8abf86b3ed99e0581 +#: 9acef304d89a448a9e734346705ba872 cf5151b6c1594ccd8beb1c3f77769acb msgid "WEAVIATE_URL=https://kt-region-m8hcy0wc.weaviate.network" msgstr "WEAVIATE_URL=https://kt-region-m8hcy0wc.weaviate.network" 
#: ../../getting_started/install/environment/environment.md:102 -#: ba7c9e707f6a4cd6b99e52b58da3ab2d +#: c3003516b2364051bf34f8c3086e348a msgid "Multi-GPU Setting" msgstr "Multi-GPU Setting" #: ../../getting_started/install/environment/environment.md:103 -#: 5ca75fdf2c264b2c844d77f659b4f0b3 +#: ade8fc381c5e438aa29d159c10041713 msgid "" "See https://developer.nvidia.com/blog/cuda-pro-tip-control-gpu-" "visibility-cuda_visible_devices/ If CUDA_VISIBLE_DEVICES is not " @@ -313,49 +315,49 @@ msgstr "" "cuda_visible_devices/ 如果 CUDA_VISIBLE_DEVICES没有设置, 会使用所有可用的gpu" #: ../../getting_started/install/environment/environment.md:106 -#: de92eb310aff43fbbbf3c5a116c3b2c6 +#: e137bd19be5e410ba6709027dbf2923a msgid "CUDA_VISIBLE_DEVICES=0" msgstr "CUDA_VISIBLE_DEVICES=0" #: ../../getting_started/install/environment/environment.md:108 -#: d2641df6123a442b8e4444ad5f01a9aa +#: 7669947acbdc4b1d92bcc029a8353a5d msgid "" "Optionally, you can also specify the gpu ID to use before the starting " "command" msgstr "你也可以通过启动命令设置gpu ID" #: ../../getting_started/install/environment/environment.md:110 -#: 76c66179d11a4e5fa369421378609aae +#: 751743d1753b4051beea46371278d793 msgid "CUDA_VISIBLE_DEVICES=3,4,5,6" msgstr "CUDA_VISIBLE_DEVICES=3,4,5,6" #: ../../getting_started/install/environment/environment.md:112 -#: 29bd0f01fdf540ad98385ea8473f7647 +#: 3acc3de0af0d4df2bb575e161e377f85 msgid "You can configure the maximum memory used by each GPU." msgstr "可以设置GPU的最大内存" #: ../../getting_started/install/environment/environment.md:114 -#: 31e5e23838734ba7a2810e2387e6d6a0 +#: 67f1d9b172b84294a44ecace5436e6e0 msgid "MAX_GPU_MEMORY=16Gib" msgstr "MAX_GPU_MEMORY=16Gib" #: ../../getting_started/install/environment/environment.md:117 -#: 99aa63ab1ae049d9b94536d6a96f3443 +#: 3c69dfe48bcf46b89b76cac1e7849a66 msgid "Other Setting" msgstr "Other Setting" #: ../../getting_started/install/environment/environment.md:118 -#: 3168732183874bffb59a3575d3473d62 +#: d5015b70f4fe4d20a63de9d87f86957a msgid "Language Settings(influence prompt language)" msgstr "Language Settings(涉及prompt语言以及知识切片方式)" #: ../../getting_started/install/environment/environment.md:119 -#: 73eb0a96f29b4739bd456faa9cb5033d +#: 5543c28bb8e34c9fb3bb6b063c2b1750 msgid "LANGUAGE=en" msgstr "LANGUAGE=en" #: ../../getting_started/install/environment/environment.md:120 -#: c6646b78c6cf4d25a13108232f5b2046 +#: cb4ed5b892ee41068c1ca76cb29aa400 msgid "LANGUAGE=zh" msgstr "LANGUAGE=zh" diff --git a/docs/locales/zh_CN/LC_MESSAGES/modules/knowledge.po b/docs/locales/zh_CN/LC_MESSAGES/modules/knowledge.po index 0b3eac094..bb2dae7af 100644 --- a/docs/locales/zh_CN/LC_MESSAGES/modules/knowledge.po +++ b/docs/locales/zh_CN/LC_MESSAGES/modules/knowledge.po @@ -8,7 +8,7 @@ msgid "" msgstr "" "Project-Id-Version: DB-GPT 0.3.0\n" "Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2023-07-13 15:39+0800\n" +"POT-Creation-Date: 2023-11-02 21:04+0800\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "Last-Translator: FULL NAME \n" "Language: zh_CN\n" @@ -19,103 +19,84 @@ msgstr "" "Content-Transfer-Encoding: 8bit\n" "Generated-By: Babel 2.12.1\n" -#: ../../modules/knowledge.rst:2 ../../modules/knowledge.rst:136 -#: 3cc8fa6e9fbd4d889603d99424e9529a +#: ../../modules/knowledge.md:1 b94b3b15cb2441ed9d78abd222a717b7 msgid "Knowledge" msgstr "知识" -#: ../../modules/knowledge.rst:4 0465a393d9d541958c39c1d07c885d1f +#: ../../modules/knowledge.md:3 c6d6e308a6ce42948d29e928136ef561 #, fuzzy msgid "" "As the knowledge base is currently the most significant user demand " "scenario, we natively support the 
construction and processing of " "knowledge bases. At the same time, we also provide multiple knowledge " -"base management strategies in this project, such as pdf knowledge,md " -"knowledge, txt knowledge, word knowledge, ppt knowledge:" +"base management strategies in this project, such as:" msgstr "" "由于知识库是当前用户需求最显著的场景,我们原生支持知识库的构建和处理。同时,我们还在本项目中提供了多种知识库管理策略,如:pdf,md , " "txt, word, ppt" -#: ../../modules/knowledge.rst:6 e670cbe14d8e4da88ba935e4120c31e0 +#: ../../modules/knowledge.md:4 268abc408d40410ba90cf5f121dc5270 +msgid "Default built-in knowledge base" +msgstr "" + +#: ../../modules/knowledge.md:5 558c3364c38b458a8ebf81030efc2a48 +msgid "Custom addition of knowledge bases" +msgstr "" + +#: ../../modules/knowledge.md:6 9cb3ce62da1440579c095848c7aef88c msgid "" -"We currently support many document formats: raw text, txt, pdf, md, html," -" doc, ppt, and url. In the future, we will continue to support more types" -" of knowledge, including audio, video, various databases, and big data " -"sources. Of course, we look forward to your active participation in " -"contributing code." +"Various usage scenarios such as constructing knowledge bases through " +"plugin capabilities and web crawling. Users only need to organize the " +"knowledge documents, and they can use our existing capabilities to build " +"the knowledge base required for the large model." msgstr "" -#: ../../modules/knowledge.rst:9 e0bf601a1a0c458297306db6ff79f931 -msgid "**Create your own knowledge repository**" +#: ../../modules/knowledge.md:9 b8ca6bc4dd9845baa56e36eea7fac2a2 +#, fuzzy +msgid "Create your own knowledge repository" msgstr "创建你自己的知识库" -#: ../../modules/knowledge.rst:11 bb26708135d44615be3c1824668010f6 -msgid "1.prepare" -msgstr "准备" +#: ../../modules/knowledge.md:11 17d7178a67924f43aa5b6293707ef041 +msgid "" +"1.Place personal knowledge files or folders in the pilot/datasets " +"directory." +msgstr "" -#: ../../modules/knowledge.rst:13 c150a0378f3e4625908fa0d8a25860e9 +#: ../../modules/knowledge.md:13 31c31f14bf444981939689f9a9fb038a #, fuzzy msgid "" -"We currently support many document formats: TEXT(raw text), " -"DOCUMENT(.txt, .pdf, .md, .doc, .ppt, .html), and URL." +"We currently support many document formats: txt, pdf, md, html, doc, ppt," +" and url." msgstr "当前支持txt, pdf, md, html, doc, ppt, url文档格式" -#: ../../modules/knowledge.rst:15 7f9f02a93d5d4325b3d2d976f4bb28a0 +#: ../../modules/knowledge.md:15 9ad2f2e05f8842a9b9d8469a3704df23 msgid "before execution:" msgstr "开始前" -#: ../../modules/knowledge.rst:24 59699a8385e04982a992cf0d71f6dcd5 -#, fuzzy +#: ../../modules/knowledge.md:22 6fd2775914b641c4b8e486417b558ea6 msgid "" -"2.prepare embedding model, you can download from https://huggingface.co/." -" Notice you have installed git-lfs." +"2.Update your .env, set your vector store type, VECTOR_STORE_TYPE=Chroma " +"(now only support Chroma and Milvus, if you set Milvus, please set " +"MILVUS_URL and MILVUS_PORT)" msgstr "" -"提前准备Embedding Model, 你可以在https://huggingface.co/进行下载,注意:你需要先安装git-lfs.eg:" -" git clone https://huggingface.co/THUDM/chatglm2-6b" - -#: ../../modules/knowledge.rst:27 2be1a17d0b54476b9dea080d244fd747 -msgid "" -"eg: git clone https://huggingface.co/sentence-transformers/all-" -"MiniLM-L6-v2" -msgstr "eg: git clone https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2" -#: ../../modules/knowledge.rst:33 d328f6e243624c9488ebd27c9324621b -msgid "" -"3.prepare vector_store instance and vector store config, now we support " -"Chroma, Milvus and Weaviate." 
-msgstr "提前准备向量数据库环境,目前支持Chroma, Milvus and Weaviate向量数据库" +#: ../../modules/knowledge.md:25 131c5f58898a4682940910980edb2043 +msgid "2.Run the knowledge repository initialization command" +msgstr "" -#: ../../modules/knowledge.rst:63 44f97154eff647d399fd30b6f9e3b867 +#: ../../modules/knowledge.md:31 2cf550f17881497bb881b19efcc18c23 msgid "" -"3.init Url Type EmbeddingEngine api and embedding your document into " -"vector store in your code." -msgstr "初始化 Url类型 EmbeddingEngine api, 将url文档embedding向量化到向量数据库 " - -#: ../../modules/knowledge.rst:75 e2581b414f0148bca88253c7af9cd591 -msgid "If you want to add your source_reader or text_splitter, do this:" -msgstr "如果你想手动添加你自定义的source_reader和text_splitter, 请参考:" - -#: ../../modules/knowledge.rst:95 74c110414f924bbfa3d512e45ba2f30f -#, fuzzy -msgid "" -"4.init Document Type EmbeddingEngine api and embedding your document into" -" vector store in your code. Document type can be .txt, .pdf, .md, .doc, " -".ppt." +"Optionally, you can run `dbgpt knowledge load --help` command to see more" +" usage." msgstr "" -"初始化 文档型类型 EmbeddingEngine api, 将文档embedding向量化到向量数据库(文档可以是.txt, .pdf, " -".md, .html, .doc, .ppt)" -#: ../../modules/knowledge.rst:108 0afd40098d5f4dfd9e44fe1d8004da25 +#: ../../modules/knowledge.md:33 c8a2ea571b944bdfbcad48fa8b54fcc9 msgid "" -"5.init TEXT Type EmbeddingEngine api and embedding your document into " -"vector store in your code." -msgstr "初始化TEXT类型 EmbeddingEngine api, 将文档embedding向量化到向量数据库" - -#: ../../modules/knowledge.rst:120 a66961bf3efd41fa8ea938129446f5a5 -msgid "4.similar search based on your knowledge base. ::" -msgstr "在知识库进行相似性搜索" +"3.Add the knowledge repository in the interface by entering the name of " +"your knowledge repository (if not specified, enter \"default\") so you " +"can use it for Q&A based on your knowledge base." +msgstr "" -#: ../../modules/knowledge.rst:126 b7066f408378450db26770f83fbd2716 +#: ../../modules/knowledge.md:35 b701170ad75e49dea7d7734c15681e0f msgid "" "Note that the default vector model used is text2vec-large-chinese (which " "is a large model, so if your personal computer configuration is not " @@ -125,48 +106,6 @@ msgstr "" "注意,这里默认向量模型是text2vec-large-chinese(模型比较大,如果个人电脑配置不够建议采用text2vec-base-" "chinese),因此确保需要将模型download下来放到models目录中。" -#: ../../modules/knowledge.rst:128 58481d55cab74936b6e84b24c39b1674 -#, fuzzy -msgid "" -"`pdf_embedding <./knowledge/pdf/pdf_embedding.html>`_: supported pdf " -"embedding." -msgstr "pdf_embedding <./knowledge/pdf_embedding.html>`_: supported pdf embedding." - -#: ../../modules/knowledge.rst:129 fbb013c4f1bc46af910c91292f6690cf -#, fuzzy -msgid "" -"`markdown_embedding <./knowledge/markdown/markdown_embedding.html>`_: " -"supported markdown embedding." -msgstr "pdf_embedding <./knowledge/pdf_embedding.html>`_: supported pdf embedding." - -#: ../../modules/knowledge.rst:130 59d45732f4914d16b4e01aee0992edf7 -#, fuzzy -msgid "" -"`word_embedding <./knowledge/word/word_embedding.html>`_: supported word " -"embedding." -msgstr "pdf_embedding <./knowledge/pdf_embedding.html>`_: supported pdf embedding." - -#: ../../modules/knowledge.rst:131 df0e6f311861423e885b38e020a7c0f0 -#, fuzzy -msgid "" -"`url_embedding <./knowledge/url/url_embedding.html>`_: supported url " -"embedding." -msgstr "pdf_embedding <./knowledge/pdf_embedding.html>`_: supported pdf embedding." - -#: ../../modules/knowledge.rst:132 7c550c1f5bc34fe9986731fb465e12cd -#, fuzzy -msgid "" -"`ppt_embedding <./knowledge/ppt/ppt_embedding.html>`_: supported ppt " -"embedding." 
-msgstr "pdf_embedding <./knowledge/pdf_embedding.html>`_: supported pdf embedding." - -#: ../../modules/knowledge.rst:133 8648684cb191476faeeb548389f79050 -#, fuzzy -msgid "" -"`string_embedding <./knowledge/string/string_embedding.html>`_: supported" -" raw text embedding." -msgstr "pdf_embedding <./knowledge/pdf_embedding.html>`_: supported pdf embedding." - #~ msgid "before execution: python -m spacy download zh_core_web_sm" #~ msgstr "在执行之前请先执行python -m spacy download zh_core_web_sm" @@ -201,3 +140,112 @@ msgstr "pdf_embedding <./knowledge/pdf_embedding.html>`_: supported pdf embeddin #~ "and MILVUS_PORT)" #~ msgstr "2.更新你的.env,设置你的向量存储类型,VECTOR_STORE_TYPE=Chroma(现在只支持Chroma和Milvus,如果你设置了Milvus,请设置MILVUS_URL和MILVUS_PORT)" +#~ msgid "" +#~ "We currently support many document " +#~ "formats: raw text, txt, pdf, md, " +#~ "html, doc, ppt, and url. In the" +#~ " future, we will continue to support" +#~ " more types of knowledge, including " +#~ "audio, video, various databases, and big" +#~ " data sources. Of course, we look " +#~ "forward to your active participation in" +#~ " contributing code." +#~ msgstr "" + +#~ msgid "1.prepare" +#~ msgstr "准备" + +#~ msgid "" +#~ "2.prepare embedding model, you can " +#~ "download from https://huggingface.co/. Notice " +#~ "you have installed git-lfs." +#~ msgstr "" +#~ "提前准备Embedding Model, 你可以在https://huggingface.co/进行下载,注意" +#~ ":你需要先安装git-lfs.eg: git clone " +#~ "https://huggingface.co/THUDM/chatglm2-6b" + +#~ msgid "" +#~ "eg: git clone https://huggingface.co/sentence-" +#~ "transformers/all-MiniLM-L6-v2" +#~ msgstr "" +#~ "eg: git clone https://huggingface.co/sentence-" +#~ "transformers/all-MiniLM-L6-v2" + +#~ msgid "" +#~ "3.prepare vector_store instance and vector " +#~ "store config, now we support Chroma, " +#~ "Milvus and Weaviate." +#~ msgstr "提前准备向量数据库环境,目前支持Chroma, Milvus and Weaviate向量数据库" + +#~ msgid "" +#~ "3.init Url Type EmbeddingEngine api and" +#~ " embedding your document into vector " +#~ "store in your code." +#~ msgstr "初始化 Url类型 EmbeddingEngine api, 将url文档embedding向量化到向量数据库 " + +#~ msgid "If you want to add your source_reader or text_splitter, do this:" +#~ msgstr "如果你想手动添加你自定义的source_reader和text_splitter, 请参考:" + +#~ msgid "" +#~ "4.init Document Type EmbeddingEngine api " +#~ "and embedding your document into vector" +#~ " store in your code. Document type" +#~ " can be .txt, .pdf, .md, .doc, " +#~ ".ppt." +#~ msgstr "" +#~ "初始化 文档型类型 EmbeddingEngine api, " +#~ "将文档embedding向量化到向量数据库(文档可以是.txt, .pdf, .md, .html," +#~ " .doc, .ppt)" + +#~ msgid "" +#~ "5.init TEXT Type EmbeddingEngine api and" +#~ " embedding your document into vector " +#~ "store in your code." +#~ msgstr "初始化TEXT类型 EmbeddingEngine api, 将文档embedding向量化到向量数据库" + +#~ msgid "4.similar search based on your knowledge base. ::" +#~ msgstr "在知识库进行相似性搜索" + +#~ msgid "" +#~ "`pdf_embedding <./knowledge/pdf/pdf_embedding.html>`_: " +#~ "supported pdf embedding." +#~ msgstr "" +#~ "pdf_embedding <./knowledge/pdf_embedding.html>`_: " +#~ "supported pdf embedding." + +#~ msgid "" +#~ "`markdown_embedding " +#~ "<./knowledge/markdown/markdown_embedding.html>`_: supported " +#~ "markdown embedding." +#~ msgstr "" +#~ "pdf_embedding <./knowledge/pdf_embedding.html>`_: " +#~ "supported pdf embedding." + +#~ msgid "" +#~ "`word_embedding <./knowledge/word/word_embedding.html>`_: " +#~ "supported word embedding." +#~ msgstr "" +#~ "pdf_embedding <./knowledge/pdf_embedding.html>`_: " +#~ "supported pdf embedding." 
+ +#~ msgid "" +#~ "`url_embedding <./knowledge/url/url_embedding.html>`_: " +#~ "supported url embedding." +#~ msgstr "" +#~ "pdf_embedding <./knowledge/pdf_embedding.html>`_: " +#~ "supported pdf embedding." + +#~ msgid "" +#~ "`ppt_embedding <./knowledge/ppt/ppt_embedding.html>`_: " +#~ "supported ppt embedding." +#~ msgstr "" +#~ "pdf_embedding <./knowledge/pdf_embedding.html>`_: " +#~ "supported pdf embedding." + +#~ msgid "" +#~ "`string_embedding <./knowledge/string/string_embedding.html>`_:" +#~ " supported raw text embedding." +#~ msgstr "" +#~ "pdf_embedding <./knowledge/pdf_embedding.html>`_: " +#~ "supported pdf embedding." + diff --git a/pilot/base_modules/agent/db/__init__.py b/pilot/base_modules/agent/db/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/component.py b/pilot/component.py index 8f8c8c5a4..8182f3435 100644 --- a/pilot/component.py +++ b/pilot/component.py @@ -46,6 +46,8 @@ class ComponentType(str, Enum): WORKER_MANAGER = "dbgpt_worker_manager" WORKER_MANAGER_FACTORY = "dbgpt_worker_manager_factory" MODEL_CONTROLLER = "dbgpt_model_controller" + MODEL_REGISTRY = "dbgpt_model_registry" + MODEL_API_SERVER = "dbgpt_model_api_server" AGENT_HUB = "dbgpt_agent_hub" EXECUTOR_DEFAULT = "dbgpt_thread_pool_default" TRACER = "dbgpt_tracer" @@ -68,7 +70,6 @@ def init_app(self, system_app: SystemApp): This method needs to be implemented by every component to define how it integrates with the main system app. """ - pass T = TypeVar("T", bound=BaseComponent) @@ -90,13 +91,28 @@ def app(self) -> Optional["FastAPI"]: """Returns the internal ASGI app.""" return self._asgi_app - def register(self, component: Type[BaseComponent], *args, **kwargs): - """Register a new component by its type.""" + def register(self, component: Type[BaseComponent], *args, **kwargs) -> T: + """Register a new component by its type. + + Args: + component (Type[BaseComponent]): The component class to register + + Returns: + T: The instance of registered component + """ instance = component(self, *args, **kwargs) self.register_instance(instance) + return instance + + def register_instance(self, instance: T) -> T: + """Register an already initialized component. - def register_instance(self, instance: T): - """Register an already initialized component.""" + Args: + instance (T): The component instance to register + + Returns: + T: The instance of registered component + """ name = instance.name if isinstance(name, ComponentType): name = name.value @@ -107,18 +123,34 @@ def register_instance(self, instance: T): logger.info(f"Register component with name {name} and instance: {instance}") self.components[name] = instance instance.init_app(self) + return instance def get_component( self, name: Union[str, ComponentType], component_type: Type[T], default_component=_EMPTY_DEFAULT_COMPONENT, + or_register_component: Type[BaseComponent] = None, + *args, + **kwargs, ) -> T: - """Retrieve a registered component by its name and type.""" + """Retrieve a registered component by its name and type. 
+
+        Args:
+            name (Union[str, ComponentType]): Component name
+            component_type (Type[T]): The expected type of the retrieved component
+            default_component: The default component instance to return if no component is found by name
+            or_register_component (Type[BaseComponent]): The component class to register if no component is found by name
+
+        Returns:
+            T: The component instance retrieved by name
+        """
         if isinstance(name, ComponentType):
             name = name.value
         component = self.components.get(name)
         if not component:
+            if or_register_component:
+                return self.register(or_register_component, *args, **kwargs)
             if default_component != _EMPTY_DEFAULT_COMPONENT:
                 return default_component
             raise ValueError(f"No component found with name {name}")
diff --git a/pilot/model/base.py b/pilot/model/base.py
index e89b243c9..48480b94b 100644
--- a/pilot/model/base.py
+++ b/pilot/model/base.py
@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-

 from enum import Enum
-from typing import TypedDict, Optional, Dict, List
+from typing import TypedDict, Optional, Dict, List, Any
 from dataclasses import dataclass, asdict
 from datetime import datetime
 from pilot.utils.parameter_utils import ParameterDescription
@@ -52,6 +52,8 @@ class ModelOutput:
     text: str
     error_code: int
     model_context: Dict = None
+    finish_reason: str = None
+    usage: Dict[str, Any] = None

     def to_dict(self) -> Dict:
         return asdict(self)
diff --git a/pilot/model/cli.py b/pilot/model/cli.py
index 1030adfc2..79b47db82 100644
--- a/pilot/model/cli.py
+++ b/pilot/model/cli.py
@@ -8,6 +8,7 @@ from pilot.model.base import WorkerApplyType
 from pilot.model.parameter import (
     ModelControllerParameters,
+    ModelAPIServerParameters,
     ModelWorkerParameters,
     ModelParameters,
     BaseParameters,
@@ -441,15 +442,27 @@ def stop_model_worker(port: int):

 @click.command(name="apiserver")
+@EnvArgumentParser.create_click_option(ModelAPIServerParameters)
 def start_apiserver(**kwargs):
-    """Start apiserver(TODO)"""
-    raise NotImplementedError
+    """Start apiserver"""
+
+    if kwargs["daemon"]:
+        log_file = os.path.join(LOGDIR, "model_apiserver_uvicorn.log")
+        _run_current_with_daemon("ModelAPIServer", log_file)
+    else:
+        from pilot.model.cluster import run_apiserver
+
+        run_apiserver()


 @click.command(name="apiserver")
-def stop_apiserver(**kwargs):
-    """Start apiserver(TODO)"""
-    raise NotImplementedError
+@add_stop_server_options
+def stop_apiserver(port: int):
+    """Stop apiserver"""
+    name = "ModelAPIServer"
+    if port:
+        name = f"{name}-{port}"
+    _stop_service("apiserver", name, port=port)


 def _stop_all_model_server(**kwargs):
diff --git a/pilot/model/cluster/__init__.py b/pilot/model/cluster/__init__.py
index 9937ffa0b..a777a8d4b 100644
--- a/pilot/model/cluster/__init__.py
+++ b/pilot/model/cluster/__init__.py
@@ -21,6 +21,7 @@
     run_model_controller,
     BaseModelController,
 )
+from pilot.model.cluster.apiserver.api import run_apiserver

 from pilot.model.cluster.worker.remote_manager import RemoteWorkerManager

@@ -40,4 +41,5 @@
     "ModelRegistryClient",
     "RemoteWorkerManager",
     "run_model_controller",
+    "run_apiserver",
 ]
diff --git a/pilot/model/cluster/apiserver/__init__.py b/pilot/model/cluster/apiserver/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/model/cluster/apiserver/api.py b/pilot/model/cluster/apiserver/api.py
new file mode 100644
index 000000000..148a51eed
--- /dev/null
+++ b/pilot/model/cluster/apiserver/api.py
@@ -0,0 +1,443 @@
+"""A server that provides OpenAI-compatible RESTful APIs. It supports:
+- Chat Completions.
(Reference: https://platform.openai.com/docs/api-reference/chat) + +Adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/openai_api_server.py +""" +from typing import Optional, List, Dict, Any, Generator + +import logging +import asyncio +import shortuuid +import json +from fastapi import APIRouter, FastAPI +from fastapi import Depends, HTTPException +from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import StreamingResponse +from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer + +from pydantic import BaseSettings + +from fastchat.protocol.openai_api_protocol import ( + ChatCompletionResponse, + ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, + ChatMessage, + ChatCompletionResponseChoice, + DeltaMessage, + EmbeddingsRequest, + EmbeddingsResponse, + ErrorResponse, + ModelCard, + ModelList, + ModelPermission, + UsageInfo, +) +from fastchat.protocol.api_protocol import ( + APIChatCompletionRequest, + APITokenCheckRequest, + APITokenCheckResponse, + APITokenCheckResponseItem, +) +from fastchat.serve.openai_api_server import create_error_response, check_requests +from fastchat.constants import ErrorCode + +from pilot.component import BaseComponent, ComponentType, SystemApp +from pilot.utils.parameter_utils import EnvArgumentParser +from pilot.scene.base_message import ModelMessage, ModelMessageRoleType +from pilot.model.base import ModelInstance, ModelOutput +from pilot.model.parameter import ModelAPIServerParameters, WorkerType +from pilot.model.cluster import ModelRegistry, ModelRegistryClient +from pilot.model.cluster.manager_base import WorkerManager, WorkerManagerFactory +from pilot.utils.utils import setup_logging + +logger = logging.getLogger(__name__) + + +class APIServerException(Exception): + def __init__(self, code: int, message: str): + self.code = code + self.message = message + + +class APISettings(BaseSettings): + api_keys: Optional[List[str]] = None + + +api_settings = APISettings() +get_bearer_token = HTTPBearer(auto_error=False) + + +async def check_api_key( + auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token), +) -> str: + if api_settings.api_keys: + if auth is None or (token := auth.credentials) not in api_settings.api_keys: + raise HTTPException( + status_code=401, + detail={ + "error": { + "message": "", + "type": "invalid_request_error", + "param": None, + "code": "invalid_api_key", + } + }, + ) + return token + else: + # api_keys not set; allow all + return None + + +class APIServer(BaseComponent): + name = ComponentType.MODEL_API_SERVER + + def init_app(self, system_app: SystemApp): + self.system_app = system_app + + def get_worker_manager(self) -> WorkerManager: + """Get the worker manager component instance + + Raises: + APIServerException: If can't get worker manager component instance + """ + worker_manager = self.system_app.get_component( + ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory + ).create() + if not worker_manager: + raise APIServerException( + ErrorCode.INTERNAL_ERROR, + f"Could not get component {ComponentType.WORKER_MANAGER_FACTORY} from system_app", + ) + return worker_manager + + def get_model_registry(self) -> ModelRegistry: + """Get the model registry component instance + + Raises: + APIServerException: If can't get model registry component instance + """ + + controller = self.system_app.get_component( + ComponentType.MODEL_REGISTRY, ModelRegistry + ) + if not controller: + 
raise APIServerException(
+                ErrorCode.INTERNAL_ERROR,
+                f"Could not get component {ComponentType.MODEL_REGISTRY} from system_app",
+            )
+        return controller
+
+    async def get_model_instances_or_raise(
+        self, model_name: str
+    ) -> List[ModelInstance]:
+        """Get healthy model instances for the requested model name
+
+        Args:
+            model_name (str): Model name
+
+        Raises:
+            APIServerException: If no healthy model instance exists for the requested model name
+        """
+        registry = self.get_model_registry()
+        registry_model_name = f"{model_name}@llm"
+        model_instances = await registry.get_all_instances(
+            registry_model_name, healthy_only=True
+        )
+        if not model_instances:
+            all_instances = await registry.get_all_model_instances(healthy_only=True)
+            models = [
+                ins.model_name.split("@llm")[0]
+                for ins in all_instances
+                if ins.model_name.endswith("@llm")
+            ]
+            if models:
+                models = "&&".join(models)
+                message = f"Only {models} are allowed now, but your model is {model_name}"
+            else:
+                message = f"No models are allowed now, but your model is {model_name}"
+            raise APIServerException(ErrorCode.INVALID_MODEL, message)
+        return model_instances
+
+    async def get_available_models(self) -> ModelList:
+        """Return available models
+
+        Only LLM and embedding models are included.
+
+        Returns:
+            ModelList: The list of models.
+        """
+        registry = self.get_model_registry()
+        model_instances = await registry.get_all_model_instances(healthy_only=True)
+        model_name_set = set()
+        for inst in model_instances:
+            name, worker_type = WorkerType.parse_worker_key(inst.model_name)
+            if worker_type == WorkerType.LLM or worker_type == WorkerType.TEXT2VEC:
+                model_name_set.add(name)
+        models = list(model_name_set)
+        models.sort()
+        # TODO: return real model permission details
+        model_cards = []
+        for m in models:
+            model_cards.append(
+                ModelCard(
+                    id=m, root=m, owned_by="DB-GPT", permission=[ModelPermission()]
+                )
+            )
+        return ModelList(data=model_cards)
+
+    async def chat_completion_stream_generator(
+        self, model_name: str, params: Dict[str, Any], n: int
+    ) -> Generator[str, Any, None]:
+        """Chat stream completion generator
+
+        Args:
+            model_name (str): Model name
+            params (Dict[str, Any]): The parameters passed to the model worker
+            n (int): How many completions to generate for each prompt.
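+
+        Yields:
+            str: Server-Sent Events frames of the form ``data: {json}\n\n``,
+            terminated by a final ``data: [DONE]\n\n`` frame.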
+        """
+        worker_manager = self.get_worker_manager()
+        id = f"chatcmpl-{shortuuid.random()}"
+        finish_stream_events = []
+        for i in range(n):
+            # First chunk with role
+            choice_data = ChatCompletionResponseStreamChoice(
+                index=i,
+                delta=DeltaMessage(role="assistant"),
+                finish_reason=None,
+            )
+            chunk = ChatCompletionStreamResponse(
+                id=id, choices=[choice_data], model=model_name
+            )
+            yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+
+            previous_text = ""
+            async for model_output in worker_manager.generate_stream(params):
+                model_output: ModelOutput = model_output
+                if model_output.error_code != 0:
+                    yield f"data: {json.dumps(model_output.to_dict(), ensure_ascii=False)}\n\n"
+                    yield "data: [DONE]\n\n"
+                    return
+                decoded_unicode = model_output.text.replace("\ufffd", "")
+                delta_text = decoded_unicode[len(previous_text) :]
+                previous_text = (
+                    decoded_unicode
+                    if len(decoded_unicode) > len(previous_text)
+                    else previous_text
+                )
+
+                if len(delta_text) == 0:
+                    delta_text = None
+                choice_data = ChatCompletionResponseStreamChoice(
+                    index=i,
+                    delta=DeltaMessage(content=delta_text),
+                    finish_reason=model_output.finish_reason,
+                )
+                chunk = ChatCompletionStreamResponse(
+                    id=id, choices=[choice_data], model=model_name
+                )
+                if delta_text is None:
+                    if model_output.finish_reason is not None:
+                        finish_stream_events.append(chunk)
+                    continue
+                yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+        # There is no "content" field in the last delta message, so use exclude_none to drop the "content" field.
+        for finish_chunk in finish_stream_events:
+            yield f"data: {finish_chunk.json(exclude_none=True, ensure_ascii=False)}\n\n"
+        yield "data: [DONE]\n\n"
+
+    async def chat_completion_generate(
+        self, model_name: str, params: Dict[str, Any], n: int
+    ) -> ChatCompletionResponse:
+        """Generate completion
+
+        Args:
+            model_name (str): Model name
+            params (Dict[str, Any]): The parameters passed to the model worker
+            n (int): How many completions to generate for each prompt.
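+
+        Returns:
+            ChatCompletionResponse: The aggregated non-streaming response;
+            token usage reported by the ``n`` generations is summed into a
+            single ``UsageInfo``.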
+ """ + worker_manager: WorkerManager = self.get_worker_manager() + choices = [] + chat_completions = [] + for i in range(n): + model_output = asyncio.create_task(worker_manager.generate(params)) + chat_completions.append(model_output) + try: + all_tasks = await asyncio.gather(*chat_completions) + except Exception as e: + return create_error_response(ErrorCode.INTERNAL_ERROR, str(e)) + usage = UsageInfo() + for i, model_output in enumerate(all_tasks): + model_output: ModelOutput = model_output + if model_output.error_code != 0: + return create_error_response(model_output.error_code, model_output.text) + choices.append( + ChatCompletionResponseChoice( + index=i, + message=ChatMessage(role="assistant", content=model_output.text), + finish_reason=model_output.finish_reason or "stop", + ) + ) + if model_output.usage: + task_usage = UsageInfo.parse_obj(model_output.usage) + for usage_key, usage_value in task_usage.dict().items(): + setattr(usage, usage_key, getattr(usage, usage_key) + usage_value) + + return ChatCompletionResponse(model=model_name, choices=choices, usage=usage) + + +def get_api_server() -> APIServer: + api_server = global_system_app.get_component( + ComponentType.MODEL_API_SERVER, APIServer, default_component=None + ) + if not api_server: + global_system_app.register(APIServer) + return global_system_app.get_component(ComponentType.MODEL_API_SERVER, APIServer) + + +router = APIRouter() + + +@router.get("/v1/models", dependencies=[Depends(check_api_key)]) +async def get_available_models(api_server: APIServer = Depends(get_api_server)): + return await api_server.get_available_models() + + +@router.post("/v1/chat/completions", dependencies=[Depends(check_api_key)]) +async def create_chat_completion( + request: APIChatCompletionRequest, api_server: APIServer = Depends(get_api_server) +): + await api_server.get_model_instances_or_raise(request.model) + error_check_ret = check_requests(request) + if error_check_ret is not None: + return error_check_ret + params = { + "model": request.model, + "messages": ModelMessage.to_dict_list( + ModelMessage.from_openai_messages(request.messages) + ), + "echo": False, + } + if request.temperature: + params["temperature"] = request.temperature + if request.top_p: + params["top_p"] = request.top_p + if request.max_tokens: + params["max_new_tokens"] = request.max_tokens + if request.stop: + params["stop"] = request.stop + if request.user: + params["user"] = request.user + + # TODO check token length + if request.stream: + generator = api_server.chat_completion_stream_generator( + request.model, params, request.n + ) + return StreamingResponse(generator, media_type="text/event-stream") + return await api_server.chat_completion_generate(request.model, params, request.n) + + +def _initialize_all(controller_addr: str, system_app: SystemApp): + from pilot.model.cluster import RemoteWorkerManager, ModelRegistryClient + from pilot.model.cluster.worker.manager import _DefaultWorkerManagerFactory + + if not system_app.get_component( + ComponentType.MODEL_REGISTRY, ModelRegistry, default_component=None + ): + # Register model registry if not exist + registry = ModelRegistryClient(controller_addr) + registry.name = ComponentType.MODEL_REGISTRY.value + system_app.register_instance(registry) + + registry = system_app.get_component( + ComponentType.MODEL_REGISTRY, ModelRegistry, default_component=None + ) + worker_manager = RemoteWorkerManager(registry) + + # Register worker manager component if not exist + system_app.get_component( + 
ComponentType.WORKER_MANAGER_FACTORY, + WorkerManagerFactory, + or_register_component=_DefaultWorkerManagerFactory, + worker_manager=worker_manager, + ) + # Register api server component if not exist + system_app.get_component( + ComponentType.MODEL_API_SERVER, APIServer, or_register_component=APIServer + ) + + +def initialize_apiserver( + controller_addr: str, + app=None, + system_app: SystemApp = None, + host: str = None, + port: int = None, + api_keys: List[str] = None, +): + global global_system_app + global api_settings + embedded_mod = True + if not app: + embedded_mod = False + app = FastAPI() + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], + allow_headers=["*"], + ) + + if not system_app: + system_app = SystemApp(app) + global_system_app = system_app + + if api_keys: + api_settings.api_keys = api_keys + + app.include_router(router, prefix="/api", tags=["APIServer"]) + + @app.exception_handler(APIServerException) + async def validation_apiserver_exception_handler(request, exc: APIServerException): + return create_error_response(exc.code, exc.message) + + @app.exception_handler(RequestValidationError) + async def validation_exception_handler(request, exc): + return create_error_response(ErrorCode.VALIDATION_TYPE_ERROR, str(exc)) + + _initialize_all(controller_addr, system_app) + + if not embedded_mod: + import uvicorn + + uvicorn.run(app, host=host, port=port, log_level="info") + + +def run_apiserver(): + parser = EnvArgumentParser() + env_prefix = "apiserver_" + apiserver_params: ModelAPIServerParameters = parser.parse_args_into_dataclass( + ModelAPIServerParameters, + env_prefixes=[env_prefix], + ) + setup_logging( + "pilot", + logging_level=apiserver_params.log_level, + logger_filename=apiserver_params.log_file, + ) + api_keys = None + if apiserver_params.api_keys: + api_keys = apiserver_params.api_keys.strip().split(",") + + initialize_apiserver( + apiserver_params.controller_addr, + host=apiserver_params.host, + port=apiserver_params.port, + api_keys=api_keys, + ) + + +if __name__ == "__main__": + run_apiserver() diff --git a/pilot/model/cluster/apiserver/tests/__init__.py b/pilot/model/cluster/apiserver/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/model/cluster/apiserver/tests/test_api.py b/pilot/model/cluster/apiserver/tests/test_api.py new file mode 100644 index 000000000..281a8aff6 --- /dev/null +++ b/pilot/model/cluster/apiserver/tests/test_api.py @@ -0,0 +1,248 @@ +import pytest +import pytest_asyncio +from aioresponses import aioresponses +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from httpx import AsyncClient, HTTPError + +from pilot.component import SystemApp +from pilot.utils.openai_utils import chat_completion_stream, chat_completion + +from pilot.model.cluster.apiserver.api import ( + api_settings, + initialize_apiserver, + ModelList, + UsageInfo, + ChatCompletionResponse, + ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, + ChatMessage, + ChatCompletionResponseChoice, + DeltaMessage, +) +from pilot.model.cluster.tests.conftest import _new_cluster + +from pilot.model.cluster.worker.manager import _DefaultWorkerManagerFactory + +app = FastAPI() +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], + allow_headers=["*"], +) + + +@pytest_asyncio.fixture +async 
def system_app():
+    return SystemApp(app)
+
+
+@pytest_asyncio.fixture
+async def client(request, system_app: SystemApp):
+    param = getattr(request, "param", {})
+    api_keys = param.get("api_keys", [])
+    client_api_key = param.get("client_api_key")
+    if "num_workers" not in param:
+        param["num_workers"] = 2
+    if "api_keys" in param:
+        del param["api_keys"]
+    headers = {}
+    if client_api_key:
+        headers["Authorization"] = "Bearer " + client_api_key
+    print(f"param: {param}")
+    if api_settings:
+        # Clear global api keys
+        api_settings.api_keys = []
+    async with AsyncClient(app=app, base_url="http://test", headers=headers) as client:
+        async with _new_cluster(**param) as cluster:
+            worker_manager, model_registry = cluster
+            system_app.register(_DefaultWorkerManagerFactory, worker_manager)
+            system_app.register_instance(model_registry)
+            # print(f"Instances {model_registry.registry}")
+            initialize_apiserver(None, app, system_app, api_keys=api_keys)
+            yield client
+
+
+@pytest.mark.asyncio
+async def test_get_all_models(client: AsyncClient):
+    res = await client.get("/api/v1/models")
+    assert res.status_code == 200
+    model_lists = ModelList.parse_obj(res.json())
+    print(f"model list json: {res.json()}")
+    assert model_lists.object == "list"
+    assert len(model_lists.data) == 2
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "client, expected_messages",
+    [
+        ({"stream_messags": ["Hello", " world."]}, "Hello world."),
+        ({"stream_messags": ["你好,我是", "张三。"]}, "你好,我是张三。"),
+    ],
+    indirect=["client"],
+)
+async def test_chat_completions(client: AsyncClient, expected_messages):
+    chat_data = {
+        "model": "test-model-name-0",
+        "messages": [{"role": "user", "content": "Hello"}],
+        "stream": True,
+    }
+    full_text = ""
+    async for text in chat_completion_stream(
+        "/api/v1/chat/completions", chat_data, client
+    ):
+        full_text += text
+    assert full_text == expected_messages
+
+    assert (
+        await chat_completion("/api/v1/chat/completions", chat_data, client)
+        == expected_messages
+    )
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "client, expected_messages, client_api_key",
+    [
+        (
+            {"stream_messags": ["Hello", " world."], "api_keys": ["abc"]},
+            "Hello world.",
+            "abc",
+        ),
+        ({"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]}, "你好,我是张三。", "abc"),
+    ],
+    indirect=["client"],
+)
+async def test_chat_completions_with_openai_lib_async_no_stream(
+    client: AsyncClient, expected_messages: str, client_api_key: str
+):
+    import openai
+
+    openai.api_key = client_api_key
+    openai.api_base = "http://test/api/v1"
+
+    model_name = "test-model-name-0"
+
+    with aioresponses() as mocked:
+        mock_message = {"text": expected_messages}
+        one_res = ChatCompletionResponseChoice(
+            index=0,
+            message=ChatMessage(role="assistant", content=expected_messages),
+            finish_reason="stop",
+        )
+        data = ChatCompletionResponse(
+            model=model_name, choices=[one_res], usage=UsageInfo()
+        )
+        mock_message = f"{data.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+        # Mock http request
+        mocked.post(
+            "http://test/api/v1/chat/completions", status=200, body=mock_message
+        )
+        completion = await openai.ChatCompletion.acreate(
+            model=model_name,
+            messages=[{"role": "user", "content": "Hello! 
diff --git a/pilot/model/cluster/controller/controller.py b/pilot/model/cluster/controller/controller.py
index 173c8c019..0006d91a0 100644
--- a/pilot/model/cluster/controller/controller.py
+++ b/pilot/model/cluster/controller/controller.py
@@ -66,7 +66,9 @@ async def get_all_instances(
             f"Get all instances with {model_name}, healthy_only: {healthy_only}"
         )
         if not model_name:
-            return await self.registry.get_all_model_instances()
+            return await self.registry.get_all_model_instances(
+                healthy_only=healthy_only
+            )
         else:
             return await self.registry.get_all_instances(model_name, healthy_only)
@@ -98,8 +100,10 @@ async def send_heartbeat(self, instance: ModelInstance) -> bool:
 
 class ModelRegistryClient(_RemoteModelController, ModelRegistry):
-    async def get_all_model_instances(self) -> List[ModelInstance]:
-        return await self.get_all_instances()
+    async def get_all_model_instances(
+        self, healthy_only: bool = False
+    ) -> List[ModelInstance]:
+        return await self.get_all_instances(healthy_only=healthy_only)
 
     @sync_api_remote(path="/api/controller/models")
     def sync_get_all_instances(
diff --git a/pilot/model/cluster/registry.py b/pilot/model/cluster/registry.py
index 398882eb9..eb5f1e415 100644
--- a/pilot/model/cluster/registry.py
+++ b/pilot/model/cluster/registry.py
@@ -1,22 +1,37 @@
 import random
 import threading
 import time
+import logging
 from abc import ABC, abstractmethod
 from collections import defaultdict
 from datetime import datetime, timedelta
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple
 import itertools
 
+from pilot.component import BaseComponent, ComponentType, SystemApp
 from pilot.model.base import ModelInstance
 
 
-class ModelRegistry(ABC):
+logger = logging.getLogger(__name__)
+
+
+class ModelRegistry(BaseComponent, ABC):
     """
     Abstract base class for a model registry. It provides an interface
     for registering, deregistering, fetching instances, and sending
     heartbeats for instances.
     """
 
+    name = ComponentType.MODEL_REGISTRY
+
+    def __init__(self, system_app: Optional[SystemApp] = None):
+        self.system_app = system_app
+        super().__init__(system_app)
+
+    def init_app(self, system_app: SystemApp):
+        """Initialize the component with the main application."""
+        self.system_app = system_app
+
     @abstractmethod
     async def register_instance(self, instance: ModelInstance) -> bool:
         """
@@ -65,9 +80,11 @@ def sync_get_all_instances(
         """Fetch all instances of a given model. Optionally, fetch only the healthy instances."""
 
     @abstractmethod
-    async def get_all_model_instances(self) -> List[ModelInstance]:
+    async def get_all_model_instances(
+        self, healthy_only: bool = False
+    ) -> List[ModelInstance]:
         """
-        Fetch all instances of all models
+        Fetch all instances of all models. Optionally, fetch only the healthy instances.
 
         Returns:
         - List[ModelInstance]: A list of instances for all models.
@@ -105,8 +122,12 @@ async def send_heartbeat(self, instance: ModelInstance) -> bool:
 
 class EmbeddedModelRegistry(ModelRegistry):
     def __init__(
-        self, heartbeat_interval_secs: int = 60, heartbeat_timeout_secs: int = 120
+        self,
+        system_app: Optional[SystemApp] = None,
+        heartbeat_interval_secs: int = 60,
+        heartbeat_timeout_secs: int = 120,
     ):
+        super().__init__(system_app)
         self.registry: Dict[str, List[ModelInstance]] = defaultdict(list)
         self.heartbeat_interval_secs = heartbeat_interval_secs
         self.heartbeat_timeout_secs = heartbeat_timeout_secs
@@ -180,9 +201,14 @@ def sync_get_all_instances(
             instances = [ins for ins in instances if ins.healthy == True]
         return instances
 
-    async def get_all_model_instances(self) -> List[ModelInstance]:
-        print(self.registry)
-        return list(itertools.chain(*self.registry.values()))
+    async def get_all_model_instances(
+        self, healthy_only: bool = False
+    ) -> List[ModelInstance]:
+        logger.debug(f"Current registry metadata:\n{self.registry}")
+        instances = list(itertools.chain(*self.registry.values()))
+        if healthy_only:
+            instances = [ins for ins in instances if ins.healthy == True]
+        return instances
 
     async def send_heartbeat(self, instance: ModelInstance) -> bool:
         _, exist_ins = self._get_instances(
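Editor's note: with the hunks above, the `healthy_only` flag now flows from the controller API down to the registry. A minimal sketch of the new call pattern against the embedded registry (instance values are illustrative):

```python
import asyncio

from pilot.model.base import ModelInstance
from pilot.model.cluster.registry import EmbeddedModelRegistry


async def main():
    registry = EmbeddedModelRegistry()
    await registry.register_instance(
        ModelInstance(
            model_name="vicuna-13b-v1.5@llm", host="127.0.0.1", port=8001, healthy=True
        )
    )
    # Every registered instance, live or not.
    print(await registry.get_all_model_instances())
    # Only instances currently marked healthy.
    print(await registry.get_all_model_instances(healthy_only=True))


asyncio.run(main())
```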
diff --git a/pilot/model/cluster/tests/__init__.py b/pilot/model/cluster/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/model/cluster/worker/tests/base_tests.py b/pilot/model/cluster/tests/conftest.py
similarity index 71%
rename from pilot/model/cluster/worker/tests/base_tests.py
rename to pilot/model/cluster/tests/conftest.py
index 21821d9f9..f614387ac 100644
--- a/pilot/model/cluster/worker/tests/base_tests.py
+++ b/pilot/model/cluster/tests/conftest.py
@@ -6,6 +6,7 @@
 from pilot.model.base import ModelOutput
 from pilot.model.cluster.worker_base import ModelWorker
 from pilot.model.cluster.worker.manager import (
+    WorkerManager,
     LocalWorkerManager,
     RegisterFunc,
     DeregisterFunc,
@@ -13,6 +14,23 @@
     ApplyFunction,
 )
 
+from pilot.model.base import ModelInstance
+from pilot.model.cluster.registry import ModelRegistry, EmbeddedModelRegistry
+
+
+@pytest.fixture
+def model_registry(request):
+    return EmbeddedModelRegistry()
+
+
+@pytest.fixture
+def model_instance():
+    return ModelInstance(
+        model_name="test_model",
+        host="192.168.1.1",
+        port=5000,
+    )
+
 
 class MockModelWorker(ModelWorker):
     def __init__(
@@ -51,8 +69,10 @@ def stop(self) -> None:
         raise Exception("Stop worker error for mock")
 
     def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
+        full_text = ""
         for msg in self.stream_messags:
-            yield ModelOutput(text=msg, error_code=0)
+            full_text += msg
+            yield ModelOutput(text=full_text, error_code=0)
 
     def generate(self, params: Dict) -> ModelOutput:
         output = None
@@ -67,6 +87,8 @@ def embeddings(self, params: Dict) -> List[List[float]]:
 _TEST_MODEL_NAME = "vicuna-13b-v1.5"
 _TEST_MODEL_PATH = "/app/models/vicuna-13b-v1.5"
 
+ClusterType = Tuple[WorkerManager, ModelRegistry]
+
 
 def _new_worker_params(
     model_name: str = _TEST_MODEL_NAME,
@@ -85,7 +107,9 @@ def _create_workers(
     worker_type: str = WorkerType.LLM.value,
     stream_messags: List[str] = None,
     embeddings: List[List[float]] = None,
-) -> List[Tuple[ModelWorker, ModelWorkerParameters]]:
+    host: str = "127.0.0.1",
+    start_port=8001,
+) -> List[Tuple[ModelWorker, ModelWorkerParameters, ModelInstance]]:
     workers = []
     for i in range(num_workers):
         model_name = f"test-model-name-{i}"
@@ -98,10 +122,16 @@
             stream_messags=stream_messags,
             embeddings=embeddings,
         )
+        model_instance = ModelInstance(
+            model_name=WorkerType.to_worker_key(model_name, worker_type),
+            host=host,
+            port=start_port + i,
+            healthy=True,
+        )
         worker_params = _new_worker_params(
            model_name, model_path, worker_type=worker_type
         )
-        workers.append((worker, worker_params))
+        workers.append((worker, worker_params, model_instance))
     return workers
 
@@ -127,12 +157,12 @@ async def _start_worker_manager(**kwargs):
         model_registry=model_registry,
     )
 
-    for worker, worker_params in _create_workers(
+    for worker, worker_params, model_instance in _create_workers(
        num_workers, error_worker, stop_error, stream_messags, embeddings
     ):
         worker_manager.add_worker(worker, worker_params)
     if workers:
-        for worker, worker_params in workers:
+        for worker, worker_params, model_instance in workers:
             worker_manager.add_worker(worker, worker_params)
 
     if start:
@@ -143,6 +173,15 @@
         await worker_manager.stop()
 
 
+async def _create_model_registry(
+    workers: List[Tuple[ModelWorker, ModelWorkerParameters, ModelInstance]]
+) -> ModelRegistry:
+    registry = EmbeddedModelRegistry()
+    for _, _, inst in workers:
+        assert await registry.register_instance(inst) == True
+    return registry
+
+
 @pytest_asyncio.fixture
 async def manager_2_workers(request):
     param = getattr(request, "param", {})
@@ -166,3 +205,27 @@ async def manager_2_embedding_workers(request):
     )
     async with _start_worker_manager(workers=workers, **param) as worker_manager:
         yield (worker_manager, workers)
+
+
+@asynccontextmanager
+async def _new_cluster(**kwargs) -> ClusterType:
+    num_workers = kwargs.get("num_workers", 0)
+    workers = _create_workers(
+        num_workers, stream_messags=kwargs.get("stream_messags", [])
+    )
+    if "num_workers" in kwargs:
+        del kwargs["num_workers"]
+    registry = await _create_model_registry(
+        workers,
+    )
+    async with _start_worker_manager(workers=workers, **kwargs) as worker_manager:
+        yield (worker_manager, registry)
+
+
+@pytest_asyncio.fixture
+async def cluster_2_workers(request):
+    param = getattr(request, "param", {})
+    workers = _create_workers(2)
+    registry = await _create_model_registry(workers)
+    async with _start_worker_manager(workers=workers, **param) as worker_manager:
+        yield (worker_manager, registry)
diff --git a/pilot/model/cluster/worker/default_worker.py b/pilot/model/cluster/worker/default_worker.py
index 5caa2ee7e..44a476f20 100644
--- a/pilot/model/cluster/worker/default_worker.py
+++ b/pilot/model/cluster/worker/default_worker.py
@@ -256,15 +256,22 @@ def _prepare_generate_stream(self, params: Dict, span_operation_name: str):
         return params, model_context, generate_stream_func, model_span
 
     def _handle_output(self, output, previous_response, model_context):
+        finish_reason = None
+        usage = None
         if isinstance(output, dict):
             finish_reason = output.get("finish_reason")
+            usage = output.get("usage")
             output = output["text"]
             if finish_reason is not None:
                 logger.info(f"finish_reason: {finish_reason}")
         incremental_output = output[len(previous_response) :]
         print(incremental_output, end="", flush=True)
         model_output = ModelOutput(
-            text=output, error_code=0, model_context=model_context
+            text=output,
+            error_code=0,
+            model_context=model_context,
+            finish_reason=finish_reason,
+            usage=usage,
         )
         return model_output, incremental_output, output
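Editor's note: to make the dict branch above concrete, here is an illustrative (not normative) sketch of the two chunk shapes `_handle_output` accepts from a generate-stream function, and how the incremental delta is derived. The exact `usage` fields are an assumption.

```python
# A chunk may be a bare string (text only)...
chunk = "Hello wor"
# ...or a dict carrying finish_reason and usage, typically on the final chunk
# (keys as read by _handle_output above; usage fields are hypothetical).
final_chunk = {
    "text": "Hello world.",
    "finish_reason": "stop",
    "usage": {"prompt_tokens": 9, "completion_tokens": 3, "total_tokens": 12},
}
# The delta sent to clients is the suffix beyond the previous response:
previous_response = "Hello wor"
incremental_output = final_chunk["text"][len(previous_response):]
assert incremental_output == "ld."
```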
diff --git a/pilot/model/cluster/worker/manager.py b/pilot/model/cluster/worker/manager.py
index a76fa6685..2dcfb086e 100644
--- a/pilot/model/cluster/worker/manager.py
+++ b/pilot/model/cluster/worker/manager.py
@@ -99,9 +99,7 @@ def __init__(
         )
 
     def _worker_key(self, worker_type: str, model_name: str) -> str:
-        if isinstance(worker_type, WorkerType):
-            worker_type = worker_type.value
-        return f"{model_name}@{worker_type}"
+        return WorkerType.to_worker_key(model_name, worker_type)
 
     async def run_blocking_func(self, func, *args):
         if asyncio.iscoroutinefunction(func):
diff --git a/pilot/model/cluster/worker/tests/test_manager.py b/pilot/model/cluster/worker/tests/test_manager.py
index 919e64f99..681fb49a3 100644
--- a/pilot/model/cluster/worker/tests/test_manager.py
+++ b/pilot/model/cluster/worker/tests/test_manager.py
@@ -3,7 +3,7 @@
 from typing import List, Iterator, Dict, Tuple
 from dataclasses import asdict
 from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
-from pilot.model.base import ModelOutput, WorkerApplyType
+from pilot.model.base import ModelOutput, WorkerApplyType, ModelInstance
 from pilot.model.cluster.base import WorkerApplyRequest, WorkerStartupRequest
 from pilot.model.cluster.worker_base import ModelWorker
 from pilot.model.cluster.manager_base import WorkerRunData
@@ -14,7 +14,7 @@
     SendHeartbeatFunc,
     ApplyFunction,
 )
-from pilot.model.cluster.worker.tests.base_tests import (
+from pilot.model.cluster.tests.conftest import (
     MockModelWorker,
     manager_2_workers,
     manager_with_2_workers,
@@ -216,7 +216,7 @@ async def test__remove_worker():
     workers = _create_workers(3)
     async with _start_worker_manager(workers=workers, stop=False) as manager:
         assert len(manager.workers) == 3
-        for _, worker_params in workers:
+        for _, worker_params, _ in workers:
             manager._remove_worker(worker_params)
         not_exist_parmas = _new_worker_params(
             model_name="this is a not exist worker params"
@@ -229,7 +229,7 @@ async def test_model_startup(mock_build_worker):
     async with _start_worker_manager() as manager:
         workers = _create_workers(1)
-        worker, worker_params = workers[0]
+        worker, worker_params, model_instance = workers[0]
         mock_build_worker.return_value = worker
 
         req = WorkerStartupRequest(
@@ -245,7 +245,7 @@
     async with _start_worker_manager() as manager:
         workers = _create_workers(1, error_worker=True)
-        worker, worker_params = workers[0]
+        worker, worker_params, model_instance = workers[0]
         mock_build_worker.return_value = worker
         req = WorkerStartupRequest(
             host="127.0.0.1",
@@ -263,7 +263,7 @@ async def test_model_shutdown(mock_build_worker):
     async with _start_worker_manager(start=False, stop=False) as manager:
         workers = _create_workers(1)
-        worker, worker_params = workers[0]
+        worker, worker_params, model_instance = workers[0]
         mock_build_worker.return_value = worker
 
         req = WorkerStartupRequest(
@@ -298,7 +298,7 @@ async def test_get_model_instances(is_async):
     workers = _create_workers(3)
     async with _start_worker_manager(workers=workers, stop=False) as manager:
         assert len(manager.workers) == 3
-        for _, worker_params in workers:
+        for _, worker_params, _ in workers:
             model_name = worker_params.model_name
             worker_type = worker_params.worker_type
             if is_async:
@@ -326,7 +326,7 @@ async def test__simple_select(
     ]
 ):
     manager, workers = manager_with_2_workers
-    for _, worker_params in workers:
+    for _, worker_params, _ in workers:
         model_name = worker_params.model_name
         worker_type = worker_params.worker_type
         instances = await manager.get_model_instances(worker_type, model_name)
@@ -351,7 +351,7 @@ async def test_select_one_instance(
     ],
 ):
     manager, workers = manager_with_2_workers
-    for _, worker_params in workers:
+    for _, worker_params, _ in workers:
         model_name = worker_params.model_name
         worker_type = worker_params.worker_type
         if is_async:
@@ -376,7 +376,7 @@ async def test__get_model(
     ],
 ):
     manager, workers = manager_with_2_workers
-    for _, worker_params in workers:
+    for _, worker_params, _ in workers:
         model_name = worker_params.model_name
         worker_type = worker_params.worker_type
         params = {"model": model_name}
@@ -403,13 +403,13 @@ async def test_generate_stream(
     expected_messages: str,
 ):
     manager, workers = manager_with_2_workers
-    for _, worker_params in workers:
+    for _, worker_params, _ in workers:
         model_name = worker_params.model_name
         worker_type = worker_params.worker_type
         params = {"model": model_name}
         text = ""
         async for out in manager.generate_stream(params):
-            text += out.text
+            text = out.text
         assert text == expected_messages
 
@@ -417,8 +417,8 @@ async def test_generate_stream(
 @pytest.mark.parametrize(
     "manager_with_2_workers, expected_messages",
     [
-        ({"stream_messags": ["Hello", " world."]}, " world."),
-        ({"stream_messags": ["你好,我是", "张三。"]}, "张三。"),
+        ({"stream_messags": ["Hello", " world."]}, "Hello world."),
+        ({"stream_messags": ["你好,我是", "张三。"]}, "你好,我是张三。"),
     ],
     indirect=["manager_with_2_workers"],
 )
@@ -429,7 +429,7 @@ async def test_generate(
     expected_messages: str,
 ):
     manager, workers = manager_with_2_workers
-    for _, worker_params in workers:
+    for _, worker_params, _ in workers:
         model_name = worker_params.model_name
         worker_type = worker_params.worker_type
         params = {"model": model_name}
@@ -454,7 +454,7 @@ async def test_embeddings(
     is_async: bool,
 ):
     manager, workers = manager_2_embedding_workers
-    for _, worker_params in workers:
+    for _, worker_params, _ in workers:
         model_name = worker_params.model_name
         worker_type = worker_params.worker_type
         params = {"model": model_name, "input": ["hello", "world"]}
@@ -472,7 +472,7 @@ async def test_parameter_descriptions(
     ]
 ):
     manager, workers = manager_with_2_workers
-    for _, worker_params in workers:
+    for _, worker_params, _ in workers:
         model_name = worker_params.model_name
         worker_type = worker_params.worker_type
         params = await manager.parameter_descriptions(worker_type, model_name)
diff --git a/pilot/model/model_adapter.py b/pilot/model/model_adapter.py
index e09b868e7..e2deeaa02 100644
--- a/pilot/model/model_adapter.py
+++ b/pilot/model/model_adapter.py
@@ -467,7 +467,8 @@ def __str__(self) -> str:
         sep="\n",
         sep2="</s>",
         stop_str=["</s>", "[UNK]"],
-    )
+    ),
+    override=True,
 )
 # source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L227
 register_conv_template(
@@ -482,7 +483,8 @@
         sep="###",
         sep2="</s>",
         stop_str=["</s>", "[UNK]"],
-    )
+    ),
+    override=True,
 )
 # source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L242
 register_conv_template(
@@ -495,5 +497,6 @@
         sep="",
         sep2="</s>",
         stop_str=["</s>", "<|endoftext|>"],
-    )
+    ),
+    override=True,
 )
diff --git a/pilot/model/parameter.py b/pilot/model/parameter.py
index ea81ec091..e21de1c42 100644
--- a/pilot/model/parameter.py
+++ b/pilot/model/parameter.py
@@ -1,9 +1,10 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
+
 import os
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Dict, Optional
+from typing import Dict, Optional, Union, Tuple
 
 from pilot.model.conversation import conv_templates
 from pilot.utils.parameter_utils import BaseParameters
@@ -19,6 +20,35 @@ class WorkerType(str, Enum):
     def values():
         return [item.value for item in WorkerType]
 
+    @staticmethod
+    def to_worker_key(worker_name, worker_type: Union[str, "WorkerType"]) -> str:
+        """Generate worker key from worker name and worker type
+
+        Args:
+            worker_name (str): Worker name (e.g., chatglm2-6b)
+            worker_type (Union[str, "WorkerType"]): Worker type (e.g., 'llm', or [`WorkerType.LLM`])
+
+        Returns:
+            str: Generated worker key
+        """
+        if "@" in worker_name:
+            raise ValueError(f"Invalid symbol '@' in your worker name {worker_name}")
+        if isinstance(worker_type, WorkerType):
+            worker_type = worker_type.value
+        return f"{worker_name}@{worker_type}"
+
+    @staticmethod
+    def parse_worker_key(worker_key: str) -> Tuple[str, str]:
+        """Parse worker name and worker type from worker key
+
+        Args:
+            worker_key (str): Worker key generated by [`WorkerType.to_worker_key`]
+
+        Returns:
+            Tuple[str, str]: Worker name and worker type
+        """
+        return tuple(worker_key.split("@"))
+
 
 @dataclass
 class ModelControllerParameters(BaseParameters):
@@ -60,6 +90,56 @@ class ModelControllerParameters(BaseParameters):
     )
 
 
+@dataclass
+class ModelAPIServerParameters(BaseParameters):
+    host: Optional[str] = field(
+        default="0.0.0.0", metadata={"help": "Model API server deploy host"}
+    )
+    port: Optional[int] = field(
+        default=8100, metadata={"help": "Model API server deploy port"}
+    )
+    daemon: Optional[bool] = field(
+        default=False, metadata={"help": "Run Model API server in background"}
+    )
+    controller_addr: Optional[str] = field(
+        default="http://127.0.0.1:8000",
+        metadata={"help": "The Model controller address to connect"},
+    )
+
+    api_keys: Optional[str] = field(
+        default=None,
+        metadata={"help": "Optional list of comma separated API keys"},
+    )
+
+    log_level: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Logging level",
+            "valid_values": [
+                "FATAL",
+                "ERROR",
+                "WARNING",
+                "INFO",
+                "DEBUG",
+                "NOTSET",
+            ],
+        },
+    )
+    log_file: Optional[str] = field(
+        default="dbgpt_model_apiserver.log",
+        metadata={
+            "help": "The filename to store log",
+        },
+    )
+    tracer_file: Optional[str] = field(
+        default="dbgpt_model_apiserver_tracer.jsonl",
+        metadata={
+            "help": "The filename to store tracer span records",
+        },
+    )
+
+
 @dataclass
 class BaseModelParameters(BaseParameters):
     model_name: str = field(metadata={"help": "Model name", "tags": "fixed"})
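Editor's note: a quick round trip with the new key helpers, assuming `WorkerType.LLM.value == "llm"` as the docstring suggests (the model name is illustrative):

```python
from pilot.model.parameter import WorkerType

# Compose a registry key from a model name and a worker type...
key = WorkerType.to_worker_key("vicuna-13b-v1.5", WorkerType.LLM)
assert key == "vicuna-13b-v1.5@llm"

# ...and split it back apart.
name, worker_type = WorkerType.parse_worker_key(key)
assert (name, worker_type) == ("vicuna-13b-v1.5", "llm")

# '@' is reserved as the separator, so it is rejected in worker names.
try:
    WorkerType.to_worker_key("bad@name", "llm")
except ValueError as e:
    print(e)
```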
diff --git a/pilot/scene/base_message.py b/pilot/scene/base_message.py
index eeb42a285..12a72e909 100644
--- a/pilot/scene/base_message.py
+++ b/pilot/scene/base_message.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Tuple, Optional
+from typing import Any, Dict, List, Tuple, Optional, Union
 
 from pydantic import BaseModel, Field, root_validator
 
@@ -70,14 +70,6 @@ def type(self) -> str:
         return "system"
 
 
-class ModelMessage(BaseModel):
-    """Type of message that interaction between dbgpt-server and llm-server"""
-
-    """Similar to openai's message format"""
-    role: str
-    content: str
-
-
 class ModelMessageRoleType:
     """ "Type of ModelMessage role"""
 
@@ -87,6 +79,45 @@ class ModelMessageRoleType:
     VIEW = "view"
 
 
+class ModelMessage(BaseModel):
+    """Type of message for interaction between dbgpt-server and llm-server"""
+
+    """Similar to openai's message format"""
+    role: str
+    content: str
+
+    @staticmethod
+    def from_openai_messages(
+        messages: Union[str, List[Dict[str, str]]]
+    ) -> List["ModelMessage"]:
+        """Convert OpenAI message format to the current ModelMessage format"""
+        if isinstance(messages, str):
+            return [ModelMessage(role=ModelMessageRoleType.HUMAN, content=messages)]
+        result = []
+        for message in messages:
+            msg_role = message["role"]
+            content = message["content"]
+            if msg_role == "system":
+                result.append(
+                    ModelMessage(role=ModelMessageRoleType.SYSTEM, content=content)
+                )
+            elif msg_role == "user":
+                result.append(
+                    ModelMessage(role=ModelMessageRoleType.HUMAN, content=content)
+                )
+            elif msg_role == "assistant":
+                result.append(
+                    ModelMessage(role=ModelMessageRoleType.AI, content=content)
+                )
+            else:
+                raise ValueError(f"Unknown role: {msg_role}")
+        return result
+
+    @staticmethod
+    def to_dict_list(messages: List["ModelMessage"]) -> List[Dict[str, str]]:
+        return list(map(lambda m: m.dict(), messages))
+
+
 class Generation(BaseModel):
     """Output of a single generation."""
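Editor's note: these conversion helpers are what let the API server accept OpenAI-style requests. A small round-trip sketch (message contents are illustrative):

```python
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType

openai_messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "hello"},
]
messages = ModelMessage.from_openai_messages(openai_messages)
assert messages[0].role == ModelMessageRoleType.SYSTEM
assert messages[1].role == ModelMessageRoleType.HUMAN

# Back to plain dicts for transport to the worker:
print(ModelMessage.to_dict_list(messages))
# A bare string is treated as a single human message:
print(ModelMessage.from_openai_messages("hello"))
```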
diff --git a/pilot/utils/openai_utils.py b/pilot/utils/openai_utils.py
new file mode 100644
index 000000000..6577d3abf
--- /dev/null
+++ b/pilot/utils/openai_utils.py
@@ -0,0 +1,99 @@
+from typing import Dict, Any, AsyncIterator, Awaitable, Callable, Optional
+import httpx
+import asyncio
+import logging
+import json
+
+logger = logging.getLogger(__name__)
+MessageCaller = Callable[[str], Awaitable[None]]
+
+
+async def _do_chat_completion(
+    url: str,
+    chat_data: Dict[str, Any],
+    client: httpx.AsyncClient,
+    headers: Dict[str, Any] = {},
+    timeout: int = 60,
+    caller: Optional[MessageCaller] = None,
+) -> AsyncIterator[str]:
+    async with client.stream(
+        "POST",
+        url,
+        headers=headers,
+        json=chat_data,
+        timeout=timeout,
+    ) as res:
+        if res.status_code != 200:
+            error_message = await res.aread()
+            if error_message:
+                error_message = error_message.decode("utf-8")
+            logger.error(
+                f"Request failed with status {res.status_code}. Error: {error_message}"
+            )
+            raise httpx.RequestError(
+                f"Request failed with status {res.status_code}",
+                request=res.request,
+            )
+        async for line in res.aiter_lines():
+            if line:
+                if not line.startswith("data: "):
+                    if caller:
+                        await caller(line)
+                    yield line
+                else:
+                    decoded_line = line.split("data: ", 1)[1]
+                    if decoded_line.strip().lower() != "[done]":
+                        obj = json.loads(decoded_line)
+                        if obj["choices"][0]["delta"].get("content") is not None:
+                            text = obj["choices"][0]["delta"].get("content")
+                            if caller:
+                                await caller(text)
+                            yield text
+            await asyncio.sleep(0.02)
+
+
+async def chat_completion_stream(
+    url: str,
+    chat_data: Dict[str, Any],
+    client: Optional[httpx.AsyncClient] = None,
+    headers: Dict[str, Any] = {},
+    timeout: int = 60,
+    caller: Optional[MessageCaller] = None,
+) -> AsyncIterator[str]:
+    if client:
+        async for text in _do_chat_completion(
+            url,
+            chat_data,
+            client=client,
+            headers=headers,
+            timeout=timeout,
+            caller=caller,
+        ):
+            yield text
+    else:
+        async with httpx.AsyncClient() as client:
+            async for text in _do_chat_completion(
+                url,
+                chat_data,
+                client=client,
+                headers=headers,
+                timeout=timeout,
+                caller=caller,
+            ):
+                yield text
+
+
+async def chat_completion(
+    url: str,
+    chat_data: Dict[str, Any],
+    client: Optional[httpx.AsyncClient] = None,
+    headers: Dict[str, Any] = {},
+    timeout: int = 60,
+    caller: Optional[MessageCaller] = None,
+) -> str:
+    full_text = ""
+    async for text in chat_completion_stream(
+        url, chat_data, client, headers=headers, timeout=timeout, caller=caller
+    ):
+        full_text += text
+    return full_text
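Editor's note: a usage sketch for these helpers against a running API server. The URL, API key, and model name are illustrative; see the docs added earlier in this change for launching the server.

```python
import asyncio

from pilot.utils.openai_utils import chat_completion, chat_completion_stream


async def main():
    url = "http://127.0.0.1:8100/api/v1/chat/completions"  # assumed local apiserver
    headers = {"Authorization": "Bearer EMPTY"}
    chat_data = {
        "model": "vicuna-13b-v1.5",
        "messages": [{"role": "user", "content": "hello"}],
        "stream": True,
    }
    # Print deltas as they arrive...
    async for text in chat_completion_stream(url, chat_data, headers=headers):
        print(text, end="", flush=True)
    # ...or collect the whole reply in one await.
    print(await chat_completion(url, chat_data, headers=headers))


asyncio.run(main())
```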
diff --git a/requirements/dev-requirements.txt b/requirements/dev-requirements.txt
index 072b527f1..d1a98ed49 100644
--- a/requirements/dev-requirements.txt
+++ b/requirements/dev-requirements.txt
@@ -8,6 +8,7 @@ pytest-integration
 pytest-mock
 pytest-recording
 pytesseract==0.3.10
+aioresponses
 # python code format
 black
 # for git hooks