@@ -107,48 +107,6 @@ def __call__(
         }
 
 
-def get_gemma3_causal_mask(
-    input_ids: torch.Tensor,
-    image_token_index: int,
-    sliding_window: Optional[int] = None,
-):
-    print("[get_gemma3_causal_mask] input_ids: ", input_ids)
-    assert input_ids.ndim == 1, "input_ids should be a 1D tensor."
-    # Get token type ids. 0 corresponds to text tokens, 1 corresponds to image tokens.
-    token_type_ids = torch.zeros_like(input_ids, device=input_ids.device)
-    image_token_mask = (input_ids == image_token_index).to(
-        device=input_ids.device, dtype=torch.bool)
-    token_type_ids[image_token_mask] = 1
-
-    sequence_length = input_ids.shape[-1]
-    # TODO: Use causal when sliding_window is larger than sequence_length.
-    if sliding_window is None:
-        causal_mask = torch.arange(
-            sequence_length,
-            device=input_ids.device).unsqueeze(0) <= torch.arange(
-                sequence_length, device=input_ids.device).unsqueeze(1)
-    else:
-        attention_mask_1 = torch.arange(
-            sequence_length,
-            device=input_ids.device).unsqueeze(0) <= torch.arange(
-                sequence_length, device=input_ids.device).unsqueeze(1)
-        attention_mask_2 = torch.arange(
-            sequence_length,
-            device=input_ids.device).unsqueeze(0) > torch.arange(
-                sequence_length,
-                device=input_ids.device).unsqueeze(1) - sliding_window
-        causal_mask = attention_mask_1 & attention_mask_2
-
-    # Apply a bidirectional mask for image tokens.
-    if token_type_ids is not None:
-        token_type_mask = token_type_ids.unsqueeze(
-            0) == token_type_ids.unsqueeze(1)
-        # If text token, do not change anything.
-        token_type_mask[token_type_ids == 0] = False
-        causal_mask = causal_mask.masked_fill(token_type_mask, True)
-    return causal_mask
-
-
 @register_auto_model("Gemma3ForConditionalGeneration")
 @register_input_processor(Gemma3InputProcessor, model_type="gemma3")
 class Gemma3Model(PreTrainedModel):
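For reference, the removed helper combined up to three constraints: a lower-triangular causal mask (query i may attend to key j iff j <= i), an optional sliding-window band (additionally j > i - sliding_window), and bidirectional attention among image tokens. A minimal sketch of the same mask logic on a toy input, without the sliding window; the token ids below are illustrative placeholders, not values from any Gemma3 config:

import torch

input_ids = torch.tensor([5, 9, 9, 7])  # text, image, image, text
image_token_index = 9                   # hypothetical image token id
seq = input_ids.shape[-1]
pos = torch.arange(seq)

# Causal part: query i may attend to key j iff j <= i.
mask = pos.unsqueeze(0) <= pos.unsqueeze(1)

# Bidirectional patch: every image token sees every image token.
is_image = input_ids == image_token_index
mask = mask | (is_image.unsqueeze(0) & is_image.unsqueeze(1))

print(mask.int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]])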