@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from abc import ABC, abstractmethod
+from itertools import groupby
 from typing import Dict, List, Union
 
 import numpy as np
@@ -104,6 +105,18 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]: |
         """ # noqa
         raise NotImplementedError()
 
+    def has_input_ids(self, messages: List[Dict]) -> bool:
+        """Check whether the messages contain input_ids directly.
+
+        Args:
+            messages (List[Dict]): a list of messages, which is supposed to
+                be the output of `preprocess`
+        Returns:
+            bool: whether the messages contain input_ids directly
+        """
+        users = [x['content'] for x in messages if x['role'] == 'user']
+        return len(users) == 1 and isinstance(users[0], List) and isinstance(users[0][0].get('text', ''), List)
+
     def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:
         """Extract image feature. ONLY implement it when the backend is
         turbomind engine.
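
For illustration, here is a minimal sketch of the message layout that `has_input_ids` detects: a single user turn whose first content item carries pre-tokenized ids (a list) in its `text` field rather than a plain string. The token ids below are made up.

```python
messages = [{
    'role': 'user',
    'content': [{
        'type': 'text',
        # a list of token ids rather than a plain string marks the
        # "input_ids provided directly" case
        'text': [9707, 11, 1879, 0],
    }]
}]

# has_input_ids(messages) returns True: exactly one user turn, list-typed
# content, and the first item's 'text' is a list instead of a str.
```
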
@@ -168,6 +181,43 @@ def collect_images(messages): |
                }) for x in content if x['type'] == 'image'])
         return images
 
+    def to_pytorch_with_input_ids(self, messages):
+        """Pack the preprocessing results in a format compatible with what
+        is required by the pytorch engine when input_ids are provided
+        directly.
+
+        Args:
+            messages (List[Dict]): the output of `preprocess`
+        """
+        # collect all preprocessing results from the messages
+        preps = [x['content'] for x in messages if x['role'] == 'preprocess']
+        assert len(preps) == 1
+        preps = preps[0]
+
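+        # split input_ids into segments delimited by the image token id; n
+        # image tokens yield n + 1 segments, with empty segments between
+        # adjacent image tokens and at either end of the sequence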
+        _input_ids = messages[0]['content'][0]['text']
+        segs = []
+        for k, g in groupby(_input_ids, lambda x: x == self.image_token_id):
+            if not k:
+                segs.append(list(g))
+            else:
+                segs.extend([[]] * (len(list(g)) - 1))
+        if _input_ids[0] == self.image_token_id:
+            segs = [[]] + segs
+        if _input_ids[-1] == self.image_token_id:
+            segs = segs + [[]]
+
+        assert self.image_token_id == preps[0]['image_token_id']
+        assert len(segs) == len(preps) + 1, (
+            f'the number of image tokens (id {self.image_token_id}) does not '
+            f'match the number of input images: {len(segs) - 1} vs {len(preps)}')
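+        # rebuild input_ids, expanding each image placeholder to the number
+        # of tokens its features occupy and recording each image's offset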
+        input_ids = []
+        for i, seg in enumerate(segs):
+            if i > 0 and i <= len(preps):
+                preps[i - 1].update(offset=len(input_ids))
+                image_tokens = preps[i - 1]['image_tokens']
+                input_ids.extend([self.image_token_id] * image_tokens)
+            input_ids.extend(seg)
+
+        return dict(prompt=None, input_ids=input_ids, multimodal=preps)
+
     def to_pytorch_aux(self, messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start):
         """Auxiliary function to pack the preprocessing results in a format
         compatible with what is required by pytorch engine.
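
For illustration, the following standalone sketch reproduces the segmentation and expansion that `to_pytorch_with_input_ids` performs. The `IMAGE_TOKEN_ID` value, token ids, and `image_tokens` counts are made up; in the method they come from `self.image_token_id` and the preprocessing results.

```python
from itertools import groupby

IMAGE_TOKEN_ID = 92544  # made-up placeholder id

# prompt ids with one single and one doubled image placeholder
input_ids = [5, 6, IMAGE_TOKEN_ID, 7, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 8]

# segmentation: split on the image token id
segs = []
for is_image, group in groupby(input_ids, lambda x: x == IMAGE_TOKEN_ID):
    if not is_image:
        segs.append(list(group))
    else:
        # m consecutive image tokens leave m - 1 empty segments between them
        segs.extend([[]] * (len(list(group)) - 1))
if input_ids[0] == IMAGE_TOKEN_ID:
    segs = [[]] + segs
if input_ids[-1] == IMAGE_TOKEN_ID:
    segs = segs + [[]]
print(segs)  # [[5, 6], [7], [], [8]] -> 3 images, 4 segments

# expansion: assume 3 preprocessed images with hypothetical token counts
preps = [{'image_tokens': 2}, {'image_tokens': 3}, {'image_tokens': 2}]
out = []
for i, seg in enumerate(segs):
    if 0 < i <= len(preps):
        preps[i - 1]['offset'] = len(out)
        out.extend([IMAGE_TOKEN_ID] * preps[i - 1]['image_tokens'])
    out.extend(seg)
print(out)
# -> [5, 6, 92544, 92544, 7, 92544, 92544, 92544, 92544, 92544, 8]
# offsets recorded in preps: 2, 5, 8
```

Each image placeholder in the prompt is thus widened to the true number of vision tokens, and `offset` records where each image's features begin in the final `input_ids`.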