@@ -147,6 +147,17 @@ def __init__(self, model_config: ModelConfig[Gemma3Config]):
147147 f"the { _MULTIMODAL_ENV_NAME } environment variable, or set it to '0'."
148148 )
149149
150+ print (
151+ "\n "
152+ "#####################################################################################\n "
153+ "NOTE: Gemma3VL decoder requires a custom mask while processing images.\n "
154+ "To ensure expected behavior, please:\n "
155+ " - Use the FlashInfer attention backend\n "
156+ " - Disable chunked prefill\n "
157+ " - Disable KV cache reuse\n "
158+ "#####################################################################################\n "
159+ "\n " )
160+
150161 config = model_config .pretrained_config
151162 super ().__init__ (config )
152163
@@ -276,17 +287,3 @@ def _get_image_features(self, pixel_values):
276287 attn_metadata = attn_metadata )[- 1 ]
277288 image_features = self .mm_projector (image_features )
278289 return image_features
279-
280-
281- def _load_weights_into_hf_module (
282- model : torch .nn .Module ,
283- weights : dict ,
284- prefix : str ,
285- model_name : str ,
286- ) -> None :
287- filtered_weights = filter_weights (prefix , weights )
288- missing_keys , _ = model .load_state_dict (filtered_weights )
289- if len (missing_keys ) > 0 :
290- raise KeyError (
291- f"Missing the following keys for the { model_name } in the checkpoint: "
292- f"[{ ', ' .join (missing_keys )} ]." )
0 commit comments