Commit a6aa375

support image_data input to /generate endpoint (#4086)
* support image_data input to /generate endpoint
* fix input_ids
1 parent c83bedc commit a6aa375

File tree

4 files changed, 87 insertions(+), 8 deletions(-)


lmdeploy/serve/openai/api_server.py

Lines changed: 21 additions & 2 deletions
@@ -906,6 +906,25 @@ async def generate(request: GenerateReqInput, raw_request: Request = None):
         return error_check_ret
     if VariableInterface.async_engine.id2step.get(request.session_id, 0) != 0:
         return create_error_response(HTTPStatus.BAD_REQUEST, f'The session_id `{request.session_id}` is occupied.')
+    if (request.prompt is not None) ^ (request.input_ids is None):
+        return create_error_response(HTTPStatus.BAD_REQUEST, 'You must specify exactly one of prompt or input_ids')
+
+    prompt = request.prompt
+    input_ids = request.input_ids
+    image_data = request.image_data
+    if image_data is not None:
+        # convert to openai format
+        image_input = []
+        if not isinstance(image_data, List):
+            image_data = [image_data]
+        for img in image_data:
+            if isinstance(img, str):
+                image_input.append(dict(type='image_url', image_url=dict(url=img)))
+            else:
+                image_input.append(dict(type='image_url', image_url=img))
+        text_input = dict(type='text', text=prompt if prompt else input_ids)
+        prompt = [dict(role='user', content=[text_input] + image_input)]
+        input_ids = None

     gen_config = GenerationConfig(
         max_new_tokens=request.max_tokens,
@@ -925,9 +944,9 @@ async def generate(request: GenerateReqInput, raw_request: Request = None):
     )

     result_generator = VariableInterface.async_engine.generate(
-        messages=request.prompt,
+        messages=prompt,
         session_id=request.session_id,
-        input_ids=request.input_ids,
+        input_ids=input_ids,
         gen_config=gen_config,
         stream_response=True, # always use stream to enable batching
         sequence_start=True,
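
For illustration only (values are made up, not part of the diff): given a request carrying prompt='Describe the image.' and a single URL string in image_data, the conversion branch above rewrites it into one OpenAI-style user message and clears input_ids, so everything reaches the engine via `messages`:

# hypothetical request values, placeholder URL
prompt = [
    dict(role='user',
         content=[
             dict(type='text', text='Describe the image.'),
             dict(type='image_url', image_url=dict(url='https://example.com/cat.jpg')),
         ])
]
input_ids = None  # image_data requests are always passed through `messages`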

lmdeploy/serve/openai/protocol.py

Lines changed: 6 additions & 0 deletions
@@ -439,11 +439,17 @@ class UpdateParamsRequest(BaseModel):
     finished: bool = False


+# str for url/base64, base64 should be data:image/jpeg;base64, dict should be {'url': url/base64, 'options': ...}
+ImageDataInputItem = Union[str, Dict]
+ImageDataFormat = Union[ImageDataInputItem, List[ImageDataInputItem]]
+
+
 # /generate input
 class GenerateReqInput(BaseModel):
     session_id: Optional[int] = -1
     prompt: Optional[str] = None
     input_ids: Optional[List[int]] = None
+    image_data: Optional[ImageDataFormat] = None
     return_logprob: Optional[bool] = None
     max_tokens: int = 128
     stop: Optional[Union[str, List[str]]] = None
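
A usage sketch, not part of the diff: a /generate request exercising the new image_data field. The server address and image URL are placeholders; per the comment above, image_data may also be a base64 data:image/jpeg;base64,... string, a dict with url/options, or a list of either.

import requests

payload = {
    'prompt': 'Describe the image.',
    'image_data': 'https://example.com/cat.jpg',  # placeholder; str, dict, or list
    'max_tokens': 128,
}
# assumed local api_server address
resp = requests.post('http://localhost:23333/generate', json=payload)
print(resp.text)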

lmdeploy/vl/engine.py

Lines changed: 10 additions & 6 deletions
@@ -86,12 +86,16 @@ async def wrap_for_pytorch(
             ]
         )
         """
-        result = self.model.to_pytorch(messages,
-                                       chat_template,
-                                       tokenizer,
-                                       sequence_start,
-                                       tools=tools,
-                                       enable_thinking=enable_thinking)
+        has_input_ids = self.model.has_input_ids(messages)
+        if not has_input_ids:
+            result = self.model.to_pytorch(messages,
+                                           chat_template,
+                                           tokenizer,
+                                           sequence_start,
+                                           tools=tools,
+                                           enable_thinking=enable_thinking)
+        else:
+            result = self.model.to_pytorch_with_input_ids(messages)
         # clear data
         for i, message in enumerate(messages):
             if isinstance(message['content'], List):

lmdeploy/vl/model/base.py

Lines changed: 50 additions & 0 deletions
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from abc import ABC, abstractmethod
+from itertools import groupby
 from typing import Dict, List, Union

 import numpy as np
@@ -104,6 +105,18 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]:
         """ # noqa
         raise NotImplementedError()

+    def has_input_ids(self, messages: List[Dict]) -> bool:
+        """Check whether the messages contain input_ids directly.
+
+        Args:
+            messages (List[Dict]): a list of message, which is supposed to be
+                the output of `preprocess`
+        Returns:
+            bool: whether the messages contain input_ids directly
+        """
+        users = [x['content'] for x in messages if x['role'] == 'user']
+        return len(users) == 1 and isinstance(users[0], List) and isinstance(users[0][0].get('text', ''), List)
+
     def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:
         """Extract image feature. ONLY implement it when the backend is
         turbomind engine.
@@ -168,6 +181,43 @@ def collect_images(messages):
             }) for x in content if x['type'] == 'image'])
         return images

+    def to_pytorch_with_input_ids(self, messages):
+        """Pack the preprocessing results in a format compatible with what is
+        required by pytorch engine when input_ids are provided directly.
+
+        Args:
+            messages(List[Dict]): the output of `preprocess`
+        """
+        # collect all preprocessing result from messages
+        preps = [x['content'] for x in messages if x['role'] == 'preprocess']
+        assert len(preps) == 1
+        preps = preps[0]
+
+        _input_ids = messages[0]['content'][0]['text']
+        segs = []
+        for k, g in groupby(_input_ids, lambda x: x == self.image_token_id):
+            if not k:
+                segs.append(list(g))
+            else:
+                segs.extend([[]] * (len(list(g)) - 1))
+        if _input_ids[0] == self.image_token_id:
+            segs = [[]] + segs
+        if _input_ids[-1] == self.image_token_id:
+            segs = segs + [[]]
+
+        assert self.image_token_id == preps[0]['image_token_id']
+        assert len(segs) == len(preps) + 1, (f'the number of image token id {self.image_token_id} is not equal '
+                                             f'to input images, {len(segs) - 1} vs {len(preps)}')
+        input_ids = []
+        for i, seg in enumerate(segs):
+            if i > 0 and i <= len(preps):
+                preps[i - 1].update(offset=len(input_ids))
+                image_tokens = preps[i - 1]['image_tokens']
+                input_ids.extend([self.image_token_id] * image_tokens)
+            input_ids.extend(seg)
+
+        return dict(prompt=None, input_ids=input_ids, multimodal=preps)
+
     def to_pytorch_aux(self, messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start):
         """Auxiliary function to pack the preprocessing results in a format
         compatible with what is required by pytorch engine.
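
A toy walk-through of the segmentation logic in to_pytorch_with_input_ids, with assumed values (the image_token_id, token ids, and image_tokens below are made up for illustration): a single image placeholder in the client-supplied input_ids is expanded to the number of tokens the image actually occupies, and the image's offset into the final sequence is recorded.

from itertools import groupby

image_token_id = 9                                   # assumed placeholder id
_input_ids = [1, 2, 9, 3]                            # text tokens with one image placeholder
preps = [dict(image_token_id=9, image_tokens=4)]     # one image occupying 4 tokens

# split the token list on the placeholder, as the method above does
segs = []
for k, g in groupby(_input_ids, lambda x: x == image_token_id):
    if not k:
        segs.append(list(g))
    else:
        segs.extend([[]] * (len(list(g)) - 1))
if _input_ids[0] == image_token_id:
    segs = [[]] + segs
if _input_ids[-1] == image_token_id:
    segs = segs + [[]]
# segs == [[1, 2], [3]]

input_ids = []
for i, seg in enumerate(segs):
    if i > 0 and i <= len(preps):
        preps[i - 1].update(offset=len(input_ids))
        input_ids.extend([image_token_id] * preps[i - 1]['image_tokens'])
    input_ids.extend(seg)
# input_ids == [1, 2, 9, 9, 9, 9, 3]; preps[0]['offset'] == 2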
