Add inference pipeline for VL models (InternLM#1214)
* support vl pipeline

* fix load vl model

* support yi-vl

* add missing file

* update llava template

* fix tokenizer encode

* fix _get_prompt_input

* remove LlavaVLModel._load_model

* add runtime deps

* add docstring

* add some log info

* update vl pipeline

* resolve comments

* resolve comments

* resolve comments

* type-hint

* update docs

* move get_hf_config_content to lmdeploy.utils

* unify lmdeploy.pipeline & lmdeploy.vl.pipeline

* remove unused

* update docs

* update index.rst

* resolve comments

* remove llava pretrained model

* resolve comments

* resolve comments

* update docs

* update

* remove <200b> character in docs

* update docs
irexyc authored Mar 14, 2024
1 parent 5682efe commit e6fecd8
Showing 27 changed files with 1,315 additions and 76 deletions.
1 change: 1 addition & 0 deletions docs/en/index.rst
@@ -39,6 +39,7 @@ Welcome to LMDeploy's tutorials!
:caption: Inference

inference/pipeline.md
inference/vl_pipeline.md
inference/turbomind.md
inference/turbomind_config.md
inference/pytorch.md
146 changes: 146 additions & 0 deletions docs/en/inference/vl_pipeline.md
@@ -0,0 +1,146 @@
# VLM Offline Inference Pipeline

LMDeploy abstracts the complex inference process of multi-modal Vision-Language Models (VLM) into an easy-to-use pipeline, similar to the Large Language Model (LLM) inference [pipeline](./pipeline.md).
In this article, we take the [liuhaotian/llava-v1.6-vicuna-7b](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b) model as an example and showcase the capabilities of the VLM pipeline through a series of examples.
First, we demonstrate the most basic usage of the pipeline, then progressively introduce additional functionality by configuring engine and generation parameters, such as tensor parallelism, context window size, random sampling, and custom chat templates. Finally, we provide inference examples for scenarios involving multiple images, batch prompts, and more.

## A 'Hello, world' example

```python
from lmdeploy import pipeline
from lmdeploy.vl import load_image

pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b')

image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
response = pipe(('describe this image', image))
print(response)
```

If an `ImportError` occurs while running this example, please install the required dependency packages as prompted.

In the above example, the inference prompt is a (prompt, image) tuple. Besides this structure, the pipeline also supports prompts in the OpenAI format:

```python
from lmdeploy import pipeline

pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b')

prompts = [
{
'role': 'user',
'content': [
{'type': 'text', 'text': 'describe this image'},
{'type': 'image_url', 'image_url': {'url': 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg'}}
]
}
]
response = pipe(prompts)
print(response)
```

### Set tensor parallelism

Tensor parallelism can be activated by setting the engine parameter `tp`:

```python
from lmdeploy import pipeline, TurbomindEngineConfig
from lmdeploy.vl import load_image

pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b',
backend_config=TurbomindEngineConfig(tp=2))

image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
response = pipe(('describe this image', image))
print(response)
```

### Set context window size

When creating the pipeline, you can customize the size of the context window by setting the engine parameter `session_len`.

```python
from lmdeploy import pipeline, TurbomindEngineConfig
from lmdeploy.vl import load_image

pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b',
backend_config=TurbomindEngineConfig(session_len=8192))

image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
response = pipe(('describe this image', image))
print(response)
```

### Set sampling parameters

You can change the pipeline's default sampling parameters by passing a `GenerationConfig`:

```python
from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig
from lmdeploy.vl import load_image

pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b',
backend_config=TurbomindEngineConfig(tp=2, session_len=8192))
gen_config = GenerationConfig(top_k=40, top_p=0.8, temperature=0.6)
image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
response = pipe(('describe this image', image), gen_config=gen_config)
print(response)
```

### Set chat template

While performing inference, LMDeploy selects an appropriate chat template from its built-in collection based on the model path and applies it to the input prompts. However, when the chat template cannot be inferred from the model path, users have to specify it explicitly. For example, [liuhaotian/llava-v1.5-7b](https://huggingface.co/liuhaotian/llava-v1.5-7b) uses the 'vicuna' chat template, but the name 'vicuna' cannot be deduced from the model path. We can specify it by passing 'vicuna' to `ChatTemplateConfig` as follows:

```python
from lmdeploy import pipeline, ChatTemplateConfig
from lmdeploy.vl import load_image
pipe = pipeline('liuhaotian/llava-v1.5-7b',
chat_template_config=ChatTemplateConfig(model_name='vicuna'))

image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
response = pipe(('describe this image', image))
print(response)
```

For more information about customizing a chat template, please refer to [this](../advance/chat_template.md) guide.
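
As a quick orientation before reading that guide, the snippet below sketches what a customized configuration might look like. It is an illustrative sketch only: the override fields shown here (`meta_instruction`, `capability`) are assumptions about `ChatTemplateConfig` and may differ across LMDeploy versions, so verify them against your installed release.

```python
from lmdeploy import pipeline, ChatTemplateConfig

# A minimal sketch, assuming ChatTemplateConfig accepts override fields such as
# `meta_instruction` (system prompt) and `capability`; check the attributes of
# ChatTemplateConfig in your LMDeploy version before relying on them.
chat_template_config = ChatTemplateConfig(
    model_name='vicuna',  # start from the built-in 'vicuna' template
    meta_instruction='You are a helpful vision-language assistant.',
    capability='chat',
)

pipe = pipeline('liuhaotian/llava-v1.5-7b',
                chat_template_config=chat_template_config)
```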

## Multi-image inference

When dealing with multiple images, you can put them all in one list. Keep in mind that multiple images will lead to a higher number of input tokens, and as a result, the size of the [context window](#set-context-window-size) typically needs to be increased.

```python
from lmdeploy import pipeline, TurbomindEngineConfig
from lmdeploy.vl import load_image

pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b',
backend_config=TurbomindEngineConfig(session_len=8192))

image_urls=[
'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/human-pose.jpg',
'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/det.jpg'
]

images = [load_image(img_url) for img_url in image_urls]
response = pipe(('describe these images', images))
print(response)
```

## Batch prompts inference

Conducting inference with batch prompts is quite straightforward; just place them within a list structure:

```python
from lmdeploy import pipeline, TurbomindEngineConfig
from lmdeploy.vl import load_image

pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b',
backend_config=TurbomindEngineConfig(session_len=8192))

image_urls=[
"https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/human-pose.jpg",
"https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/det.jpg"
]
prompts = [('describe this image', load_image(img_url)) for img_url in image_urls]
response = pipe(prompts)
print(response)
```
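
Continuing the example above, the pipeline returns one result per prompt, in the same order as the input list. The short sketch below shows one way to pair each prompt with its result; the `.text` attribute is an assumption about the response object and may vary between LMDeploy versions, hence the fallback to printing the raw object.

```python
# A minimal sketch for consuming the batch results from the example above.
# `.text` is an assumed field, so getattr falls back to the raw response
# object if it is absent.
for (prompt, _image), resp in zip(prompts, response):
    print(f'prompt  : {prompt}')
    print(f'response: {getattr(resp, "text", resp)}')
```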
1 change: 1 addition & 0 deletions docs/zh_cn/index.rst
@@ -40,6 +40,7 @@
:caption: 推理

inference/pipeline.md
inference/vl_pipeline.md
inference/turbomind.md
inference/turbomind_config.md
inference/pytorch.md
145 changes: 145 additions & 0 deletions docs/zh_cn/inference/vl_pipeline.md
@@ -0,0 +1,145 @@
# VLM Offline Inference Pipeline

LMDeploy abstracts the complex inference process of Vision-Language Models (VLM) into an easy-to-use pipeline, whose usage is similar to the Large Language Model (LLM) inference [pipeline](./pipeline.md). This article takes the [liuhaotian/llava-v1.6-vicuna-7b](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b) model as an example and demonstrates the capabilities of the VLM pipeline through several examples.
First, we show the most basic usage of the pipeline, and on top of that, progressively introduce more capabilities, such as tensor parallelism, custom context length, and random sampling, by configuring the engine and generation parameters. Then we give inference examples for scenarios such as multiple images and batch prompts.

## A 'Hello, world' example

```python
from lmdeploy import pipeline
from lmdeploy.vl import load_image

pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b')

image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
response = pipe(('describe this image', image))
print(response)
```

If an `ImportError` occurs while running this example, please install the required dependency packages as prompted.

In the above example, the inference prompt is a (prompt, image) tuple. Besides this structure, the pipeline also supports prompts in the OpenAI format:

```python
from lmdeploy import pipeline

pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b')

prompts = [
{
'role': 'user',
'content': [
{'type': 'text', 'text': 'describe this image'},
{'type': 'image_url', 'image_url': {'url': 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg'}}
]
}
]
response = pipe(prompts)
print(response)
```

### Set tensor parallelism

Multi-GPU tensor parallelism can be enabled by setting the engine parameter `tp`:

```python
from lmdeploy import pipeline, TurbomindEngineConfig
from lmdeploy.vl import load_image

pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b',
backend_config=TurbomindEngineConfig(tp=2))

image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
response = pipe(('describe this image', image))
print(response)
```

### Set context window size

When creating the pipeline, you can customize the maximum length of the context window by setting the engine parameter `session_len`:

```python
from lmdeploy import pipeline, TurbomindEngineConfig
from lmdeploy.vl import load_image

pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b',
backend_config=TurbomindEngineConfig(session_len=8192))

image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
response = pipe(('describe this image', image))
print(response)
```

### Set sampling parameters

The default sampling parameters of the pipeline's generation interface can be changed by passing a `GenerationConfig`:

```python
from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig
from lmdeploy.vl import load_image

pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b',
backend_config=TurbomindEngineConfig(tp=2, session_len=8192))
gen_config = GenerationConfig(top_k=40, top_p=0.8, temperature=0.6)
image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
response = pipe(('describe this image', image), gen_config=gen_config)
print(response)
```

### Set chat template

During inference, LMDeploy matches a built-in chat template based on the model path and applies it to the input prompts. However, for a vision-language model such as [llava-v1.5-7b](https://huggingface.co/liuhaotian/llava-v1.5-7b), the chat template it uses is 'vicuna', yet this name cannot be derived from the model path, so the user has to specify it, as follows:

```python
from lmdeploy import pipeline, ChatTemplateConfig
from lmdeploy.vl import load_image
pipe = pipeline('liuhaotian/llava-v1.5-7b',
chat_template_config=ChatTemplateConfig(model_name='vicuna'))

image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
response = pipe(('describe this image', image))
print(response)
```

For details on customizing a chat template, please refer to [this](../advance/chat_template.md) guide.

## Multi-image inference

When dealing with multiple images, simply put them all in one list. Keep in mind that multiple images mean more input tokens, so the [context window](#set-context-window-size) usually needs to be enlarged.

```python
from lmdeploy import pipeline, TurbomindEngineConfig
from lmdeploy.vl import load_image

pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b',
backend_config=TurbomindEngineConfig(session_len=8192))

image_urls=[
'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/human-pose.jpg',
'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/det.jpg'
]

images = [load_image(img_url) for img_url in image_urls]
response = pipe(('describe these images', images))
print(response)
```

## Batch prompts inference

Conducting inference with batch prompts is straightforward; just place them in a list:

```python
from lmdeploy import pipeline, TurbomindEngineConfig
from lmdeploy.vl import load_image

pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b',
backend_config=TurbomindEngineConfig(session_len=8192))

image_urls=[
"https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/human-pose.jpg",
"https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/det.jpg"
]
prompts = [('describe this image', load_image(img_url)) for img_url in image_urls]
response = pipe(prompts)
print(response)
```
38 changes: 28 additions & 10 deletions lmdeploy/api.py
@@ -2,7 +2,7 @@
import os
from typing import List, Literal, Optional, Union

from .archs import autoget_backend_config
from .archs import autoget_backend_config, get_task
from .messages import PytorchEngineConfig, TurbomindEngineConfig
from .model import ChatTemplateConfig

@@ -39,6 +39,7 @@ def pipeline(model_path: str,
log_level(str): set log level whose value among [CRITICAL, ERROR, WARNING, INFO, DEBUG]
Examples:
>>> # LLM
>>> import lmdeploy
>>> pipe = lmdeploy.pipeline('internlm/internlm-chat-7b')
>>> response = pipe(['hi','say this is a test'])
>>> print(response)
>>>
>>> # VLM
>>> from lmdeploy.vl import load_image
>>> from lmdeploy import pipeline, TurbomindEngineConfig, ChatTemplateConfig
>>> pipe = pipeline('liuhaotian/llava-v1.5-7b',
... backend_config=TurbomindEngineConfig(session_len=8192),
... chat_template_config=ChatTemplateConfig(model_name='vicuna'))
>>> im = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/human-pose.jpg')
>>> response = pipe([('describe this image', [im])])
>>> print(response)
""" # noqa E501
from lmdeploy.serve.async_engine import AsyncEngine
if os.getenv('TM_LOG_LEVEL') is None:
os.environ['TM_LOG_LEVEL'] = log_level
from lmdeploy.utils import get_logger
logger = get_logger('lmdeploy')
logger.setLevel(log_level)

if type(backend_config) is not PytorchEngineConfig:
pipeline_type, pipeline_class = get_task(model_path)
if pipeline_type == 'vlm':
assert (type(backend_config) is TurbomindEngineConfig) or \
(backend_config is None), \
f'{pipeline_type} model only support turbomind backend.'

if pipeline_type == 'llm' and type(
backend_config) is not PytorchEngineConfig:
# set auto backend mode
backend_config = autoget_backend_config(model_path, backend_config)
backend = 'pytorch' if type(
@@ -65,13 +82,14 @@ def pipeline(model_path: str,
kwargs.pop('tp')
else:
tp = 1 if backend_config is None else backend_config.tp
return AsyncEngine(model_path,
model_name=model_name,
backend=backend,
backend_config=backend_config,
chat_template_config=chat_template_config,
tp=tp,
**kwargs)

return pipeline_class(model_path,
model_name=model_name,
backend=backend,
backend_config=backend_config,
chat_template_config=chat_template_config,
tp=tp,
**kwargs)


def serve(model_path: str,