[V1] Initial support of multimodal models for V1 re-arch #10699

Merged: 31 commits, merged into main on Dec 8, 2024
The diff below shows changes from 4 of the 31 commits.

Commits (31)
246d75b  internvl (ywang96, Nov 27, 2024)
2a081bb  fix token id (ywang96, Nov 27, 2024)
e4d6bb2  Merge branch 'vllm-project:main' into v1-initial (ywang96, Nov 28, 2024)
94d66cc  Pixtral (ywang96, Nov 30, 2024)
79f24c6  use special ids (ywang96, Nov 30, 2024)
7a88433  comment (ywang96, Nov 30, 2024)
af1dbab  cleanup for pixtral (ywang96, Nov 30, 2024)
39dd4f2  Merge branch 'vllm-project:main' into v1-initial (ywang96, Nov 30, 2024)
6d0df5a  qwen2vl (ywang96, Dec 1, 2024)
124b0c1  Merge branch 'vllm-project:main' into v1-initial (ywang96, Dec 2, 2024)
8c4da46  molmo (ywang96, Dec 2, 2024)
3e3a346  minor changes on interfaces (ywang96, Dec 2, 2024)
1c50613  typo (ywang96, Dec 2, 2024)
6d8ddff  pad (ywang96, Dec 2, 2024)
7ddf7d9  Merge branch 'vllm-project:main' into v1-initial (ywang96, Dec 2, 2024)
f1fa769  remove print (ywang96, Dec 3, 2024)
ee8e0ae  Merge branch 'vllm-project:main' into v1-initial (ywang96, Dec 3, 2024)
319e689  Merge branch 'vllm-project:main' into v1-initial (ywang96, Dec 4, 2024)
77256d9  change check order (ywang96, Dec 4, 2024)
bdd8da6  Merge branch 'main' into v1-initial (ywang96, Dec 5, 2024)
e32efd5  Merge branch 'vllm-project:main' into v1-initial (ywang96, Dec 6, 2024)
0176b7b  molmo (ywang96, Dec 6, 2024)
69f4e5f  fix launch args (ywang96, Dec 6, 2024)
8b7e746  fix qwen2-vl (ywang96, Dec 6, 2024)
bb15b01  typing (ywang96, Dec 6, 2024)
610e662  add documentation (ywang96, Dec 6, 2024)
2b5fdd7  minor fix (ywang96, Dec 6, 2024)
a5a38dd  typehint (ywang96, Dec 6, 2024)
fbf9cd0  Merge branch 'main' into v1-initial (ywang96, Dec 7, 2024)
8d1d80e  iterate (ywang96, Dec 8, 2024)
4a79255  revert changes in qwen2vl (ywang96, Dec 8, 2024)

68 changes: 56 additions & 12 deletions vllm/model_executor/models/internvl.py
@@ -26,7 +26,7 @@
InternVisionPatchModel)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
@@ -52,12 +52,18 @@ class InternVLImagePixelInputs(TypedDict):
Shape:
`(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
"""
patches_per_image: List[int]
"""
List of the total number of patches for each image in the batch.
"""


class InternVLImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
data: NestedTensors
"""
A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
or a list of tensors of shape `(total_image_feature_size, hidden_size)`

`hidden_size` must match the hidden size of the language model backbone.
"""
@@ -349,10 +355,32 @@ def input_processor(
new_prompt = self._expand_image_prompt(prompt, image_feature_sizes,
num_patches)
new_prompt_token_ids = tokenizer.encode(new_prompt)
img_context_token_id = tokenizer.encode(self.img_context_token,
add_special_tokens=False)
assert len(img_context_token_id) == 1, \
(f"Invalid image token '{self.img_context_token}': A valid image "
f"token encodes to a single token ID, got {img_context_token_id}.")
img_context_token_id = img_context_token_id[0]

# Get precise tracking of placeholder positions
token_idx = image_idx = 0
placeholder_ranges = []
while token_idx < len(new_prompt_token_ids):
if new_prompt_token_ids[token_idx] == img_context_token_id:
curr_image_feature_size = image_feature_sizes[image_idx]
placeholder_ranges.append(
PlaceholderRange(offset=token_idx,
length=curr_image_feature_size))
image_idx += 1
token_idx += curr_image_feature_size
else:
token_idx += 1

return token_inputs(prompt=prompt,
prompt_token_ids=new_prompt_token_ids,
multi_modal_data=multi_modal_data)
return token_inputs(
prompt=prompt,
prompt_token_ids=new_prompt_token_ids,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"image": placeholder_ranges})

def input_mapper(
self,
Expand Down Expand Up @@ -612,26 +640,46 @@ def _parse_and_validate_image_input(
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")

patches_per_image = []
for request_pixel_values in pixel_values:
for image_pixel_values in request_pixel_values:
patches_per_image.append(image_pixel_values.shape[0])
# We need to flatten (B, N, P) to (B*N*P),
# so we call flatten_bn twice.
return InternVLImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(
flatten_bn(flatten_bn(pixel_values), concat=True)),
)
patches_per_image=patches_per_image)

raise AssertionError("This line should be unreachable.")

def _process_image_input(
self,
image_input: InternVLImageInputs,
) -> torch.Tensor:
) -> Tuple[torch.Tensor]:
if image_input["type"] == "image_embeds":
return image_input["data"]

assert self.vision_model is not None

image_embeds = self.extract_feature(image_input["data"])

patches_per_image = image_input["patches_per_image"]
if len(patches_per_image) == 1:
image_embeds = image_embeds.unsqueeze(0)
return image_embeds

# NOTE: Image embeddings are split into separate tensors for each image
# by the size of each embedding.
feature_size = image_embeds.shape[1]
image_embeds = image_embeds.view(-1,
self.config.text_config.hidden_size)
image_feature_sizes = [
num_patches * feature_size for num_patches in patches_per_image
]
image_embeds = image_embeds.split(image_feature_sizes)
return image_embeds

def _get_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
@@ -697,10 +745,6 @@ def forward(
if self.img_context_token_id is not None:
visual_token_mask = self._get_visual_token_mask(input_ids)

# We always overwrite it back to None after computing visual token
# mask so that this doesn't need to depend on encoder output
self.img_context_token_id = None

if self.is_mono:
forward_kwargs.update({"visual_token_mask": visual_token_mask})

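To make the placeholder bookkeeping in the input_processor change above concrete, here is a minimal, self-contained sketch of recovering per-image (offset, length) ranges from an expanded prompt. It is an illustration rather than the code in this diff: IMG = 100 is an arbitrary stand-in for the image-context token id, and plain dicts stand in for vLLM's PlaceholderRange.

```python
from typing import Dict, List

IMG = 100  # arbitrary stand-in for the image-context token id


def find_placeholder_ranges(token_ids: List[int],
                            image_feature_sizes: List[int]) -> List[Dict[str, int]]:
    """Return one {"offset", "length"} entry per image in the prompt."""
    ranges: List[Dict[str, int]] = []
    token_idx = image_idx = 0
    while token_idx < len(token_ids):
        if token_ids[token_idx] == IMG:
            size = image_feature_sizes[image_idx]
            ranges.append({"offset": token_idx, "length": size})
            image_idx += 1
            token_idx += size  # skip the whole run of image tokens
        else:
            token_idx += 1
    return ranges


# Two images occupying 4 and 2 feature tokens in a toy prompt.
prompt = [1, 2, IMG, IMG, IMG, IMG, 3, IMG, IMG, 4]
print(find_placeholder_ranges(prompt, [4, 2]))
# -> [{'offset': 2, 'length': 4}, {'offset': 7, 'length': 2}]
```

Each range starts at the first image-context token of an image and spans that image's full feature length, which is what the placeholder map in this PR records.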
114 changes: 85 additions & 29 deletions vllm/model_executor/models/pixtral.py
@@ -48,6 +48,9 @@
except ImportError:
USE_XFORMERS_OPS = False

PIXTRAL_IMAGE_BREAK_ID = 12
PIXTRAL_IMAGE_END_ID = 13


def get_max_pixtral_image_tokens(ctx: InputContext):
tokenizer = cached_get_tokenizer(
@@ -68,7 +71,6 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
tokenizer_mode=ctx.model_config.tokenizer_mode)

mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
patch_size = mm_encoder.mm_config.image_patch_size
image_token_id = mm_encoder.special_ids.img

mm_config = ctx.model_config.multimodal_config
@@ -78,8 +80,8 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
size = 256
image = Image.new("RGB", (size, size), color=0)

image_feature_size = (size**2) // (patch_size**2)

encoding = tokenizer.instruct.mm_encoder(ImageChunk(image=image))
image_feature_size = len(encoding.tokens)
num_image_tokens = image_feature_size * num_images
seq_data = SequenceData.from_prompt_token_counts(
(image_token_id, num_image_tokens),
@@ -101,50 +103,80 @@ def input_mapper_for_pixtral(ctx: InputContext,

Args:
ctx: Context of the loaded model.
data: data potentially containing image/image embeddings to be mapped
to pixel_values in .forward() for a visual QWenLMHeadModel model.
data: data potentially containing PIL images to be processed
and mapped to `images`.

Returns:
MultiModalKwargs containing the stacked normalized images tensor or
image embeddings.
"""
# Early exit if we have provided an image to a language only Qwen model
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(
model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode)

data_list = data if isinstance(data, list) else [data]

images = []
image_tokens_list = []
for image_data in data_list:
image = ImageChunk(image=image_data)
encoding = tokenizer.instruct.mm_encoder(image)
image = torch.from_numpy(encoding.image).to(device="cuda",
dtype=torch.float16)
images.append(image)
image_tokens_list.append(encoding.tokens)

return MultiModalKwargs({"images": images})
image_tokens = torch.flatten(
torch.tensor([
token_id for image_tokens in image_tokens_list
for token_id in image_tokens
]))
return MultiModalKwargs({"images": images, "image_tokens": image_tokens})


def input_processor_for_pixtral(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is not None and "image" in multi_modal_data:
tokenizer = cached_get_tokenizer(
ctx.model_config.tokenizer,
tokenizer_mode=ctx.model_config.tokenizer_mode)
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs

mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
image_token_id = mm_encoder.special_ids.img
prompt_token_ids = inputs.get("prompt_token_ids")
prompt = inputs.get("prompt")
tokenizer = cached_get_tokenizer(
ctx.model_config.tokenizer,
tokenizer_mode=ctx.model_config.tokenizer_mode)

if image_token_id not in inputs['prompt_token_ids']:
raise ValueError(
f"You've passed {inputs=} without {image_token_id=}"
" Make sure to process your input via mistral_common's"
" tokenizer or pass a chat completion request. For more"
" For more info, see: "
"https://github.com/vllm-project/vllm/issues/8411.")
mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
image_token_id = mm_encoder.special_ids.img

return inputs
if image_token_id not in inputs['prompt_token_ids']:
raise ValueError(
f"You've passed {inputs=} without {image_token_id=}"
" Make sure to process your input via mistral_common's"
" tokenizer or pass a chat completion request. For more"
" For more info, see: "
"https://github.com/vllm-project/vllm/issues/8411.")

# Get precise tracking of placeholder positions
placeholder_ranges = []
curr_length = 0
curr_offset = 0
for i in range(len(prompt_token_ids)):
if prompt_token_ids[i] in (image_token_id, PIXTRAL_IMAGE_BREAK_ID):
if curr_offset == 0:
curr_offset = i
curr_length += 1
elif prompt_token_ids[i] == PIXTRAL_IMAGE_END_ID:
curr_length += 1
placeholder_ranges.append(
PlaceholderRange(offset=curr_offset, length=curr_length))
curr_offset = 0
curr_length = 0
else:
pass
return token_inputs(prompt=prompt,
prompt_token_ids=prompt_token_ids,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"image": placeholder_ranges})
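
To illustrate the range-grouping loop above with concrete values, here is a small self-contained sketch (not part of this diff). The ids 12 and 13 mirror PIXTRAL_IMAGE_BREAK_ID and PIXTRAL_IMAGE_END_ID defined earlier in this file; IMG = 10 is an arbitrary stand-in for the tokenizer's image token id, and plain dicts stand in for PlaceholderRange.

```python
IMG, IMG_BREAK, IMG_END = 10, 12, 13  # IMG is made up; 12/13 mirror the constants above


def pixtral_style_ranges(token_ids):
    """Group runs of image/break tokens, closed by an end token, into ranges."""
    ranges, offset, length = [], 0, 0
    for i, tok in enumerate(token_ids):
        if tok in (IMG, IMG_BREAK):
            if length == 0:
                offset = i  # first token of a new image region
            length += 1
        elif tok == IMG_END:
            length += 1  # the end token belongs to the region
            ranges.append({"offset": offset, "length": length})
            length = 0
    return ranges


tokens = [1, IMG, IMG, IMG_BREAK, IMG, IMG, IMG_END, 2, IMG, IMG_END, 3]
print(pixtral_style_ranges(tokens))
# -> [{'offset': 1, 'length': 6}, {'offset': 8, 'length': 2}]
```

The sketch uses the running length, rather than a zero offset, as its start-of-region test; otherwise it follows the same logic as the loop above.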


@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral)
Expand Down Expand Up @@ -191,11 +223,33 @@ def sampler(self):
return get_sampler()

def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self._parse_and_validate_image_input(**kwargs)
image_input, image_tokens = self._parse_and_validate_image_input(
**kwargs)
if image_input is None:
return None

image_tokens = torch.flatten(
torch.tensor([
token_id for image_tokens_per_request in image_tokens
for token_id in image_tokens_per_request
],
device=self.vision_encoder.device))

vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
image_embeds = self.language_model.get_input_embeddings(image_tokens)
image_token_mask = image_tokens == self.vision_args.image_token_id
image_embeds[image_token_mask] = vision_embeddings

# NOTE: Image embeddings are split into separate tensors for each image
# by the indices of `[IMG_END]` token.
split_indices = torch.where(
image_tokens == PIXTRAL_IMAGE_END_ID)[0] + 1
if len(split_indices) <= 1:
# Do not split, return as tensor of shape [1, fs, hs]
return image_embeds.unsqueeze(0)

image_embeds = image_embeds.tensor_split(split_indices.cpu())
return image_embeds

def get_input_embeddings(
self,
@@ -205,8 +259,10 @@ def get_input_embeddings(
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.vision_args.image_token_id)
input_ids, inputs_embeds, multimodal_embeddings, [
self.vision_args.image_token_id, PIXTRAL_IMAGE_END_ID,
PIXTRAL_IMAGE_BREAK_ID
])
return inputs_embeds

def forward(
@@ -244,10 +300,11 @@ def forward(
def _parse_and_validate_image_input(
self,
images: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor],
torch.Tensor]] = None
torch.Tensor]] = None,
image_tokens: Optional[torch.Tensor] = None,
) -> Optional[List[torch.Tensor]]:
if images is None:
return None
return None, None

if isinstance(images, torch.Tensor):
# if passed as batch take all images
@@ -265,8 +322,7 @@ def _parse_and_validate_image_input(
flatten_images.extend(imgs_per_req)

images = flatten_images

return images
return images, image_tokens

def _process_image_input(self,
image_input: List[torch.Tensor]) -> torch.Tensor:
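The NOTE in get_multimodal_embeddings above describes splitting the merged image embeddings back into one tensor per image at each [IMG_END] position. Below is a simplified sketch of that split under stated assumptions: it drops the trailing boundary instead of producing an empty chunk, uses a toy hidden size of 4, and uses made-up token ids (10 for the image token, 13 matching PIXTRAL_IMAGE_END_ID).

```python
import torch

PIXTRAL_IMAGE_END_ID = 13  # matches the constant defined in this file


def split_image_embeds(image_tokens: torch.Tensor, image_embeds: torch.Tensor):
    """Split a flat (num_tokens, hidden) tensor into one chunk per image."""
    # Index one past each [IMG_END] token marks the boundary after an image.
    boundaries = (torch.where(image_tokens == PIXTRAL_IMAGE_END_ID)[0] + 1).tolist()
    if len(boundaries) <= 1:
        return (image_embeds.unsqueeze(0),)  # single image: keep one tensor
    # The last boundary is the end of the sequence, so it is not needed.
    return image_embeds.tensor_split(boundaries[:-1])


# Toy input: two images with 3 and 2 tokens each (10 stands in for [IMG]).
tokens = torch.tensor([10, 10, 13, 10, 13])
embeds = torch.arange(5 * 4, dtype=torch.float32).view(5, 4)
print([chunk.shape for chunk in split_image_embeds(tokens, embeds)])
# -> [torch.Size([3, 4]), torch.Size([2, 4])]
```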
12 changes: 10 additions & 2 deletions vllm/model_executor/models/utils.py
@@ -392,7 +392,7 @@ def merge_multimodal_embeddings(
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
multimodal_embeddings: NestedTensors,
placeholder_token_id: int,
placeholder_token_id: Union[int, List[int]],
) -> torch.Tensor:
"""
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
@@ -402,9 +402,17 @@
Note:
This updates ``inputs_embeds`` in place.
"""
if isinstance(placeholder_token_id, int):
return _merge_multimodal_embeddings(
inputs_embeds,
(input_ids == placeholder_token_id),
multimodal_embeddings,
)
placeholder_token_id = torch.tensor(placeholder_token_id,
device=input_ids.device)
return _merge_multimodal_embeddings(
inputs_embeds,
(input_ids == placeholder_token_id),
torch.isin(input_ids, placeholder_token_id),
multimodal_embeddings,
)

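A short sketch of the two masking paths that merge_multimodal_embeddings now supports: == for a single placeholder id and torch.isin for a list of ids, with the selected rows of inputs_embeds overwritten in place. The helper name and the token ids 32000/32001 are arbitrary choices for the example, not vLLM's API.

```python
import torch


def build_placeholder_mask(input_ids: torch.Tensor, placeholder_token_id):
    """Boolean mask of positions holding any of the placeholder token ids."""
    if isinstance(placeholder_token_id, int):
        return input_ids == placeholder_token_id
    ids = torch.tensor(placeholder_token_id, device=input_ids.device)
    return torch.isin(input_ids, ids)


input_ids = torch.tensor([5, 32000, 32001, 7, 32000])
print(build_placeholder_mask(input_ids, 32000).tolist())
# -> [False, True, False, False, True]
print(build_placeholder_mask(input_ids, [32000, 32001]).tolist())
# -> [False, True, True, False, True]

# Overwrite the masked rows with multimodal embeddings, one row per position.
hidden = 4
inputs_embeds = torch.zeros(len(input_ids), hidden)
mm_embeds = torch.ones(3, hidden)
inputs_embeds[build_placeholder_mask(input_ids, [32000, 32001])] = mm_embeds
```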
4 changes: 2 additions & 2 deletions vllm/v1/core/scheduler.py
@@ -73,12 +73,12 @@ def __init__(
# has the Transformer architecture (e.g., ViT).
# FIXME(woosuk): Below are placeholder values. We need to calculate the
# actual values from the configurations.
self.max_num_encoder_input_tokens = 2048
self.max_num_encoder_input_tokens = 8192
# NOTE(woosuk): For the models without encoder (e.g., text-only models),
# the encoder cache will not be initialized and used, regardless of
# the cache size. This is because the memory space for the encoder cache
# is preallocated in the profiling run.
self.encoder_cache_manager = EncoderCacheManager(cache_size=2048)
self.encoder_cache_manager = EncoderCacheManager(cache_size=8192)

def schedule(self) -> "SchedulerOutput":
# NOTE(woosuk) on the scheduling algorithm:
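The two bumped constants above act as per-step and cache-wide token budgets for encoder inputs. A toy illustration of that kind of budget check follows; the class and method names are hypothetical and are not vLLM's EncoderCacheManager API.

```python
class ToyEncoderBudget:
    """Toy budget check, not vLLM's scheduler or EncoderCacheManager API."""

    def __init__(self, max_tokens_per_step: int = 8192, cache_size: int = 8192):
        self.remaining_step_tokens = max_tokens_per_step
        self.free_cache_tokens = cache_size

    def can_schedule(self, num_encoder_tokens: int) -> bool:
        # An encoder input fits only if it stays within both budgets.
        return (num_encoder_tokens <= self.remaining_step_tokens
                and num_encoder_tokens <= self.free_cache_tokens)

    def allocate(self, num_encoder_tokens: int) -> None:
        self.remaining_step_tokens -= num_encoder_tokens
        self.free_cache_tokens -= num_encoder_tokens


budget = ToyEncoderBudget()
if budget.can_schedule(576):  # e.g. a 576-token image placeholder
    budget.allocate(576)
```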