Skip to content

Commit 43671ec

Browse files
committed
[None][chore]: Add note in Gemma3VL ctor
Signed-off-by: Balaram Buddharaju <[email protected]>
1 parent 3e46624 commit 43671ec

File tree

1 file changed

+11
-14
lines changed

1 file changed

+11
-14
lines changed

tensorrt_llm/_torch/models/modeling_gemma3vl.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,17 @@ def __init__(self, model_config: ModelConfig[Gemma3Config]):
147147
f"the {_MULTIMODAL_ENV_NAME} environment variable, or set it to '0'."
148148
)
149149

150+
print(
151+
"\n"
152+
"#####################################################################################\n"
153+
"NOTE: Gemma3VL decoder requires a custom mask while processing images.\n"
154+
"To ensure expected behavior, please:\n"
155+
" - Use the FlashInfer attention backend\n"
156+
" - Disable chunked prefill\n"
157+
" - Disable KV cache reuse\n"
158+
"#####################################################################################\n"
159+
"\n")
160+
150161
config = model_config.pretrained_config
151162
super().__init__(config)
152163

@@ -276,17 +287,3 @@ def _get_image_features(self, pixel_values):
276287
attn_metadata=attn_metadata)[-1]
277288
image_features = self.mm_projector(image_features)
278289
return image_features
279-
280-
281-
def _load_weights_into_hf_module(
282-
model: torch.nn.Module,
283-
weights: dict,
284-
prefix: str,
285-
model_name: str,
286-
) -> None:
287-
filtered_weights = filter_weights(prefix, weights)
288-
missing_keys, _ = model.load_state_dict(filtered_weights)
289-
if len(missing_keys) > 0:
290-
raise KeyError(
291-
f"Missing the following keys for the {model_name} in the checkpoint: "
292-
f"[{', '.join(missing_keys)}].")

0 commit comments

Comments
 (0)