[Model] Further cleanup MiniCPM-V #6995

Closed
wants to merge 5 commits

Changes from 2 commits
125 changes: 71 additions & 54 deletions vllm/model_executor/models/minicpmv.py
@@ -24,7 +24,7 @@
import math
import re
from functools import partial
from typing import Dict, Iterable, List, Optional, Tuple, Union
from typing import Callable, Iterable, List, Optional, Tuple, TypedDict, Union

import numpy as np
import torch
@@ -60,6 +60,12 @@
}


class MiniCPMVInputs(TypedDict):
input_ids: torch.Tensor
pixel_values: Union[torch.Tensor, List[torch.Tensor]]
tgt_sizes: List[List[int]]


def get_abs_pos(abs_pos: torch.Tensor, tgt_size: torch.Tensor):
# abs_pos: L, C
# tgt_size: (H, W)
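Note: a TypedDict is only a static-typing aid, so at runtime MiniCPMVInputs behaves like a plain dict. A minimal construction sketch (the tensor shapes and values below are illustrative, not taken from the model):

import torch

inputs = {
    "input_ids": torch.tensor([101, 102, 103]),
    "pixel_values": [torch.zeros(3, 448, 448)],  # one flattened image slice
    "tgt_sizes": [[32, 32]],                     # its patch grid as (H, W)
}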
@@ -109,7 +115,7 @@ def get_2d_sincos_pos_embed(embed_dim: int,


def get_2d_sincos_pos_embed_from_grid(embed_dim: int,
grid: Union[int, Tuple[int, int]],
grid: np.ndarray,
version: Tuple[int, int] = (2, 0)):
assert embed_dim % 2 == 0

@@ -127,7 +133,7 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim: int,


def get_1d_sincos_pos_embed_from_grid(embed_dim: int,
pos: int,
pos: np.ndarray,
version: Tuple[int, int] = (2, 0)):
"""
embed_dim: output dimension for each position
@@ -164,23 +170,27 @@ class Resampler(nn.Module):
default_norm_layer = partial(nn.LayerNorm, eps=1e-6)

def __init__(self,
num_queries: int,
grid_size: int,
num_queries: Optional[int],
grid_size: Optional[int],
embed_dim: int,
num_heads: int,
kv_dim: Optional[int] = None,
norm_layer: nn.Module = default_norm_layer,
norm_layer: Callable[[int],
nn.LayerNorm] = default_norm_layer,
adaptive: bool = False,
max_size: Tuple[int, int] = (70, 70),
version: Tuple[int, int] = (2, 0)):
super().__init__()

self.version = version
if self.version == (2, 0):
self.num_queries = grid_size**2
assert grid_size is not None
num_queries = grid_size**2
else:
self.num_queries = num_queries
assert num_queries is not None
self.max_size = max_size

self.num_queries = num_queries
self.embed_dim = embed_dim
self.num_heads = num_heads
self.adaptive = adaptive
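Note: the Callable[[int], nn.LayerNorm] annotation reflects that norm_layer is a factory taking the embedding width, not a module instance; a small sketch of the pattern (the width 1152 is illustrative):

from functools import partial
from torch import nn

default_norm_layer = partial(nn.LayerNorm, eps=1e-6)
ln_q = default_norm_layer(1152)  # instantiate a LayerNorm for a given embed_dim
assert isinstance(ln_q, nn.LayerNorm)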
@@ -202,6 +212,7 @@ def __init__(self,
(embed_dim**-0.5) * torch.randn(embed_dim, embed_dim))

if self.version == (2, 0):
assert grid_size is not None
self.pos_embed = nn.Parameter(
torch.from_numpy(
get_2d_sincos_pos_embed(
@@ -223,13 +234,15 @@ def _set_2d_pos_cache(self,

def _adjust_pos_cache(self, tgt_sizes: torch.Tensor,
device: torch.types.Device):
max_h = torch.max(tgt_sizes[:, 0])
max_w = torch.max(tgt_sizes[:, 1])
max_h = tgt_sizes[:, 0].max().item()
max_w = tgt_sizes[:, 1].max().item()
assert isinstance(max_h, int) and isinstance(max_w, int)

if max_h > self.max_size[0] or max_w > self.max_size[1]:
self.max_size = [
self.max_size = (
max(max_h, self.max_size[0]),
max(max_w, self.max_size[1])
]
max(max_w, self.max_size[1]),
)
self._set_2d_pos_cache(self.max_size, device)

def _init_weights(self, m: nn.Module):
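Note: .max() on an integer tensor returns a zero-dimensional tensor, and .item() converts it to a Python int, which is what the new isinstance assertions rely on; a standalone sketch (the sizes are illustrative):

import torch

tgt_sizes = torch.tensor([[24, 32], [40, 16]])
max_h = tgt_sizes[:, 0].max().item()
max_w = tgt_sizes[:, 1].max().item()
assert isinstance(max_h, int) and isinstance(max_w, int)
assert (max_h, max_w) == (40, 32)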
@@ -241,9 +254,7 @@ def _init_weights(self, m: nn.Module):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)

def forward_2_5(self,
x: torch.Tensor,
tgt_sizes: Optional[torch.Tensor] = None):
def forward_2_5(self, x: torch.Tensor, tgt_sizes: torch.Tensor):
assert x.shape[0] == tgt_sizes.shape[0]
bs = x.shape[0]

@@ -254,14 +265,16 @@ def forward_2_5(self,

self._adjust_pos_cache(tgt_sizes, device=device)

max_patch_len = torch.max(patch_len)
max_patch_len = patch_len.max().item()
assert isinstance(max_patch_len, int)

key_padding_mask = torch.zeros((bs, max_patch_len),
dtype=torch.bool,
device=device)

pos_embed = []
for i in range(bs):
tgt_h, tgt_w = tgt_sizes[i]
tgt_h, tgt_w = tgt_sizes[i].tolist()
pos_embed.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape(
(tgt_h * tgt_w, -1)).to(dtype)) # patches * D
key_padding_mask[i, patch_len[i]:] = True
@@ -291,7 +304,7 @@ def forward_2_5(self,

def forward_2(self,
x: torch.Tensor,
tgt_sizes: Optional[torch.Tensor] = None,
tgt_sizes: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None):
if self.adaptive:
pos_embed = torch.Tensor(
Expand All @@ -318,7 +331,7 @@ def forward_2(self,

def forward(self,
x: torch.Tensor,
tgt_sizes: Optional[torch.Tensor] = None,
tgt_sizes: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None):
if self.version == (2, 0):
return self.forward_2(x, tgt_sizes=tgt_sizes, attn_mask=attn_mask)
@@ -413,8 +426,9 @@ def __init__(
else:
self.version = (2, 5)
else:
self.version = str(self.config.version).split(".")
self.version = tuple([int(x) for x in self.version])
major, minor = str(self.config.version).split(".")
self.version = int(major), int(minor)

self.llm = self.init_llm(config, cache_config, quant_config)
self.vpm = self.init_vision_module()
param_dtype = torch.get_default_dtype()
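Note: a standalone sketch of the new two-part version parse; parse_version is a hypothetical helper, and the unpacking raises ValueError unless the string contains exactly one dot:

def parse_version(version_str: str) -> tuple:
    major, minor = version_str.split(".")
    return int(major), int(minor)

assert parse_version("2.5") == (2, 5)
assert parse_version("2.0") == (2, 0)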
@@ -491,17 +505,18 @@ def init_vision_module(self):
def init_resampler(self, embed_dim: int, vision_dim: int):
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.float16)

query_num = self.config.query_num
if self.version == (2, 0):
resampler = Resampler(grid_size=int(
math.sqrt(self.config.query_num)),
resampler = Resampler(grid_size=int(math.sqrt(query_num)),
num_queries=None,
embed_dim=embed_dim,
num_heads=embed_dim // 128,
kv_dim=vision_dim,
adaptive=True,
version=self.version)
else:
resampler = Resampler(num_queries=self.config.query_num,
resampler = Resampler(num_queries=query_num,
grid_size=None,
embed_dim=embed_dim,
num_heads=embed_dim // 128,
@@ -511,11 +526,13 @@ def init_resampler(self, embed_dim: int, vision_dim: int):
torch.set_default_dtype(default_dtype)
return resampler

def get_vision_embedding(self,
pixel_values: List[List[torch.Tensor]],
patch_attn_mask: Optional[torch.Tensor] = None,
tgt_sizes: Optional[torch.Tensor] = None,
version: Tuple[int, int] = (2, 0)):
def get_vision_embedding(
self,
pixel_values: Union[List[torch.Tensor], torch.Tensor],
patch_attn_mask: Optional[torch.Tensor] = None,
tgt_sizes: Optional[torch.Tensor] = None,
version: Tuple[int, int] = (2, 0),
):
if version == (2, 0):
res = []
dtype = self.vpm.pos_embed.data.dtype
@@ -568,9 +585,7 @@ def get_image_bounds(self, input_ids: torch.Tensor):

return image_bound

def get_vision_hidden_states(self, data: Dict[str,
Union[List[torch.Tensor],
torch.Tensor]]):
def get_vision_hidden_states(self, data: MiniCPMVInputs):
if "vision_hidden_states" not in data:
pixel_values = data["pixel_values"]
tgt_sizes = data["tgt_sizes"]
@@ -627,8 +642,7 @@ def get_vision_hidden_states(self, data: Dict[str,

return vision_hidden_states

def get_embedding(self, data: Dict[str, Union[List[torch.Tensor],
torch.Tensor]]):
def get_embedding(self, data: MiniCPMVInputs):
input_ids = data["input_ids"]

vision_hidden_states = self.get_vision_hidden_states(data)
@@ -659,19 +673,23 @@ def get_embedding(self, data: Dict[str, Union[List[torch.Tensor],
vision_hidden_states.view(-1, vision_hidden_states.shape[-1]))
return vlm_embedding, vision_hidden_states

def process_multimodal_inputs(self, inputs: Dict[str,
Union[List[torch.Tensor],
torch.Tensor]]):
pixel_values = []
tgt_sizes = []
for b in range(len(inputs["pixel_values"])):
pixel_values += inputs["pixel_values"][b]
tgt_sizes += inputs["tgt_sizes"][b]
return {
"pixel_values": pixel_values,
"input_ids": inputs["input_ids"],
"tgt_sizes": tgt_sizes
}
def process_multimodal_inputs(
self,
pixel_values: Union[List[torch.Tensor], torch.Tensor],
input_ids: torch.Tensor,
tgt_sizes: torch.Tensor,
):
pixel_values_lst = []
tgt_sizes_lst = []
for b in range(len(pixel_values)):
pixel_values_lst += pixel_values[b]
tgt_sizes_lst += tgt_sizes[b]

return MiniCPMVInputs(
pixel_values=pixel_values_lst,
input_ids=input_ids,
tgt_sizes=tgt_sizes_lst,
)

def forward(
self,
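Note: process_multimodal_inputs flattens per-request lists of image tensors and target sizes into single flat lists; a standalone sketch of the intended behaviour (flatten_batched and the shapes are illustrative, not part of the PR):

import torch

def flatten_batched(pixel_values, tgt_sizes):
    # Each element holds the images / sizes for one request in the batch.
    pixel_values_lst = []
    tgt_sizes_lst = []
    for b in range(len(pixel_values)):
        pixel_values_lst += pixel_values[b]
        tgt_sizes_lst += tgt_sizes[b]
    return pixel_values_lst, tgt_sizes_lst

flat_pv, flat_ts = flatten_batched(
    [[torch.zeros(3, 448, 448)], [torch.zeros(3, 448, 448)] * 2],
    [[[32, 32]], [[32, 32], [16, 48]]],
)
assert len(flat_pv) == 3 and len(flat_ts) == 3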
@@ -682,12 +700,11 @@ def forward(
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object,
):
inputs = {
"pixel_values": kwargs.pop("pixel_values", []),
"input_ids": input_ids,
"tgt_sizes": kwargs.pop("tgt_sizes", None),
}
inputs = self.process_multimodal_inputs(inputs)
inputs = self.process_multimodal_inputs(
pixel_values=kwargs.pop("pixel_values", []), # type: ignore
input_ids=input_ids,
tgt_sizes=kwargs.pop("tgt_sizes"), # type: ignore
)

vlm_embeddings, vision_hidden_states = self.get_embedding(inputs)
