From 373bd01394cfa784a9ab3fe0d404d6ae1beee3d0 Mon Sep 17 00:00:00 2001 From: AllentDan <41138331+AllentDan@users.noreply.github.com> Date: Wed, 1 Nov 2023 11:40:35 +0800 Subject: [PATCH] Improve api_server and webui usage (#544) * make IPv6 compatible, safe run for coroutine interrupting * instance_id -> session_id and fix api_client.py * update doc * remove useless faq * safe ip mapping * update app.py * WIP completion * completion * update doc * disable interactive mode for /v1/chat/completions * docstring * docstring * refactor gradio * update gradio * udpate * update doc * rename * session_id default -1 * missed two files * add a APIClient * add chat func for APIClient * refine * add concurrent function * sequence_start, sequence_end --> interactive_mode * update doc * comments * doc * better text completion * remove /v1/embeddings * comments * deprecate generate and use /v1/interactive/completions * /v1/interactive/completion -> /v1/chat/interactive * embeddings * rename * remove wrong arg description * docstring * fix * update cli * update doc * strict session_len limit condition * pass model args to api_server --- README.md | 8 +- README_zh-CN.md | 8 +- benchmark/profile_restful_api.py | 54 +- benchmark/profile_throughput.py | 3 +- docs/en/restful_api.md | 126 ++--- docs/en/supported_models/codellama.md | 8 +- docs/zh_cn/restful_api.md | 125 ++--- docs/zh_cn/supported_models/codellama.md | 10 +- lmdeploy/cli/serve.py | 8 +- lmdeploy/serve/async_engine.py | 128 +---- lmdeploy/serve/gradio/__init__.py | 5 + lmdeploy/serve/gradio/api_server_backend.py | 193 +++++++ lmdeploy/serve/gradio/app.py | 530 +----------------- lmdeploy/serve/gradio/constants.py | 28 + lmdeploy/serve/gradio/css.py | 18 - .../serve/gradio/triton_server_backend.py | 134 +++++ lmdeploy/serve/gradio/turbomind_coupled.py | 194 +++++++ lmdeploy/serve/openai/api_client.py | 339 +++++++++-- lmdeploy/serve/openai/api_server.py | 256 +++++++-- lmdeploy/serve/openai/protocol.py | 11 +- lmdeploy/turbomind/chat.py | 23 +- 21 files changed, 1280 insertions(+), 929 deletions(-) create mode 100644 lmdeploy/serve/gradio/api_server_backend.py create mode 100644 lmdeploy/serve/gradio/constants.py delete mode 100644 lmdeploy/serve/gradio/css.py create mode 100644 lmdeploy/serve/gradio/triton_server_backend.py create mode 100644 lmdeploy/serve/gradio/turbomind_coupled.py diff --git a/README.md b/README.md index c65cff7e5..b40c0b90c 100644 --- a/README.md +++ b/README.md @@ -157,16 +157,16 @@ Then, you can communicate with it by command line, ```shell # restful_api_url is what printed in api_server.py, e.g. http://localhost:23333 -lmdeploy serve api_client restful_api_url +lmdeploy serve api_client api_server_url ``` or webui, ```shell -# restful_api_url is what printed in api_server.py, e.g. http://localhost:23333 +# api_server_url is what printed in api_server.py, e.g. http://localhost:23333 # server_ip and server_port here are for gradio ui -# example: lmdeploy serve gradio http://localhost:23333 --server_name localhost --server_port 6006 --restful_api True -lmdeploy serve gradio restful_api_url --server_name ${server_ip} --server_port ${server_port} --restful_api True +# example: lmdeploy serve gradio http://localhost:23333 --server_name localhost --server_port 6006 +lmdeploy serve gradio api_server_url --server_name ${gradio_ui_ip} --server_port ${gradio_ui_port} ``` Refer to [restful_api.md](docs/en/restful_api.md) for more details. 
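As a usage note for the README changes above: this patch also adds an `APIClient` class (documented in docs/en/restful_api.md further down), so the same api_server can be queried from Python as well as from the CLI client or webui. A minimal sketch, assuming the server from the README example is reachable at http://localhost:23333:

```python
# Minimal sketch of querying a running api_server with the APIClient added in
# this patch. The server address is an assumption; replace it with the URL
# printed by `lmdeploy serve api_server`.
from lmdeploy.serve.openai.api_client import APIClient

api_client = APIClient('http://localhost:23333')
model_name = api_client.available_models[0]  # first model served by api_server
messages = [{'role': 'user', 'content': 'Hello! How are you?'}]
for item in api_client.chat_completions_v1(model=model_name, messages=messages):
    print(item)
```

Per the rest of this patch, the CLI client and the gradio UI consume these same endpoints, so the sketch above exercises the same code path as `lmdeploy serve api_client`.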
diff --git a/README_zh-CN.md b/README_zh-CN.md index 84f860ef3..763432f7c 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -157,16 +157,16 @@ lmdeploy serve api_server ./workspace --server_name 0.0.0.0 --server_port ${serv ```shell # restful_api_url is what printed in api_server.py, e.g. http://localhost:23333 -lmdeploy serve api_client restful_api_url +lmdeploy serve api_client api_server_url ``` 也可以通过 WebUI 方式来对话: ```shell -# restful_api_url is what printed in api_server.py, e.g. http://localhost:23333 +# api_server_url is what printed in api_server.py, e.g. http://localhost:23333 # server_ip and server_port here are for gradio ui -# example: lmdeploy serve gradio http://localhost:23333 --server_name localhost --server_port 6006 --restful_api True -lmdeploy serve gradio restful_api_url --server_name ${server_ip} --server_port${server_port} --restful_api True +# example: lmdeploy serve gradio http://localhost:23333 --server_name localhost --server_port 6006 +lmdeploy serve gradio api_server_url --server_name ${gradio_ui_ip} --server_port ${gradio_ui_port} ``` 更多详情可以查阅 [restful_api.md](docs/zh_cn/restful_api.md)。 diff --git a/benchmark/profile_restful_api.py b/benchmark/profile_restful_api.py index d1f6ebf80..b827db16d 100644 --- a/benchmark/profile_restful_api.py +++ b/benchmark/profile_restful_api.py @@ -2,48 +2,15 @@ import multiprocessing as mp import random import time -from typing import Iterable, List import fire import numpy as np -import requests +from lmdeploy.serve.openai.api_client import get_streaming_response from lmdeploy.tokenizer import Tokenizer from lmdeploy.utils import get_logger -def get_streaming_response(prompt: str, - api_url: str, - session_id: int, - request_output_len: int, - stream: bool = True, - sequence_start: bool = True, - sequence_end: bool = False, - ignore_eos: bool = False) -> Iterable[List[str]]: - headers = {'User-Agent': 'Test Client'} - pload = { - 'prompt': prompt, - 'stream': stream, - 'session_id': session_id, - 'request_output_len': request_output_len, - 'sequence_start': sequence_start, - 'sequence_end': sequence_end, - 'ignore_eos': ignore_eos - } - response = requests.post(api_url, - headers=headers, - json=pload, - stream=stream) - for chunk in response.iter_lines(chunk_size=8192, - decode_unicode=False, - delimiter=b'\n'): - if chunk: - data = json.loads(chunk.decode('utf-8')) - output = data['text'] - tokens = data['tokens'] - yield output, tokens - - def infer(server_addr: str, session_id: int, req_queue: mp.Queue, res_que: mp.Queue): stats = [] @@ -55,13 +22,12 @@ def infer(server_addr: str, session_id: int, req_queue: mp.Queue, timestamps = [] tokens = [] start = time.perf_counter() - for res, token in get_streaming_response( + for res, token, status in get_streaming_response( prompt, server_addr, session_id, request_output_len=output_seqlen, - sequence_start=True, - sequence_end=True): + interactive_mode=False): timestamps.append(time.perf_counter()) tokens.append(token) @@ -80,13 +46,11 @@ def warmup(server_addr: str, def _infer(server_addr, session_id): for _ in range(warmup_round): - for _, _ in get_streaming_response( - '', - server_addr, - session_id, - request_output_len=output_seqlen, - sequence_start=True, - sequence_end=True): + for _ in get_streaming_response('', + server_addr, + session_id, + request_output_len=output_seqlen, + interactive_mode=False): continue _start = time.perf_counter() @@ -150,7 +114,7 @@ def main(server_addr: str, concurrency: int = 1, session_len: int = 2048, samples: int = 1000): - api_url = 
server_addr + '/generate' + api_url = server_addr + '/v1/chat/interactive' warmup(api_url, concurrency, session_len - 1) req_queue, n_req = read_dataset(tokenizer_path, dataset_path, samples, session_len) diff --git a/benchmark/profile_throughput.py b/benchmark/profile_throughput.py index 9d92b31fa..8fc5090f7 100644 --- a/benchmark/profile_throughput.py +++ b/benchmark/profile_throughput.py @@ -24,7 +24,8 @@ def sample_requests( dataset = [data for data in dataset if len(data['conversations']) >= 2] # Only keep the first two turns of each conversation. dataset = [(data['conversations'][0]['value'], - data['conversations'][1]['value']) for data in dataset] + data['conversations'][1]['value']) + for data in dataset][:num_requests * 2] # speed up encoding # Tokenize the prompts and completions. prompts = [prompt for prompt, _ in dataset] diff --git a/docs/en/restful_api.md b/docs/en/restful_api.md index a66859c0c..7f49edce1 100644 --- a/docs/en/restful_api.md +++ b/docs/en/restful_api.md @@ -7,52 +7,57 @@ lmdeploy serve api_server ./workspace --server_name 0.0.0.0 --server_port ${serv ``` Then, the user can open the swagger UI: `http://{server_ip}:{server_port}` for the detailed api usage. -We provide four restful api in total. Three of them are in OpenAI format. However, we recommend users try -our own api which provides more arguments for users to modify. The performance is comparatively better. +We provide four RESTful APIs in total. Three of them are in OpenAI format. + +- /v1/chat/completions +- /v1/models +- /v1/completions + +However, we recommend users try +our own API `/v1/chat/interactive`, which provides more arguments for users to modify. Its performance is comparatively better. + +**Note**: if you want to launch multiple requests, please assign a different `session_id` to each request for both the +`/v1/chat/completions` and `/v1/chat/interactive` APIs. Otherwise, random values will be assigned. ### python -Here is an example for our own api `generate`. +We have integrated the client-side functionalities of these services into the `APIClient` class. Below are some examples demonstrating how to invoke the `api_server` service on the client side. + +If you want to use the `/v1/chat/completions` endpoint, you can try the following code: + +```python +from lmdeploy.serve.openai.api_client import APIClient +api_client = APIClient('http://{server_ip}:{server_port}') +model_name = api_client.available_models[0] +messages = [{"role": "user", "content": "Say this is a test!"}] +for item in api_client.chat_completions_v1(model=model_name, messages=messages): + print(item) +``` + +If you want to use the `/v1/completions` endpoint, you can try: + +```python +from lmdeploy.serve.openai.api_client import APIClient +api_client = APIClient('http://{server_ip}:{server_port}') +model_name = api_client.available_models[0] +for item in api_client.completions_v1(model=model_name, prompt='hi'): + print(item) +``` + +LMDeploy supports maintaining session histories on the server for the `/v1/chat/interactive` API. We disable this +feature by default. + +- In interactive mode, the chat history is kept on the server. In a multi-round conversation, you should set + `interactive_mode = True` and the same `session_id` (it can't be -1, which is the default) for every request to `/v1/chat/interactive`. +- In normal mode, no chat history is kept on the server. + +The interactive mode can be controlled by the `interactive_mode` boolean parameter. The following is an example of normal mode. 
If you want to experience the interactive mode, simply pass in `interactive_mode=True`. ```python -import json -import requests -from typing import Iterable, List - - -def get_streaming_response(prompt: str, - api_url: str, - session_id: int, - request_output_len: int, - stream: bool = True, - sequence_start: bool = True, - sequence_end: bool = True, - ignore_eos: bool = False) -> Iterable[List[str]]: - headers = {'User-Agent': 'Test Client'} - pload = { - 'prompt': prompt, - 'stream': stream, - 'session_id': session_id, - 'request_output_len': request_output_len, - 'sequence_start': sequence_start, - 'sequence_end': sequence_end, - 'ignore_eos': ignore_eos - } - response = requests.post( - api_url, headers=headers, json=pload, stream=stream) - for chunk in response.iter_lines( - chunk_size=8192, decode_unicode=False, delimiter=b'\n'): - if chunk: - data = json.loads(chunk.decode('utf-8')) - output = data['text'] - tokens = data['tokens'] - yield output, tokens - - -for output, tokens in get_streaming_response( - "Hi, how are you?", "http://{server_ip}:{server_port}/generate", 0, - 512): - print(output, end='') +from lmdeploy.serve.openai.api_client import APIClient +api_client = APIClient('http://{server_ip}:{server_port}') +for item in api_client.generate(prompt='hi'): + print(item) ``` ### Java/Golang/Rust @@ -84,16 +89,15 @@ List Models: curl http://{server_ip}:{server_port}/v1/models ``` -Generate: +Interactive Chat: ```bash -curl http://{server_ip}:{server_port}/generate \ +curl http://{server_ip}:{server_port}/v1/chat/interactive \ -H "Content-Type: application/json" \ -d '{ "prompt": "Hello! How are you?", "session_id": 1, - "sequence_start": true, - "sequence_end": true + "interactive_mode": true }' ``` @@ -104,19 +108,19 @@ curl http://{server_ip}:{server_port}/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "model": "internlm-chat-7b", - "messages": [{"role": "user", "content": "Hello! Ho are you?"}] + "messages": [{"role": "user", "content": "Hello! How are you?"}] }' ``` -Embeddings: +Text Completions: -```bash -curl http://{server_ip}:{server_port}/v1/embeddings \ - -H "Content-Type: application/json" \ +```shell +curl http://{server_ip}:{server_port}/v1/completions \ + -H 'Content-Type: application/json' \ -d '{ - "model": "internlm-chat-7b", - "input": "Hello world!" - }' + "model": "llama", + "prompt": "two steps to build a house:" +}' ``` ### CLI client @@ -125,7 +129,7 @@ There is a client script for restful api server. ```shell # restful_api_url is what printed in api_server.py, e.g. http://localhost:23333 -lmdeploy serve api_client restful_api_url +lmdeploy serve api_client api_server_url ``` ### webui @@ -133,10 +137,10 @@ lmdeploy serve api_client restful_api_url You can also test restful-api through webui. ```shell -# restful_api_url is what printed in api_server.py, e.g. http://localhost:23333 +# api_server_url is what printed in api_server.py, e.g. 
http://localhost:23333 # server_ip and server_port here are for gradio ui -# example: lmdeploy serve gradio http://localhost:23333 --server_name localhost --server_port 6006 --restful_api True -lmdeploy serve gradio restful_api_url --server_name ${server_ip} --server_port ${server_port} --restful_api True +# example: lmdeploy serve gradio http://localhost:23333 --server_name localhost --server_port 6006 +lmdeploy serve gradio api_server_url --server_name ${gradio_ui_ip} --server_port ${gradio_ui_port} ``` ### FAQ @@ -146,10 +150,6 @@ lmdeploy serve gradio restful_api_url --server_name ${server_ip} --server_port $ 2. When OOM appears at the server side, please reduce the number of `instance_num` when launching the service. -3. When the request with the same `session_id` to `generate` got a empty return value and a negative `tokens`, please consider setting `sequence_start=false` for the second question and the same for the afterwards. -4. Requests were previously being handled sequentially rather than concurrently. To resolve this issue, - - - kindly provide unique session_id values when calling the `generate` API or else your requests may be associated with client IP addresses +3. When a request with the same `session_id` to `/v1/chat/interactive` gets an empty return value and a negative `tokens`, please consider setting `interactive_mode=false` to restart the session. -5. Both `generate` api and `v1/chat/completions` upport engaging in multiple rounds of conversation, where input `prompt` or `messages` consists of either single strings or entire chat histories.These inputs are interpreted using multi-turn dialogue modes. However, ff you want to turn the mode of and manage the chat history in clients, please the parameter `sequence_end: true` when utilizing the `generate` function, or specify `renew_session: true` when making use of `v1/chat/completions` +4. The `/v1/chat/interactive` API disables engaging in multiple rounds of conversation by default. The input argument `prompt` can be either a single string or an entire chat history. diff --git a/docs/en/supported_models/codellama.md b/docs/en/supported_models/codellama.md index 886dc5922..78f4d2ce5 100644 --- a/docs/en/supported_models/codellama.md +++ b/docs/en/supported_models/codellama.md @@ -97,16 +97,16 @@ Then, you can communicate with it by command line, ```shell # restful_api_url is what printed in api_server.py, e.g. http://localhost:23333 -lmdeploy serve api_client restful_api_url +lmdeploy serve api_client api_server_url ``` or through webui after launching gradio, ```shell -# restful_api_url is what printed in api_server.py, e.g. http://localhost:23333 +# api_server_url is what printed in api_server.py, e.g. http://localhost:23333 # server_ip and server_port here are for gradio ui -# example: lmdeploy serve gradio http://localhost:23333 --server_name localhost --server_port 6006 --restful_api True -lmdeploy serve gradio restful_api_url --server_name ${server_ip} --server_port ${server_port} --restful_api True +# example: lmdeploy serve gradio http://localhost:23333 --server_name localhost --server_port 6006 +lmdeploy serve gradio api_server_url --server_name ${gradio_ui_ip} --server_port ${gradio_ui_port} ``` Regarding the detailed information of RESTful API, you can refer to [restful_api.md](../restful_api.md). 
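For reference alongside the doc changes above, the CLI client and the gradio api_server backend in this patch both stream from the `/v1/chat/interactive` endpoint through the helper in `lmdeploy.serve.openai.api_client`. A minimal sketch of calling it directly; the server address and `session_id` here are assumptions for illustration:

```python
# Sketch of streaming tokens from /v1/chat/interactive, mirroring how the
# gradio api_server backend in this patch consumes the endpoint.
from lmdeploy.serve.openai.api_client import get_streaming_response

api_url = 'http://localhost:23333/v1/chat/interactive'
for response, tokens, finish_reason in get_streaming_response(
        'Hello! How are you?',
        api_url,
        session_id=1,
        request_output_len=512,
        interactive_mode=True):  # keep the chat history on the server
    print(response, end='')
```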
diff --git a/docs/zh_cn/restful_api.md b/docs/zh_cn/restful_api.md index 484ab5686..409e29c64 100644 --- a/docs/zh_cn/restful_api.md +++ b/docs/zh_cn/restful_api.md @@ -9,52 +9,52 @@ lmdeploy serve api_server ./workspace 0.0.0.0 --server_port ${server_port} --ins ``` 然后用户可以打开 swagger UI: `http://{server_ip}:{server_port}` 详细查看所有的 API 及其使用方法。 -我们一共提供四个 restful api,其中三个仿照 OpenAI 的形式。不过,我们建议用户用我们提供的另一个 API: `generate`。 +我们一共提供四个 restful api,其中三个仿照 OpenAI 的形式。 + +- /v1/chat/completions +- /v1/models +- /v1/completions + +不过,我们建议用户用我们提供的另一个 API: `/v1/chat/interactive`。 它有更好的性能,提供更多的参数让用户自定义修改。 ### python -这是一个 python 示例,展示如何使用 `generate`。 +我们将这些服务的客户端功能集成在 `APIClient` 类中。下面是一些例子,展示如何在客户端调用 `api_server` 服务。 +如果你想用 `/v1/chat/completions` 接口,你可以尝试下面代码: + +```python +from lmdeploy.serve.openai.api_client import APIClient +api_client = APIClient('http://{server_ip}:{server_port}') +model_name = api_client.available_models[0] +messages = [{"role": "user", "content": "Say this is a test!"}] +for item in api_client.chat_completions_v1(model=model_name, messages=messages): + print(item) +``` + +如果你想用 `/v1/completions` 接口,你可以尝试: ```python -import json -import requests -from typing import Iterable, List - - -def get_streaming_response(prompt: str, - api_url: str, - session_id: int, - request_output_len: int, - stream: bool = True, - sequence_start: bool = True, - sequence_end: bool = True, - ignore_eos: bool = False) -> Iterable[List[str]]: - headers = {'User-Agent': 'Test Client'} - pload = { - 'prompt': prompt, - 'stream': stream, - 'session_id': session_id, - 'request_output_len': request_output_len, - 'sequence_start': sequence_start, - 'sequence_end': sequence_end, - 'ignore_eos': ignore_eos - } - response = requests.post( - api_url, headers=headers, json=pload, stream=stream) - for chunk in response.iter_lines( - chunk_size=8192, decode_unicode=False, delimiter=b'\n'): - if chunk: - data = json.loads(chunk.decode('utf-8')) - output = data['text'] - tokens = data['tokens'] - yield output, tokens - - -for output, tokens in get_streaming_response( - "Hi, how are you?", "http://{server_ip}:{server_port}/generate", 0, - 512): - print(output, end='') +from lmdeploy.serve.openai.api_client import APIClient +api_client = APIClient('http://{server_ip}:{server_port}') +model_name = api_client.available_models[0] +for item in api_client.completions_v1(model=model_name, prompt='hi'): + print(item) +``` + +LMDeploy 的 `/v1/chat/interactive` api 支持将对话内容管理在服务端,但是我们默认关闭。如果想尝试,请阅读以下介绍: + +- 交互模式下,对话历史保存在 server。在一次完整的多轮对话中,所有请求设置`interactive_mode = True`, `session_id`保持相同 (不为 -1,这是缺省值)。 +- 非交互模式下,server 不保存历史记录。 + +交互模式可以通过 `interactive_mode` 布尔量参数控制。下面是一个普通模式的例子, +如果要体验交互模式,将 `interactive_mode=True` 传入即可。 + +```python +from lmdeploy.serve.openai.api_client import APIClient +api_client = APIClient('http://{server_ip}:{server_port}') +for item in api_client.generate(prompt='hi'): + print(item) ``` ### Java/Golang/Rust @@ -86,16 +86,15 @@ cURL 也可以用于查看 API 的输出结果 curl http://{server_ip}:{server_port}/v1/models ``` -使用 generate: +Interactive Chat: ```bash -curl http://{server_ip}:{server_port}/generate \ +curl http://{server_ip}:{server_port}/v1/chat/interactive \ -H "Content-Type: application/json" \ -d '{ "prompt": "Hello! 
How are you?", "session_id": 1, - "sequence_start": true, - "sequence_end": true + "interactive_mode": true }' ``` @@ -106,19 +105,19 @@ curl http://{server_ip}:{server_port}/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "model": "internlm-chat-7b", - "messages": [{"role": "user", "content": "Hello! Ho are you?"}] + "messages": [{"role": "user", "content": "Hello! How are you?"}] }' ``` -Embeddings: +Text Completions: -```bash -curl http://{server_ip}:{server_port}/v1/embeddings \ - -H "Content-Type: application/json" \ +```shell +curl http://{server_ip}:{server_port}/v1/completions \ + -H 'Content-Type: application/json' \ -d '{ - "model": "internlm-chat-7b", - "input": "Hello world!" - }' + "model": "llama", + "prompt": "two steps to build a house:" +}' ``` ### CLI client @@ -126,8 +125,8 @@ curl http://{server_ip}:{server_port}/v1/embeddings \ restful api 服务可以通过客户端测试,例如 ```shell -# restful_api_url 就是 api_server 产生的,比如 http://localhost:23333 -lmdeploy serve api_client restful_api_url +# restful_api_url is what printed in api_server.py, e.g. http://localhost:23333 +lmdeploy serve api_client api_server_url ``` ### webui @@ -135,10 +134,10 @@ lmdeploy serve api_client restful_api_url 也可以直接用 webui 测试使用 restful-api。 ```shell -# restful_api_url 就是 api_server 产生的,比如 http://localhost:23333 -# server_ip 和 server_port 是用来提供 gradio ui 访问服务的 -# 例子: lmdeploy serve gradio http://localhost:23333 --server_name localhost --server_port 6006 --restful_api True -lmdeploy serve gradio restful_api_url --server_name ${server_ip} --server_port ${server_port} --restful_api True +# api_server_url 就是 api_server 产生的,比如 http://localhost:23333 +# server_name 和 server_port 是用来提供 gradio ui 访问服务的 +# 例子: lmdeploy serve gradio http://localhost:23333 --server_name localhost --server_port 6006 +lmdeploy serve gradio api_server_url --server_name ${gradio_ui_ip} --server_port ${gradio_ui_port} ``` ### FAQ @@ -148,12 +147,6 @@ lmdeploy serve gradio restful_api_url --server_name ${server_ip} --server_port $ 2. 当服务端显存 OOM 时,可以适当减小启动服务时的 `instance_num` 个数 -3. 当同一个 `session_id` 的请求给 `generate` 函数后,出现返回空字符串和负值的 `tokens`,应该是第二次问话没有设置 `sequence_start=false` - -4. 如果感觉请求不是并发地被处理,而是一个一个地处理,请设置好以下参数: - - - 不同的 session_id 传入 `generate` api。否则,我们将自动绑定会话 id 为请求端的 ip 地址编号。 +3. 当同一个 `session_id` 的请求给 `/v1/chat/interactive` 函数后,出现返回空字符串和负值的 `tokens`,应该是 `session_id` 混乱了,可以先将交互模式关闭,再重新开启。 -5. `generate` api 和 `v1/chat/completions` 均支持多轮对话。`messages` 或者 `prompt` 参数既可以是一个简单字符串表示用户的单词提问,也可以是一段对话历史。 - 两个 api 都是默认开启多伦对话的,如果你想关闭这个功能,然后在客户端管理会话记录,请设置 `sequence_end: true` 传入 `generate`,或者设置 - `renew_session: true` 传入 `v1/chat/completions`。 +4. 
`/v1/chat/interactive` api 支持多轮对话, 但是默认关闭。`messages` 或者 `prompt` 参数既可以是一个简单字符串表示用户的单词提问,也可以是一段对话历史。 diff --git a/docs/zh_cn/supported_models/codellama.md b/docs/zh_cn/supported_models/codellama.md index a2abd2f4a..017df62b5 100644 --- a/docs/zh_cn/supported_models/codellama.md +++ b/docs/zh_cn/supported_models/codellama.md @@ -98,17 +98,17 @@ lmdeploy serve api_server ./workspace --server_name 0.0.0.0 --server_port ${serv 你可以用命令行,在控制台与 server 通信: ```shell -# restful_api_url 就是 api_server 产生的,比如 http://localhost:23333 -lmdeploy serve api_client restful_api_url +# api_server_url 就是 api_server 产生的,比如 http://localhost:23333 +lmdeploy serve api_client api_server_url ``` 或者,启动 gradio,在 webui 的聊天对话框中,与 codellama 交流: ```shell -# restful_api_url 就是 api_server 产生的,比如 http://localhost:23333 +# api_server_url 就是 api_server 产生的,比如 http://localhost:23333 # server_ip 和 server_port 是用来提供 gradio ui 访问服务的 -# 例子: lmdeploy serve gradio http://localhost:23333 --server_name localhost --server_port 6006 --restful_api True -lmdeploy serve gradio restful_api_url --server_name ${server_ip} --server_port ${server_port} --restful_api True +# 例子: lmdeploy serve gradio http://localhost:23333 --server_name localhost --server_port 6006 +lmdeploy serve gradio api_server_url --server_name ${gradio_ui_ip} --server_port ${gradio_ui_port} ``` 关于 RESTful API的详细介绍,请参考[这份](../restful_api.md)文档。 diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py index 0bff69c31..33580cdfe 100644 --- a/lmdeploy/cli/serve.py +++ b/lmdeploy/cli/serve.py @@ -7,7 +7,7 @@ class SubCliServe(object): def gradio(self, model_path_or_server: str, - server_name: str = 'localhost', + server_name: str = '0.0.0.0', server_port: int = 6006, batch_size: int = 32, tp: int = 1, @@ -18,8 +18,8 @@ def gradio(self, lmdeploy serve gradio ./workspace Example 2: - lmdeploy serve gradio http://localhost:23333 - --server_name localhost + lmdeploy serve gradio http://0.0.0.0:23333 + --server_name 0.0.0.0 --server_port 6006 --restful_api True @@ -48,7 +48,7 @@ def gradio(self, def api_server(self, model_path: str, - server_name: str = 'localhost', + server_name: str = '0.0.0.0', server_port: int = 23333, instance_num: int = 32, tp: int = 1, diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py index 4aba9dce7..6cc3c4a53 100644 --- a/lmdeploy/serve/async_engine.py +++ b/lmdeploy/serve/async_engine.py @@ -28,7 +28,7 @@ class AsyncEngine: tp (int): tensor parallel """ - def __init__(self, model_path, instance_num=32, tp=1) -> None: + def __init__(self, model_path, instance_num=32, tp=1, **kwargs) -> None: from lmdeploy import turbomind as tm from lmdeploy.tokenizer import Tokenizer tokenizer_model_path = osp.join(model_path, 'triton_models', @@ -42,13 +42,14 @@ def __init__(self, model_path, instance_num=32, tp=1) -> None: self.tm_model.create_instance() for i in range(instance_num) ] self.instance_num = instance_num - self.model: BaseModel = MODELS.get(self.tm_model.model_name)() + self.model: BaseModel = MODELS.get(self.tm_model.model_name)(**kwargs) self.available = [True] * instance_num self.starts = [None] * instance_num self.steps = {} self.loop = asyncio.get_event_loop() def stop_session(self, session_id: int): + """Stop a session by a session_id.""" instance_id = session_id % self.instance_num input_ids = self.tokenizer.encode('') for outputs in self.generators[instance_id].stream_infer( @@ -61,8 +62,24 @@ def stop_session(self, session_id: int): pass self.available[instance_id] = True + def end_session(self, session_id: int): + """Clear a 
session by a session_id.""" + instance_id = session_id % self.instance_num + input_ids = self.tokenizer.encode('') + for outputs in self.generators[instance_id].stream_infer( + session_id, + input_ids, + request_output_len=0, + sequence_start=False, + sequence_end=True, + stop=True): + pass + self.steps[str(session_id)] = 0 + self.available[instance_id] = True + @contextmanager def safe_run(self, instance_id: int, session_id: Optional[int] = None): + """A context manager to make sure server's safe running.""" self.available[instance_id] = False try: yield @@ -142,7 +159,7 @@ async def generate( session_id, stream_response=True, sequence_start=True, - sequence_end=False, + sequence_end=True, # no interactive mode by default step=0, request_output_len=512, stop=False, @@ -151,6 +168,7 @@ async def generate( temperature=0.8, repetition_penalty=1.0, ignore_eos=False, + do_preprocess=True, ): """Generate responses. @@ -172,6 +190,7 @@ async def generate( repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty ignore_eos (bool): indicator for ignoring eos + do_preprocess (bool): whether pre-process the messages. """ instance_id = session_id % self.instance_num if str(session_id) not in self.steps: @@ -179,14 +198,18 @@ async def generate( if step != 0: self.steps[str(session_id)] = step seed = random.getrandbits(64) - prompt = self.model.messages2prompt(messages, sequence_start) + prompt = messages + if do_preprocess: + prompt = self.model.messages2prompt(prompt, sequence_start) input_ids = self.tokenizer.encode(prompt) finish_reason = 'stop' if stop else None if self.steps[str(session_id)] + len( - input_ids) >= self.tm_model.session_len: + input_ids) + request_output_len >= self.tm_model.session_len: finish_reason = 'length' yield GenOut('', self.steps[str(session_id)], len(input_ids), 0, finish_reason) + if sequence_end is True and sequence_start is False: + self.end_session(session_id) else: generator = await self.get_generator(instance_id, stop) with self.safe_run(instance_id, session_id): @@ -225,98 +248,3 @@ async def generate( self.steps[str(session_id)] += len(input_ids) + tokens if sequence_end or stop: self.steps[str(session_id)] = 0 - - async def generate_openai( - self, - messages, - instance_id, - stream_response=True, - renew_session=False, - request_output_len=512, - stop=False, - top_k=40, - top_p=0.8, - temperature=0.8, - repetition_penalty=1.0, - ignore_eos=False, - ): - """Generate responses. - - Args: - messages (str | List): chat history or prompt - instance_id (int): actually request host ip - stream_response (bool): whether return responses streamingly - renew_session (bool): renew the session - request_output_len (int): output token nums - stop (bool): whether stop inference - top_k (int): The number of the highest probability vocabulary - tokens to keep for top-k-filtering - top_p (float): If set to float < 1, only the smallest set of most - probable tokens with probabilities that add up to top_p or higher - are kept for generation. - temperature (float): to modulate the next token probability - repetition_penalty (float): The parameter for repetition penalty. 
- 1.0 means no penalty - ignore_eos (bool): indicator for ignoring eos - """ - session_id = instance_id - instance_id %= self.instance_num - sequence_start = False - generator = await self.get_generator(instance_id) - if renew_session: # renew a session - empty_input_ids = self.tokenizer.encode('') - for outputs in generator.stream_infer(session_id=session_id, - input_ids=[empty_input_ids], - request_output_len=0, - sequence_start=False, - sequence_end=True, - stop=True): - pass - self.steps[str(session_id)] = 0 - if str(session_id) not in self.steps: - self.steps[str(session_id)] = 0 - if self.steps[str(session_id)] == 0: - sequence_start = True - seed = random.getrandbits(64) - prompt = self.model.messages2prompt(messages, sequence_start) - input_ids = self.tokenizer.encode(prompt) - finish_reason = 'stop' if stop else None - if self.steps[str(session_id)] + len( - input_ids) >= self.tm_model.session_len: - finish_reason = 'length' - yield GenOut('', self.steps[str(session_id)], len(input_ids), 0, - finish_reason) - else: - with self.safe_run(instance_id, session_id): - response_size = 0 - async for outputs in generator.async_stream_infer( - session_id=session_id, - input_ids=[input_ids], - stream_output=stream_response, - request_output_len=request_output_len, - sequence_start=(sequence_start), - sequence_end=False, - step=self.steps[str(session_id)], - stop=stop, - top_k=top_k, - top_p=top_p, - temperature=temperature, - repetition_penalty=repetition_penalty, - ignore_eos=ignore_eos, - random_seed=seed if sequence_start else None): - res, tokens = outputs[0] - # decode res - response = self.tokenizer.decode(res.tolist(), - offset=response_size) - # utf-8 char at the end means it's a potential unfinished - # byte sequence, continue to concate it with the next - # sequence and decode them together - if response.endswith('�'): - continue - # response, history len, input len, generation len - yield GenOut(response, self.steps[str(session_id)], - len(input_ids), tokens, finish_reason) - response_size = tokens - - # update step - self.steps[str(session_id)] += len(input_ids) + tokens diff --git a/lmdeploy/serve/gradio/__init__.py b/lmdeploy/serve/gradio/__init__.py index ef101fec6..770138a44 100644 --- a/lmdeploy/serve/gradio/__init__.py +++ b/lmdeploy/serve/gradio/__init__.py @@ -1 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .api_server_backend import run_api_server +from .triton_server_backend import run_triton_server +from .turbomind_coupled import run_local + +__all__ = ['run_api_server', 'run_triton_server', 'run_local'] diff --git a/lmdeploy/serve/gradio/api_server_backend.py b/lmdeploy/serve/gradio/api_server_backend.py new file mode 100644 index 000000000..ce6450879 --- /dev/null +++ b/lmdeploy/serve/gradio/api_server_backend.py @@ -0,0 +1,193 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import threading +import time +from typing import Sequence + +import gradio as gr + +from lmdeploy.serve.gradio.constants import CSS, THEME, disable_btn, enable_btn +from lmdeploy.serve.openai.api_client import (get_model_list, + get_streaming_response) +from lmdeploy.serve.openai.api_server import ip2id + + +class InterFace: + api_server_url: str = None + + +def chat_stream_restful( + instruction: str, + state_chatbot: Sequence, + cancel_btn: gr.Button, + reset_btn: gr.Button, + request: gr.Request, +): + """Chat with AI assistant. 
+ + Args: + instruction (str): user's prompt + state_chatbot (Sequence): the chatting history + request (gr.Request): the request from a user + """ + session_id = threading.current_thread().ident + if request is not None: + session_id = ip2id(request.kwargs['client']['host']) + bot_summarized_response = '' + state_chatbot = state_chatbot + [(instruction, None)] + + yield (state_chatbot, state_chatbot, disable_btn, enable_btn, + f'{bot_summarized_response}'.strip()) + + for response, tokens, finish_reason in get_streaming_response( + instruction, + f'{InterFace.api_server_url}/v1/chat/interactive', + session_id=session_id, + request_output_len=512, + interactive_mode=True): + if finish_reason == 'length': + gr.Warning('WARNING: exceed session max length.' + ' Please restart the session by reset button.') + if tokens < 0: + gr.Warning('WARNING: running on the old session.' + ' Please restart the session by reset button.') + if state_chatbot[-1][-1] is None: + state_chatbot[-1] = (state_chatbot[-1][0], response) + else: + state_chatbot[-1] = (state_chatbot[-1][0], + state_chatbot[-1][1] + response + ) # piece by piece + yield (state_chatbot, state_chatbot, enable_btn, disable_btn, + f'{bot_summarized_response}'.strip()) + + yield (state_chatbot, state_chatbot, disable_btn, enable_btn, + f'{bot_summarized_response}'.strip()) + + +def reset_restful_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State, + request: gr.Request): + """reset the session. + + Args: + instruction_txtbox (str): user's prompt + state_chatbot (Sequence): the chatting history + request (gr.Request): the request from a user + """ + state_chatbot = [] + + session_id = threading.current_thread().ident + if request is not None: + session_id = ip2id(request.kwargs['client']['host']) + # end the session + for response, tokens, finish_reason in get_streaming_response( + '', + f'{InterFace.api_server_url}/v1/chat/interactive', + session_id=session_id, + request_output_len=0, + interactive_mode=False): + pass + + return ( + state_chatbot, + state_chatbot, + gr.Textbox.update(value=''), + ) + + +def cancel_restful_func(state_chatbot: gr.State, cancel_btn: gr.Button, + reset_btn: gr.Button, request: gr.Request): + """stop the session. + + Args: + instruction_txtbox (str): user's prompt + state_chatbot (Sequence): the chatting history + request (gr.Request): the request from a user + """ + yield (state_chatbot, disable_btn, disable_btn) + session_id = threading.current_thread().ident + if request is not None: + session_id = ip2id(request.kwargs['client']['host']) + # end the session + for out in get_streaming_response( + '', + f'{InterFace.api_server_url}/v1/chat/interactive', + session_id=session_id, + request_output_len=0, + stop=True): + pass + time.sleep(0.5) + messages = [] + for qa in state_chatbot: + messages.append(dict(role='user', content=qa[0])) + if qa[1] is not None: + messages.append(dict(role='assistant', content=qa[1])) + for out in get_streaming_response( + messages, + f'{InterFace.api_server_url}/v1/chat/interactive', + session_id=session_id, + request_output_len=0, + interactive_mode=True): + pass + yield (state_chatbot, disable_btn, enable_btn) + + +def run_api_server(api_server_url: str, + server_name: str = 'localhost', + server_port: int = 6006, + batch_size: int = 32): + """chat with AI assistant through web ui. 
+ + Args: + api_server_url (str): restufl api url + server_name (str): the ip address of gradio server + server_port (int): the port of gradio server + batch_size (int): batch size for running Turbomind directly + """ + InterFace.api_server_url = api_server_url + model_names = get_model_list(f'{api_server_url}/v1/models') + model_name = '' + if isinstance(model_names, list) and len(model_names) > 0: + model_name = model_names[0] + else: + raise ValueError('gradio can find a suitable model from restful-api') + + with gr.Blocks(css=CSS, theme=THEME) as demo: + state_chatbot = gr.State([]) + + with gr.Column(elem_id='container'): + gr.Markdown('## LMDeploy Playground') + + chatbot = gr.Chatbot(elem_id='chatbot', label=model_name) + instruction_txtbox = gr.Textbox( + placeholder='Please input the instruction', + label='Instruction') + with gr.Row(): + cancel_btn = gr.Button(value='Cancel', interactive=False) + reset_btn = gr.Button(value='Reset') + + send_event = instruction_txtbox.submit( + chat_stream_restful, + [instruction_txtbox, state_chatbot, cancel_btn, reset_btn], + [state_chatbot, chatbot, cancel_btn, reset_btn]) + instruction_txtbox.submit( + lambda: gr.Textbox.update(value=''), + [], + [instruction_txtbox], + ) + cancel_btn.click(cancel_restful_func, + [state_chatbot, cancel_btn, reset_btn], + [state_chatbot, cancel_btn, reset_btn], + cancels=[send_event]) + + reset_btn.click(reset_restful_func, + [instruction_txtbox, state_chatbot], + [state_chatbot, chatbot, instruction_txtbox], + cancels=[send_event]) + + print(f'server is gonna mount on: http://{server_name}:{server_port}') + demo.queue(concurrency_count=batch_size, max_size=100, + api_open=True).launch( + max_threads=10, + share=True, + server_port=server_port, + server_name=server_name, + ) diff --git a/lmdeploy/serve/gradio/app.py b/lmdeploy/serve/gradio/app.py index 5c200517b..5b1668224 100644 --- a/lmdeploy/serve/gradio/app.py +++ b/lmdeploy/serve/gradio/app.py @@ -1,538 +1,36 @@ # Copyright (c) OpenMMLab. All rights reserved. -import os -import threading -import time -from functools import partial -from typing import Sequence - -import gradio as gr - -from lmdeploy.serve.async_engine import AsyncEngine -from lmdeploy.serve.gradio.css import CSS -from lmdeploy.serve.openai.api_client import (get_model_list, - get_streaming_response) -from lmdeploy.serve.openai.api_server import ip2id -from lmdeploy.serve.turbomind.chatbot import Chatbot - -THEME = gr.themes.Soft( - primary_hue=gr.themes.colors.blue, - secondary_hue=gr.themes.colors.sky, - font=[gr.themes.GoogleFont('Inconsolata'), 'Arial', 'sans-serif']) - -enable_btn = gr.Button.update(interactive=True) -disable_btn = gr.Button.update(interactive=False) - - -def chat_stream(state_chatbot: Sequence, llama_chatbot: Chatbot, - request: gr.Request): - """Chat with AI assistant. 
- - Args: - instruction (str): user's prompt - state_chatbot (Sequence): the chatting history - llama_chatbot (Chatbot): the instance of a chatbot - request (gr.Request): the request from a user - model_name (str): the name of deployed model - """ - instruction = state_chatbot[-1][0] - session_id = threading.current_thread().ident - if request is not None: - session_id = ip2id(request.kwargs['client']['host']) - - bot_response = llama_chatbot.stream_infer( - session_id, instruction, f'{session_id}-{len(state_chatbot)}') - - for status, tokens, _ in bot_response: - state_chatbot[-1] = (state_chatbot[-1][0], tokens) - yield (state_chatbot, state_chatbot, '') - - return (state_chatbot, state_chatbot, '') - - -def reset_all_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State, - llama_chatbot: gr.State, triton_server_addr: str, - model_name: str): - """reset the session.""" - state_chatbot = [] - log_level = os.environ.get('SERVICE_LOG_LEVEL', 'INFO') - llama_chatbot = Chatbot(triton_server_addr, - model_name, - log_level=log_level, - display=True) - - return ( - llama_chatbot, - state_chatbot, - state_chatbot, - gr.Textbox.update(value=''), - ) - - -def cancel_func( - instruction_txtbox: gr.Textbox, - state_chatbot: gr.State, - llama_chatbot: gr.State, -): - """cancel the session.""" - session_id = llama_chatbot._session.session_id - llama_chatbot.cancel(session_id) - - return ( - llama_chatbot, - state_chatbot, - ) - - -def add_instruction(instruction, state_chatbot): - state_chatbot = state_chatbot + [(instruction, None)] - return ('', state_chatbot) - - -def run_server(triton_server_addr: str, - server_name: str = 'localhost', - server_port: int = 6006): - """chat with AI assistant through web ui. - - Args: - triton_server_addr (str): the communication address of inference server - server_name (str): the ip address of gradio server - server_port (int): the port of gradio server - """ - with gr.Blocks(css=CSS, theme=THEME) as demo: - log_level = os.environ.get('SERVICE_LOG_LEVEL', 'INFO') - llama_chatbot = gr.State( - Chatbot(triton_server_addr, log_level=log_level, display=True)) - state_chatbot = gr.State([]) - model_name = llama_chatbot.value.model_name - reset_all = partial(reset_all_func, - model_name=model_name, - triton_server_addr=triton_server_addr) - - with gr.Column(elem_id='container'): - gr.Markdown('## LMDeploy Playground') - - chatbot = gr.Chatbot(elem_id='chatbot', label=model_name) - instruction_txtbox = gr.Textbox( - placeholder='Please input the instruction', - label='Instruction') - with gr.Row(): - cancel_btn = gr.Button(value='Cancel') - reset_btn = gr.Button(value='Reset') - - send_event = instruction_txtbox.submit( - add_instruction, [instruction_txtbox, state_chatbot], - [instruction_txtbox, state_chatbot]).then( - chat_stream, [state_chatbot, llama_chatbot], - [state_chatbot, chatbot]) - - cancel_btn.click(cancel_func, - [instruction_txtbox, state_chatbot, llama_chatbot], - [llama_chatbot, chatbot], - cancels=[send_event]) - - reset_btn.click( - reset_all, [instruction_txtbox, state_chatbot, llama_chatbot], - [llama_chatbot, state_chatbot, chatbot, instruction_txtbox], - cancels=[send_event]) - - print(f'server is gonna mount on: http://{server_name}:{server_port}') - demo.queue(concurrency_count=4, max_size=100, api_open=True).launch( - max_threads=10, - share=True, - server_port=server_port, - server_name=server_name, - ) - - -# a IO interface mananing variables -class InterFace: - async_engine: AsyncEngine = None # for run_local - restful_api_url: str = 
None # for run_restful - - -def chat_stream_restful( - instruction: str, - state_chatbot: Sequence, - cancel_btn: gr.Button, - reset_btn: gr.Button, - request: gr.Request, -): - """Chat with AI assistant. - - Args: - instruction (str): user's prompt - state_chatbot (Sequence): the chatting history - request (gr.Request): the request from a user - """ - session_id = threading.current_thread().ident - if request is not None: - session_id = ip2id(request.kwargs['client']['host']) - bot_summarized_response = '' - state_chatbot = state_chatbot + [(instruction, None)] - - yield (state_chatbot, state_chatbot, disable_btn, enable_btn, - f'{bot_summarized_response}'.strip()) - - for response, tokens, finish_reason in get_streaming_response( - instruction, - f'{InterFace.restful_api_url}/generate', - session_id=session_id, - request_output_len=512, - sequence_start=(len(state_chatbot) == 1), - sequence_end=False): - if finish_reason == 'length': - gr.Warning('WARNING: exceed session max length.' - ' Please restart the session by reset button.') - if tokens < 0: - gr.Warning('WARNING: running on the old session.' - ' Please restart the session by reset button.') - if state_chatbot[-1][-1] is None: - state_chatbot[-1] = (state_chatbot[-1][0], response) - else: - state_chatbot[-1] = (state_chatbot[-1][0], - state_chatbot[-1][1] + response - ) # piece by piece - yield (state_chatbot, state_chatbot, enable_btn, disable_btn, - f'{bot_summarized_response}'.strip()) - - yield (state_chatbot, state_chatbot, disable_btn, enable_btn, - f'{bot_summarized_response}'.strip()) - - -def reset_restful_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State, - request: gr.Request): - """reset the session. - - Args: - instruction_txtbox (str): user's prompt - state_chatbot (Sequence): the chatting history - request (gr.Request): the request from a user - """ - state_chatbot = [] - - session_id = threading.current_thread().ident - if request is not None: - session_id = ip2id(request.kwargs['client']['host']) - # end the session - for response, tokens, finish_reason in get_streaming_response( - '', - f'{InterFace.restful_api_url}/generate', - session_id=session_id, - request_output_len=0, - sequence_start=False, - sequence_end=True): - pass - - return ( - state_chatbot, - state_chatbot, - gr.Textbox.update(value=''), - ) - - -def cancel_restful_func(state_chatbot: gr.State, cancel_btn: gr.Button, - reset_btn: gr.Button, request: gr.Request): - """stop the session. 
- - Args: - instruction_txtbox (str): user's prompt - state_chatbot (Sequence): the chatting history - request (gr.Request): the request from a user - """ - session_id = threading.current_thread().ident - if request is not None: - session_id = ip2id(request.kwargs['client']['host']) - # end the session - for out in get_streaming_response('', - f'{InterFace.restful_api_url}/generate', - session_id=session_id, - request_output_len=0, - sequence_start=False, - sequence_end=False, - stop=True): - pass - time.sleep(0.5) - messages = [] - for qa in state_chatbot: - messages.append(dict(role='user', content=qa[0])) - if qa[1] is not None: - messages.append(dict(role='assistant', content=qa[1])) - for out in get_streaming_response(messages, - f'{InterFace.restful_api_url}/generate', - session_id=session_id, - request_output_len=0, - sequence_start=True, - sequence_end=False): - pass - return (state_chatbot, disable_btn, enable_btn) - - -def run_restful(restful_api_url: str, - server_name: str = 'localhost', - server_port: int = 6006, - batch_size: int = 32): - """chat with AI assistant through web ui. - - Args: - restful_api_url (str): restufl api url - server_name (str): the ip address of gradio server - server_port (int): the port of gradio server - batch_size (int): batch size for running Turbomind directly - """ - InterFace.restful_api_url = restful_api_url - model_names = get_model_list(f'{restful_api_url}/v1/models') - model_name = '' - if isinstance(model_names, list) and len(model_names) > 0: - model_name = model_names[0] - else: - raise ValueError('gradio can find a suitable model from restful-api') - - with gr.Blocks(css=CSS, theme=THEME) as demo: - state_chatbot = gr.State([]) - - with gr.Column(elem_id='container'): - gr.Markdown('## LMDeploy Playground') - - chatbot = gr.Chatbot(elem_id='chatbot', label=model_name) - instruction_txtbox = gr.Textbox( - placeholder='Please input the instruction', - label='Instruction') - with gr.Row(): - cancel_btn = gr.Button(value='Cancel', interactive=False) - reset_btn = gr.Button(value='Reset') - - send_event = instruction_txtbox.submit( - chat_stream_restful, - [instruction_txtbox, state_chatbot, cancel_btn, reset_btn], - [state_chatbot, chatbot, cancel_btn, reset_btn]) - instruction_txtbox.submit( - lambda: gr.Textbox.update(value=''), - [], - [instruction_txtbox], - ) - cancel_btn.click(cancel_restful_func, - [state_chatbot, cancel_btn, reset_btn], - [state_chatbot, cancel_btn, reset_btn], - cancels=[send_event]) - - reset_btn.click(reset_restful_func, - [instruction_txtbox, state_chatbot], - [state_chatbot, chatbot, instruction_txtbox], - cancels=[send_event]) - - print(f'server is gonna mount on: http://{server_name}:{server_port}') - demo.queue(concurrency_count=batch_size, max_size=100, - api_open=True).launch( - max_threads=10, - share=True, - server_port=server_port, - server_name=server_name, - ) - - -async def chat_stream_local( - instruction: str, - state_chatbot: Sequence, - cancel_btn: gr.Button, - reset_btn: gr.Button, - request: gr.Request, -): - """Chat with AI assistant. 
- - Args: - instruction (str): user's prompt - state_chatbot (Sequence): the chatting history - request (gr.Request): the request from a user - """ - session_id = threading.current_thread().ident - if request is not None: - session_id = ip2id(request.kwargs['client']['host']) - bot_summarized_response = '' - state_chatbot = state_chatbot + [(instruction, None)] - - yield (state_chatbot, state_chatbot, disable_btn, enable_btn, - f'{bot_summarized_response}'.strip()) - - async for outputs in InterFace.async_engine.generate( - instruction, - session_id, - stream_response=True, - sequence_start=(len(state_chatbot) == 1)): - response = outputs.response - if outputs.finish_reason == 'length': - gr.Warning('WARNING: exceed session max length.' - ' Please restart the session by reset button.') - if outputs.generate_token_len < 0: - gr.Warning('WARNING: running on the old session.' - ' Please restart the session by reset button.') - if state_chatbot[-1][-1] is None: - state_chatbot[-1] = (state_chatbot[-1][0], response) - else: - state_chatbot[-1] = (state_chatbot[-1][0], - state_chatbot[-1][1] + response - ) # piece by piece - yield (state_chatbot, state_chatbot, enable_btn, disable_btn, - f'{bot_summarized_response}'.strip()) - - yield (state_chatbot, state_chatbot, disable_btn, enable_btn, - f'{bot_summarized_response}'.strip()) - - -async def reset_local_func(instruction_txtbox: gr.Textbox, - state_chatbot: gr.State, request: gr.Request): - """reset the session. - - Args: - instruction_txtbox (str): user's prompt - state_chatbot (Sequence): the chatting history - request (gr.Request): the request from a user - """ - state_chatbot = [] - - session_id = threading.current_thread().ident - if request is not None: - session_id = ip2id(request.kwargs['client']['host']) - # end the session - async for out in InterFace.async_engine.generate('', - session_id, - request_output_len=1, - stream_response=True, - sequence_start=False, - sequence_end=True): - pass - - return ( - state_chatbot, - state_chatbot, - gr.Textbox.update(value=''), - ) - - -async def cancel_local_func(state_chatbot: gr.State, cancel_btn: gr.Button, - reset_btn: gr.Button, request: gr.Request): - """stop the session. - - Args: - instruction_txtbox (str): user's prompt - state_chatbot (Sequence): the chatting history - request (gr.Request): the request from a user - """ - session_id = threading.current_thread().ident - if request is not None: - session_id = ip2id(request.kwargs['client']['host']) - # end the session - async for out in InterFace.async_engine.generate('', - session_id, - request_output_len=0, - stream_response=True, - sequence_start=False, - sequence_end=False, - stop=True): - pass - messages = [] - for qa in state_chatbot: - messages.append(dict(role='user', content=qa[0])) - if qa[1] is not None: - messages.append(dict(role='assistant', content=qa[1])) - async for out in InterFace.async_engine.generate(messages, - session_id, - request_output_len=0, - stream_response=True, - sequence_start=True, - sequence_end=False): - pass - return (state_chatbot, disable_btn, enable_btn) - - -def run_local(model_path: str, - server_name: str = 'localhost', - server_port: int = 6006, - batch_size: int = 4, - tp: int = 1): - """chat with AI assistant through web ui. 
- - Args: - model_path (str): the path of the deployed model - server_name (str): the ip address of gradio server - server_port (int): the port of gradio server - batch_size (int): batch size for running Turbomind directly - tp (int): tensor parallel for Turbomind - """ - InterFace.async_engine = AsyncEngine(model_path=model_path, - instance_num=batch_size, - tp=tp) - - with gr.Blocks(css=CSS, theme=THEME) as demo: - state_chatbot = gr.State([]) - - with gr.Column(elem_id='container'): - gr.Markdown('## LMDeploy Playground') - - chatbot = gr.Chatbot( - elem_id='chatbot', - label=InterFace.async_engine.tm_model.model_name) - instruction_txtbox = gr.Textbox( - placeholder='Please input the instruction', - label='Instruction') - with gr.Row(): - cancel_btn = gr.Button(value='Cancel', interactive=False) - reset_btn = gr.Button(value='Reset') - - send_event = instruction_txtbox.submit( - chat_stream_local, - [instruction_txtbox, state_chatbot, cancel_btn, reset_btn], - [state_chatbot, chatbot, cancel_btn, reset_btn]) - instruction_txtbox.submit( - lambda: gr.Textbox.update(value=''), - [], - [instruction_txtbox], - ) - cancel_btn.click(cancel_local_func, - [state_chatbot, cancel_btn, reset_btn], - [state_chatbot, cancel_btn, reset_btn], - cancels=[send_event]) - - reset_btn.click(reset_local_func, [instruction_txtbox, state_chatbot], - [state_chatbot, chatbot, instruction_txtbox], - cancels=[send_event]) - - print(f'server is gonna mount on: http://{server_name}:{server_port}') - demo.queue(concurrency_count=batch_size, max_size=100, - api_open=True).launch( - max_threads=10, - share=True, - server_port=server_port, - server_name=server_name, - ) def run(model_path_or_server: str, - server_name: str = 'localhost', + server_name: str = '0.0.0.0', server_port: int = 6006, batch_size: int = 32, tp: int = 1, - restful_api: bool = False): + **kwargs): """chat with AI assistant through web ui. Args: model_path_or_server (str): the path of the deployed model or the - tritonserver URL or restful api URL. The former is for directly - running service with gradio. The latter is for running with - tritonserver by default. If the input URL is restful api. Please - enable another flag `restful_api`. + tritonserver URL or restful api URL. For example: + - ./workspace + - 0.0.0.0:23333 + - http://0.0.0.0:23333 server_name (str): the ip address of gradio server server_port (int): the port of gradio server batch_size (int): batch size for running Turbomind directly tp (int): tensor parallel for Turbomind - restful_api (bool): a flag for model_path_or_server """ if ':' in model_path_or_server: - if restful_api: - run_restful(model_path_or_server, server_name, server_port, - batch_size) + if 'http:' in model_path_or_server: + from lmdeploy.serve.gradio.api_server_backend import run_api_server + run_api_server(model_path_or_server, server_name, server_port, + batch_size) else: - run_server(model_path_or_server, server_name, server_port) + from lmdeploy.serve.gradio.triton_server_backend import \ + run_triton_server + run_triton_server(model_path_or_server, server_name, server_port) else: + from lmdeploy.serve.gradio.turbomind_coupled import run_local run_local(model_path_or_server, server_name, server_port, batch_size, tp) diff --git a/lmdeploy/serve/gradio/constants.py b/lmdeploy/serve/gradio/constants.py new file mode 100644 index 000000000..891c572e5 --- /dev/null +++ b/lmdeploy/serve/gradio/constants.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+ +import gradio as gr + +CSS = """ +#container { + width: 95%; + margin-left: auto; + margin-right: auto; +} + +#chatbot { + height: 500px; + overflow: auto; +} + +.chat_wrap_space { + margin-left: 0.5em +} +""" + +THEME = gr.themes.Soft( + primary_hue=gr.themes.colors.blue, + secondary_hue=gr.themes.colors.sky, + font=[gr.themes.GoogleFont('Inconsolata'), 'Arial', 'sans-serif']) + +enable_btn = gr.Button.update(interactive=True) +disable_btn = gr.Button.update(interactive=False) diff --git a/lmdeploy/serve/gradio/css.py b/lmdeploy/serve/gradio/css.py deleted file mode 100644 index b3bd23322..000000000 --- a/lmdeploy/serve/gradio/css.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. - -CSS = """ -#container { - width: 95%; - margin-left: auto; - margin-right: auto; -} - -#chatbot { - height: 500px; - overflow: auto; -} - -.chat_wrap_space { - margin-left: 0.5em -} -""" diff --git a/lmdeploy/serve/gradio/triton_server_backend.py b/lmdeploy/serve/gradio/triton_server_backend.py new file mode 100644 index 000000000..5936f4ba5 --- /dev/null +++ b/lmdeploy/serve/gradio/triton_server_backend.py @@ -0,0 +1,134 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import threading +from functools import partial +from typing import Sequence + +import gradio as gr + +from lmdeploy.serve.gradio.constants import CSS, THEME, disable_btn, enable_btn +from lmdeploy.serve.openai.api_server import ip2id +from lmdeploy.serve.turbomind.chatbot import Chatbot + + +def chat_stream(state_chatbot: Sequence, llama_chatbot: Chatbot, + cancel_btn: gr.Button, reset_btn: gr.Button, + request: gr.Request): + """Chat with AI assistant. + + Args: + instruction (str): user's prompt + state_chatbot (Sequence): the chatting history + llama_chatbot (Chatbot): the instance of a chatbot + cancel_btn (bool): enable the cancel button or not + reset_btn (bool): enable the reset button or not + request (gr.Request): the request from a user + """ + instruction = state_chatbot[-1][0] + session_id = threading.current_thread().ident + if request is not None: + session_id = ip2id(request.kwargs['client']['host']) + + bot_response = llama_chatbot.stream_infer( + session_id, instruction, f'{session_id}-{len(state_chatbot)}') + + for status, tokens, _ in bot_response: + state_chatbot[-1] = (state_chatbot[-1][0], tokens) + yield (state_chatbot, state_chatbot, enable_btn, disable_btn) + + yield (state_chatbot, state_chatbot, disable_btn, enable_btn) + + +def reset_all_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State, + llama_chatbot: gr.State, triton_server_addr: str, + model_name: str): + """reset the session.""" + state_chatbot = [] + log_level = os.environ.get('SERVICE_LOG_LEVEL', 'INFO') + llama_chatbot = Chatbot(triton_server_addr, + model_name, + log_level=log_level, + display=True) + + return ( + llama_chatbot, + state_chatbot, + state_chatbot, + gr.Textbox.update(value=''), + ) + + +def cancel_func( + state_chatbot: gr.State, + llama_chatbot: gr.State, + cancel_btn: gr.Button, + reset_btn: gr.Button, +): + """cancel the session.""" + yield (llama_chatbot, state_chatbot, disable_btn, disable_btn) + session_id = llama_chatbot._session.session_id + llama_chatbot.cancel(session_id) + + yield (llama_chatbot, state_chatbot, disable_btn, enable_btn) + + +def add_instruction(instruction, state_chatbot): + state_chatbot = state_chatbot + [(instruction, None)] + return ('', state_chatbot) + + +def run_triton_server(triton_server_addr: str, + server_name: str = 'localhost', + 
server_port: int = 6006): + """chat with AI assistant through web ui. + + Args: + triton_server_addr (str): the communication address of inference server + server_name (str): the ip address of gradio server + server_port (int): the port of gradio server + """ + with gr.Blocks(css=CSS, theme=THEME) as demo: + log_level = os.environ.get('SERVICE_LOG_LEVEL', 'INFO') + llama_chatbot = gr.State( + Chatbot(triton_server_addr, log_level=log_level, display=True)) + state_chatbot = gr.State([]) + model_name = llama_chatbot.value.model_name + reset_all = partial(reset_all_func, + model_name=model_name, + triton_server_addr=triton_server_addr) + + with gr.Column(elem_id='container'): + gr.Markdown('## LMDeploy Playground') + + chatbot = gr.Chatbot(elem_id='chatbot', label=model_name) + instruction_txtbox = gr.Textbox( + placeholder='Please input the instruction', + label='Instruction') + with gr.Row(): + cancel_btn = gr.Button(value='Cancel', interactive=False) + reset_btn = gr.Button(value='Reset') + + send_event = instruction_txtbox.submit( + add_instruction, [instruction_txtbox, state_chatbot], + [instruction_txtbox, state_chatbot]).then( + chat_stream, + [state_chatbot, llama_chatbot, cancel_btn, reset_btn], + [state_chatbot, chatbot, cancel_btn, reset_btn]) + + cancel_btn.click(cancel_func, + [state_chatbot, llama_chatbot, cancel_btn, reset_btn], + [llama_chatbot, chatbot, cancel_btn, reset_btn], + cancels=[send_event]) + + reset_btn.click( + reset_all, [instruction_txtbox, state_chatbot, llama_chatbot], + [llama_chatbot, state_chatbot, chatbot, instruction_txtbox], + cancels=[send_event]) + + print(f'server is gonna mount on: http://{server_name}:{server_port}') + demo.queue(concurrency_count=4, max_size=100, api_open=True).launch( + max_threads=10, + share=True, + server_port=server_port, + server_name=server_name, + ) diff --git a/lmdeploy/serve/gradio/turbomind_coupled.py b/lmdeploy/serve/gradio/turbomind_coupled.py new file mode 100644 index 000000000..d5cd59867 --- /dev/null +++ b/lmdeploy/serve/gradio/turbomind_coupled.py @@ -0,0 +1,194 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import threading +from typing import Sequence + +import gradio as gr + +from lmdeploy.serve.async_engine import AsyncEngine +from lmdeploy.serve.gradio.constants import CSS, THEME, disable_btn, enable_btn +from lmdeploy.serve.openai.api_server import ip2id + + +class InterFace: + async_engine: AsyncEngine = None + + +async def chat_stream_local( + instruction: str, + state_chatbot: Sequence, + cancel_btn: gr.Button, + reset_btn: gr.Button, + request: gr.Request, +): + """Chat with AI assistant. + + Args: + instruction (str): user's prompt + state_chatbot (Sequence): the chatting history + cancel_btn (bool): enable the cancel button or not + reset_btn (bool): enable the reset button or not + request (gr.Request): the request from a user + """ + session_id = threading.current_thread().ident + if request is not None: + session_id = ip2id(request.kwargs['client']['host']) + bot_summarized_response = '' + state_chatbot = state_chatbot + [(instruction, None)] + + yield (state_chatbot, state_chatbot, disable_btn, enable_btn, + f'{bot_summarized_response}'.strip()) + + async for outputs in InterFace.async_engine.generate( + instruction, + session_id, + stream_response=True, + sequence_start=(len(state_chatbot) == 1), + sequence_end=False): + response = outputs.response + if outputs.finish_reason == 'length': + gr.Warning('WARNING: exceed session max length.' 
+ ' Please restart the session by reset button.') + if outputs.generate_token_len < 0: + gr.Warning('WARNING: running on the old session.' + ' Please restart the session by reset button.') + if state_chatbot[-1][-1] is None: + state_chatbot[-1] = (state_chatbot[-1][0], response) + else: + state_chatbot[-1] = (state_chatbot[-1][0], + state_chatbot[-1][1] + response + ) # piece by piece + yield (state_chatbot, state_chatbot, enable_btn, disable_btn, + f'{bot_summarized_response}'.strip()) + + yield (state_chatbot, state_chatbot, disable_btn, enable_btn, + f'{bot_summarized_response}'.strip()) + + +async def reset_local_func(instruction_txtbox: gr.Textbox, + state_chatbot: gr.State, request: gr.Request): + """reset the session. + + Args: + instruction_txtbox (str): user's prompt + state_chatbot (Sequence): the chatting history + request (gr.Request): the request from a user + """ + state_chatbot = [] + + session_id = threading.current_thread().ident + if request is not None: + session_id = ip2id(request.kwargs['client']['host']) + # end the session + async for out in InterFace.async_engine.generate('', + session_id, + request_output_len=1, + stream_response=True, + sequence_start=False, + sequence_end=True): + pass + + return ( + state_chatbot, + state_chatbot, + gr.Textbox.update(value=''), + ) + + +async def cancel_local_func(state_chatbot: gr.State, cancel_btn: gr.Button, + reset_btn: gr.Button, request: gr.Request): + """stop the session. + + Args: + state_chatbot (Sequence): the chatting history + cancel_btn (bool): enable the cancel button or not + reset_btn (bool): enable the reset button or not + request (gr.Request): the request from a user + """ + yield (state_chatbot, disable_btn, disable_btn) + session_id = threading.current_thread().ident + if request is not None: + session_id = ip2id(request.kwargs['client']['host']) + # end the session + async for out in InterFace.async_engine.generate('', + session_id, + request_output_len=0, + stream_response=True, + sequence_start=False, + sequence_end=False, + stop=True): + pass + messages = [] + for qa in state_chatbot: + messages.append(dict(role='user', content=qa[0])) + if qa[1] is not None: + messages.append(dict(role='assistant', content=qa[1])) + async for out in InterFace.async_engine.generate(messages, + session_id, + request_output_len=0, + stream_response=True, + sequence_start=True, + sequence_end=False): + pass + yield (state_chatbot, disable_btn, enable_btn) + + +def run_local(model_path: str, + server_name: str = 'localhost', + server_port: int = 6006, + batch_size: int = 4, + tp: int = 1): + """chat with AI assistant through web ui. 
+ + Args: + model_path (str): the path of the deployed model + server_name (str): the ip address of gradio server + server_port (int): the port of gradio server + batch_size (int): batch size for running Turbomind directly + tp (int): tensor parallel for Turbomind + """ + InterFace.async_engine = AsyncEngine(model_path=model_path, + instance_num=batch_size, + tp=tp) + + with gr.Blocks(css=CSS, theme=THEME) as demo: + state_chatbot = gr.State([]) + + with gr.Column(elem_id='container'): + gr.Markdown('## LMDeploy Playground') + + chatbot = gr.Chatbot( + elem_id='chatbot', + label=InterFace.async_engine.tm_model.model_name) + instruction_txtbox = gr.Textbox( + placeholder='Please input the instruction', + label='Instruction') + with gr.Row(): + cancel_btn = gr.Button(value='Cancel', interactive=False) + reset_btn = gr.Button(value='Reset') + + send_event = instruction_txtbox.submit( + chat_stream_local, + [instruction_txtbox, state_chatbot, cancel_btn, reset_btn], + [state_chatbot, chatbot, cancel_btn, reset_btn]) + instruction_txtbox.submit( + lambda: gr.Textbox.update(value=''), + [], + [instruction_txtbox], + ) + cancel_btn.click(cancel_local_func, + [state_chatbot, cancel_btn, reset_btn], + [state_chatbot, cancel_btn, reset_btn], + cancels=[send_event]) + + reset_btn.click(reset_local_func, [instruction_txtbox, state_chatbot], + [state_chatbot, chatbot, instruction_txtbox], + cancels=[send_event]) + + print(f'server is gonna mount on: http://{server_name}:{server_port}') + demo.queue(concurrency_count=batch_size, max_size=100, + api_open=True).launch( + max_threads=10, + share=True, + server_port=server_port, + server_name=server_name, + ) diff --git a/lmdeploy/serve/openai/api_client.py b/lmdeploy/serve/openai/api_client.py index 26977bc6c..a1610e05e 100644 --- a/lmdeploy/serve/openai/api_client.py +++ b/lmdeploy/serve/openai/api_client.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import json -from typing import Iterable, List +from typing import Any, Dict, Iterable, List, Optional, Union import requests @@ -14,13 +14,306 @@ def get_model_list(api_url: str): return None +class APIClient: + """Chatbot for LLaMA series models with turbomind as inference engine. 
+ + Args: + api_server_url (str): communicating address 'http://:' of + api_server + """ + + def __init__(self, api_server_url: str, **kwargs): + self.api_server_url = api_server_url + self.chat_intractive_v1_url = f'{api_server_url}/v1/chat/interactive' + self.chat_completions_v1_url = f'{api_server_url}/v1/chat/completions' + self.completions_v1_url = f'{api_server_url}/v1/completions' + self.models_v1_url = f'{api_server_url}/v1/models' + self._available_models = None + + @property + def available_models(self): + """Show available models.""" + if self._available_models is not None: + return self._available_models + response = requests.get(self.models_v1_url) + if hasattr(response, 'text'): + model_list = json.loads(response.text) + model_list = model_list.pop('data', []) + self._available_models = [item['id'] for item in model_list] + return self._available_models + return None + + def chat_completions_v1(self, + model: str, + messages: Union[str, List[Dict[str, str]]], + temperature: Optional[float] = 0.7, + top_p: Optional[float] = 1.0, + n: Optional[int] = 1, + max_tokens: Optional[int] = 512, + stop: Optional[bool] = False, + stream: Optional[bool] = False, + presence_penalty: Optional[float] = 0.0, + frequency_penalty: Optional[float] = 0.0, + user: Optional[str] = None, + repetition_penalty: Optional[float] = 1.0, + session_id: Optional[int] = -1, + ignore_eos: Optional[bool] = False, + **kwargs): + """Chat completion v1. + + Args: + model: model name. Available from self.available_models. + messages: string prompt or chat history in OpenAI format. + temperature (float): to modulate the next token probability + top_p (float): If set to float < 1, only the smallest set of most + probable tokens with probabilities that add up to top_p or + higher are kept for generation. + n (int): How many chat completion choices to generate for each + input message. Only support one here. + stream: whether to stream the results or not. Default to false. + max_tokens (int): output token nums + repetition_penalty (float): The parameter for repetition penalty. + 1.0 means no penalty + ignore_eos (bool): indicator for ignoring eos + session_id (int): if not specified, will set random value + + Yields: + json objects in openai formats + """ + pload = { + k: v + for k, v in locals().copy().items() + if k[:2] != '__' and k not in ['self'] + } + headers = {'content-type': 'application/json'} + response = requests.post(self.chat_completions_v1_url, + headers=headers, + json=pload, + stream=stream) + for chunk in response.iter_lines(chunk_size=8192, + decode_unicode=False, + delimiter=b'\n'): + if chunk: + if stream: + decoded = chunk.decode('utf-8') + if decoded == 'data: [DONE]': + continue + if decoded[:6] == 'data: ': + decoded = decoded[6:] + output = json.loads(decoded) + yield output + else: + decoded = chunk.decode('utf-8') + output = json.loads(decoded) + yield output + + def chat_interactive_v1(self, + prompt: Union[str, List[Dict[str, str]]], + session_id: int = -1, + interactive_mode: bool = False, + stream: bool = False, + stop: bool = False, + request_output_len: int = 512, + top_p: float = 0.8, + top_k: int = 40, + temperature: float = 0.8, + repetition_penalty: float = 1.0, + ignore_eos: bool = False, + **kwargs): + """Interactive completions. + + - On interactive mode, the chat history is kept on the server. Please + set `interactive_mode = True`. + - On normal mode, no chat history is kept on the server. Set + `interactive_mode = False`. 
+ + Args: + prompt: the prompt to use for the generation. + session_id: determine which instance will be called. + If not specified with a value other than -1, using random value + directly. + interactive_mode (bool): turn on interactive mode or not. On + interactive mode, session history is kept on the server (and + vice versa). + stream: whether to stream the results or not. + stop: whether to stop the session response or not. + request_output_len (int): output token nums + top_p (float): If set to float < 1, only the smallest set of most + probable tokens with probabilities that add up to top_p or + higher are kept for generation. + top_k (int): The number of the highest probability vocabulary + tokens to keep for top-k-filtering + temperature (float): to modulate the next token probability + repetition_penalty (float): The parameter for repetition penalty. + 1.0 means no penalty + ignore_eos (bool): indicator for ignoring eos + + Yields: + json objects consist of text, tokens, finish_reason + """ + pload = { + k: v + for k, v in locals().copy().items() + if k[:2] != '__' and k not in ['self'] + } + headers = {'content-type': 'application/json'} + response = requests.post(self.chat_intractive_v1_url, + headers=headers, + json=pload, + stream=stream) + for chunk in response.iter_lines(chunk_size=8192, + decode_unicode=False, + delimiter=b'\n'): + if chunk: + decoded = chunk.decode('utf-8') + output = json.loads(decoded) + yield output + + def completions_v1( + self, + model: str, + prompt: Union[str, List[Any]], + suffix: Optional[str] = None, + temperature: Optional[float] = 0.7, + n: Optional[int] = 1, + max_tokens: Optional[int] = 16, + stream: Optional[bool] = False, + top_p: Optional[float] = 1.0, + user: Optional[str] = None, + # additional argument of lmdeploy + repetition_penalty: Optional[float] = 1.0, + session_id: Optional[int] = -1, + ignore_eos: Optional[bool] = False, + **kwargs): + """Chat completion v1. + + Args: + model (str): model name. Available from /v1/models. + prompt (str): the input prompt. + suffix (str): The suffix that comes after a completion of inserted + text. + max_tokens (int): output token nums + temperature (float): to modulate the next token probability + top_p (float): If set to float < 1, only the smallest set of most + probable tokens with probabilities that add up to top_p or + higher are kept for generation. + n (int): How many chat completion choices to generate for each + input message. Only support one here. + stream: whether to stream the results or not. Default to false. + repetition_penalty (float): The parameter for repetition penalty. + 1.0 means no penalty + user (str): A unique identifier representing your end-user. 
+ ignore_eos (bool): indicator for ignoring eos + session_id (int): if not specified, will set random value + + Yields: + json objects in openai formats + """ + pload = { + k: v + for k, v in locals().copy().items() + if k[:2] != '__' and k not in ['self'] + } + headers = {'content-type': 'application/json'} + response = requests.post(self.completions_v1_url, + headers=headers, + json=pload, + stream=stream) + for chunk in response.iter_lines(chunk_size=8192, + decode_unicode=False, + delimiter=b'\n'): + if chunk: + if stream: + decoded = chunk.decode('utf-8')[6:] + if decoded == 'data: [DONE]': + continue + if decoded[:6] == 'data: ': + decoded = decoded[6:] + output = json.loads(decoded) + yield output + else: + decoded = chunk.decode('utf-8') + output = json.loads(decoded) + yield output + + def chat(self, + prompt: str, + session_id: int, + request_output_len: int = 512, + stream: bool = False, + top_p: float = 0.8, + top_k: int = 40, + temperature: float = 0.8, + repetition_penalty: float = 1.0, + ignore_eos: bool = False): + """Chat with a unique session_id. + + Args: + prompt: the prompt to use for the generation. + session_id: determine which instance will be called. + If not specified with a value other than -1, using random value + directly. + stream: whether to stream the results or not. + stop: whether to stop the session response or not. + request_output_len (int): output token nums + top_p (float): If set to float < 1, only the smallest set of most + probable tokens with probabilities that add up to top_p or + higher are kept for generation. + top_k (int): The number of the highest probability vocabulary + tokens to keep for top-k-filtering + temperature (float): to modulate the next token probability + repetition_penalty (float): The parameter for repetition penalty. + 1.0 means no penalty + ignore_eos (bool): indicator for ignoring eos + + Yields: + text, tokens, finish_reason + """ + assert session_id != -1, 'please set a value other than -1' + for outputs in self.chat_interactive_v1( + prompt, + session_id=session_id, + request_output_len=request_output_len, + interactive_mode=True, + stream=stream, + top_k=top_k, + top_p=top_p, + temperature=temperature, + repetition_penalty=repetition_penalty, + ignore_eos=ignore_eos): + if outputs['finish_reason'] == 'length': + print('WARNING: exceed session max length.' + ' Please end the session.') + yield outputs['text'], outputs['tokens'], outputs['finish_reason'] + + def end_session(self, session_id: int): + """End the session with a unique session_id. + + Args: + session_id: determine which instance will be called. + If not specified with a value other than -1, using random value + directly. 
+ """ + for out in self.chat_interactive_v1(prompt='', + session_id=session_id, + request_output_len=0, + interactive_mode=False): + pass + + +def input_prompt(): + """Input a prompt in the consolo interface.""" + print('\ndouble enter to end input >>> ', end='') + sentinel = '' # ends when this string is seen + return '\n'.join(iter(input, sentinel)) + + def get_streaming_response(prompt: str, api_url: str, session_id: int, request_output_len: int = 512, stream: bool = True, - sequence_start: bool = True, - sequence_end: bool = True, + interactive_mode: bool = False, ignore_eos: bool = False, stop: bool = False) -> Iterable[List[str]]: headers = {'User-Agent': 'Test Client'} @@ -29,8 +322,7 @@ def get_streaming_response(prompt: str, 'stream': stream, 'session_id': session_id, 'request_output_len': request_output_len, - 'sequence_start': sequence_start, - 'sequence_end': sequence_end, + 'interactive_mode': interactive_mode, 'ignore_eos': ignore_eos, 'stop': stop } @@ -49,42 +341,23 @@ def get_streaming_response(prompt: str, yield output, tokens, finish_reason -def input_prompt(): - """Input a prompt in the consolo interface.""" - print('\ndouble enter to end input >>> ', end='') - sentinel = '' # ends when this string is seen - return '\n'.join(iter(input, sentinel)) - - -def main(restful_api_url: str, session_id: int = 0): - nth_round = 1 +def main(api_server_url: str, session_id: int = 0): + api_client = APIClient(api_server_url) while True: prompt = input_prompt() - if prompt == 'exit': - for output, tokens, finish_reason in get_streaming_response( - '', - f'{restful_api_url}/generate', - session_id=session_id, - request_output_len=0, - sequence_start=(nth_round == 1), - sequence_end=True): - pass - exit(0) + if prompt in ['exit', 'end']: + api_client.end_session(session_id) + if prompt == 'exit': + exit(0) else: - for output, tokens, finish_reason in get_streaming_response( + for text, tokens, finish_reason in api_client.chat( prompt, - f'{restful_api_url}/generate', session_id=session_id, request_output_len=512, - sequence_start=(nth_round == 1), - sequence_end=False): + stream=True): if finish_reason == 'length': - print('WARNING: exceed session max length.' - ' Please end the session.') continue - print(output, end='') - - nth_round += 1 + print(text, end='') if __name__ == '__main__': diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index 3ba8b80b4..97e5e518c 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -1,5 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. 
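The `api_client.py` rewrite above replaces the ad-hoc `/generate` helpers with an `APIClient` class. A minimal usage sketch, assuming an api_server is already listening at `http://0.0.0.0:23333` (the address and prompts are placeholders):

```python
# Usage sketch of the new APIClient against a running api_server.
from lmdeploy.serve.openai.api_client import APIClient

api_client = APIClient('http://0.0.0.0:23333')
model_name = api_client.available_models[0]

# Stateless OpenAI-style chat completion: the history travels with the request.
messages = [{'role': 'user', 'content': 'Hi, how are you?'}]
for item in api_client.chat_completions_v1(model=model_name, messages=messages):
    print(item['choices'][0]['message']['content'])

# Stateful interactive chat: the history is kept on the server per session_id.
for text, tokens, finish_reason in api_client.chat('Tell me a joke',
                                                   session_id=1,
                                                   stream=True):
    print(text, end='')
api_client.end_session(session_id=1)
```

`chat()` refuses `session_id=-1`, so pick a concrete id and call `end_session()` once the conversation is over.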
+import asyncio import os +import random import time from http import HTTPStatus from typing import AsyncGenerator, List, Optional @@ -13,8 +15,10 @@ from lmdeploy.serve.openai.protocol import ( # noqa: E501 ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, - ChatCompletionStreamResponse, ChatMessage, DeltaMessage, EmbeddingsRequest, - EmbeddingsResponse, ErrorResponse, GenerateRequest, GenerateResponse, + ChatCompletionStreamResponse, ChatMessage, CompletionRequest, + CompletionResponse, CompletionResponseChoice, + CompletionResponseStreamChoice, CompletionStreamResponse, DeltaMessage, + EmbeddingsRequest, ErrorResponse, GenerateRequest, GenerateResponse, ModelCard, ModelList, ModelPermission, UsageInfo) os.environ['TM_LOG_LEVEL'] = 'ERROR' @@ -104,9 +108,8 @@ async def chat_completions_v1(request: ChatCompletionRequest, 1.0 means no penalty Additional arguments supported by LMDeploy: - - renew_session (bool): Whether renew the session. Can be used when the - session length is exceeded. - ignore_eos (bool): indicator for ignoring eos + - session_id (int): if not specified, will set random value Currently we do not support the following features: - function_call (Users should implement this by themselves) @@ -114,20 +117,22 @@ async def chat_completions_v1(request: ChatCompletionRequest, - presence_penalty (replaced with repetition_penalty) - frequency_penalty (replaced with repetition_penalty) """ - session_id = ip2id(raw_request.client.host) + if request.session_id == -1: + request.session_id = random.randint(1, 10086) error_check_ret = await check_request(request) if error_check_ret is not None: return error_check_ret model_name = request.model - request_id = str(session_id) + request_id = str(request.session_id) created_time = int(time.time()) - result_generator = VariableInterface.async_engine.generate_openai( + result_generator = VariableInterface.async_engine.generate( request.messages, - session_id, + request.session_id, True, # always use stream to enable batching - request.renew_session, + sequence_start=True, + sequence_end=True, request_output_len=request.max_tokens if request.max_tokens else 512, stop=request.stop, top_p=request.top_p, @@ -188,7 +193,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: async for res in result_generator: if await raw_request.is_disconnected(): # Abort the request if the client disconnects. - VariableInterface.async_engine.stop_session(session_id) + VariableInterface.async_engine.stop_session(request.session_id) return create_error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected') final_res = res @@ -222,51 +227,191 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: return response -@app.post('/v1/embeddings') -async def create_embeddings(request: EmbeddingsRequest, - raw_request: Request = None): - """Creates embeddings for the text.""" +@app.post('/v1/completions') +async def completions_v1(request: CompletionRequest, + raw_request: Request = None): + """Completion API similar to OpenAI's API. + + Go to `https://platform.openai.com/docs/api-reference/completions/create` + for the API specification. + + The request should be a JSON object with the following fields: + - model (str): model name. Available from /v1/models. + - prompt (str): the input prompt. + - suffix (str): The suffix that comes after a completion of inserted text. 
+ - max_tokens (int): output token nums + - temperature (float): to modulate the next token probability + - top_p (float): If set to float < 1, only the smallest set of most + probable tokens with probabilities that add up to top_p or higher + are kept for generation. + - n (int): How many chat completion choices to generate for each input + message. Only support one here. + - stream: whether to stream the results or not. Default to false. + - repetition_penalty (float): The parameter for repetition penalty. + 1.0 means no penalty + - user (str): A unique identifier representing your end-user. + + Additional arguments supported by LMDeploy: + - ignore_eos (bool): indicator for ignoring eos + - session_id (int): if not specified, will set random value + + Currently we do not support the following features: + - logprobs (not supported yet) + - presence_penalty (replaced with repetition_penalty) + - frequency_penalty (replaced with repetition_penalty) + """ + if request.session_id == -1: + request.session_id = random.randint(1, 10086) error_check_ret = await check_request(request) if error_check_ret is not None: return error_check_ret - if isinstance(request.input, str): - request.input = [request.input] - - data = [] - token_num = 0 - for i, prompt in enumerate(request.input): - embedding = await VariableInterface.async_engine.get_embeddings(prompt) - data.append({ - 'object': 'embedding', - 'embedding': embedding, - 'index': i - }) - token_num += len(embedding) - return EmbeddingsResponse( - data=data, - model=request.model, - usage=UsageInfo( - prompt_tokens=token_num, - total_tokens=token_num, - completion_tokens=None, - ), - ).dict(exclude_none=True) - - -@app.post('/generate') -async def generate(request: GenerateRequest, raw_request: Request = None): + + model_name = request.model + request_id = str(request.session_id) + created_time = int(time.time()) + if isinstance(request.prompt, str): + request.prompt = [request.prompt] + generators = [] + for i in range(len(request.prompt)): + result_generator = VariableInterface.async_engine.generate( + request.prompt[i], + request.session_id + i, + True, # always use stream to enable batching + sequence_start=True, + sequence_end=True, + request_output_len=request.max_tokens + if request.max_tokens else 512, + stop=False, + top_p=request.top_p, + temperature=request.temperature, + repetition_penalty=request.repetition_penalty, + ignore_eos=request.ignore_eos, + do_preprocess=False) + generators.append(result_generator) + + def create_stream_response_json( + index: int, + text: str, + finish_reason: Optional[str] = None, + ) -> str: + choice_data = CompletionResponseStreamChoice( + index=index, + text=text, + finish_reason=finish_reason, + ) + response = CompletionStreamResponse( + id=request_id, + created=created_time, + model=model_name, + choices=[choice_data], + ) + response_json = response.model_dump_json() + + return response_json + + async def completion_stream_generator() -> AsyncGenerator[str, None]: + # First chunk with role + for generator in generators: + for i in range(request.n): + choice_data = CompletionResponseStreamChoice( + index=i, + text='', + finish_reason=None, + ) + chunk = CompletionStreamResponse(id=request_id, + choices=[choice_data], + model=model_name) + data = chunk.model_dump_json(exclude_unset=True) + yield f'data: {data}\n\n' + + async for res in generator: + response_json = create_stream_response_json( + index=0, + text=res.response, + ) + yield f'data: {response_json}\n\n' + yield 'data: [DONE]\n\n' + + # 
Streaming response + if request.stream: + return StreamingResponse(completion_stream_generator(), + media_type='text/event-stream') + + # Non-streaming response + usage = UsageInfo() + choices = [] + + async def _inner_call(i, generator): + final_res = None + text = '' + async for res in generator: + if await raw_request.is_disconnected(): + # Abort the request if the client disconnects. + VariableInterface.async_engine.stop_session(request.session_id) + return create_error_response(HTTPStatus.BAD_REQUEST, + 'Client disconnected') + final_res = res + text += res.response + assert final_res is not None + choice_data = CompletionResponseChoice( + index=0, + text=text, + finish_reason=final_res.finish_reason, + ) + choices.append(choice_data) + + total_tokens = sum([ + final_res.history_token_len, final_res.input_token_len, + final_res.generate_token_len + ]) + usage.prompt_tokens += final_res.input_token_len + usage.completion_tokens += final_res.generate_token_len + usage.total_tokens += total_tokens + + await asyncio.gather( + *[_inner_call(i, generators[i]) for i in range(len(generators))]) + + response = CompletionResponse( + id=request_id, + created=created_time, + model=model_name, + choices=choices, + usage=usage, + ) + + return response + + +@app.post('/v1/embeddings', tags=['unsupported']) +async def create_embeddings(request: EmbeddingsRequest, + raw_request: Request = None): + """Creates embeddings for the text.""" + return create_error_response(HTTPStatus.BAD_REQUEST, + 'Unsupported by turbomind.') + + +@app.post('/generate', + tags=['deprecated'], + description='please use /v1/chat/interactive') +@app.post('/v1/chat/interactive') +async def chat_interactive_v1(request: GenerateRequest, + raw_request: Request = None): """Generate completion for the request. + - On interactive mode, the chat history is kept on the server. Please set + `interactive_mode = True`. + - On normal mode, no chat history is kept on the server. Set + `interactive_mode = False`. + The request should be a JSON object with the following fields: - prompt: the prompt to use for the generation. - session_id: determine which instance will be called. If not specified - with a value other than -1, using host ip directly. - - sequence_start (bool): indicator for starting a sequence. - - sequence_end (bool): indicator for ending a sequence + with a value other than -1, using random value directly. + - interactive_mode (bool): turn on interactive mode or not. On interactive + mode, session history is kept on the server (and vice versa). - stream: whether to stream the results or not. - stop: whether to stop the session response or not. - request_output_len (int): output token nums - - step (int): the offset of the k/v cache - top_p (float): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. 
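For completeness, here is a hedged sketch of posting to the renamed `/v1/chat/interactive` route directly with `requests`; the field names follow `GenerateRequest`, while the server address and the shape of the reply are assumptions based on the handler above:

```python
# Sketch: call /v1/chat/interactive without the APIClient wrapper.
import json
import requests

url = 'http://0.0.0.0:23333/v1/chat/interactive'  # assumed api_server address
payload = {
    'prompt': 'Hello!',
    'session_id': 1,
    'interactive_mode': True,   # keep the chat history on the server
    'stream': False,
    'request_output_len': 64,
}
resp = requests.post(url, json=payload)
# the non-streaming reply is expected to carry text / tokens / finish_reason
print(json.loads(resp.text)['text'])
```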
@@ -278,15 +423,18 @@ async def generate(request: GenerateRequest, raw_request: Request = None): - ignore_eos (bool): indicator for ignoring eos """ if request.session_id == -1: - session_id = ip2id(raw_request.client.host) - request.session_id = session_id + request.session_id = random.randint(10087, 23333) + + async_engine = VariableInterface.async_engine + sequence_start = async_engine.steps.get(str(request.session_id), 0) == 0 + sequence_end = not request.interactive_mode - generation = VariableInterface.async_engine.generate( + generation = async_engine.generate( request.prompt, request.session_id, stream_response=True, # always use stream to enable batching - sequence_start=request.sequence_start, - sequence_end=request.sequence_end, + sequence_start=sequence_start, + sequence_end=sequence_end, request_output_len=request.request_output_len, top_p=request.top_p, top_k=request.top_k, @@ -315,7 +463,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: async for out in generation: if await raw_request.is_disconnected(): # Abort the request if the client disconnects. - VariableInterface.async_engine.stop_session(session_id) + async_engine.stop_session(request.session_id) return create_error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected') text += out.response @@ -326,14 +474,15 @@ async def stream_results() -> AsyncGenerator[bytes, None]: def main(model_path: str, - server_name: str = 'localhost', + server_name: str = '0.0.0.0', server_port: int = 23333, instance_num: int = 32, tp: int = 1, allow_origins: List[str] = ['*'], allow_credentials: bool = True, allow_methods: List[str] = ['*'], - allow_headers: List[str] = ['*']): + allow_headers: List[str] = ['*'], + **kwargs): """An example to perform model inference through the command line interface. 
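`main()` above now binds to `0.0.0.0` by default and accepts extra keyword arguments that are passed through to the engine (see the hunk that follows). A minimal sketch of launching the server from Python instead of the CLI, with the workspace path as a placeholder:

```python
# Sketch: start the RESTful api_server programmatically.
from lmdeploy.serve.openai.api_server import main as serve_api

serve_api('./workspace',        # path of the deployed model (placeholder)
          server_name='0.0.0.0',
          server_port=23333,
          instance_num=32,      # number of turbomind instances
          tp=1)
```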
@@ -359,7 +508,8 @@ def main(model_path: str, VariableInterface.async_engine = AsyncEngine(model_path=model_path, instance_num=instance_num, - tp=tp) + tp=tp, + **kwargs) uvicorn.run(app=app, host=server_name, port=server_port, log_level='info') diff --git a/lmdeploy/serve/openai/protocol.py b/lmdeploy/serve/openai/protocol.py index 78bf56531..b39e6cbc8 100644 --- a/lmdeploy/serve/openai/protocol.py +++ b/lmdeploy/serve/openai/protocol.py @@ -70,7 +70,9 @@ class ChatCompletionRequest(BaseModel): user: Optional[str] = None # additional argument of lmdeploy repetition_penalty: Optional[float] = 1.0 - renew_session: Optional[bool] = False + session_id: Optional[int] = -1 + renew_session: Optional[ + bool] = False # lagecy and useless, will be removed ignore_eos: Optional[bool] = False @@ -135,6 +137,10 @@ class CompletionRequest(BaseModel): presence_penalty: Optional[float] = 0.0 frequency_penalty: Optional[float] = 0.0 user: Optional[str] = None + # additional argument of lmdeploy + repetition_penalty: Optional[float] = 1.0 + session_id: Optional[int] = -1 + ignore_eos: Optional[bool] = False class CompletionResponseChoice(BaseModel): @@ -191,8 +197,7 @@ class GenerateRequest(BaseModel): """Generate request.""" prompt: Union[str, List[Dict[str, str]]] session_id: int = -1 - sequence_start: bool = True - sequence_end: bool = False + interactive_mode: bool = False stream: bool = False stop: bool = False request_output_len: int = 512 diff --git a/lmdeploy/turbomind/chat.py b/lmdeploy/turbomind/chat.py index bf0ce7399..69a1f768d 100644 --- a/lmdeploy/turbomind/chat.py +++ b/lmdeploy/turbomind/chat.py @@ -69,8 +69,9 @@ def get_gen_param(cap, def main(model_path, session_id: int = 1, cap: str = 'chat', - tp=1, - stream_output=True, + tp: int = 1, + stream_output: bool = True, + request_output_len: int = 512, **kwargs): """An example to perform model inference through the command line interface. @@ -106,12 +107,13 @@ def main(model_path, elif prompt == 'end': prompt = model.get_prompt('', nth_round == 1) input_ids = tokenizer.encode(prompt) - for outputs in generator.stream_infer(session_id=session_id, - input_ids=[input_ids], - request_output_len=512, - sequence_start=False, - sequence_end=True, - stream_output=stream_output): + for outputs in generator.stream_infer( + session_id=session_id, + input_ids=[input_ids], + request_output_len=request_output_len, + sequence_start=False, + sequence_end=True, + stream_output=stream_output): pass nth_round = 1 step = 0 @@ -119,13 +121,14 @@ def main(model_path, else: prompt = model.get_prompt(prompt, nth_round == 1) input_ids = tokenizer.encode(prompt) - if step + len(input_ids) >= tm_model.session_len: + if step + len( + input_ids) + request_output_len >= tm_model.session_len: print('WARNING: exceed session max length.' ' Please end the session.') continue gen_param = get_gen_param(cap, model.sampling_param, nth_round, - step, **kwargs) + step, request_output_len, **kwargs) print(f'{prompt} ', end='', flush=True) response_size = 0
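Finally, the new text-completion path can be exercised end to end through the client added earlier in this patch. A hedged sketch, with the server address and prompt as placeholders:

```python
# Sketch: APIClient.completions_v1 talking to the /v1/completions route.
from lmdeploy.serve.openai.api_client import APIClient

api_client = APIClient('http://0.0.0.0:23333')
model_name = api_client.available_models[0]

for item in api_client.completions_v1(model=model_name,
                                       prompt='The capital of France is',
                                       max_tokens=32):
    # a non-streaming request yields a single OpenAI-style completion object
    print(item['choices'][0]['text'])
```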