From 54b8c70f7963d5811b6b29724d92af0ad96e6b0e Mon Sep 17 00:00:00 2001
From: ywang96 <ywang@example.com>
Date: Sat, 21 Dec 2024 06:13:04 +0000
Subject: [PATCH] fix: handle Pixtral-Large image break and end token ids

Signed-off-by: ywang96 <ywang@example.com>
---
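Notes (not part of the commit message):

Pixtral-Large uses different ids for the [IMG_BREAK] and [IMG_END]
tokens than Pixtral-12B (14/15 vs. 12/13), so a single pair of
hardcoded constants no longer covers both checkpoints. This patch
hardcodes both pairs (they are not exposed through the model config)
and checks for either end-token id when splitting the image
embeddings. It also stops forcing the mapped image tensor onto
"cuda" in the input mapper, leaving device placement to the caller.

For reference, the splitting logic in get_multimodal_embeddings boils
down to the following standalone sketch. It is illustrative only: the
argument names and the final tensor_split call are assumptions for
the sake of a runnable example, not a copy of the vLLM code.

    import torch

    PIXTRAL_12B_IMAGE_END_ID = 13
    PIXTRAL_LARGE_IMAGE_END_ID = 15

    def split_by_image_end(image_embeds, image_tokens):
        # A boundary sits one position past every [IMG_END] token,
        # whichever vocabulary (12B or Large) produced it.
        image_end_condition = (
            image_tokens == PIXTRAL_12B_IMAGE_END_ID) | (
            image_tokens == PIXTRAL_LARGE_IMAGE_END_ID)
        split_indices = torch.where(image_end_condition)[0] + 1
        if len(split_indices) <= 1:
            # Single image: one tensor of shape [1, fs, hs].
            return image_embeds.unsqueeze(0)
        # Multiple images: split along dim 0 at every boundary. The
        # last boundary equals len(image_embeds), so the resulting
        # empty trailing chunk is dropped.
        chunks = torch.tensor_split(image_embeds, split_indices.tolist())
        return [c for c in chunks if len(c)]
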
 vllm/model_executor/models/pixtral.py | 23 +++++++++++++++--------
 1 file changed, 15 insertions(+), 8 deletions(-)

diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index 6676dd16e005f..f3d66c2313198 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -45,8 +45,12 @@
 except ImportError:
     USE_XFORMERS_OPS = False
 
-PIXTRAL_IMAGE_BREAK_ID = 12
-PIXTRAL_IMAGE_END_ID = 13
+# These token ids cannot be retrieved from the model config,
+# so we hardcode them here.
+PIXTRAL_12B_IMAGE_BREAK_ID = 12
+PIXTRAL_12B_IMAGE_END_ID = 13
+PIXTRAL_LARGE_IMAGE_BREAK_ID = 14
+PIXTRAL_LARGE_IMAGE_END_ID = 15
 
 
 def get_max_pixtral_image_tokens(ctx: InputContext):
@@ -118,8 +122,7 @@ def input_mapper_for_pixtral(ctx: InputContext,
     for image_data in data_list:
         image = ImageChunk(image=image_data)
         encoding = tokenizer.instruct.mm_encoder(image)
-        image = torch.from_numpy(encoding.image).to(device="cuda",
-                                                    dtype=torch.float16)
+        image = torch.from_numpy(encoding.image).to(dtype=torch.float16)
         images.append(image)
         image_tokens_list.append(encoding.tokens)
 
@@ -237,8 +240,9 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
 
         # NOTE: Image embeddings are split into separate tensors for each image
         # by the indices of `[IMG_END]` token.
-        split_indices = torch.where(
-            image_tokens == PIXTRAL_IMAGE_END_ID)[0] + 1
+        image_end_condition = (image_tokens == PIXTRAL_12B_IMAGE_END_ID) | (
+            image_tokens == PIXTRAL_LARGE_IMAGE_END_ID)
+        split_indices = torch.where(image_end_condition)[0] + 1
         if len(split_indices) <= 1:
             # Do not split, return as tensor of shape [1, fs, hs]
             return image_embeds.unsqueeze(0)
@@ -260,8 +264,11 @@ def get_input_embeddings(
         if multimodal_embeddings is not None:
             inputs_embeds = merge_multimodal_embeddings(
                 input_ids, inputs_embeds, multimodal_embeddings, [
-                    self.vision_args.image_token_id, PIXTRAL_IMAGE_END_ID,
-                    PIXTRAL_IMAGE_BREAK_ID
+                    self.vision_args.image_token_id,
+                    PIXTRAL_12B_IMAGE_END_ID,
+                    PIXTRAL_12B_IMAGE_BREAK_ID,
+                    PIXTRAL_LARGE_IMAGE_BREAK_ID,
+                    PIXTRAL_LARGE_IMAGE_END_ID,
                 ])
         return inputs_embeds
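
The placeholder merge above can be pictured with the minimal,
self-contained sketch below. It is an assumption about the semantics
for illustration only, not the actual merge_multimodal_embeddings
helper, which also validates the placeholder count and accepts
nested per-image tensors.

    import torch

    def merge_placeholder_embeddings(input_ids, inputs_embeds,
                                     image_embeds, placeholder_ids):
        # Every position whose token id is one of the image
        # placeholder ids ([IMG], plus [IMG_BREAK] and [IMG_END] for
        # both 12B and Large) is overwritten with the matching row of
        # the flattened image embeddings.
        mask = torch.isin(
            input_ids,
            torch.tensor(placeholder_ids, device=input_ids.device))
        # The number of placeholder positions must match the number
        # of image embedding rows.
        assert mask.sum() == image_embeds.shape[0]
        inputs_embeds = inputs_embeds.clone()
        inputs_embeds[mask] = image_embeds.to(inputs_embeds.dtype)
        return inputs_embeds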