From 81ac865b5e9499bcc1acaa7a59899b501d94e262 Mon Sep 17 00:00:00 2001
From: Cyrus Leung <tlleungac@connect.ust.hk>
Date: Wed, 23 Oct 2024 22:09:04 +0800
Subject: [PATCH] [Bugfix] Fix `_init_vision_model` in NVLM_D model (#9611)

Co-authored-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Alvant <alvasian@yandex.ru>
---
 vllm/model_executor/models/nvlm_d.py | 37 +++++++++++++++++++++-------
 1 file changed, 28 insertions(+), 9 deletions(-)

diff --git a/vllm/model_executor/models/nvlm_d.py b/vllm/model_executor/models/nvlm_d.py
index 3e3c3b05879fb..df4fd0a3256e9 100644
--- a/vllm/model_executor/models/nvlm_d.py
+++ b/vllm/model_executor/models/nvlm_d.py
@@ -58,12 +58,31 @@ def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
             nn.Linear(llm_intermediate_size, llm_hidden_size, bias=False),
         )
 
-    def _init_vision_model(self, config: PretrainedConfig,
-                           quant_config: Optional[QuantizationConfig],
-                           num_hidden_layers: int):
-        # We added additional dummy heads to the original num of heads to make
-        # the number of heads divisible by 8.
-        return InternVisionModel(config.vision_config,
-                                 quant_config=quant_config,
-                                 num_hidden_layers_override=num_hidden_layers,
-                                 num_dummy_heads=7)
+    def _init_vision_model(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig],
+        *,
+        is_mono: bool,
+        prefix: str,
+    ):
+        if not is_mono:
+            vision_feature_layer = config.select_layer
+            if vision_feature_layer < 0:
+                num_hidden_layers = config.vision_config.num_hidden_layers \
+                    + vision_feature_layer + 1
+            else:
+                num_hidden_layers = vision_feature_layer + 1
+
+            # We added additional dummy heads to the original num of heads to
+            # make the number of heads divisible by 8.
+            return InternVisionModel(
+                config.vision_config,
+                quant_config=quant_config,
+                num_hidden_layers_override=num_hidden_layers,
+                num_dummy_heads=7,
+                prefix=prefix,
+            )
+        else:
+            msg = "Monolith mode is not applicable to NVLM_D"
+            raise NotImplementedError(msg)