Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Fix issues for Pixtral-Large-Instruct-2411 #11393

Merged
merged 1 commit into from
Dec 21, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions vllm/model_executor/models/pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,12 @@
except ImportError:
USE_XFORMERS_OPS = False

PIXTRAL_IMAGE_BREAK_ID = 12
PIXTRAL_IMAGE_END_ID = 13
# These token ids cannot be retrieved from model config
# so we hardcode them here.
PIXTRAL_12B_IMAGE_BREAK_ID = 12
PIXTRAL_12B_IMAGE_END_ID = 13
PIXTRAL_LARGE_IMAGE_BREAK_ID = 14
PIXTRAL_LARGE_IMAGE_END_ID = 15


def get_max_pixtral_image_tokens(ctx: InputContext):
Expand Down Expand Up @@ -118,8 +122,7 @@ def input_mapper_for_pixtral(ctx: InputContext,
for image_data in data_list:
image = ImageChunk(image=image_data)
encoding = tokenizer.instruct.mm_encoder(image)
image = torch.from_numpy(encoding.image).to(device="cuda",
dtype=torch.float16)
image = torch.from_numpy(encoding.image).to(dtype=torch.float16)
images.append(image)
image_tokens_list.append(encoding.tokens)

Expand Down Expand Up @@ -237,8 +240,9 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:

# NOTE: Image embeddings are split into separate tensors for each image
# by the indices of `[IMG_END]` token.
split_indices = torch.where(
image_tokens == PIXTRAL_IMAGE_END_ID)[0] + 1
image_end_condition = (image_tokens == PIXTRAL_12B_IMAGE_END_ID) | (
image_tokens == PIXTRAL_LARGE_IMAGE_END_ID)
split_indices = torch.where(image_end_condition)[0] + 1
if len(split_indices) <= 1:
# Do not split, return as tensor of shape [1, fs, hs]
return image_embeds.unsqueeze(0)
Expand All @@ -260,8 +264,11 @@ def get_input_embeddings(
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, [
self.vision_args.image_token_id, PIXTRAL_IMAGE_END_ID,
PIXTRAL_IMAGE_BREAK_ID
self.vision_args.image_token_id,
PIXTRAL_12B_IMAGE_END_ID,
PIXTRAL_12B_IMAGE_BREAK_ID,
PIXTRAL_LARGE_IMAGE_BREAK_ID,
PIXTRAL_LARGE_IMAGE_END_ID,
])
return inputs_embeds

Expand Down
Loading