Commit 65c78c2

fix

lvhan028 committed Dec 8, 2024
1 parent c89a1d9

Showing 11 changed files with 34 additions and 9 deletions.
1 change: 1 addition & 0 deletions lmdeploy/vl/model/deepseek.py
@@ -131,6 +131,7 @@ def forward(self,
                     self.vision_model.parameters()).device,
                 dtype=torch.float16)
             # [b x n_images, T2, D]
+            logger.info(f'vision forward shape: {pixel_values.shape}')
             feats = self.aligner(self.vision_model(pixel_values))
             feats = torch.split(feats, 1, dim=0)
             outputs.extend([x.squeeze() for x in feats])
4 changes: 3 additions & 1 deletion lmdeploy/vl/model/internvl.py
@@ -182,6 +182,7 @@ def _forward_v1_5(self, inputs, max_batch_size):
             pixel_values = torch.cat(pixel_values, dim=0)
             pixel_values = pixel_values.to(self.model.device,
                                            dtype=torch.float16)
+            logger.info(f'vision forward shape: {pixel_values.shape}')
             feats = self.model.extract_feature(pixel_values)
             feats = torch.split(feats, split, dim=0)
             outputs.extend([x.reshape(-1, x.shape[-1]) for x in feats])
@@ -204,9 +205,10 @@ def _forward(self, inputs, max_batch_size):
             pixel_values = torch.cat(outputs, dim=0)
             pixel_values = pixel_values.to(self.model.device,
                                            dtype=torch.float16)
+            logger.info(f'vision forward shape: {pixel_values.shape}')
             feats = self.model.extract_feature(pixel_values)
             feats = torch.split(feats, 1, dim=0)
-            outputs.extend([x.squeeze() for x in outputs])
+            outputs.extend([x.squeeze() for x in feats])
         return outputs

     def preprocess(self, messages: List[Dict]) -> List[Dict]:
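
Note: the `_forward` paths above follow the same batched vision-encoding pattern used throughout these files: concatenate per-image tensors, run the encoder once, then split the features back out per image. A minimal, self-contained sketch of that pattern, where `encoder` is a hypothetical stand-in for the real vision tower (e.g. `self.model.extract_feature`):

import torch


def batched_vision_forward(pixel_values_list, encoder, max_batch_size=4):
    """Concat -> encode -> split, mirroring the loop bodies above.

    `encoder` is a hypothetical stand-in; each element of
    `pixel_values_list` is a [n_tiles, C, H, W] tensor.
    """
    outputs = []
    for idx in range(0, len(pixel_values_list), max_batch_size):
        chunk = pixel_values_list[idx:idx + max_batch_size]
        split = [x.shape[0] for x in chunk]       # tiles per image
        pixel_values = torch.cat(chunk, dim=0)    # [sum(split), C, H, W]
        feats = encoder(pixel_values)             # [sum(split), T, D]
        feats = torch.split(feats, split, dim=0)  # regroup per image
        outputs.extend([x.reshape(-1, x.shape[-1]) for x in feats])
    return outputs


# toy usage: flattened patches stand in for real vision features
encoder = lambda x: x.flatten(2).transpose(1, 2)  # [B, H*W, C]
images = [torch.randn(n, 3, 8, 8) for n in (1, 3, 2)]
print([f.shape for f in batched_vision_forward(images, encoder)])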
1 change: 1 addition & 0 deletions lmdeploy/vl/model/internvl_llava.py
@@ -166,6 +166,7 @@ def forward(self,
             pixel_values = torch.cat(pixel_values, dim=0)
             pixel_values = pixel_values.to(device=self.vision_tower.device,
                                            dtype=torch.float16)
+            logger.info(f'vision forward shape: {pixel_values.shape}')
             if pixel_values.ndim == 5:
                 feats = self.encode_images(pixel_values)
                 feats = torch.split(feats, split_sizes, dim=0)
17 changes: 11 additions & 6 deletions lmdeploy/vl/model/llava.py
@@ -352,12 +352,13 @@ def forward(self,
             pixel_values = [
                 x['pixel_values'] for x in inputs[idx:idx + max_batch_size]
             ]
-            pixel_values = torch.cat(pixel_values, dim=0)
-            pixel_values = pixel_values.to(device=self.vision_tower.device,
-                                           dtype=torch.float16)
-            if pixel_values.ndim == 5:
-                split_sizes = [x.shape[0] for x in pixel_values]
-                pixel_values = torch.cat([x for x in pixel_values], dim=0)
+            if pixel_values[0].ndim == 5:
+                split_sizes = [x.shape[1] for x in pixel_values]
+                pixel_values = torch.cat([x for x in pixel_values], dim=1)
+                logger.info(f'vision forward shape: {pixel_values.shape}')
+                pixel_values = pixel_values.squeeze(0)
+                pixel_values = pixel_values.to(device=self.vision_tower.device,
+                                               dtype=torch.float16)
                 feats = self.encode_images(pixel_values)
                 feats = torch.split(feats, split_sizes, dim=0)
                 mm_patch_merge_type = getattr(self.config,
@@ -411,6 +412,10 @@ def forward(self,
                     raise ValueError('Unexpected mm_patch_merge_type: '
                                      f'{self.config.mm_patch_merge_type}')
             else:
+                pixel_values = torch.cat(pixel_values, dim=0)
+                pixel_values = pixel_values.to(device=self.vision_tower.device,
+                                               dtype=torch.float16)
+                logger.info(f'vision forward shape: {pixel_values.shape}')
                 feats = self.encode_images(pixel_values)
                 outputs.extend([x for x in feats])
         messages.append(dict(role='forward', content=outputs))
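
Note: in the 5-D branch above, each preprocessed entry appears to be a [1, n_tiles, C, H, W] tensor, so tile counts come from shape[1], concatenation happens along dim=1, and the dummy leading batch dim is squeezed before encoding. A small shape-only sketch under that assumption:

import torch

# each preprocessed entry: [1, n_tiles, C, H, W] (assumption for this sketch)
pixel_values = [torch.randn(1, n, 3, 336, 336) for n in (5, 9)]

if pixel_values[0].ndim == 5:
    split_sizes = [x.shape[1] for x in pixel_values]  # tiles per image
    stacked = torch.cat(pixel_values, dim=1)          # [1, 14, 3, 336, 336]
    stacked = stacked.squeeze(0)                      # [14, 3, 336, 336]
    # the real code encodes here; torch.split restores per-image grouping
    feats = torch.split(stacked, split_sizes, dim=0)
    assert [f.shape[0] for f in feats] == split_sizes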
4 changes: 4 additions & 0 deletions lmdeploy/vl/model/llava_hf.py
@@ -5,9 +5,12 @@
 import torch
 from transformers import AutoProcessor

+from lmdeploy.utils import get_logger
 from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
 from lmdeploy.vl.model.utils import disable_logging

+logger = get_logger('lmdeploy')
+

 @VISION_MODELS.register_module()
 class LlavaHfVisionModel(VisonModel):
@@ -98,6 +101,7 @@ def forward(self,
             pixel_values = torch.cat(pixel_values, dim=0)
             pixel_values = pixel_values.to(device=self.model.device,
                                            dtype=self.model.dtype)
+            logger.info(f'vision forward shape: {pixel_values.shape}')
             image_outputs = self.model.vision_tower.forward(
                 pixel_values, output_hidden_states=True)
             image_features = image_outputs.hidden_states[
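
Note: all of these hunks log through lmdeploy's shared logger obtained from lmdeploy.utils.get_logger. As a rough sketch of how such a named-logger helper typically behaves (a generic stand-in, not lmdeploy's actual implementation):

import logging


def get_logger(name: str) -> logging.Logger:
    """Generic stand-in for lmdeploy.utils.get_logger (not its real code)."""
    logger = logging.getLogger(name)
    if not logger.handlers:  # avoid duplicate handlers on repeated calls
        handler = logging.StreamHandler()
        handler.setFormatter(
            logging.Formatter('%(asctime)s - %(name)s - %(message)s'))
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)
    return logger


logger = get_logger('lmdeploy')
logger.info('vision forward shape: torch.Size([2, 3, 336, 336])')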
1 change: 1 addition & 0 deletions lmdeploy/vl/model/llava_next.py
@@ -164,6 +164,7 @@ def forward(self,
                 raise ValueError(
                     f'pixel_values of shape {pixel_values.shape}, '
                     'expect to be of 4 or 5 dimensions')
+            logger.info(f'vision forward shape: {pixel_values.shape}')
             image_outputs = self.model.vision_tower.forward(
                 pixel_values, output_hidden_states=True)
             image_features = image_outputs.hidden_states[
6 changes: 5 additions & 1 deletion lmdeploy/vl/model/mini_gemeni.py
@@ -7,11 +7,14 @@

 import torch

+from lmdeploy.utils import get_logger
 from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
 from lmdeploy.vl.model.utils import (add_device_hook, disable_logging,
                                      disable_transformers_logging,
                                      hack_import_with)

+logger = get_logger('lmdeploy')
+

 def check_mini_gemini_install():
     """check mini gemini install."""
@@ -330,12 +333,13 @@ def forward(self,
                 image.to(self.model.device, dtype=torch.float16)
                 for image in image_tensor_aux
             ]
+            logger.info(f'vision forward bs: {len(image_tensor)}')
         else:
             image_tensor = image_tensor.to(self.model.device,
                                            dtype=torch.float16)
             image_tensor_aux = image_tensor_aux.to(self.model.device,
                                                    dtype=torch.float16)
-
+            logger.info(f'vision forward shape: {image_tensor.shape}')
         images_embeds = self.model.encode_images(image_tensor,
                                                  image_tensor_aux)
1 change: 1 addition & 0 deletions lmdeploy/vl/model/minicpmv.py
@@ -201,6 +201,7 @@ def forward(self,
                 device=self.model.device)
             for i in range(B):
                 patch_attn_mask[i, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
+            logger.info(f'vision forward shape: {pixel_values.shape}')
             if self.version == '2.5':
                 embeddings = self.model.vpm(
                     pixel_values.type(torch.half),
1 change: 1 addition & 0 deletions lmdeploy/vl/model/molmo.py
@@ -117,6 +117,7 @@ def forward(self,
             embeddings = self.model.model.transformer.wte(input_ids)
             images = images.to(self.model.dtype)
             image_masks = image_masks.to(self.model.dtype)
+            logger.info(f'vision forward shape: {images.shape}')
             image_features, _ = self.model.model.vision_backbone(
                 images, image_masks)
             num_image, num_patch = image_features.shape[1:3]
6 changes: 5 additions & 1 deletion lmdeploy/vl/model/qwen.py
@@ -5,9 +5,12 @@
 import torch
 from transformers import AutoModelForCausalLM

+from lmdeploy.utils import get_logger
 from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
 from lmdeploy.vl.model.utils import disable_logging

+logger = get_logger('lmdeploy')
+

 @VISION_MODELS.register_module()
 class QwenVisionModel(VisonModel):
@@ -104,11 +107,12 @@ def forward(self,
         inputs = [x['content'] for x in messages if x['role'] == 'preprocess']
         inputs = inputs[0]
         outputs = []
-        for idx in range(0, len(messages), max_batch_size):
+        for idx in range(0, len(inputs), max_batch_size):
             pixel_values = [
                 x['pixel_values'] for x in inputs[idx:idx + max_batch_size]
             ]
             pixel_values = torch.stack(pixel_values, dim=0)
+            logger.info(f'vision forward shape: {pixel_values.shape}')
             feats = self.model(pixel_values)
             feats = torch.split(feats, 1, dim=0)
             outputs.extend([x.squeeze() for x in feats])
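
Note: the loop above must range over len(inputs), since `inputs` holds one preprocessed entry per image while `messages` is the conversation list, and the two lengths generally differ. A toy sketch of the chunked iteration:

import torch

# `inputs` holds one preprocessed entry per image; `messages` (the chat
# history) generally has a different length, hence the loop bound
inputs = [dict(pixel_values=torch.randn(3, 224, 224)) for _ in range(5)]
max_batch_size = 2

outputs = []
for idx in range(0, len(inputs), max_batch_size):
    pixel_values = [
        x['pixel_values'] for x in inputs[idx:idx + max_batch_size]
    ]
    pixel_values = torch.stack(pixel_values, dim=0)  # [b, 3, 224, 224]
    # self.model(pixel_values) would produce the visual features here
    outputs.extend(torch.split(pixel_values, 1, dim=0))

assert len(outputs) == len(inputs)  # every image visited exactly once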
1 change: 1 addition & 0 deletions lmdeploy/vl/model/xcomposer2.py
@@ -262,6 +262,7 @@ def forward(self,
                 x['pixel_values'] for x in inputs[idx:idx + max_batch_size]
             ]
             pixel_values = torch.cat(pixel_values, dim=0)
+            logger.info(f'vision forward shape: {pixel_values.shape}')
             embeds = self.model.vit(pixel_values)
             embeds = self.model.vision_proj(embeds)
             embeds = torch.split(embeds, 1, dim=0)
