diff --git a/README.md b/README.md
index ab406c4d68..11937237cf 100644
--- a/README.md
+++ b/README.md
@@ -50,11 +50,9 @@ And the request throughput of TurboMind is 30% higher than vLLM.
### Installation
-Below are quick steps for installation:
+Install lmdeploy with pip (Python 3.8+) or [from source](./docs/en/build.md):
```shell
-conda create -n lmdeploy python=3.10 -y
-conda activate lmdeploy
pip install lmdeploy
```
@@ -92,7 +90,15 @@ python -m lmdeploy.turbomind.chat ./workspace
> **Note**
> Tensor parallelism is available for inference on multiple GPUs. Add `--tp=` to the `chat` command to enable runtime TP.
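+
+For example, to shard the model across two GPUs (a sketch, assuming a machine with at least two GPUs):
+
+```shell
+python -m lmdeploy.turbomind.chat ./workspace --tp=2
+```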
-#### Serving
+#### Serving with Gradio
+
+```shell
+python3 -m lmdeploy.serve.gradio.app ./workspace
+```
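+
+The Gradio server binds to `localhost:6006` by default. Since the entry point is exposed through `fire`, the host and port can be overridden on the command line (a minimal sketch, assuming the defaults in `lmdeploy.serve.gradio.app`):
+
+```shell
+python3 -m lmdeploy.serve.gradio.app ./workspace --server_name 0.0.0.0 --server_port 6006
+```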
+
+
+
+#### Serving with Triton Inference Server
Launch inference server by:
@@ -109,11 +115,9 @@ python3 -m lmdeploy.serve.client {server_ip_addresss}:33337
or webui,
```shell
-python3 -m lmdeploy.app {server_ip_addresss}:33337
+python3 -m lmdeploy.serve.gradio.app {server_ip_addresss}:33337
```
-
-
For the deployment of other supported models, such as LLaMA, LLaMA-2, vicuna and so on, you can find the guide from [here](docs/en/serving.md)
### Inference with PyTorch
diff --git a/README_zh-CN.md b/README_zh-CN.md
index 8f6b480d5b..af594ea9e8 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -51,9 +51,9 @@ TurboMind 的 output token throughput 超过 2000 token/s, 整体比 DeepSpeed
### Installation
+Install LMDeploy with pip (Python 3.8+), or [install from source](./docs/zh_cn/build.md):
+
```shell
-conda create -n lmdeploy python=3.10 -y
-conda activate lmdeploy
pip install lmdeploy
```
@@ -90,7 +90,15 @@ python3 -m lmdeploy.turbomind.chat ./workspace
> **Note**
> Tensor parallelism can be used to run inference on multiple GPUs. Add `--tp=` to the `chat` command to enable runtime TP.
-#### 部署推理服务
+#### Serving with Gradio
+
+```shell
+python3 -m lmdeploy.serve.gradio.app ./workspace
+```
+
+
+
+#### Serving with Triton Inference Server
Launch the inference server with the following command:
@@ -107,11 +115,9 @@ python3 -m lmdeploy.serve.client {server_ip_addresss}:33337
You can also chat through the WebUI:
```shell
-python3 -m lmdeploy.app {server_ip_addresss}:33337
+python3 -m lmdeploy.serve.gradio.app {server_ip_addresss}:33337
```
-
-
For the deployment of other supported models, such as LLaMA, LLaMA-2, vicuna and so on, please refer to the guide [here](docs/zh_cn/serving.md)
### Inference with PyTorch
diff --git a/docs/en/build.md b/docs/en/build.md
new file mode 100644
index 0000000000..e9934c25e6
--- /dev/null
+++ b/docs/en/build.md
@@ -0,0 +1,26 @@
+## Build from source
+
+- Make sure the local gcc version is no less than 9, which can be confirmed by `gcc --version`.
+- Install the packages required for compiling and running:
+ ```shell
+ pip install -r requirements.txt
+ ```
+- Install [nccl](https://docs.nvidia.com/deeplearning/nccl/install-guide/index.html) and set the environment variables:
+ ```shell
+ export NCCL_ROOT_DIR=/path/to/nccl/build
+ export NCCL_LIBRARIES=/path/to/nccl/build/lib
+ ```
+- Install rapidjson.
+- Install openmpi; building from source is recommended:
+ ```shell
+ wget https://download.open-mpi.org/release/open-mpi/v4.1/openmpi-4.1.5.tar.gz
+ tar -xzf openmpi-*.tar.gz && cd openmpi-*
+ ./configure --with-cuda
+ make -j$(nproc)
+ make install
+ ```
+- Build and install lmdeploy:
+ ```shell
+ mkdir build && cd build
+ sh ../generate.sh
+ ```
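+- Optionally, run a quick sanity check (a minimal sketch; it assumes the `lmdeploy` package is importable from your current environment):
+  ```shell
+  python3 -c "import lmdeploy"
+  ```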
diff --git a/docs/zh_cn/build.md b/docs/zh_cn/build.md
new file mode 100644
index 0000000000..93ca25dae5
--- /dev/null
+++ b/docs/zh_cn/build.md
@@ -0,0 +1,26 @@
+### Build from source
+
+- Make sure the gcc version on the host machine is no less than 9, which can be confirmed by `gcc --version`.
+- Install the packages required for compiling and running:
+ ```shell
+ pip install -r requirements.txt
+ ```
+- Install [nccl](https://docs.nvidia.com/deeplearning/nccl/install-guide/index.html) and set the environment variables:
+ ```shell
+ export NCCL_ROOT_DIR=/path/to/nccl/build
+ export NCCL_LIBRARIES=/path/to/nccl/build/lib
+ ```
+- Install rapidjson.
+- Install openmpi; building from source is recommended:
+ ```shell
+ wget https://download.open-mpi.org/release/open-mpi/v4.1/openmpi-4.1.5.tar.gz
+ tar -xzf openmpi-*.tar.gz && cd openmpi-*
+ ./configure --with-cuda
+ make -j$(nproc)
+ make install
+ ```
+- Build and install lmdeploy:
+ ```shell
+ mkdir build && cd build
+ sh ../generate.sh
+ ```
diff --git a/lmdeploy/app.py b/lmdeploy/app.py
deleted file mode 100644
index ccb1d36f0e..0000000000
--- a/lmdeploy/app.py
+++ /dev/null
@@ -1,169 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import os
-import threading
-from functools import partial
-from typing import Sequence
-
-import fire
-import gradio as gr
-
-from lmdeploy.serve.turbomind.chatbot import Chatbot
-
-CSS = """
-#container {
- width: 95%;
- margin-left: auto;
- margin-right: auto;
-}
-
-#chatbot {
- height: 500px;
- overflow: auto;
-}
-
-.chat_wrap_space {
- margin-left: 0.5em
-}
-"""
-
-THEME = gr.themes.Soft(
- primary_hue=gr.themes.colors.blue,
- secondary_hue=gr.themes.colors.sky,
- font=[gr.themes.GoogleFont('Inconsolata'), 'Arial', 'sans-serif'])
-
-
-def chat_stream(instruction: str,
- state_chatbot: Sequence,
- llama_chatbot: Chatbot,
- model_name: str = None):
- """Chat with AI assistant.
-
- Args:
- instruction (str): user's prompt
- state_chatbot (Sequence): the chatting history
- llama_chatbot (Chatbot): the instance of a chatbot
- model_name (str): the name of deployed model
- """
- bot_summarized_response = ''
- model_type = 'turbomind'
- state_chatbot = state_chatbot + [(instruction, None)]
- session_id = threading.current_thread().ident
- bot_response = llama_chatbot.stream_infer(
- session_id, instruction, f'{session_id}-{len(state_chatbot)}')
-
- yield (state_chatbot, state_chatbot, f'{bot_summarized_response}'.strip())
-
- for status, tokens, _ in bot_response:
- if state_chatbot[-1][-1] is None or model_type != 'fairscale':
- state_chatbot[-1] = (state_chatbot[-1][0], tokens)
- else:
- state_chatbot[-1] = (state_chatbot[-1][0],
- state_chatbot[-1][1] + tokens
- ) # piece by piece
- yield (state_chatbot, state_chatbot,
- f'{bot_summarized_response}'.strip())
-
- yield (state_chatbot, state_chatbot, f'{bot_summarized_response}'.strip())
-
-
-def reset_all_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State,
- llama_chatbot: gr.State, triton_server_addr: str,
- model_name: str):
- """reset the session."""
- state_chatbot = []
- log_level = os.environ.get('SERVICE_LOG_LEVEL', 'INFO')
- llama_chatbot = Chatbot(triton_server_addr,
- model_name,
- log_level=log_level,
- display=True)
-
- return (
- llama_chatbot,
- state_chatbot,
- state_chatbot,
- gr.Textbox.update(value=''),
- )
-
-
-def cancel_func(
- instruction_txtbox: gr.Textbox,
- state_chatbot: gr.State,
- llama_chatbot: gr.State,
-):
- """cancel the session."""
- session_id = llama_chatbot._session.session_id
- llama_chatbot.cancel(session_id)
-
- return (
- llama_chatbot,
- state_chatbot,
- )
-
-
-def run(triton_server_addr: str,
- server_name: str = 'localhost',
- server_port: int = 6006):
- """chat with AI assistant through web ui.
-
- Args:
- triton_server_addr (str): the communication address of inference server
- server_name (str): the ip address of gradio server
- server_port (int): the port of gradio server
- """
- with gr.Blocks(css=CSS, theme=THEME) as demo:
- log_level = os.environ.get('SERVICE_LOG_LEVEL', 'INFO')
- _chatbot = Chatbot(triton_server_addr,
- log_level=log_level,
- display=True)
- model_name = _chatbot.model_name
- chat_interface = partial(chat_stream, model_name=model_name)
- reset_all = partial(reset_all_func,
- model_name=model_name,
- triton_server_addr=triton_server_addr)
- llama_chatbot = gr.State(_chatbot)
- state_chatbot = gr.State([])
-
- with gr.Column(elem_id='container'):
- gr.Markdown('## LMDeploy Playground')
-
- chatbot = gr.Chatbot(elem_id='chatbot', label=model_name)
- instruction_txtbox = gr.Textbox(
- placeholder='Please input the instruction',
- label='Instruction')
- with gr.Row():
- cancel_btn = gr.Button(value='Cancel')
- reset_btn = gr.Button(value='Reset')
-
- send_event = instruction_txtbox.submit(
- chat_interface,
- [instruction_txtbox, state_chatbot, llama_chatbot],
- [state_chatbot, chatbot],
- batch=False,
- max_batch_size=1,
- )
- instruction_txtbox.submit(
- lambda: gr.Textbox.update(value=''),
- [],
- [instruction_txtbox],
- )
-
- cancel_btn.click(cancel_func,
- [instruction_txtbox, state_chatbot, llama_chatbot],
- [llama_chatbot, chatbot],
- cancels=[send_event])
-
- reset_btn.click(
- reset_all, [instruction_txtbox, state_chatbot, llama_chatbot],
- [llama_chatbot, state_chatbot, chatbot, instruction_txtbox],
- cancels=[send_event])
-
- demo.queue(concurrency_count=4, max_size=100, api_open=True).launch(
- max_threads=10,
- share=True,
- server_port=server_port,
- server_name=server_name,
- )
-
-
-if __name__ == '__main__':
- fire.Fire(run)
diff --git a/lmdeploy/serve/gradio/__init__.py b/lmdeploy/serve/gradio/__init__.py
new file mode 100644
index 0000000000..ef101fec61
--- /dev/null
+++ b/lmdeploy/serve/gradio/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) OpenMMLab. All rights reserved.
diff --git a/lmdeploy/serve/gradio/app.py b/lmdeploy/serve/gradio/app.py
new file mode 100644
index 0000000000..9fa80cba06
--- /dev/null
+++ b/lmdeploy/serve/gradio/app.py
@@ -0,0 +1,336 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import os.path as osp
+import random
+import threading
+from functools import partial
+from typing import Sequence
+
+import fire
+import gradio as gr
+
+from lmdeploy import turbomind as tm
+from lmdeploy.model import MODELS
+from lmdeploy.serve.gradio.css import CSS
+from lmdeploy.serve.turbomind.chatbot import Chatbot
+from lmdeploy.turbomind.chat import valid_str
+from lmdeploy.turbomind.tokenizer import Tokenizer
+
+THEME = gr.themes.Soft(
+ primary_hue=gr.themes.colors.blue,
+ secondary_hue=gr.themes.colors.sky,
+ font=[gr.themes.GoogleFont('Inconsolata'), 'Arial', 'sans-serif'])
+
+
+def chat_stream(state_chatbot: Sequence, llama_chatbot: Chatbot,
+ request: gr.Request):
+ """Chat with AI assistant.
+
+ Args:
+        state_chatbot (Sequence): the chatting history
+        llama_chatbot (Chatbot): the instance of a chatbot
+        request (gr.Request): the request from a user
+ """
+ instruction = state_chatbot[-1][0]
+ session_id = threading.current_thread().ident
+ if request is not None:
+ session_id = int(request.kwargs['client']['host'].replace('.', ''))
+
+ bot_response = llama_chatbot.stream_infer(
+ session_id, instruction, f'{session_id}-{len(state_chatbot)}')
+
+ for status, tokens, _ in bot_response:
+ state_chatbot[-1] = (state_chatbot[-1][0], tokens)
+ yield (state_chatbot, state_chatbot, '')
+
+ return (state_chatbot, state_chatbot, '')
+
+
+def reset_all_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State,
+ llama_chatbot: gr.State, triton_server_addr: str,
+ model_name: str):
+ """reset the session."""
+ state_chatbot = []
+ log_level = os.environ.get('SERVICE_LOG_LEVEL', 'INFO')
+ llama_chatbot = Chatbot(triton_server_addr,
+ model_name,
+ log_level=log_level,
+ display=True)
+
+ return (
+ llama_chatbot,
+ state_chatbot,
+ state_chatbot,
+ gr.Textbox.update(value=''),
+ )
+
+
+def cancel_func(
+ instruction_txtbox: gr.Textbox,
+ state_chatbot: gr.State,
+ llama_chatbot: gr.State,
+):
+ """cancel the session."""
+ session_id = llama_chatbot._session.session_id
+ llama_chatbot.cancel(session_id)
+
+ return (
+ llama_chatbot,
+ state_chatbot,
+ )
+
+
+def add_instruction(instruction, state_chatbot):
+ state_chatbot = state_chatbot + [(instruction, None)]
+ return ('', state_chatbot)
+
+
+def run_server(triton_server_addr: str,
+ server_name: str = 'localhost',
+ server_port: int = 6006):
+ """chat with AI assistant through web ui.
+
+ Args:
+ triton_server_addr (str): the communication address of inference server
+ server_name (str): the ip address of gradio server
+ server_port (int): the port of gradio server
+ """
+ with gr.Blocks(css=CSS, theme=THEME) as demo:
+ log_level = os.environ.get('SERVICE_LOG_LEVEL', 'INFO')
+ llama_chatbot = gr.State(
+ Chatbot(triton_server_addr, log_level=log_level, display=True))
+ state_chatbot = gr.State([])
+ model_name = llama_chatbot.value.model_name
+ reset_all = partial(reset_all_func,
+ model_name=model_name,
+ triton_server_addr=triton_server_addr)
+
+ with gr.Column(elem_id='container'):
+ gr.Markdown('## LMDeploy Playground')
+
+ chatbot = gr.Chatbot(elem_id='chatbot', label=model_name)
+ instruction_txtbox = gr.Textbox(
+ placeholder='Please input the instruction',
+ label='Instruction')
+ with gr.Row():
+ cancel_btn = gr.Button(value='Cancel')
+ reset_btn = gr.Button(value='Reset')
+
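+            # on submit: first append the user turn to the chat history,
+            # then stream the model reply into the last history entry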
+ send_event = instruction_txtbox.submit(
+ add_instruction, [instruction_txtbox, state_chatbot],
+ [instruction_txtbox, state_chatbot]).then(
+ chat_stream, [state_chatbot, llama_chatbot],
+ [state_chatbot, chatbot])
+
+ cancel_btn.click(cancel_func,
+ [instruction_txtbox, state_chatbot, llama_chatbot],
+ [llama_chatbot, chatbot],
+ cancels=[send_event])
+
+ reset_btn.click(
+ reset_all, [instruction_txtbox, state_chatbot, llama_chatbot],
+ [llama_chatbot, state_chatbot, chatbot, instruction_txtbox],
+ cancels=[send_event])
+
+    print(f'Server is starting at: http://{server_name}:{server_port}')
+ demo.queue(concurrency_count=4, max_size=100, api_open=True).launch(
+ max_threads=10,
+ share=True,
+ server_port=server_port,
+ server_name=server_name,
+ )
+
+
+# a class holding global variables shared across requests
+class InterFace:
+ tokenizer_model_path = None
+ tokenizer = None
+ tm_model = None
+ request2instance = None
+ model_name = None
+ model = None
+
+
+def chat_stream_local(
+ instruction: str,
+ state_chatbot: Sequence,
+ step: gr.State,
+ nth_round: gr.State,
+ request: gr.Request,
+):
+ """Chat with AI assistant.
+
+ Args:
+ instruction (str): user's prompt
+ state_chatbot (Sequence): the chatting history
+ step (gr.State): chat history length
+ nth_round (gr.State): round num
+ request (gr.Request): the request from a user
+ """
+ session_id = threading.current_thread().ident
+ if request is not None:
+ session_id = int(request.kwargs['client']['host'].replace('.', ''))
+ if str(session_id) not in InterFace.request2instance:
+ InterFace.request2instance[str(
+ session_id)] = InterFace.tm_model.create_instance()
+ llama_chatbot = InterFace.request2instance[str(session_id)]
+ seed = random.getrandbits(64)
+ bot_summarized_response = ''
+ state_chatbot = state_chatbot + [(instruction, None)]
+ instruction = InterFace.model.get_prompt(instruction, nth_round == 1)
+ if step >= InterFace.tm_model.session_len:
+        raise gr.Error('WARNING: exceeded the maximum session length.'
+                       ' Please end the session.')
+ input_ids = InterFace.tokenizer.encode(instruction)
+ bot_response = llama_chatbot.stream_infer(
+ session_id, [input_ids],
+ stream_output=True,
+ request_output_len=512,
+ sequence_start=(nth_round == 1),
+ sequence_end=False,
+ step=step,
+ stop=False,
+ top_k=40,
+ top_p=0.8,
+ temperature=0.8,
+ repetition_penalty=1.0,
+ ignore_eos=False,
+ random_seed=seed if nth_round == 1 else None)
+
+ yield (state_chatbot, state_chatbot, step, nth_round,
+ f'{bot_summarized_response}'.strip())
+
+ response_size = 0
+ for outputs in bot_response:
+ res, tokens = outputs[0]
+ # decode res
+ response = InterFace.tokenizer.decode(res)[response_size:]
+ response = valid_str(response)
+ response_size += len(response)
+ if state_chatbot[-1][-1] is None:
+ state_chatbot[-1] = (state_chatbot[-1][0], response)
+ else:
+ state_chatbot[-1] = (state_chatbot[-1][0],
+ state_chatbot[-1][1] + response
+ ) # piece by piece
+ yield (state_chatbot, state_chatbot, step, nth_round,
+ f'{bot_summarized_response}'.strip())
+
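+    # advance the cached history length by the prompt tokens plus the tokens
+    # generated in this round before moving to the next round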
+ step += len(input_ids) + tokens
+ nth_round += 1
+ yield (state_chatbot, state_chatbot, step, nth_round,
+ f'{bot_summarized_response}'.strip())
+
+
+def reset_local_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State,
+ step: gr.State, nth_round: gr.State, request: gr.Request):
+ """reset the session.
+
+ Args:
+ instruction_txtbox (str): user's prompt
+ state_chatbot (Sequence): the chatting history
+ step (gr.State): chat history length
+ nth_round (gr.State): round num
+ request (gr.Request): the request from a user
+ """
+ state_chatbot = []
+ step = 0
+ nth_round = 1
+
+ session_id = threading.current_thread().ident
+ if request is not None:
+ session_id = int(request.kwargs['client']['host'].replace('.', ''))
+ InterFace.request2instance[str(
+ session_id)] = InterFace.tm_model.create_instance()
+
+ return (
+ state_chatbot,
+ state_chatbot,
+ step,
+ nth_round,
+ gr.Textbox.update(value=''),
+ )
+
+
+def run_local(model_path: str,
+ server_name: str = 'localhost',
+ server_port: int = 6006):
+ """chat with AI assistant through web ui.
+
+ Args:
+ model_path (str): the path of the deployed model
+ server_name (str): the ip address of gradio server
+ server_port (int): the port of gradio server
+ """
+ InterFace.tokenizer_model_path = osp.join(model_path, 'triton_models',
+ 'tokenizer')
+ InterFace.tokenizer = Tokenizer(InterFace.tokenizer_model_path)
+ InterFace.tm_model = tm.TurboMind(model_path,
+ eos_id=InterFace.tokenizer.eos_token_id)
+ InterFace.request2instance = dict()
+ InterFace.model_name = InterFace.tm_model.model_name
+ InterFace.model = MODELS.get(InterFace.model_name)()
+
+ with gr.Blocks(css=CSS, theme=THEME) as demo:
+ state_chatbot = gr.State([])
+ nth_round = gr.State(1)
+ step = gr.State(0)
+
+ with gr.Column(elem_id='container'):
+ gr.Markdown('## LMDeploy Playground')
+
+ chatbot = gr.Chatbot(elem_id='chatbot', label=InterFace.model_name)
+ instruction_txtbox = gr.Textbox(
+ placeholder='Please input the instruction',
+ label='Instruction')
+ with gr.Row():
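+                # the Cancel button is not wired to any callback yet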
+ gr.Button(value='Cancel') # noqa: E501
+ reset_btn = gr.Button(value='Reset')
+
+ send_event = instruction_txtbox.submit(
+ chat_stream_local,
+ [instruction_txtbox, state_chatbot, step, nth_round],
+ [state_chatbot, chatbot, step, nth_round])
+ instruction_txtbox.submit(
+ lambda: gr.Textbox.update(value=''),
+ [],
+ [instruction_txtbox],
+ )
+
+ reset_btn.click(
+ reset_local_func,
+ [instruction_txtbox, state_chatbot, step, nth_round],
+ [state_chatbot, chatbot, step, nth_round, instruction_txtbox],
+ cancels=[send_event])
+
+    print(f'Server is starting at: http://{server_name}:{server_port}')
+ demo.queue(concurrency_count=4, max_size=100, api_open=True).launch(
+ max_threads=10,
+ share=True,
+ server_port=server_port,
+ server_name=server_name,
+ )
+
+
+def run(model_path_or_server: str,
+ server_name: str = 'localhost',
+ server_port: int = 6006):
+ """chat with AI assistant through web ui.
+
+ Args:
+        model_path_or_server (str): the path of the deployed model or the
+            tritonserver URL. The former runs the service directly with
+            gradio, while the latter connects to a running tritonserver
+ server_name (str): the ip address of gradio server
+ server_port (int): the port of gradio server
+ """
+ if ':' in model_path_or_server:
+ run_server(model_path_or_server, server_name, server_port)
+ else:
+ run_local(model_path_or_server, server_name, server_port)
+
+
+if __name__ == '__main__':
+ fire.Fire(run)
diff --git a/lmdeploy/serve/gradio/css.py b/lmdeploy/serve/gradio/css.py
new file mode 100644
index 0000000000..b3bd233222
--- /dev/null
+++ b/lmdeploy/serve/gradio/css.py
@@ -0,0 +1,18 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+CSS = """
+#container {
+ width: 95%;
+ margin-left: auto;
+ margin-right: auto;
+}
+
+#chatbot {
+ height: 500px;
+ overflow: auto;
+}
+
+.chat_wrap_space {
+ margin-left: 0.5em
+}
+"""