Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
Signed-off-by: ywang96 <[email protected]>
  • Loading branch information
ywang96 committed Dec 21, 2024
1 parent d573aea commit 54b8c70
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions vllm/model_executor/models/pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,12 @@
except ImportError:
USE_XFORMERS_OPS = False

PIXTRAL_IMAGE_BREAK_ID = 12
PIXTRAL_IMAGE_END_ID = 13
# These token ids cannot be retrieved from model config
# so we hardcode them here.
PIXTRAL_12B_IMAGE_BREAK_ID = 12
PIXTRAL_12B_IMAGE_END_ID = 13
PIXTRAL_LARGE_IMAGE_BREAK_ID = 14
PIXTRAL_LARGE_IMAGE_END_ID = 15


def get_max_pixtral_image_tokens(ctx: InputContext):
Expand Down Expand Up @@ -118,8 +122,7 @@ def input_mapper_for_pixtral(ctx: InputContext,
for image_data in data_list:
image = ImageChunk(image=image_data)
encoding = tokenizer.instruct.mm_encoder(image)
image = torch.from_numpy(encoding.image).to(device="cuda",
dtype=torch.float16)
image = torch.from_numpy(encoding.image).to(dtype=torch.float16)
images.append(image)
image_tokens_list.append(encoding.tokens)

Expand Down Expand Up @@ -237,8 +240,9 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:

# NOTE: Image embeddings are split into separate tensors for each image
# by the indices of `[IMG_END]` token.
split_indices = torch.where(
image_tokens == PIXTRAL_IMAGE_END_ID)[0] + 1
image_end_condition = (image_tokens == PIXTRAL_12B_IMAGE_END_ID) | (
image_tokens == PIXTRAL_LARGE_IMAGE_END_ID)
split_indices = torch.where(image_end_condition)[0] + 1
if len(split_indices) <= 1:
# Do not split, return as tensor of shape [1, fs, hs]
return image_embeds.unsqueeze(0)
Expand All @@ -260,8 +264,11 @@ def get_input_embeddings(
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, [
self.vision_args.image_token_id, PIXTRAL_IMAGE_END_ID,
PIXTRAL_IMAGE_BREAK_ID
self.vision_args.image_token_id,
PIXTRAL_12B_IMAGE_END_ID,
PIXTRAL_12B_IMAGE_BREAK_ID,
PIXTRAL_LARGE_IMAGE_BREAK_ID,
PIXTRAL_LARGE_IMAGE_END_ID,
])
return inputs_embeds

Expand Down

0 comments on commit 54b8c70

Please sign in to comment.