From 4279d8ca44eca22994406345d91a31fb89b142a7 Mon Sep 17 00:00:00 2001 From: AllentDan <41138331+AllentDan@users.noreply.github.com> Date: Thu, 24 Aug 2023 19:35:25 +0800 Subject: [PATCH] Enable the Gradio server to call inference services through the RESTful API (#287) * app use async engine * add stop logic * app update cancel * app support restful-api * update doc and use the right model name * set doc url root * add comments * add an example * renew_session * update readme.md * resolve comments * Update restful_api.md * Update restful_api.md * Update restful_api.md --------- Co-authored-by: tpoisonooo --- README.md | 26 ++ README_zh-CN.md | 26 ++ docs/en/restful_api.md | 55 ++- docs/zh_cn/restful_api.md | 54 ++- lmdeploy/serve/{openai => }/async_engine.py | 86 +++-- lmdeploy/serve/gradio/app.py | 397 +++++++++++++++----- lmdeploy/serve/openai/api_client.py | 25 +- lmdeploy/serve/openai/api_server.py | 12 +- lmdeploy/serve/openai/protocol.py | 3 +- requirements.txt | 1 + 10 files changed, 515 insertions(+), 170 deletions(-) rename lmdeploy/serve/{openai => }/async_engine.py (77%) diff --git a/README.md b/README.md index d0252a239..4a4d4dd4a 100644 --- a/README.md +++ b/README.md @@ -133,6 +133,32 @@ python3 -m lmdeploy.serve.gradio.app ./workspace ![](https://github.com/InternLM/lmdeploy/assets/67539920/08d1e6f2-3767-44d5-8654-c85767cec2ab) +#### Serving with Restful API + +Launch inference server by: + +```shell +python3 -m lmdeploy.serve.openai.api_server ./workspace server_ip server_port --instance_num 32 --tp 1 +``` + +Then, you can communicate with it by command line, + +```shell +# restful_api_url is what printed in api_server.py, e.g. http://localhost:23333 +python -m lmdeploy.serve.openai.api_client restful_api_url +``` + +or webui, + +```shell +# restful_api_url is what printed in api_server.py, e.g. http://localhost:23333 +# server_ip and server_port here are for gradio ui +# example: python -m lmdeploy.serve.gradio.app http://localhost:23333 localhost 6006 --restful_api True +python -m lmdeploy.serve.gradio.app restful_api_url server_ip --restful_api True +``` + +Refer to [restful_api.md](docs/en/restful_api.md) for more details. + #### Serving with Triton Inference Server Launch inference server by: diff --git a/README_zh-CN.md b/README_zh-CN.md index d7c31c70d..3e649269e 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -133,6 +133,32 @@ python3 -m lmdeploy.serve.gradio.app ./workspace ![](https://github.com/InternLM/lmdeploy/assets/67539920/08d1e6f2-3767-44d5-8654-c85767cec2ab) +#### 通过 Restful API 部署服务 + +使用下面的命令启动推理服务: + +```shell +python3 -m lmdeploy.serve.openai.api_server ./workspace server_ip server_port --instance_num 32 --tp 1 +``` + +你可以通过命令行方式与推理服务进行对话: + +```shell +# restful_api_url is what printed in api_server.py, e.g. http://localhost:23333 +python -m lmdeploy.serve.openai.api_client restful_api_url +``` + +也可以通过 WebUI 方式来对话: + +```shell +# restful_api_url is what printed in api_server.py, e.g. 
http://localhost:23333 +# server_ip and server_port here are for gradio ui +# example: python -m lmdeploy.serve.gradio.app http://localhost:23333 localhost 6006 --restful_api True +python -m lmdeploy.serve.gradio.app restful_api_url server_ip --restful_api True +``` + +更多详情可以查阅 [restful_api.md](docs/zh_cn/restful_api.md)。 + #### 通过容器部署推理服务 使用下面的命令启动推理服务: diff --git a/docs/en/restful_api.md b/docs/en/restful_api.md index beff6c161..f55d66b2f 100644 --- a/docs/en/restful_api.md +++ b/docs/en/restful_api.md @@ -3,10 +3,10 @@ ### Launch Service ```shell -python3 -m lmdeploy.serve.openai.api_server ./workspace server_name server_port --instance_num 32 --tp 1 +python3 -m lmdeploy.serve.openai.api_server ./workspace 0.0.0.0 server_port --instance_num 32 --tp 1 ``` -Then, the user can open the swagger UI: http://{server_name}:{server_port}/docs for the detailed api usage. +Then, the user can open the swagger UI: `http://{server_ip}:{server_port}` for the detailed api usage. We provide four restful api in total. Three of them are in OpenAI format. However, we recommend users try our own api which provides more arguments for users to modify. The performance is comparatively better. @@ -50,16 +50,29 @@ def get_streaming_response(prompt: str, for output, tokens in get_streaming_response( - "Hi, how are you?", "http://{server_name}:{server_port}/generate", 0, + "Hi, how are you?", "http://{server_ip}:{server_port}/generate", 0, 512): print(output, end='') ``` -### Golang/Rust +### Java/Golang/Rust -Golang can also build a http request to use the service. You may refer -to [the blog](https://pkg.go.dev/net/http) for details to build own client. -Besides, Rust supports building a client in [many ways](https://blog.logrocket.com/best-rust-http-client/). +May use [openapi-generator-cli](https://github.com/OpenAPITools/openapi-generator-cli) to convert `http://{server_ip}:{server_port}/openapi.json` to java/rust/golang client. +Here is an example: + +```shell +$ docker run -it --rm -v ${PWD}:/local openapitools/openapi-generator-cli generate -i /local/openapi.json -g rust -o /local/rust + +$ ls rust/* +rust/Cargo.toml rust/git_push.sh rust/README.md + +rust/docs: +ChatCompletionRequest.md EmbeddingsRequest.md HttpValidationError.md LocationInner.md Prompt.md +DefaultApi.md GenerateRequest.md Input.md Messages.md ValidationError.md + +rust/src: +apis lib.rs models +``` ### cURL @@ -68,13 +81,13 @@ cURL is a tool for observing the output of the api. List Models: ```bash -curl http://{server_name}:{server_port}/v1/models +curl http://{server_ip}:{server_port}/v1/models ``` Generate: ```bash -curl http://{server_name}:{server_port}/generate \ +curl http://{server_ip}:{server_port}/generate \ -H "Content-Type: application/json" \ -d '{ "model": "internlm-chat-7b", @@ -87,7 +100,7 @@ curl http://{server_name}:{server_port}/generate \ Chat Completions: ```bash -curl http://{server_name}:{server_port}/v1/chat/completions \ +curl http://{server_ip}:{server_port}/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "model": "internlm-chat-7b", @@ -98,7 +111,7 @@ curl http://{server_name}:{server_port}/v1/chat/completions \ Embeddings: ```bash -curl http://{server_name}:{server_port}/v1/embeddings \ +curl http://{server_ip}:{server_port}/v1/embeddings \ -H "Content-Type: application/json" \ -d '{ "model": "internlm-chat-7b", @@ -106,6 +119,26 @@ curl http://{server_name}:{server_port}/v1/embeddings \ }' ``` +### CLI client + +There is a client script for restful api server. 
+ +```shell +# restful_api_url is what printed in api_server.py, e.g. http://localhost:23333 +python -m lmdeploy.serve.openai.api_client restful_api_url +``` + +### webui + +You can also test restful-api through webui. + +```shell +# restful_api_url is what printed in api_server.py, e.g. http://localhost:23333 +# server_ip and server_port here are for gradio ui +# example: python -m lmdeploy.serve.gradio.app http://localhost:23333 localhost 6006 --restful_api True +python -m lmdeploy.serve.gradio.app restful_api_url server_ip --restful_api True +``` + ### FAQ 1. When user got `"finish_reason":"length"` which means the session is too long to be continued. diff --git a/docs/zh_cn/restful_api.md b/docs/zh_cn/restful_api.md index 09c7c6467..fb796de56 100644 --- a/docs/zh_cn/restful_api.md +++ b/docs/zh_cn/restful_api.md @@ -5,10 +5,10 @@ 运行脚本 ```shell -python3 -m lmdeploy.serve.openai.api_server ./workspace server_name server_port --instance_num 32 --tp 1 +python3 -m lmdeploy.serve.openai.api_server ./workspace 0.0.0.0 server_port --instance_num 32 --tp 1 ``` -然后用户可以打开 swagger UI: http://{server_name}:{server_port}/docs 详细查看所有的 API 及其使用方法。 +然后用户可以打开 swagger UI: `http://{server_ip}:{server_port}` 详细查看所有的 API 及其使用方法。 我们一共提供四个 restful api,其中三个仿照 OpenAI 的形式。不过,我们建议用户用我们提供的另一个 API: `generate`。 它有更好的性能,提供更多的参数让用户自定义修改。 @@ -52,15 +52,29 @@ def get_streaming_response(prompt: str, for output, tokens in get_streaming_response( - "Hi, how are you?", "http://{server_name}:{server_port}/generate", 0, + "Hi, how are you?", "http://{server_ip}:{server_port}/generate", 0, 512): print(output, end='') ``` -### Golang/Rust +### Java/Golang/Rust -Golang 也可以建立 http 请求使用启动的服务,用户可以参考[这篇博客](https://pkg.go.dev/net/http)构建自己的客户端。 -Rust 也有许多[方法](https://blog.logrocket.com/best-rust-http-client/)构建客户端,使用服务。 +可以使用代码生成工具 [openapi-generator-cli](https://github.com/OpenAPITools/openapi-generator-cli) 将 `http://{server_ip}:{server_port}/openapi.json` 转成 java/rust/golang 客户端。 +下面是一个使用示例: + +```shell +$ docker run -it --rm -v ${PWD}:/local openapitools/openapi-generator-cli generate -i /local/openapi.json -g rust -o /local/rust + +$ ls rust/* +rust/Cargo.toml rust/git_push.sh rust/README.md + +rust/docs: +ChatCompletionRequest.md EmbeddingsRequest.md HttpValidationError.md LocationInner.md Prompt.md +DefaultApi.md GenerateRequest.md Input.md Messages.md ValidationError.md + +rust/src: +apis lib.rs models +``` ### cURL @@ -69,13 +83,13 @@ cURL 也可以用于查看 API 的输出结果 查看模型列表: ```bash -curl http://{server_name}:{server_port}/v1/models +curl http://{server_ip}:{server_port}/v1/models ``` 使用 generate: ```bash -curl http://{server_name}:{server_port}/generate \ +curl http://{server_ip}:{server_port}/generate \ -H "Content-Type: application/json" \ -d '{ "model": "internlm-chat-7b", @@ -88,7 +102,7 @@ curl http://{server_name}:{server_port}/generate \ Chat Completions: ```bash -curl http://{server_name}:{server_port}/v1/chat/completions \ +curl http://{server_ip}:{server_port}/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "model": "internlm-chat-7b", @@ -99,7 +113,7 @@ curl http://{server_name}:{server_port}/v1/chat/completions \ Embeddings: ```bash -curl http://{server_name}:{server_port}/v1/embeddings \ +curl http://{server_ip}:{server_port}/v1/embeddings \ -H "Content-Type: application/json" \ -d '{ "model": "internlm-chat-7b", @@ -107,6 +121,26 @@ curl http://{server_name}:{server_port}/v1/embeddings \ }' ``` +### CLI client + +restful api 服务可以通过客户端测试,例如 + +```shell +# restful_api_url 就是 api_server 产生的,比如 
http://localhost:23333 +python -m lmdeploy.serve.openai.api_client restful_api_url +``` + +### webui + +也可以直接用 webui 测试使用 restful-api。 + +```shell +# restful_api_url 就是 api_server 产生的,比如 http://localhost:23333 +# server_ip 和 server_port 是用来提供 gradio ui 访问服务的 +# 例子: python -m lmdeploy.serve.gradio.app http://localhost:23333 localhost 6006 --restful_api True +python -m lmdeploy.serve.gradio.app restful_api_url server_ip --restful_api True +``` + ### FAQ 1. 当返回结果结束原因为 `"finish_reason":"length"`,这表示回话长度超过最大值。 diff --git a/lmdeploy/serve/openai/async_engine.py b/lmdeploy/serve/async_engine.py similarity index 77% rename from lmdeploy/serve/openai/async_engine.py rename to lmdeploy/serve/async_engine.py index 7754c515f..8d7c8a027 100644 --- a/lmdeploy/serve/openai/async_engine.py +++ b/lmdeploy/serve/async_engine.py @@ -3,11 +3,11 @@ import dataclasses import os.path as osp import random +from contextlib import contextmanager from typing import Literal, Optional from lmdeploy import turbomind as tm from lmdeploy.model import MODELS, BaseModel -from lmdeploy.turbomind.tokenizer import Tokenizer @dataclasses.dataclass @@ -30,6 +30,7 @@ class AsyncEngine: """ def __init__(self, model_path, instance_num=32, tp=1) -> None: + from lmdeploy.turbomind.tokenizer import Tokenizer tokenizer_model_path = osp.join(model_path, 'triton_models', 'tokenizer') tokenizer = Tokenizer(tokenizer_model_path) @@ -46,15 +47,22 @@ def __init__(self, model_path, instance_num=32, tp=1) -> None: self.starts = [None] * instance_num self.steps = {} + @contextmanager + def safe_run(self, instance_id: int, stop: bool = False): + self.available[instance_id] = False + yield + self.available[instance_id] = True + async def get_embeddings(self, prompt): prompt = self.model.get_prompt(prompt) input_ids = self.tokenizer.encode(prompt) return input_ids - async def get_generator(self, instance_id): + async def get_generator(self, instance_id: int, stop: bool = False): """Only return the model instance if it is available.""" - while self.available[instance_id] is False: - await asyncio.sleep(0.1) + if not stop: + while self.available[instance_id] is False: + await asyncio.sleep(0.1) return self.generators[instance_id] async def generate( @@ -104,43 +112,43 @@ async def generate( prompt = self.model.messages2prompt(messages, sequence_start) input_ids = self.tokenizer.encode(prompt) finish_reason = 'stop' if stop else None - if self.steps[str(session_id)] + len( + if not sequence_end and self.steps[str(session_id)] + len( input_ids) >= self.tm_model.session_len: finish_reason = 'length' yield GenOut('', self.steps[str(session_id)], len(input_ids), 0, finish_reason) else: - generator = await self.get_generator(instance_id) - self.available[instance_id] = False - response_size = 0 - async for outputs in generator.async_stream_infer( - session_id=session_id, - input_ids=[input_ids], - stream_output=stream_response, - request_output_len=request_output_len, - sequence_start=(sequence_start), - sequence_end=sequence_end, - step=self.steps[str(session_id)], - stop=stop, - top_k=top_k, - top_p=top_p, - temperature=temperature, - repetition_penalty=repetition_penalty, - ignore_eos=ignore_eos, - random_seed=seed if sequence_start else None): - res, tokens = outputs[0] - # decode res - response = self.tokenizer.decode(res[response_size:]) - # response, history token len, input token len, gen token len - yield GenOut(response, self.steps[str(session_id)], - len(input_ids), tokens, finish_reason) - response_size = tokens + generator = await 
self.get_generator(instance_id, stop) + with self.safe_run(instance_id): + response_size = 0 + async for outputs in generator.async_stream_infer( + session_id=session_id, + input_ids=[input_ids], + stream_output=stream_response, + request_output_len=request_output_len, + sequence_start=(sequence_start), + sequence_end=sequence_end, + step=self.steps[str(session_id)], + stop=stop, + top_k=top_k, + top_p=top_p, + temperature=temperature, + repetition_penalty=repetition_penalty, + ignore_eos=ignore_eos, + random_seed=seed if sequence_start else None): + res, tokens = outputs[0] + # decode res + response = self.tokenizer.decode(res)[response_size:] + # response, history token len, + # input token len, gen token len + yield GenOut(response, self.steps[str(session_id)], + len(input_ids), tokens, finish_reason) + response_size += len(response) - # update step - self.steps[str(session_id)] += len(input_ids) + tokens - if sequence_end: - self.steps[str(session_id)] = 0 - self.available[instance_id] = True + # update step + self.steps[str(session_id)] += len(input_ids) + tokens + if sequence_end or stop: + self.steps[str(session_id)] = 0 async def generate_openai( self, @@ -180,13 +188,11 @@ async def generate_openai( sequence_start = False generator = await self.get_generator(instance_id) self.available[instance_id] = False - if renew_session and str(session_id) in self.steps and self.steps[str( - session_id)] > 0: # renew a session - empty_prompt = self.model.messages2prompt('', False) - empty_input_ids = self.tokenizer.encode(empty_prompt) + if renew_session: # renew a session + empty_input_ids = self.tokenizer.encode('') for outputs in generator.stream_infer(session_id=session_id, input_ids=[empty_input_ids], - request_output_len=1, + request_output_len=0, sequence_start=False, sequence_end=True): pass diff --git a/lmdeploy/serve/gradio/app.py b/lmdeploy/serve/gradio/app.py index c3278eaba..954a5bcd3 100644 --- a/lmdeploy/serve/gradio/app.py +++ b/lmdeploy/serve/gradio/app.py @@ -1,17 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. import os -import os.path as osp -import random import threading +import time from functools import partial from typing import Sequence import fire import gradio as gr -from lmdeploy import turbomind as tm -from lmdeploy.model import MODELS +from lmdeploy.serve.async_engine import AsyncEngine from lmdeploy.serve.gradio.css import CSS +from lmdeploy.serve.openai.api_client import (get_model_list, + get_streaming_response) from lmdeploy.serve.turbomind.chatbot import Chatbot THEME = gr.themes.Soft( @@ -19,6 +19,9 @@ secondary_hue=gr.themes.colors.sky, font=[gr.themes.GoogleFont('Inconsolata'), 'Arial', 'sans-serif']) +enable_btn = gr.Button.update(interactive=True) +disable_btn = gr.Button.update(interactive=False) + def chat_stream(state_chatbot: Sequence, llama_chatbot: Chatbot, request: gr.Request): @@ -141,21 +144,17 @@ def run_server(triton_server_addr: str, ) -# a IO interface mananing global variables +# a IO interface mananing variables class InterFace: - tokenizer_model_path = None - tokenizer = None - tm_model = None - request2instance = None - model_name = None - model = None + async_engine: AsyncEngine = None # for run_local + restful_api_url: str = None # for run_restful -def chat_stream_local( +def chat_stream_restful( instruction: str, state_chatbot: Sequence, - step: gr.State, - nth_round: gr.State, + cancel_btn: gr.Button, + reset_btn: gr.Button, request: gr.Request, ): """Chat with AI assistant. 
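+    Streams the reply from the restful api server's `/generate` endpoint
+    and appends it to the chatbot piece by piece, enabling the cancel
+    button while the response is streaming.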
@@ -163,173 +162,379 @@ def chat_stream_local( Args: instruction (str): user's prompt state_chatbot (Sequence): the chatting history - step (gr.State): chat history length - nth_round (gr.State): round num request (gr.Request): the request from a user """ - from lmdeploy.turbomind.chat import valid_str session_id = threading.current_thread().ident if request is not None: session_id = int(request.kwargs['client']['host'].replace('.', '')) - if str(session_id) not in InterFace.request2instance: - InterFace.request2instance[str( - session_id)] = InterFace.tm_model.create_instance() - llama_chatbot = InterFace.request2instance[str(session_id)] - seed = random.getrandbits(64) bot_summarized_response = '' state_chatbot = state_chatbot + [(instruction, None)] - instruction = InterFace.model.get_prompt(instruction, nth_round == 1) - if step >= InterFace.tm_model.session_len: - raise gr.Error('WARNING: exceed session max length.' - ' Please end the session.') - input_ids = InterFace.tokenizer.encode(instruction) - bot_response = llama_chatbot.stream_infer( - session_id, [input_ids], - stream_output=True, - request_output_len=512, - sequence_start=(nth_round == 1), - sequence_end=False, - step=step, - stop=False, - top_k=40, - top_p=0.8, - temperature=0.8, - repetition_penalty=1.0, - ignore_eos=False, - random_seed=seed if nth_round == 1 else None) - - yield (state_chatbot, state_chatbot, step, nth_round, + + yield (state_chatbot, state_chatbot, disable_btn, enable_btn, + f'{bot_summarized_response}'.strip()) + + for response, tokens, finish_reason in get_streaming_response( + instruction, + f'{InterFace.restful_api_url}/generate', + instance_id=session_id, + request_output_len=512, + sequence_start=(len(state_chatbot) == 1), + sequence_end=False): + if finish_reason == 'length': + gr.Warning('WARNING: exceed session max length.' + ' Please restart the session by reset button.') + if tokens < 0: + gr.Warning('WARNING: running on the old session.' + ' Please restart the session by reset button.') + if state_chatbot[-1][-1] is None: + state_chatbot[-1] = (state_chatbot[-1][0], response) + else: + state_chatbot[-1] = (state_chatbot[-1][0], + state_chatbot[-1][1] + response + ) # piece by piece + yield (state_chatbot, state_chatbot, enable_btn, disable_btn, + f'{bot_summarized_response}'.strip()) + + yield (state_chatbot, state_chatbot, disable_btn, enable_btn, + f'{bot_summarized_response}'.strip()) + + +def reset_restful_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State, + request: gr.Request): + """reset the session. + + Args: + instruction_txtbox (str): user's prompt + state_chatbot (Sequence): the chatting history + request (gr.Request): the request from a user + """ + state_chatbot = [] + + session_id = threading.current_thread().ident + if request is not None: + session_id = int(request.kwargs['client']['host'].replace('.', '')) + # end the session + for response, tokens, finish_reason in get_streaming_response( + '', + f'{InterFace.restful_api_url}/generate', + instance_id=session_id, + request_output_len=0, + sequence_start=False, + sequence_end=True): + pass + + return ( + state_chatbot, + state_chatbot, + gr.Textbox.update(value=''), + ) + + +def cancel_restful_func(state_chatbot: gr.State, cancel_btn: gr.Button, + reset_btn: gr.Button, request: gr.Request): + """stop the session. 
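+    Sends an empty prompt with `stop=True` to interrupt the running
+    generation, then replays the chat history with `sequence_start=True`
+    so the server-side session stays consistent.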
+ + Args: + instruction_txtbox (str): user's prompt + state_chatbot (Sequence): the chatting history + request (gr.Request): the request from a user + """ + session_id = threading.current_thread().ident + if request is not None: + session_id = int(request.kwargs['client']['host'].replace('.', '')) + # end the session + for out in get_streaming_response('', + f'{InterFace.restful_api_url}/generate', + instance_id=session_id, + request_output_len=0, + sequence_start=False, + sequence_end=False, + stop=True): + pass + time.sleep(0.5) + messages = [] + for qa in state_chatbot: + messages.append(dict(role='user', content=qa[0])) + if qa[1] is not None: + messages.append(dict(role='assistant', content=qa[1])) + for out in get_streaming_response(messages, + f'{InterFace.restful_api_url}/generate', + instance_id=session_id, + request_output_len=0, + sequence_start=True, + sequence_end=False): + pass + return (state_chatbot, disable_btn, enable_btn) + + +def run_restful(restful_api_url: str, + server_name: str = 'localhost', + server_port: int = 6006, + batch_size: int = 32): + """chat with AI assistant through web ui. + + Args: + restful_api_url (str): restufl api url + server_name (str): the ip address of gradio server + server_port (int): the port of gradio server + batch_size (int): batch size for running Turbomind directly + """ + InterFace.restful_api_url = restful_api_url + model_names = get_model_list(f'{restful_api_url}/v1/models') + model_name = '' + if isinstance(model_names, list) and len(model_names) > 0: + model_name = model_names[0] + else: + raise ValueError('gradio can find a suitable model from restful-api') + + with gr.Blocks(css=CSS, theme=THEME) as demo: + state_chatbot = gr.State([]) + + with gr.Column(elem_id='container'): + gr.Markdown('## LMDeploy Playground') + + chatbot = gr.Chatbot(elem_id='chatbot', label=model_name) + instruction_txtbox = gr.Textbox( + placeholder='Please input the instruction', + label='Instruction') + with gr.Row(): + cancel_btn = gr.Button(value='Cancel', interactive=False) + reset_btn = gr.Button(value='Reset') + + send_event = instruction_txtbox.submit( + chat_stream_restful, + [instruction_txtbox, state_chatbot, cancel_btn, reset_btn], + [state_chatbot, chatbot, cancel_btn, reset_btn]) + instruction_txtbox.submit( + lambda: gr.Textbox.update(value=''), + [], + [instruction_txtbox], + ) + cancel_btn.click(cancel_restful_func, + [state_chatbot, cancel_btn, reset_btn], + [state_chatbot, cancel_btn, reset_btn], + cancels=[send_event]) + + reset_btn.click(reset_restful_func, + [instruction_txtbox, state_chatbot], + [state_chatbot, chatbot, instruction_txtbox], + cancels=[send_event]) + + print(f'server is gonna mount on: http://{server_name}:{server_port}') + demo.queue(concurrency_count=batch_size, max_size=100, + api_open=True).launch( + max_threads=10, + share=True, + server_port=server_port, + server_name=server_name, + ) + + +async def chat_stream_local( + instruction: str, + state_chatbot: Sequence, + cancel_btn: gr.Button, + reset_btn: gr.Button, + request: gr.Request, +): + """Chat with AI assistant. 
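+    Streams the reply from the local `AsyncEngine` and appends it to the
+    chatbot piece by piece, enabling the cancel button while the response
+    is streaming.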
+ + Args: + instruction (str): user's prompt + state_chatbot (Sequence): the chatting history + request (gr.Request): the request from a user + """ + session_id = threading.current_thread().ident + if request is not None: + session_id = int(request.kwargs['client']['host'].replace('.', '')) + bot_summarized_response = '' + state_chatbot = state_chatbot + [(instruction, None)] + + yield (state_chatbot, state_chatbot, disable_btn, enable_btn, f'{bot_summarized_response}'.strip()) - response_size = 0 - for outputs in bot_response: - res, tokens = outputs[0] - # decode res - response = InterFace.tokenizer.decode(res)[response_size:] - response = valid_str(response) - response_size += len(response) + async for outputs in InterFace.async_engine.generate( + instruction, + session_id, + stream_response=True, + sequence_start=(len(state_chatbot) == 1)): + response = outputs.response + if outputs.finish_reason == 'length': + gr.Warning('WARNING: exceed session max length.' + ' Please restart the session by reset button.') + if outputs.generate_token_len < 0: + gr.Warning('WARNING: running on the old session.' + ' Please restart the session by reset button.') if state_chatbot[-1][-1] is None: state_chatbot[-1] = (state_chatbot[-1][0], response) else: state_chatbot[-1] = (state_chatbot[-1][0], state_chatbot[-1][1] + response ) # piece by piece - yield (state_chatbot, state_chatbot, step, nth_round, + yield (state_chatbot, state_chatbot, enable_btn, disable_btn, f'{bot_summarized_response}'.strip()) - step += len(input_ids) + tokens - nth_round += 1 - yield (state_chatbot, state_chatbot, step, nth_round, + yield (state_chatbot, state_chatbot, disable_btn, enable_btn, f'{bot_summarized_response}'.strip()) -def reset_local_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State, - step: gr.State, nth_round: gr.State, request: gr.Request): +async def reset_local_func(instruction_txtbox: gr.Textbox, + state_chatbot: gr.State, request: gr.Request): """reset the session. Args: instruction_txtbox (str): user's prompt state_chatbot (Sequence): the chatting history - step (gr.State): chat history length - nth_round (gr.State): round num request (gr.Request): the request from a user """ state_chatbot = [] - step = 0 - nth_round = 1 session_id = threading.current_thread().ident if request is not None: session_id = int(request.kwargs['client']['host'].replace('.', '')) - InterFace.request2instance[str( - session_id)] = InterFace.tm_model.create_instance() + # end the session + async for out in InterFace.async_engine.generate('', + session_id, + request_output_len=1, + stream_response=True, + sequence_start=False, + sequence_end=True): + pass return ( state_chatbot, state_chatbot, - step, - nth_round, gr.Textbox.update(value=''), ) +async def cancel_local_func(state_chatbot: gr.State, cancel_btn: gr.Button, + reset_btn: gr.Button, request: gr.Request): + """stop the session. 
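+    Asks the local `AsyncEngine` to stop the running generation, then
+    replays the chat history with `sequence_start=True` so the session
+    state stays consistent.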
+ + Args: + instruction_txtbox (str): user's prompt + state_chatbot (Sequence): the chatting history + request (gr.Request): the request from a user + """ + session_id = threading.current_thread().ident + if request is not None: + session_id = int(request.kwargs['client']['host'].replace('.', '')) + # end the session + async for out in InterFace.async_engine.generate('', + session_id, + request_output_len=0, + stream_response=True, + sequence_start=False, + sequence_end=False, + stop=True): + pass + messages = [] + for qa in state_chatbot: + messages.append(dict(role='user', content=qa[0])) + if qa[1] is not None: + messages.append(dict(role='assistant', content=qa[1])) + async for out in InterFace.async_engine.generate(messages, + session_id, + request_output_len=0, + stream_response=True, + sequence_start=True, + sequence_end=False): + pass + return (state_chatbot, disable_btn, enable_btn) + + def run_local(model_path: str, server_name: str = 'localhost', - server_port: int = 6006): + server_port: int = 6006, + batch_size: int = 4, + tp: int = 1): """chat with AI assistant through web ui. Args: model_path (str): the path of the deployed model server_name (str): the ip address of gradio server server_port (int): the port of gradio server + batch_size (int): batch size for running Turbomind directly + tp (int): tensor parallel for Turbomind """ - from lmdeploy.turbomind.tokenizer import Tokenizer - InterFace.tokenizer_model_path = osp.join(model_path, 'triton_models', - 'tokenizer') - InterFace.tokenizer = Tokenizer(InterFace.tokenizer_model_path) - InterFace.tm_model = tm.TurboMind(model_path, - eos_id=InterFace.tokenizer.eos_token_id) - InterFace.request2instance = dict() - InterFace.model_name = InterFace.tm_model.model_name - InterFace.model = MODELS.get(InterFace.model_name)() + InterFace.async_engine = AsyncEngine(model_path=model_path, + instance_num=batch_size, + tp=tp) with gr.Blocks(css=CSS, theme=THEME) as demo: state_chatbot = gr.State([]) - nth_round = gr.State(1) - step = gr.State(0) with gr.Column(elem_id='container'): gr.Markdown('## LMDeploy Playground') - chatbot = gr.Chatbot(elem_id='chatbot', label=InterFace.model_name) + chatbot = gr.Chatbot( + elem_id='chatbot', + label=InterFace.async_engine.tm_model.model_name) instruction_txtbox = gr.Textbox( placeholder='Please input the instruction', label='Instruction') with gr.Row(): - gr.Button(value='Cancel') # noqa: E501 + cancel_btn = gr.Button(value='Cancel', interactive=False) reset_btn = gr.Button(value='Reset') send_event = instruction_txtbox.submit( chat_stream_local, - [instruction_txtbox, state_chatbot, step, nth_round], - [state_chatbot, chatbot, step, nth_round]) + [instruction_txtbox, state_chatbot, cancel_btn, reset_btn], + [state_chatbot, chatbot, cancel_btn, reset_btn]) instruction_txtbox.submit( lambda: gr.Textbox.update(value=''), [], [instruction_txtbox], ) + cancel_btn.click(cancel_local_func, + [state_chatbot, cancel_btn, reset_btn], + [state_chatbot, cancel_btn, reset_btn], + cancels=[send_event]) - reset_btn.click( - reset_local_func, - [instruction_txtbox, state_chatbot, step, nth_round], - [state_chatbot, chatbot, step, nth_round, instruction_txtbox], - cancels=[send_event]) + reset_btn.click(reset_local_func, [instruction_txtbox, state_chatbot], + [state_chatbot, chatbot, instruction_txtbox], + cancels=[send_event]) print(f'server is gonna mount on: http://{server_name}:{server_port}') - demo.queue(concurrency_count=4, max_size=100, api_open=True).launch( - max_threads=10, - share=True, - 
server_port=server_port, - server_name=server_name, - ) + demo.queue(concurrency_count=batch_size, max_size=100, + api_open=True).launch( + max_threads=10, + share=True, + server_port=server_port, + server_name=server_name, + ) def run(model_path_or_server: str, server_name: str = 'localhost', - server_port: int = 6006): + server_port: int = 6006, + batch_size: int = 32, + tp: int = 1, + restful_api: bool = False): """chat with AI assistant through web ui. Args: model_path_or_server (str): the path of the deployed model or the - tritonserver URL. The former is for directly running service with - gradio. The latter is for running with tritonserver + tritonserver URL or restful api URL. The former is for directly + running service with gradio. The latter is for running with + tritonserver by default. If the input URL is restful api. Please + enable another flag `restful_api`. server_name (str): the ip address of gradio server server_port (int): the port of gradio server + batch_size (int): batch size for running Turbomind directly + tp (int): tensor parallel for Turbomind + restufl_api (bool): a flag for model_path_or_server """ if ':' in model_path_or_server: - run_server(model_path_or_server, server_name, server_port) + if restful_api: + run_restful(model_path_or_server, server_name, server_port, + batch_size) + else: + run_server(model_path_or_server, server_name, server_port) else: - run_local(model_path_or_server, server_name, server_port) + run_local(model_path_or_server, server_name, server_port, batch_size, + tp) if __name__ == '__main__': diff --git a/lmdeploy/serve/openai/api_client.py b/lmdeploy/serve/openai/api_client.py index 3199f5148..449b8a294 100644 --- a/lmdeploy/serve/openai/api_client.py +++ b/lmdeploy/serve/openai/api_client.py @@ -6,6 +6,15 @@ import requests +def get_model_list(api_url: str): + response = requests.get(api_url) + if hasattr(response, 'text'): + model_list = json.loads(response.text) + model_list = model_list.pop('data', []) + return [item['id'] for item in model_list] + return None + + def get_streaming_response(prompt: str, api_url: str, instance_id: int, @@ -13,7 +22,8 @@ def get_streaming_response(prompt: str, stream: bool = True, sequence_start: bool = True, sequence_end: bool = True, - ignore_eos: bool = False) -> Iterable[List[str]]: + ignore_eos: bool = False, + stop: bool = False) -> Iterable[List[str]]: headers = {'User-Agent': 'Test Client'} pload = { 'prompt': prompt, @@ -22,7 +32,8 @@ def get_streaming_response(prompt: str, 'request_output_len': request_output_len, 'sequence_start': sequence_start, 'sequence_end': sequence_end, - 'ignore_eos': ignore_eos + 'ignore_eos': ignore_eos, + 'stop': stop } response = requests.post(api_url, headers=headers, @@ -33,9 +44,9 @@ def get_streaming_response(prompt: str, delimiter=b'\0'): if chunk: data = json.loads(chunk.decode('utf-8')) - output = data['text'] - tokens = data['tokens'] - finish_reason = data['finish_reason'] + output = data.pop('text', '') + tokens = data.pop('tokens', 0) + finish_reason = data.pop('finish_reason', None) yield output, tokens, finish_reason @@ -46,7 +57,7 @@ def input_prompt(): return '\n'.join(iter(input, sentinel)) -def main(server_name: str, server_port: int, session_id: int = 0): +def main(restful_api_url: str, session_id: int = 0): nth_round = 1 while True: prompt = input_prompt() @@ -55,7 +66,7 @@ def main(server_name: str, server_port: int, session_id: int = 0): else: for output, tokens, finish_reason in get_streaming_response( prompt, - 
f'http://{server_name}:{server_port}/generate', + f'{restful_api_url}/generate', instance_id=session_id, request_output_len=512, sequence_start=(nth_round == 1), diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index 0e257a71d..932372bde 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -10,7 +10,7 @@ from fastapi import BackgroundTasks, FastAPI, Request from fastapi.responses import JSONResponse, StreamingResponse -from lmdeploy.serve.openai.async_engine import AsyncEngine +from lmdeploy.serve.async_engine import AsyncEngine from lmdeploy.serve.openai.protocol import ( # noqa: E501 ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, @@ -27,7 +27,7 @@ class VariableInterface: request_hosts = [] -app = FastAPI() +app = FastAPI(docs_url='/') def get_model_list(): @@ -253,11 +253,12 @@ async def generate(request: GenerateRequest, raw_request: Request = None): The request should be a JSON object with the following fields: - prompt: the prompt to use for the generation. - - stream: whether to stream the results or not. - - sequence_start (bool): indicator for starting a sequence. - - sequence_end (bool): indicator for ending a sequence - instance_id: determine which instance will be called. If not specified with a value other than -1, using host ip directly. + - sequence_start (bool): indicator for starting a sequence. + - sequence_end (bool): indicator for ending a sequence + - stream: whether to stream the results or not. + - stop: whether to stop the session response or not. - request_output_len (int): output token nums - step (int): the offset of the k/v cache - top_p (float): If set to float < 1, only the smallest set of most @@ -283,6 +284,7 @@ async def generate(request: GenerateRequest, raw_request: Request = None): request_output_len=request.request_output_len, top_p=request.top_p, top_k=request.top_k, + stop=request.stop, temperature=request.temperature, repetition_penalty=request.repetition_penalty, ignore_eos=request.ignore_eos) diff --git a/lmdeploy/serve/openai/protocol.py b/lmdeploy/serve/openai/protocol.py index dea090d70..8d5b38757 100644 --- a/lmdeploy/serve/openai/protocol.py +++ b/lmdeploy/serve/openai/protocol.py @@ -189,11 +189,12 @@ class EmbeddingsResponse(BaseModel): class GenerateRequest(BaseModel): """Generate request.""" - prompt: str + prompt: Union[str, List[Dict[str, str]]] instance_id: int = -1 sequence_start: bool = True sequence_end: bool = False stream: bool = False + stop: bool = False request_output_len: int = 512 top_p: float = 0.8 top_k: int = 40 diff --git a/requirements.txt b/requirements.txt index 66f04c739..c0cd48396 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,6 +9,7 @@ pybind11 safetensors sentencepiece setuptools +shortuuid tiktoken torch transformers
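
For reference, a minimal usage sketch (not part of the patch) showing how a
client could exercise the new `stop` field through the helper in
`lmdeploy/serve/openai/api_client.py`, mirroring what `cancel_restful_func`
does in the gradio app. The server address and the instance id below are
assumptions, not values taken from the patch.

```python
# Minimal sketch: interrupt a streaming generation via the new `stop` field.
# Assumes an api_server built from this patch is listening on localhost:23333.
from lmdeploy.serve.openai.api_client import get_streaming_response

api_url = 'http://localhost:23333/generate'  # assumed address
session_id = 1  # arbitrary instance id

# Start a streaming request and stop reading after a few chunks,
# as if the user pressed the "Cancel" button in the webui.
for i, (output, tokens, finish_reason) in enumerate(
        get_streaming_response('Write a long story.',
                               api_url,
                               instance_id=session_id,
                               request_output_len=512,
                               sequence_start=True,
                               sequence_end=False)):
    print(output, end='', flush=True)
    if i == 4:
        break

# Ask the server to stop the running generation for this instance.
for _ in get_streaming_response('',
                                api_url,
                                instance_id=session_id,
                                request_output_len=0,
                                sequence_start=False,
                                sequence_end=False,
                                stop=True):
    pass
# The gradio app additionally replays the chat history with
# sequence_start=True afterwards so that the session can be continued.
```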