Add pipeline.chat api for easy use (#1292)
* add pipeline.chat

* fix stream output

* fix close

* remove stream output

* vl_async_engine.chat

* update docs

* fix yi-vl template

* use end to close session

* resolve comments

* resolve comments

* remove thread

* check model name when initializing  VLAsyncEngine

* use property
irexyc authored Mar 18, 2024
1 parent 12ef4eb commit bd29205
Showing 5 changed files with 157 additions and 3 deletions.
19 changes: 19 additions & 0 deletions docs/en/inference/vl_pipeline.md
@@ -144,3 +144,22 @@ prompts = [('describe this image', load_image(img_url)) for img_url in image_url
response = pipe(prompts)
print(response)
```

## Multi-turn conversation

There are two ways to conduct multi-turn conversations with the pipeline. One is to construct messages in the OpenAI format and use the method introduced above; the other is to use the `pipeline.chat` interface, demonstrated below (a sketch of the messages-based approach follows the example).

```python
from lmdeploy import pipeline, TurbomindEngineConfig, GenerationConfig
from lmdeploy.vl import load_image

pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b',
                backend_config=TurbomindEngineConfig(session_len=8192))

image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/human-pose.jpg')
gen_config = GenerationConfig(top_k=40, top_p=0.8, temperature=0.8)
sess = pipe.chat(('describe this image', image), gen_config=gen_config)
print(sess.response.text)
sess = pipe.chat('What is the woman doing?', session=sess, gen_config=gen_config)
print(sess.response.text)
```
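
For the messages-based approach mentioned above, here is a minimal sketch assuming the pipeline accepts OpenAI/GPT-4V style message lists for vision-language models; the structure of the `content` entries and the handling of the second turn are illustrative, not taken from this diff:

```python
from lmdeploy import pipeline, TurbomindEngineConfig, GenerationConfig

pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b',
                backend_config=TurbomindEngineConfig(session_len=8192))

img_url = 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/human-pose.jpg'
gen_config = GenerationConfig(top_k=40, top_p=0.8, temperature=0.8)

# first turn: a GPT-4V style message carrying both text and the image URL
messages = [{
    'role': 'user',
    'content': [
        {'type': 'text', 'text': 'describe this image'},
        {'type': 'image_url', 'image_url': {'url': img_url}},
    ]
}]
response = pipe(messages, gen_config=gen_config)
print(response.text)

# second turn: append the assistant reply and the follow-up question,
# then run the whole conversation through the pipeline again
messages.append({'role': 'assistant', 'content': response.text})
messages.append({'role': 'user', 'content': 'What is the woman doing?'})
response = pipe(messages, gen_config=gen_config)
print(response.text)
```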
19 changes: 19 additions & 0 deletions docs/zh_cn/inference/vl_pipeline.md
@@ -143,3 +143,22 @@ prompts = [('describe this image', load_image(img_url)) for img_url in image_url
response = pipe(prompts)
print(response)
```

## Multi-turn conversation

There are two ways to conduct multi-turn conversations with the pipeline: one is to construct messages in the OpenAI format, and the other is to use the `pipeline.chat` interface.

```python
from lmdeploy import pipeline, TurbomindEngineConfig, GenerationConfig
from lmdeploy.vl import load_image

pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b',
                backend_config=TurbomindEngineConfig(session_len=8192))

image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/human-pose.jpg')
gen_config = GenerationConfig(top_k=40, top_p=0.8, temperature=0.6)
sess = pipe.chat(('describe this image', image), gen_config=gen_config)
print(sess.response.text)
sess = pipe.chat('What is the woman doing?', session=sess, gen_config=gen_config)
print(sess.response.text)
```
2 changes: 1 addition & 1 deletion lmdeploy/model.py
@@ -877,7 +877,7 @@ def __init__(
meta_instruction="""This is a chat between an inquisitive human and an AI assistant. Assume the role of the AI assistant. Read all the images carefully, and respond to the human's questions with informative, helpful, detailed and polite answers. 这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。\n\n""", # noqa: E501
user='### Human: ',
eoh='\n',
assistant='### Assistant: ',
assistant='### Assistant:',
eoa='\n',
stop_words=['###'],
**kwargs):
105 changes: 103 additions & 2 deletions lmdeploy/serve/async_engine.py
@@ -5,9 +5,10 @@
import random
from argparse import ArgumentError
from contextlib import asynccontextmanager
from itertools import count
from queue import Empty, Queue
from threading import Thread
from typing import Dict, List, Literal, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

from lmdeploy.messages import (EngineGenerationConfig, GenerationConfig,
PytorchEngineConfig, Response,
@@ -73,6 +74,55 @@ class GenOut:
finish_reason: Optional[Literal['stop', 'length']] = None


class Session:
    """Session for AsyncEngine.chat.

    Args:
        _id (int): session_id for internal use.
        _step (int): the offset of the k/v cache for internal use.
        _prompt (Any): input prompt for internal use.
        _response (Response): model output for the latest prompt.
        _engine (Any): engine for internal use.
        history (List[Tuple[Any, str]]): chat history.
    """
    _ids = count(0)

    def __init__(self):
        self._id: int = next(self._ids)
        self._step: int = 0
        self._prompt: Any = None
        self._response: Response = None
        self._engine: Any = None
        self.history: List[Tuple[Any, str]] = []

    def _merge_response(self, resp: Response, step: Union[Response, GenOut]):
        """merge response."""
        resp.text += step.text if isinstance(step, Response) else step.response
        resp.input_token_len = step.input_token_len
        resp.generate_token_len = step.generate_token_len
        resp.finish_reason = step.finish_reason
        return resp

    @property
    def response(self) -> Response:
        """return response."""
        return self._response

    def close(self):
        """release engine storage for this session."""
        if self._engine:
            inst = self._engine.create_instance()
            inst.end(self._id)

    def __repr__(self) -> str:
        res = ''
        for user, assistant in self.history:
            if isinstance(user, list):
                user = str(user)
            res += f'USER:\n{user}\nASSISTANT:\n{assistant}\n'
        return res


class AsyncEngine:
"""Async inference engine. Maintaining a bunch of tm_model instances.
@@ -520,7 +570,7 @@ async def generate(
**prompt_input,
gen_config=gen_config,
stream_output=stream_response,
sequence_start=(sequence_start),
sequence_start=sequence_start,
sequence_end=sequence_end,
step=self.id2step[str(session_id)]):
_, res, tokens = outputs
@@ -550,3 +600,54 @@ async def generate(
# TODO modify pytorch or turbomind api
if self.backend == 'pytorch' and sequence_end:
await self.end_session(session_id)

    def chat(self,
             prompt: str,
             session=None,
             gen_config: Optional[Union[GenerationConfig,
                                        EngineGenerationConfig]] = None,
             do_preprocess: bool = True,
             **kwargs) -> Session:
        """Chat.

        Args:
            prompt (str): user's input prompt
            session (Session): the chat session
            gen_config (GenerationConfig | None): an instance of
                GenerationConfig. Default to None.
            do_preprocess (bool): whether to pre-process the messages.
                Default to True, which means chat_template will be applied.
            **kwargs (dict): ad hoc parametrization of `gen_config`
        """
        if session is None:
            session = Session()
            session._engine = self.engine

        # sync & init
        session._prompt = prompt
        session._response = None

        sequence_start = session._step == 0

        async def _work():
            resp = Response('', -1, -1, session._id)
            async for output in self.generate(prompt,
                                              session_id=session._id,
                                              gen_config=gen_config,
                                              stream_response=False,
                                              sequence_start=sequence_start,
                                              sequence_end=False,
                                              step=session._step,
                                              do_preprocess=do_preprocess,
                                              **kwargs):
                resp = session._merge_response(resp, output)
            return resp

        from lmdeploy.pytorch.engine.request import _run_until_complete
        resp = _run_until_complete(_work())

        session._response = resp
        session._step += resp.generate_token_len + resp.input_token_len
        session.history.append((session._prompt, resp.text))

        return session
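
For reference, a minimal text-only usage sketch of the new `chat` API and `Session` object added above; the model path is a placeholder, and calling `close()` to release the session's storage follows the `Session.close` method shown in this diff:

```python
from lmdeploy import pipeline, GenerationConfig

# placeholder model path; any chat model the pipeline supports would do
pipe = pipeline('internlm/internlm2-chat-7b')
gen_config = GenerationConfig(top_k=40, top_p=0.8, temperature=0.8)

# first turn: a fresh Session is created, so sequence_start is True (_step == 0)
sess = pipe.chat('Introduce yourself briefly.', gen_config=gen_config)
print(sess.response.text)

# later turns reuse the Session, so the k/v cache offset (_step) carries over
sess = pipe.chat('Now do it in one sentence.', session=sess, gen_config=gen_config)
print(sess.response.text)

# release the engine storage held by this session
sess.close()
```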
15 changes: 15 additions & 0 deletions lmdeploy/serve/vl_async_engine.py
@@ -14,6 +14,10 @@ class VLAsyncEngine(AsyncEngine):

    def __init__(self, model_path: str, **kwargs) -> None:
        super().__init__(model_path, **kwargs)
        if self.model_name == 'base':
            raise RuntimeError(
                'please specify chat template as guided in https://lmdeploy.readthedocs.io/en/latest/inference/vl_pipeline.html#set-chat-template'  # noqa: E501
            )
        self.vl_encoder = ImageEncoder(model_path)
        self.vl_prompt_template = get_vl_prompt_template(
            model_path, self.chat_template, self.model_name)
@@ -93,3 +97,14 @@ def __call__(self, prompts: Union[VLPromptType, List[Dict],
"""Inference a batch of prompts."""
prompts = self._convert_prompts(prompts)
return super().__call__(prompts, **kwargs)

def chat(self, prompts: VLPromptType, **kwargs):
"""chat."""
_prompts = self._convert_prompts(prompts)
sess = super().chat(_prompts, **kwargs)

# recover prompts & history
sess._prompt = prompts
last_round = sess.history[-1]
sess.history[-1] = (prompts, last_round[-1])
return sess
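
The `model_name == 'base'` check added to `VLAsyncEngine.__init__` above rejects models whose chat template cannot be inferred. As the linked guide suggests, the template can be named explicitly when building the pipeline; the following is a hedged sketch in which the `ChatTemplateConfig` usage and the `vicuna` template name are assumptions rather than part of this diff:

```python
from lmdeploy import ChatTemplateConfig, TurbomindEngineConfig, pipeline

# Assumed usage: name the chat template explicitly so that model_name does not
# resolve to 'base' when it cannot be inferred from the model path.
pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b',
                chat_template_config=ChatTemplateConfig(model_name='vicuna'),
                backend_config=TurbomindEngineConfig(session_len=8192))
```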
