diff --git a/lmdeploy/vl/engine.py b/lmdeploy/vl/engine.py index ee251a5cb..7f786d5f9 100644 --- a/lmdeploy/vl/engine.py +++ b/lmdeploy/vl/engine.py @@ -1,12 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. import asyncio -import queue -import time -from threading import Thread from typing import Dict, List, Optional, Union import torch -from PIL.Image import Image from lmdeploy.messages import (PytorchEngineConfig, TurbomindEngineConfig, VisionConfig) @@ -26,58 +22,6 @@ def _raise_exception_on_finish(task: asyncio.Task) -> None: raise e -class Record: - """Batching manager.""" - - def __init__(self, thread_safe): - self.thread_safe = thread_safe - self.number = [] - self.waiting = [] - self.kwargs = [] - self.done = [] - self.res_que = [] - self.total = 0 - - def enqueue(self, images: List[Image], kwargs: List[Dict], - que: Union[queue.Queue, asyncio.Queue]): - """add ith request to manager.""" - self.number.append(len(images)) - self.waiting.extend(images) - self.kwargs.extend(kwargs) - self.res_que.append(que) - self.total += len(images) - self.log('received', len(images)) - - def dequeue(self, max_batch_size): - """try to dequeue max batch size images.""" - inputs = self.waiting[:max_batch_size] - kwargs = self.kwargs[:max_batch_size] - self.waiting = self.waiting[max_batch_size:] - self.kwargs = self.kwargs[max_batch_size:] - self.total -= len(inputs) - self.log('process', len(inputs)) - return inputs, kwargs - - def notify(self): - """set result if request i is finished.""" - if len(self.number) == 0 or self.number[0] > len(self.done): - return False - num_images = self.number.pop(0) - outputs = self.done[:num_images] - self.done = self.done[num_images:] - que = self.res_que.pop(0) - self.log('done', num_images) - if self.thread_safe: - que._loop.call_soon_threadsafe(que.put_nowait, outputs) - else: - que.put_nowait(outputs) - return True - - def log(self, task: str, num: int): - logger.info(f'ImageEncoder {task} {num} images, ' - f'left {self.total} images.') - - class ImageEncoder: """Image encoder.""" @@ -97,93 +41,14 @@ def __init__( self.vision_config = vision_config self.max_batch_size = vision_config.max_batch_size torch.cuda.empty_cache() - self._que: asyncio.Queue = None - self._loop_task: asyncio.Task = None - if vision_config.thread_safe: - self._create_thread_safe_task() - - def _create_thread_safe_task(self): - """thread safe loop task.""" - self._loop = asyncio.new_event_loop() - self._que = asyncio.Queue() - - def _work_thread(): - asyncio.set_event_loop(self._loop) - self._loop.run_until_complete(self._forward_loop()) - - thread = Thread(target=_work_thread, daemon=True) - thread.start() - self._loop_thread = thread - - def _create_event_loop_task(self): - """event loop task.""" - task = asyncio.get_event_loop().create_task(self._forward_loop()) - self._loop_task = task - self._loop = task.get_loop() - - @property - def req_que(self): - if self.vision_config.thread_safe: - return self._que - if self._que is None: - self._que = asyncio.Queue() - if self._loop_task is None: - self._create_event_loop_task() - if asyncio.get_event_loop() != self._loop: - raise RuntimeError('Current event loop is different from' - ' the one bound to loop task!') - return self._que - - async def _forward_loop(self): - """working loop to process images.""" - logger.info('start ImageEncoder._forward_loop') - record = Record(self.vision_config.thread_safe) - while True: - while record.total == 0 or (self._que.qsize() and - record.total < self.max_batch_size): - while self._que.qsize() 
== 0: - await asyncio.sleep(0.01) - item = await self._que.get() - record.enqueue(item[0], item[1], item[2]) - inputs, kwargs = record.dequeue(self.max_batch_size) - future = asyncio.get_event_loop().run_in_executor( - None, self.forward, inputs, kwargs) - future.add_done_callback(_raise_exception_on_finish) - outputs = await future - record.done.extend(outputs) - while record.notify(): - pass - - def forward(self, messages: List[List[Dict]]) -> Dict: - # messages in batch - assert all(isinstance(message, List) for message in messages) - - time_start = time.perf_counter() - outputs, n_image = self.model.forward(messages) - if isinstance(outputs[0], torch.Tensor): - outputs = [x.cpu() for x in outputs] - time_end = time.perf_counter() - logger.info(f'ImageEncoder forward {n_image} images, ' - f'cost {time_end - time_start:.3f}s') - return outputs - - def infer(self, messages: List[Dict]) -> Dict: - """perform vision encoding to get a dict, in which there are input_ids, - embeddings, embedding_ranges and so on. They will be used by turbomind - engine. The key in the dict must be the same defined in turbmoind - engine's infer API. - - Args: - messages (List[Dict]): user's input in GPT4V format - """ - assert isinstance(messages, List) - assert all(isinstance(item, Dict) for item in messages) - - return self.forward(messages) async def preprocess(self, messages: List[Dict]) -> List[Dict]: """preprocess multimodal data in the messages.""" - return self.model.preprocess(messages) + future = asyncio.get_event_loop().run_in_executor( + None, self.model.preprocess, messages) + future.add_done_callback(_raise_exception_on_finish) + outputs = await future + return outputs async def async_infer(self, messages: List[Dict]) -> List[Dict]: """get multimodal embedding. @@ -192,7 +57,11 @@ async def async_infer(self, messages: List[Dict]) -> List[Dict]: messages (List[Dict]): a list of message, which is the output of `preprocess()` """ - return self.model.forward(messages) + future = asyncio.get_event_loop().run_in_executor( + None, self.model.forward, messages, self.max_batch_size) + future.add_done_callback(_raise_exception_on_finish) + outputs = await future + return outputs async def wrap_for_pytorch(self, messages: List[Dict], chat_template, tokenizer, sequence_start) -> List[Dict]: diff --git a/lmdeploy/vl/model/base.py b/lmdeploy/vl/model/base.py index 803df4522..87de31e06 100644 --- a/lmdeploy/vl/model/base.py +++ b/lmdeploy/vl/model/base.py @@ -39,7 +39,6 @@ def build_preprocessor(self, ): """ raise NotImplementedError() - @abstractmethod def build_model(self, ): """build model. @@ -90,13 +89,16 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]: """ # noqa raise NotImplementedError() - @abstractmethod - def forward(self, messages: List[Dict]) -> List[Dict]: + def forward(self, + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: """extract image feature. ONLY implement it when the backend is turbomind engine. Args: messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model Return: the message list with forwarding results included, which is determined by the derived classes @@ -104,7 +106,6 @@ def forward(self, messages: List[Dict]) -> List[Dict]: if self.backend == 'turbomind': raise NotImplementedError() - @abstractmethod def to_pytorch(self, messages, chat_template, tokenizer, sequence_start): """pack the preprocessing results in a format compatible with what is required by pytorch engine. 
ONLY implement it when the backend is @@ -119,7 +120,6 @@ def to_pytorch(self, messages, chat_template, tokenizer, sequence_start): if self.backend == 'pytorch': raise NotImplementedError() - @abstractmethod def to_turbomind(self, messages, chat_template, tokenizer, sequence_start): """pack the forwarding results in a format compatible with what is required by turbomind engine. ONLY implement it when the backend is diff --git a/lmdeploy/vl/model/cogvlm.py b/lmdeploy/vl/model/cogvlm.py index fd1a7b90e..8b1ebbc67 100644 --- a/lmdeploy/vl/model/cogvlm.py +++ b/lmdeploy/vl/model/cogvlm.py @@ -1,13 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. -import warnings from typing import Dict, List -import torch -from transformers import AutoModelForCausalLM - from lmdeploy.utils import get_logger from lmdeploy.vl.model.base import VISION_MODELS, VisonModel -from lmdeploy.vl.model.utils import disable_logging logger = get_logger('lmdeploy') @@ -32,49 +27,6 @@ def build_preprocessor(self): patch_size = self.hf_config.vision_config['patch_size'] self.n_token_per_image = 2 + (image_size // patch_size // 2)**2 - def build_model(self): - from accelerate import init_empty_weights, load_checkpoint_and_dispatch - from accelerate.utils import get_balanced_memory, infer_auto_device_map - with init_empty_weights(), warnings.catch_warnings(): - self.model = AutoModelForCausalLM.from_config( - self.hf_config, trust_remote_code=True) - if not self.with_llm: - del self.model.lm_head - for key in ['layers', 'norm', 'embed_tokens']: - setattr(self.model.model, key, None) - else: - self.vl_model = self.model - - no_split_module_classes = ['TransformerLayer'] - max_memory = get_balanced_memory( - self.model, - max_memory=self.max_memory, - dtype=torch.half, - no_split_module_classes=no_split_module_classes) - device_map = infer_auto_device_map( - self.model, - no_split_module_classes=no_split_module_classes, - max_memory=max_memory, - dtype=torch.half) - same_device_keys = [('model.vision.linear_proj', 'model.vision.boi', - 'model.vision.eoi')] - for keys in same_device_keys: - keys = [k for k in keys if k in device_map] - if len(keys) <= 1: - continue - for k in keys[1:]: - device_map[k] = device_map[keys[0]] - - with disable_logging(): - load_checkpoint_and_dispatch( - model=self.model, - checkpoint=self.model_path, - device_map=device_map if not self.with_llm else {'': 'cpu'}, - no_split_module_classes=no_split_module_classes, - dtype=torch.half) - self.model = self.model.model.vision - self.model.eval() - def preprocess(self, messages: List[Dict]) -> List[Dict]: """refer to the spec of `super().preprocess`""" images = self.collect_images(messages) @@ -90,18 +42,6 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]: messages.append(dict(role='preprocess', content=outputs)) return messages - @torch.no_grad() - def forward(self, messages: List[Dict]) -> List[Dict]: - """extract image feature. ONLY implement it when the backend is - turbomind engine. 
- - Args: - messages(List[Dict]): the outputs of `preprocess` - Return: - the message list with forwarding results included - """ - assert 0, 'cogvlm is not supported by turbomind' - @classmethod def proc_messages(cls, messages, chat_template, sequence_start): """apply chat template to get the prompt.""" @@ -145,6 +85,3 @@ def to_pytorch(self, messages, chat_template, tokenizer, sequence_start): sequence_start) return super().to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start) - - def to_turbomind(self, messages, chat_template, sequence_start): - assert 0, 'cogvlm is not supported by turbomind' diff --git a/lmdeploy/vl/model/deepseek.py b/lmdeploy/vl/model/deepseek.py index 99a3ec649..af682fb3b 100644 --- a/lmdeploy/vl/model/deepseek.py +++ b/lmdeploy/vl/model/deepseek.py @@ -106,26 +106,35 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]: return messages @torch.no_grad() - def forward(self, messages: List[Dict]) -> List[Dict]: + def forward(self, + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: """extract image feature. ONLY implement it when the backend is turbomind engine. Args: messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model Return: the message list with forwarding results included """ inputs = [x['content'] for x in messages if x['role'] == 'preprocess'] inputs = inputs[0] - pixel_values = [x['pixel_values'] for x in inputs] - pixel_values = torch.cat(pixel_values, dim=0) - pixel_values = pixel_values.to(device=next( - self.vision_model.parameters()).device, - dtype=torch.float16) - # [b x n_images, T2, D] - images_embeds = self.aligner(self.vision_model(pixel_values)) - outputs = torch.split(images_embeds, 1, dim=0) - outputs = [x.squeeze() for x in outputs] + outputs = [] + for idx in range(0, len(inputs), max_batch_size): + pixel_values = [ + x['pixel_values'] for x in inputs[idx:idx + max_batch_size] + ] + pixel_values = torch.cat(pixel_values, dim=0) + pixel_values = pixel_values.to(device=next( + self.vision_model.parameters()).device, + dtype=torch.float16) + # [b x n_images, T2, D] + logger.info(f'vision forward shape: {pixel_values.shape}') + feats = self.aligner(self.vision_model(pixel_values)) + feats = torch.split(feats, 1, dim=0) + outputs.extend([x.squeeze() for x in feats]) messages.append(dict(role='forward', content=outputs)) return messages diff --git a/lmdeploy/vl/model/glm_4v.py b/lmdeploy/vl/model/glm_4v.py index 5a8e2bb1b..2a72ee18f 100644 --- a/lmdeploy/vl/model/glm_4v.py +++ b/lmdeploy/vl/model/glm_4v.py @@ -1,13 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-import warnings from typing import Dict, List -import torch from transformers import AutoConfig from lmdeploy.utils import get_logger from lmdeploy.vl.model.base import VISION_MODELS, VisonModel -from lmdeploy.vl.model.utils import disable_logging logger = get_logger('lmdeploy') @@ -40,51 +37,6 @@ def build_preprocessor(self): patch_size = self.hf_config.vision_config['patch_size'] self.n_token_per_image = 2 + (image_size // patch_size // 2)**2 - def build_model(self): - from accelerate import init_empty_weights, load_checkpoint_and_dispatch - from accelerate.utils import infer_auto_device_map - - with init_empty_weights(), warnings.catch_warnings(): - warnings.simplefilter('ignore') - from transformers import AutoModelForCausalLM - self.model = AutoModelForCausalLM.from_config( - self.hf_config, trust_remote_code=True) - if not self.with_llm: - del self.model.transformer.embedding - del self.model.transformer.rotary_pos_emb - del self.model.transformer.encoder - del self.model.transformer.output_layer - else: - self.vl_model = self.model - - no_split_module_classes = ['TransformerLayer'] - - device_map = infer_auto_device_map( - self.model, - no_split_module_classes=no_split_module_classes, - max_memory=self.max_memory, - dtype=torch.half) - - same_device_keys = [ - ('transformer.vision.linear_proj', 'transformer.vision.boi', - 'transformer.vision.eoi') - ] - for keys in same_device_keys: - keys = [k for k in keys if k in device_map] - if len(keys) <= 1: - continue - for k in keys[1:]: - device_map[k] = device_map[keys[0]] - - with disable_logging(): - load_checkpoint_and_dispatch( - model=self.model, - checkpoint=self.model_path, - device_map=device_map if not self.with_llm else {'': 'cpu'}, - no_split_module_classes=no_split_module_classes, - dtype=torch.half) - self.model.eval() - def preprocess(self, messages: List[Dict]) -> List[Dict]: """refers to the spec of `super.preprocess()""" images = self.collect_images(messages) @@ -100,18 +52,6 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]: messages.append(dict(role='preprocess', content=outputs)) return messages - @torch.no_grad() - def forward(self, messages: List[Dict]) -> List[Dict]: - """extract image feature. ONLY implement it when the backend is - turbomind engine. 
-
-        Args:
-            messages(List[Dict]): the outputs of `preprocess`
-        Return:
-            the message list with forwarding results included
-        """
-        assert 0, 'glm4v is not supported by turbomind'
-
     @classmethod
     def proc_messages(cls, messages, chat_template, sequence_start):
         """apply chat template to get the prompt."""
@@ -136,6 +76,3 @@ def to_pytorch(self, messages, chat_template, tokenizer, sequence_start):
                                                  sequence_start)
         return super().to_pytorch_aux(messages, prompt, IMAGE_TOKEN,
                                       tokenizer, sequence_start)
-
-    def to_turbomind(self, messages, chat_template, tokenizer, sequence_start):
-        assert 0, 'glm4v is not supported by turbomind'
diff --git a/lmdeploy/vl/model/internvl.py b/lmdeploy/vl/model/internvl.py
index e130b7738..b5abb55fd 100644
--- a/lmdeploy/vl/model/internvl.py
+++ b/lmdeploy/vl/model/internvl.py
@@ -167,16 +167,25 @@ def _preprocess_v1_5(self, image, params=None):
         pixel_values = torch.stack(pixel_values)
         return pixel_values

-    def _forward_v1_5(self, inputs):
+    def _forward_v1_5(self, inputs, max_batch_size):
         """forward for internvl-chat-v1-5."""
         assert all(x.get('pixel_values') is not None for x in inputs)
-        outputs = [x['pixel_values'] for x in inputs]
-        split = [x['pixel_values'].shape[0] for x in inputs]
-        outputs = torch.cat(outputs, dim=0)
-        outputs = outputs.to(self.model.device, dtype=torch.float16)
-        outputs = self.model.extract_feature(outputs)
-        outputs = torch.split(outputs, split, dim=0)
-        outputs = [x.reshape(-1, x.shape[-1]) for x in outputs]
+        outputs = []
+        for idx in range(0, len(inputs), max_batch_size):
+            pixel_values = [
+                x['pixel_values'] for x in inputs[idx:idx + max_batch_size]
+            ]
+            split = [
+                x['pixel_values'].shape[0]
+                for x in inputs[idx:idx + max_batch_size]
+            ]
+            pixel_values = torch.cat(pixel_values, dim=0)
+            pixel_values = pixel_values.to(self.model.device,
+                                           dtype=torch.float16)
+            logger.info(f'vision forward shape: {pixel_values.shape}')
+            feats = self.model.extract_feature(pixel_values)
+            feats = torch.split(feats, split, dim=0)
+            outputs.extend([x.reshape(-1, x.shape[-1]) for x in feats])
         return outputs

     def _preprocess(self, image, params=None):
@@ -185,15 +194,21 @@ def _preprocess(self, image, params=None):
                                             return_tensors='pt').pixel_values
         return pixel_values

-    def _forward(self, inputs):
+    def _forward(self, inputs, max_batch_size):
         """forward for internvl-chat-v1-1, internvl-chat-v1-2."""
         assert all(x.get('pixel_values') is not None for x in inputs)
-        outputs = [x['pixel_values'] for x in inputs]
-        outputs = torch.cat(outputs, dim=0)
-        outputs = outputs.to(self.model.device, dtype=torch.float16)
-        outputs = self.model.extract_feature(outputs)
-        outputs = torch.split(outputs, 1, dim=0)
-        outputs = [x.squeeze() for x in outputs]
+        outputs = []
+        for idx in range(0, len(inputs), max_batch_size):
+            pixel_values = [
+                x['pixel_values'] for x in inputs[idx:idx + max_batch_size]
+            ]
+            pixel_values = torch.cat(pixel_values, dim=0)
+            pixel_values = pixel_values.to(self.model.device,
+                                           dtype=torch.float16)
+            logger.info(f'vision forward shape: {pixel_values.shape}')
+            feats = self.model.extract_feature(pixel_values)
+            feats = torch.split(feats, 1, dim=0)
+            outputs.extend([x.squeeze() for x in feats])
         return outputs

     def preprocess(self, messages: List[Dict]) -> List[Dict]:
@@ -214,18 +229,22 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]:
         return messages

     @torch.no_grad()
-    def forward(self, messages: List[Dict]) -> List[Dict]:
+    def forward(self,
+                messages: List[Dict],
+                max_batch_size: int = 1) -> List[Dict]:
         """extract image feature.
ONLY implement it when the backend is turbomind engine. Args: messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model Return: the message list with forwarding results included """ inputs = [x['content'] for x in messages if x['role'] == 'preprocess'] inputs = inputs[0] - outputs = self._forward_func(inputs) + outputs = self._forward_func(inputs, max_batch_size) messages.append(dict(role='forward', content=outputs)) return messages diff --git a/lmdeploy/vl/model/internvl_llava.py b/lmdeploy/vl/model/internvl_llava.py index df51e6893..46839f5f9 100644 --- a/lmdeploy/vl/model/internvl_llava.py +++ b/lmdeploy/vl/model/internvl_llava.py @@ -142,29 +142,38 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]: return super().preprocess(messages) @torch.no_grad() - def forward(self, messages: List[Dict]) -> List[Dict]: + def forward(self, + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: """extract image feature. ONLY implement it when the backend is turbomind engine. Args: messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model Return: the message list with forwarding results included """ inputs = [x['content'] for x in messages if x['role'] == 'preprocess'] inputs = inputs[0] - pixel_values = [x['pixel_values'] for x in inputs] - split_sizes = [x.shape[0] for x in pixel_values] - pixel_values = torch.cat(pixel_values, dim=0) - pixel_values = pixel_values.to(device=self.vision_tower.device, - dtype=torch.float16) - - if pixel_values.ndim == 5: - image_features = self.encode_images(pixel_values) - image_features = torch.split(image_features, split_sizes, dim=0) - image_features = [x.flatten(0, 1) for x in image_features] - else: - image_features = self.encode_images(pixel_values) - image_features = [x for x in image_features] - messages.append(dict(role='forward', content=image_features)) + outputs = [] + for idx in range(0, len(inputs), max_batch_size): + pixel_values = [ + x['pixel_values'] for x in inputs[idx:idx + max_batch_size] + ] + split_sizes = [x.shape[0] for x in pixel_values] + pixel_values = torch.cat(pixel_values, dim=0) + pixel_values = pixel_values.to(device=self.vision_tower.device, + dtype=torch.float16) + logger.info(f'vision forward shape: {pixel_values.shape}') + if pixel_values.ndim == 5: + feats = self.encode_images(pixel_values) + feats = torch.split(feats, split_sizes, dim=0) + feats = [x.flatten(0, 1) for x in feats] + else: + feats = self.encode_images(pixel_values) + feats = [x for x in feats] + outputs.extend(feats) + messages.append(dict(role='forward', content=outputs)) return messages diff --git a/lmdeploy/vl/model/llava.py b/lmdeploy/vl/model/llava.py index a004d0fb7..30a44f848 100644 --- a/lmdeploy/vl/model/llava.py +++ b/lmdeploy/vl/model/llava.py @@ -327,12 +327,16 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]: return messages @torch.no_grad() - def forward(self, messages: List[Dict]) -> List[Dict]: + def forward(self, + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: """extract image feature. ONLY implement it when the backend is turbomind engine. 
         Args:
             messages(List[Dict]): the outputs of `preprocess`
+            max_batch_size(int): the max batch size when forwarding vision
+                model
         Return:
             the message list with forwarding results included
         """
@@ -340,79 +344,79 @@ def forward(self, messages: List[Dict]) -> List[Dict]:
                                             unpad_image)
         inputs = [x['content'] for x in messages if x['role'] == 'preprocess']
         inputs = inputs[0]
-        image_sizes = [x['image_size'] for x in inputs]
-        pixel_values = [x['pixel_values'] for x in inputs]
-        pixel_values = torch.cat(pixel_values, dim=0)
-        pixel_values = pixel_values.to(device=self.vision_tower.device,
-                                       dtype=torch.float16)
-        if pixel_values.ndim == 5:
-            split_sizes = [x.shape[0] for x in pixel_values]
-            pixel_values = torch.cat([x for x in pixel_values], dim=0)
-            image_features = self.encode_images(pixel_values)
-            image_features = torch.split(image_features, split_sizes, dim=0)
-            mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type',
-                                          'flat')
-            image_aspect_ratio = getattr(self.config, 'image_aspect_ratio',
-                                         'square')
-            if mm_patch_merge_type == 'flat':
-                image_features = [x.flatten(0, 1) for x in image_features]
-            elif mm_patch_merge_type.startswith('spatial'):
-                new_image_features = []
-                for image_idx, image_feature in enumerate(image_features):
-                    if image_feature.shape[0] > 1:
-                        base_image_feature = image_feature[0]
-                        image_feature = image_feature[1:]
-                        height = width = self.vision_tower.num_patches_per_side
-                        assert height * width == base_image_feature.shape[0]
-                        if image_aspect_ratio == 'anyres':
-                            num_patch_width, num_patch_height = \
-                                get_anyres_image_grid_shape(
-                                    image_sizes[image_idx],
-                                    self.config.image_grid_pinpoints,
-                                    self.vision_tower.config.image_size)
-                            image_feature = image_feature.view(
-                                num_patch_height, num_patch_width, height,
-                                width, -1)
-                        else:
-                            raise NotImplementedError
-                        if 'unpad' in mm_patch_merge_type:
-                            image_feature = image_feature.permute(
-                                4, 0, 2, 1, 3).contiguous()
-                            image_feature = image_feature.flatten(1,
-                                                                  2).flatten(
-                                                                      2, 3)
-                            image_feature = unpad_image(
-                                image_feature, image_sizes[image_idx])
-                            image_feature = torch.cat((
-                                image_feature,
-                                self.model.image_newline[:, None, None].expand(
-                                    *image_feature.shape[:-1], 1).to(
-                                        image_feature.device)),
-                                dim=-1)
-                            image_feature = image_feature.flatten(1,
-                                                                  2).transpose(
-                                                                      0, 1)
+        outputs = []
+        for idx in range(0, len(inputs), max_batch_size):
+            image_sizes = [
+                x['image_size'] for x in inputs[idx:idx + max_batch_size]
+            ]
+            pixel_values = [
+                x['pixel_values'] for x in inputs[idx:idx + max_batch_size]
+            ]
+            if pixel_values[0].ndim == 5:
+                split_sizes = [x.shape[1] for x in pixel_values]
+                pixel_values = torch.cat([x for x in pixel_values], dim=1)
+                logger.info(f'vision forward shape: {pixel_values.shape}')
+                pixel_values = pixel_values.squeeze(0)
+                pixel_values = pixel_values.to(device=self.vision_tower.device,
+                                               dtype=torch.float16)
+                feats = self.encode_images(pixel_values)
+                feats = torch.split(feats, split_sizes, dim=0)
+                mm_patch_merge_type = getattr(self.config,
+                                              'mm_patch_merge_type', 'flat')
+                image_aspect_ratio = getattr(self.config, 'image_aspect_ratio',
+                                             'square')
+                if mm_patch_merge_type == 'flat':
+                    outputs.extend([x.flatten(0, 1) for x in feats])
+                elif mm_patch_merge_type.startswith('spatial'):
+                    for img_idx, feat in enumerate(feats):
+                        if feat.shape[0] > 1:
+                            base_feat = feat[0]
+                            feat = feat[1:]
+                            height = self.vision_tower.num_patches_per_side
+                            width = self.vision_tower.num_patches_per_side
+                            assert height * width == base_feat.shape[0]
+                            if image_aspect_ratio == 'anyres':
num_patch_width, num_patch_height = \ + get_anyres_image_grid_shape( + image_sizes[img_idx], + self.config.image_grid_pinpoints, + self.vision_tower.config.image_size) + feat = feat.view(num_patch_height, + num_patch_width, height, + width, -1) + else: + raise NotImplementedError + if 'unpad' in mm_patch_merge_type: + feat = feat.permute(4, 0, 2, 1, 3).contiguous() + feat = feat.flatten(1, 2).flatten(2, 3) + feat = unpad_image(feat, image_sizes[img_idx]) + feat = torch.cat( + (feat, self.model. + image_newline[:, None, None].expand( + *feat.shape[:-1], 1).to(feat.device)), + dim=-1) + feat = feat.flatten(1, 2).transpose(0, 1) + else: + feat = feat.permute(0, 2, 1, 3, 4).contiguous() + feat = feat.flatten(0, 3) + feat = torch.cat((base_feat, feat), dim=0) else: - image_feature = image_feature.permute( - 0, 2, 1, 3, 4).contiguous() - image_feature = image_feature.flatten(0, 3) - image_feature = torch.cat( - (base_image_feature, image_feature), dim=0) - else: - image_feature = image_feature[0] - if 'unpad' in mm_patch_merge_type: - image_feature = torch.cat( - (image_feature, - self.model.image_newline[None].to( - image_feature.device)), - dim=0) - new_image_features.append(image_feature) - image_features = new_image_features + feat = feat[0] + if 'unpad' in mm_patch_merge_type: + feat = torch.cat( + (feat, self.model.image_newline[None].to( + feat.device)), + dim=0) + outputs.append(feat) + else: + raise ValueError('Unexpected mm_patch_merge_type: ' + f'{self.config.mm_patch_merge_type}') else: - raise ValueError('Unexpected mm_patch_merge_type: ' - f'{self.config.mm_patch_merge_type}') - else: - image_features = self.encode_images(pixel_values) - image_features = [x for x in image_features] - messages.append(dict(role='forward', content=image_features)) + pixel_values = torch.cat(pixel_values, dim=0) + pixel_values = pixel_values.to(device=self.vision_tower.device, + dtype=torch.float16) + logger.info(f'vision forward shape: {pixel_values.shape}') + feats = self.encode_images(pixel_values) + outputs.extend([x for x in feats]) + messages.append(dict(role='forward', content=outputs)) return messages diff --git a/lmdeploy/vl/model/llava_hf.py b/lmdeploy/vl/model/llava_hf.py index bf8d1abe4..d97a736b9 100644 --- a/lmdeploy/vl/model/llava_hf.py +++ b/lmdeploy/vl/model/llava_hf.py @@ -5,9 +5,12 @@ import torch from transformers import AutoProcessor +from lmdeploy.utils import get_logger from lmdeploy.vl.model.base import VISION_MODELS, VisonModel from lmdeploy.vl.model.utils import disable_logging +logger = get_logger('lmdeploy') + @VISION_MODELS.register_module() class LlavaHfVisionModel(VisonModel): @@ -75,36 +78,45 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]: return messages @torch.no_grad() - def forward(self, messages: List[Dict]) -> List[Dict]: + def forward(self, + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: """extract image feature. ONLY implement it when the backend is turbomind engine. 
Args: messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model Return: the message list with forwarding results included """ inputs = [x['content'] for x in messages if x['role'] == 'preprocess'] inputs = inputs[0] - pixel_values = [x['pixel_values'] for x in inputs] - pixel_values = torch.cat(pixel_values, dim=0) - pixel_values = pixel_values.to(device=self.model.device, - dtype=self.model.dtype) - image_outputs = self.model.vision_tower.forward( - pixel_values, output_hidden_states=True) - image_features = image_outputs.hidden_states[ - self.hf_config.vision_feature_layer] - if self.hf_config.vision_feature_select_strategy == 'default': - image_features = image_features[:, 1:] - elif self.hf_config.vision_feature_select_strategy == 'full': - image_features = image_features - else: - raise ValueError( - 'Unexpected select feature strategy: ' - f'{self.hf_config.vision_feature_select_strategy}') - image_features = self.model.multi_modal_projector(image_features) - outputs = torch.split(image_features, 1, dim=0) - outputs = [x.squeeze() for x in outputs] + outputs = [] + for idx in range(0, len(inputs), max_batch_size): + pixel_values = [ + x['pixel_values'] for x in inputs[idx:idx + max_batch_size] + ] + pixel_values = torch.cat(pixel_values, dim=0) + pixel_values = pixel_values.to(device=self.model.device, + dtype=self.model.dtype) + logger.info(f'vision forward shape: {pixel_values.shape}') + image_outputs = self.model.vision_tower.forward( + pixel_values, output_hidden_states=True) + image_features = image_outputs.hidden_states[ + self.hf_config.vision_feature_layer] + if self.hf_config.vision_feature_select_strategy == 'default': + image_features = image_features[:, 1:] + elif self.hf_config.vision_feature_select_strategy == 'full': + image_features = image_features + else: + raise ValueError( + 'Unexpected select feature strategy: ' + f'{self.hf_config.vision_feature_select_strategy}') + image_features = self.model.multi_modal_projector(image_features) + image_features = torch.split(image_features, 1, dim=0) + outputs.extend([x.squeeze() for x in image_features]) messages.append(dict(role='forward', content=outputs)) return messages diff --git a/lmdeploy/vl/model/llava_next.py b/lmdeploy/vl/model/llava_next.py index 74f4837ed..ab58b105d 100644 --- a/lmdeploy/vl/model/llava_next.py +++ b/lmdeploy/vl/model/llava_next.py @@ -116,65 +116,80 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]: return messages @torch.no_grad() - def forward(self, messages: List[Dict]) -> List[Dict]: + def forward(self, + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: """extract image feature. ONLY implement it when the backend is turbomind engine. 
Args: messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model Return: the message list with forwarding results included """ inputs = [x['content'] for x in messages if x['role'] == 'preprocess'] inputs = inputs[0] - pixel_values = [ - x['pixel_values'].to(device=self.model.device, - dtype=self.model.dtype) for x in inputs - ] - pixel_values = torch.cat(pixel_values, dim=0) - image_sizes = [ - x['image_sizes'].to(device=self.model.device, - dtype=self.model.dtype) for x in inputs - ] - image_sizes = torch.cat(image_sizes, dim=0) - image_num_patches = [x['num_patch'] for x in inputs] - image_num_patches = list(itertools.chain(*image_num_patches)) - # figure out if pixel_values is concatenated or stacked - if pixel_values.dim() == 5: - # stacking when input is - # (batch_size, num_patches, num_channels, height, width) - _pixel_values_list = [ - pix_val[:num_patch] - for pix_val, num_patch in zip(pixel_values, image_num_patches) + outputs = [] + for idx in range(0, len(inputs), max_batch_size): + pixel_values = [ + x['pixel_values'].to(device=self.model.device, + dtype=self.model.dtype) + for x in inputs[idx:idx + max_batch_size] ] - pixel_values = torch.cat(_pixel_values_list, dim=0) - elif pixel_values.dim() != 4: - # otherwise has to be stacked from list of - # (num_patches, num_channels, height, width) - raise ValueError(f'pixel_values of shape {pixel_values.shape}, ' - 'expect to be of 4 or 5 dimensions') - image_outputs = self.model.vision_tower.forward( - pixel_values, output_hidden_states=True) - image_features = image_outputs.hidden_states[ - self.hf_config.vision_feature_layer] - strategy = self.hf_config.vision_feature_select_strategy - if strategy == 'default': - image_features = image_features[:, 1:] - elif strategy == 'full': - image_features = image_features - else: - raise ValueError('Unexpected select feature strategy: ' - f'{strategy}') - image_features = self.model.multi_modal_projector(image_features) - image_features = torch.split(image_features, image_num_patches, dim=0) - image_features, feature_lens = self.model.pack_image_features( - image_features, - image_sizes, - vision_feature_select_strategy=strategy, - image_newline=self.model.image_newline, - ) - outputs = torch.split(image_features, - feature_lens.cpu().numpy().tolist(), - dim=0) + pixel_values = torch.cat(pixel_values, dim=0) + image_sizes = [ + x['image_sizes'].to(device=self.model.device, + dtype=self.model.dtype) + for x in inputs[idx:idx + max_batch_size] + ] + image_sizes = torch.cat(image_sizes, dim=0) + image_num_patches = [ + x['num_patch'] for x in inputs[idx:idx + max_batch_size] + ] + image_num_patches = list(itertools.chain(*image_num_patches)) + # figure out if pixel_values is concatenated or stacked + if pixel_values.dim() == 5: + # stacking when input is + # (batch_size, num_patches, num_channels, height, width) + _pixel_values_list = [ + pix_val[:num_patch] for pix_val, num_patch in zip( + pixel_values, image_num_patches) + ] + pixel_values = torch.cat(_pixel_values_list, dim=0) + elif pixel_values.dim() != 4: + # otherwise has to be stacked from list of + # (num_patches, num_channels, height, width) + raise ValueError( + f'pixel_values of shape {pixel_values.shape}, ' + 'expect to be of 4 or 5 dimensions') + logger.info(f'vision forward shape: {pixel_values.shape}') + image_outputs = self.model.vision_tower.forward( + pixel_values, output_hidden_states=True) + image_features = image_outputs.hidden_states[ + 
self.hf_config.vision_feature_layer] + strategy = self.hf_config.vision_feature_select_strategy + if strategy == 'default': + image_features = image_features[:, 1:] + elif strategy == 'full': + image_features = image_features + else: + raise ValueError('Unexpected select feature strategy: ' + f'{strategy}') + image_features = self.model.multi_modal_projector(image_features) + image_features = torch.split(image_features, + image_num_patches, + dim=0) + image_features, feature_lens = self.model.pack_image_features( + image_features, + image_sizes, + vision_feature_select_strategy=strategy, + image_newline=self.model.image_newline, + ) + image_features = torch.split(image_features, + feature_lens.cpu().numpy().tolist(), + dim=0) + outputs.extend(image_features) messages.append(dict(role='forward', content=outputs)) return messages diff --git a/lmdeploy/vl/model/mini_gemeni.py b/lmdeploy/vl/model/mini_gemeni.py index 417f32fe6..f7054628c 100644 --- a/lmdeploy/vl/model/mini_gemeni.py +++ b/lmdeploy/vl/model/mini_gemeni.py @@ -7,11 +7,14 @@ import torch +from lmdeploy.utils import get_logger from lmdeploy.vl.model.base import VISION_MODELS, VisonModel from lmdeploy.vl.model.utils import (add_device_hook, disable_logging, disable_transformers_logging, hack_import_with) +logger = get_logger('lmdeploy') + def check_mini_gemini_install(): """check mini gemini install.""" @@ -255,12 +258,16 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]: return messages @torch.no_grad() - def forward(self, messages: List[Dict]) -> List[Dict]: + def forward(self, + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: """extract image feature. ONLY implement it when the backend is turbomind engine. Args: messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model Return: the message list with forwarding results included """ @@ -326,12 +333,13 @@ def forward(self, messages: List[Dict]) -> List[Dict]: image.to(self.model.device, dtype=torch.float16) for image in image_tensor_aux ] + logger.info(f'vision forward bs: {len(image_tensor)}') else: image_tensor = image_tensor.to(self.model.device, dtype=torch.float16) image_tensor_aux = image_tensor_aux.to(self.model.device, dtype=torch.float16) - + logger.info(f'vision forward shape: {image_tensor.shape}') images_embeds = self.model.encode_images(image_tensor, image_tensor_aux) diff --git a/lmdeploy/vl/model/minicpmv.py b/lmdeploy/vl/model/minicpmv.py index 3098944be..a31135be4 100644 --- a/lmdeploy/vl/model/minicpmv.py +++ b/lmdeploy/vl/model/minicpmv.py @@ -160,12 +160,16 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]: return messages @torch.no_grad() - def forward(self, messages: List[Dict]) -> List[Dict]: + def forward(self, + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: """extract image feature. ONLY implement it when the backend is turbomind engine. 
Args: messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model Return: the message list with forwarding results included """ @@ -197,6 +201,7 @@ def forward(self, messages: List[Dict]) -> List[Dict]: device=self.model.device) for i in range(B): patch_attn_mask[i, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True + logger.info(f'vision forward shape: {pixel_values.shape}') if self.version == '2.5': embeddings = self.model.vpm( pixel_values.type(torch.half), diff --git a/lmdeploy/vl/model/mllama.py b/lmdeploy/vl/model/mllama.py index 3163381db..779ef4f92 100644 --- a/lmdeploy/vl/model/mllama.py +++ b/lmdeploy/vl/model/mllama.py @@ -2,8 +2,6 @@ from typing import Dict, List -import torch - from lmdeploy.vl.model.base import VISION_MODELS, VisonModel @@ -18,9 +16,6 @@ def build_preprocessor(self): self.processor = AutoProcessor.from_pretrained(self.model_path) self.image_token_id = 128256 - def build_model(self): - assert 0, 'mllama is not supported by turbomind' - def preprocess(self, messages: List[Dict]) -> List[Dict]: """refer to the spec of `super().preprocess`""" images = self.collect_images(messages) @@ -58,23 +53,8 @@ def proc_messages(cls, messages, chat_template, sequence_start): prompt = chat_template.messages2prompt(prompt_messages, sequence_start) return prompt, IMAGE_TOKEN - @torch.no_grad() - def forward(self, messages: List[Dict]) -> List[Dict]: - """extract image feature. ONLY implement it when the backend is - turbomind engine. - - Args: - messages(List[Dict]): the outputs of `preprocess` - Return: - the message list with forwarding results included - """ - assert 0, 'cogvlm is not supported by turbomind' - def to_pytorch(self, messages, chat_template, tokenizer, sequence_start): prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start) return super().to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start) - - def to_turbomind(self, messages, chat_template, sequence_start): - assert 0, 'mllama is not supported by turbomind' diff --git a/lmdeploy/vl/model/molmo.py b/lmdeploy/vl/model/molmo.py index 042fc3108..2eec485f0 100644 --- a/lmdeploy/vl/model/molmo.py +++ b/lmdeploy/vl/model/molmo.py @@ -83,12 +83,16 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]: return messages @torch.no_grad() - def forward(self, messages: List[Dict]) -> List[Dict]: + def forward(self, + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: """extract image feature. ONLY implement it when the backend is turbomind engine. Args: messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model Return: the message list with forwarding results included """ @@ -113,6 +117,7 @@ def forward(self, messages: List[Dict]) -> List[Dict]: embeddings = self.model.model.transformer.wte(input_ids) images = images.to(self.model.dtype) image_masks = image_masks.to(self.model.dtype) + logger.info(f'vision forward shape: {images.shape}') image_features, _ = self.model.model.vision_backbone( images, image_masks) num_image, num_patch = image_features.shape[1:3] diff --git a/lmdeploy/vl/model/phi3_vision.py b/lmdeploy/vl/model/phi3_vision.py index a4a848f27..80204a2de 100644 --- a/lmdeploy/vl/model/phi3_vision.py +++ b/lmdeploy/vl/model/phi3_vision.py @@ -1,125 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-import warnings from typing import Dict, List -import torch from transformers import AutoProcessor from lmdeploy.vl.model.llava_hf import VISION_MODELS, LlavaHfVisionModel -from lmdeploy.vl.model.utils import disable_logging - - -# from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py # noqa E501 -def _process_image_embedding(self, pixel_values: torch.Tensor, - image_sizes: torch.Tensor): - """process image embedding.""" - img_embeds = pixel_values - img_sizes = image_sizes - target_device = pixel_values.device - target_dtype = pixel_values.dtype - if self.use_hd_transform and img_sizes is not None and len(img_sizes): - assert img_embeds.ndim == 5, f'img_embeds size: {img_embeds.size()}, expect 5D tensor for hd transform' # noqa E501 - # img_embeds: (num_images, max_num_crops, 3, H, W) - # img_sizes: (num_images, 2).view(1, -1) - - bs = img_embeds.shape[0] - # Nx(HW)xC - img_features = self.get_img_features(img_embeds.flatten(0, 1)) - base_feat_height = base_feat_width = int(img_features.shape[1]**0.5) - - assert base_feat_height == 24 and base_feat_width == 24, f'base_feat_height: {base_feat_height}, base_feat_width: {base_feat_width}, expect 24x24 features for hd transform' # noqa E501 - - # bs x max_num_crops x (24x24) x C - img_features = img_features.view(bs, -1, - base_feat_height * base_feat_width, - self.image_dim_out) - C = self.image_dim_out - H = base_feat_height - - output_imgs = [] - output_len = [] - # training is tensor, inference is list - if isinstance(img_sizes, torch.Tensor): - img_sizes = img_sizes.view(-1, 2) - for _bs in range(bs): - h, w = img_sizes[_bs] - h = h // 336 - w = w // 336 - B_ = h * w - - # 1 x (24x24) x 1024 - global_img_feature = img_features[_bs, :1] - - # 1 x 12 x 12 x 4096 - glb_img = global_img_feature.reshape(1, H, H, C).reshape( - 1, H // 2, 2, H // 2, 2, - C).contiguous().permute(0, 1, 3, 2, 4, - 5).reshape(1, H // 2, H // 2, - 4 * C).contiguous() - temp_glb_GN = self.sub_GN.repeat(1, H // 2, 1, 1) - - # 1 x 156 x 4096 - glb_img = torch.cat([glb_img, temp_glb_GN], - dim=2).reshape(1, -1, 4 * C) - - # (max_num_crops-1) x (12x12) x C - sub_img = img_features[_bs, 1:] - # 16x574x1024 - # get rid of padding sub_img - sub_img = sub_img[:B_] - - # (num_crops, 12, 2, 12, 2, 1024)->(num_crops, 12, 12, 2, 2, 1024) - # -> (num_crops, 12*12, 4*1024) - sub_img = sub_img.reshape(B_, H, H, C).reshape( - B_, H // 2, 2, H // 2, 2, - C).contiguous().permute(0, 1, 3, 2, 4, - 5).reshape(B_, -1, 4 * C).contiguous() - sub_img = sub_img.reshape(1, h, w, 12, 12, -1).permute( - 0, 1, 3, 2, 4, 5).reshape(1, h * 12, w * 12, 4 * C) - temp_sub_GN = self.sub_GN.repeat(1, h * 12, 1, 1) - sub_img = torch.cat([sub_img, temp_sub_GN], - dim=2).reshape(1, -1, 4 * C) - # (1, num_img_tokens, 1024*4) - - # glb + sub - if self.hd_transform_order == 'glb_sub': - output_imgs.append( - torch.cat([glb_img, self.glb_GN, sub_img], dim=1)) - elif self.hd_transform_order == 'sub_glb': - output_imgs.append( - torch.cat([sub_img, self.glb_GN, glb_img], dim=1)) - else: - raise NotImplementedError( - f'hd_transform_order = {self.hd_transform_order}' - ) # noqa E501 - - temp_len = int((h * w + 1) * 144 + 1 + (h + 1) * 12) - assert temp_len == output_imgs[-1].shape[ - 1], f'temp_len: {temp_len}, output_imgs[-1].shape[1]: {output_imgs[-1].shape[1]}' # noqa E501 - output_len.append(temp_len) - - img_set_tensor = [] - for _output_img in output_imgs: - img_feature_proj = self.img_projection( - _output_img.to(target_device).to(target_dtype)) - 
img_set_tensor.append(img_feature_proj) - elif img_embeds.ndim == 4: - tt = (self.get_img_features(img_embeds).to(target_device).to( - target_dtype).reshape(-1, self.image_dim_out)) - img_set_tensor = self.img_projection(tt) # adapted visual features. - elif img_embeds.ndim == 3: - tt = (img_embeds.to(target_device).to(target_dtype).view( - -1, self.image_dim_out)) - img_set_tensor = self.img_projection(tt) # adapted visual features. - else: - raise NotImplementedError - return img_set_tensor @VISION_MODELS.register_module() class Phi3VisionModel(LlavaHfVisionModel): - """Llava hf vision model.""" + """Phi3-vision model.""" _arch = 'Phi3VForCausalLM' @@ -131,56 +21,6 @@ def build_preprocessor(self): processor.tokenizer = None self.processor = processor - def build_model(self): - from accelerate import init_empty_weights, load_checkpoint_and_dispatch - from accelerate.utils import get_balanced_memory, infer_auto_device_map - - with init_empty_weights(), warnings.catch_warnings(): - warnings.simplefilter('ignore') - from transformers import AutoModelForCausalLM - model = AutoModelForCausalLM.from_config(self.hf_config, - trust_remote_code=True) - if not self.with_llm: - del model.lm_head - del model.model.layers - del model.model.norm - del model.model.embed_tokens - del model.model.vision_embed_tokens.wte - else: - self.vl_model = model - - no_split_module_classes = ['CLIPEncoderLayer'] - max_memory = get_balanced_memory( - model, - max_memory=self.max_memory, - dtype=torch.half, - no_split_module_classes=no_split_module_classes) - device_map = infer_auto_device_map( - model, - no_split_module_classes=no_split_module_classes, - max_memory=max_memory, - dtype=torch.half) - same_device_keys = [('model.vision_embed_tokens.img_projection', - 'model.vision_embed_tokens.sub_GN', - 'model.vision_embed_tokens.glb_GN')] - for keys in same_device_keys: - keys = [k for k in keys if k in device_map] - if len(keys) <= 1: - continue - for k in keys[1:]: - device_map[k] = device_map[keys[0]] - - with disable_logging(): - load_checkpoint_and_dispatch( - model=model, - checkpoint=self.model_path, - device_map=device_map if not self.with_llm else {'': 'cpu'}, - no_split_module_classes=no_split_module_classes, - dtype=torch.half) - - model.eval() - self.model = model - def preprocess(self, messages: List[Dict]) -> List[Dict]: """refers to `super.preprocess() for spec.""" images = self.collect_images(messages) @@ -198,27 +38,3 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]: outputs.append(result) messages.append(dict(role='preprocess', content=outputs)) return messages - - @torch.no_grad() - def forward(self, messages: List[Dict]) -> List[Dict]: - """extract image feature. ONLY implement it when the backend is - turbomind engine. 
- - Args: - messages(List[Dict]): the outputs of `preprocess` - Return: - the message list with forwarding results included - """ - inputs = [x['content'] for x in messages if x['role'] == 'preprocess'] - inputs = inputs[0] - pixel_values = [x['pixel_values'] for x in inputs] - pixel_values = torch.stack(pixel_values, dim=0) - image_sizes = [x['image_sizes'] for x in inputs] - image_sizes = torch.stack(image_sizes, dim=0) - image_features = _process_image_embedding( - self.model.model.vision_embed_tokens, - pixel_values=pixel_values, - image_sizes=image_sizes) - outputs = [x.squeeze() for x in image_features] - messages.append(dict(role='forward', content=outputs)) - return messages diff --git a/lmdeploy/vl/model/qwen.py b/lmdeploy/vl/model/qwen.py index f2089919d..a822da1e4 100644 --- a/lmdeploy/vl/model/qwen.py +++ b/lmdeploy/vl/model/qwen.py @@ -5,9 +5,12 @@ import torch from transformers import AutoModelForCausalLM +from lmdeploy.utils import get_logger from lmdeploy.vl.model.base import VISION_MODELS, VisonModel from lmdeploy.vl.model.utils import disable_logging +logger = get_logger('lmdeploy') + @VISION_MODELS.register_module() class QwenVisionModel(VisonModel): @@ -88,22 +91,31 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]: return messages @torch.no_grad() - def forward(self, messages: List[Dict]) -> List[Dict]: + def forward(self, + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: """extract image feature. ONLY implement it when the backend is turbomind engine. Args: messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model Return: the message list with forwarding results included """ inputs = [x['content'] for x in messages if x['role'] == 'preprocess'] inputs = inputs[0] - pixel_values = [x['pixel_values'] for x in inputs] - pixel_values = torch.stack(pixel_values, dim=0) - outputs = self.model(pixel_values) - outputs = torch.split(outputs, 1, dim=0) - outputs = [x.squeeze() for x in outputs] + outputs = [] + for idx in range(0, len(inputs), max_batch_size): + pixel_values = [ + x['pixel_values'] for x in inputs[idx:idx + max_batch_size] + ] + pixel_values = torch.stack(pixel_values, dim=0) + logger.info(f'vision forward shape: {pixel_values.shape}') + feats = self.model(pixel_values) + feats = torch.split(feats, 1, dim=0) + outputs.extend([x.squeeze() for x in feats]) messages.append(dict(role='forward', content=outputs)) return messages diff --git a/lmdeploy/vl/model/qwen2.py b/lmdeploy/vl/model/qwen2.py index d281f49e4..cd4ef4b6c 100644 --- a/lmdeploy/vl/model/qwen2.py +++ b/lmdeploy/vl/model/qwen2.py @@ -99,12 +99,16 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]: return messages @torch.no_grad() - def forward(self, messages: List[Dict]) -> List[Dict]: + def forward(self, + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: """extract image feature. ONLY implement it when the backend is turbomind engine. 
Args: messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model Return: the message list with forwarding results included """ diff --git a/lmdeploy/vl/model/xcomposer2.py b/lmdeploy/vl/model/xcomposer2.py index 4400a00f9..8df093c0f 100644 --- a/lmdeploy/vl/model/xcomposer2.py +++ b/lmdeploy/vl/model/xcomposer2.py @@ -228,35 +228,47 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]: return messages @torch.no_grad() - def forward(self, messages: List[Dict]) -> List[Dict]: + def forward(self, + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: """extract image feature. ONLY implement it when the backend is turbomind engine. Args: messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model Return: the message list with forwarding results included """ inputs = [x['content'] for x in messages if x['role'] == 'preprocess'] inputs = inputs[0] - if self.model_type in [ - ModelType.XCOMPOSER2D5, ModelType.XCOMPOSER2_4KHD - ]: - pixel_values = [x['pixel_values'] for x in inputs] - embeds, split = self.model.vit(pixel_values, - self.model.plora_glb_GN, - self.model.plora_sub_GN) - embeds = self.model.vision_proj(embeds) - embeds = torch.split(embeds, split, dim=1) - embeds = [x.squeeze() for x in embeds] - else: - pixel_values = [x['pixel_values'] for x in inputs] - pixel_values = torch.cat(pixel_values, dim=0) - embeds = self.model.vit(pixel_values) - embeds = self.model.vision_proj(embeds) - embeds = torch.split(embeds, 1, dim=0) - embeds = [x.squeeze() for x in embeds] - messages.append(dict(role='forward', content=embeds)) + outputs = [] + for idx in range(0, len(inputs), max_batch_size): + if self.model_type in [ + ModelType.XCOMPOSER2D5, ModelType.XCOMPOSER2_4KHD + ]: + pixel_values = [ + x['pixel_values'] for x in inputs[idx:idx + max_batch_size] + ] + embeds, split = self.model.vit(pixel_values, + self.model.plora_glb_GN, + self.model.plora_sub_GN) + embeds = self.model.vision_proj(embeds) + embeds = torch.split(embeds, split, dim=1) + embeds = [x.squeeze() for x in embeds] + else: + pixel_values = [ + x['pixel_values'] for x in inputs[idx:idx + max_batch_size] + ] + pixel_values = torch.cat(pixel_values, dim=0) + logger.info(f'vision forward shape: {pixel_values.shape}') + embeds = self.model.vit(pixel_values) + embeds = self.model.vision_proj(embeds) + embeds = torch.split(embeds, 1, dim=0) + embeds = [x.squeeze() for x in embeds] + outputs.extend(embeds) + messages.append(dict(role='forward', content=outputs)) return messages @classmethod