 from tensorrt_llm.functional import PositionEmbeddingType, RotaryScalingType
 from tensorrt_llm.mapping import Mapping
 
-from ..attention_backend import AttentionMetadata
-from ..attention_backend.interface import (PositionalEmbeddingParams,
+from ..attention_backend import AttentionMetadata, FlashInferAttentionMetadata
+from ..attention_backend.interface import (AttentionMask, CustomAttentionMask,
+                                           PositionalEmbeddingParams,
                                            PredefinedAttentionMask, RopeParams)
 from ..distributed import AllReduceParams
 from ..model_config import ModelConfig
@@ -101,14 +102,19 @@ def forward(
         position_ids: Optional[torch.IntTensor],
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
-        attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.
-        CAUSAL,
+        attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL,
         mrope_config: Optional[dict] = None,
         all_reduce_params: Optional[AllReduceParams] = None,
         lora_params: Optional[dict] = None,
+        attention_mask_data: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
 
+        if attention_mask_data is not None:
+            assert isinstance(
+                attn_metadata, FlashInferAttentionMetadata
+            ), "Only FlashInfer backend supports custom attention mask currently."
+            assert attention_mask == CustomAttentionMask.CUSTOM
         return super().forward(position_ids=position_ids,
                                hidden_states=hidden_states,
                                attn_metadata=attn_metadata,
@@ -117,6 +123,7 @@ def forward(
                                all_reduce_params=all_reduce_params,
                                lora_params=lora_params,
                                attention_window_size=self.attention_window_size,
+                               attention_mask_data=attention_mask_data,
                                **kwargs)
 
     def apply_qk_norm(self, q, k):
@@ -214,6 +221,7 @@ def forward(
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor] = None,
+        attention_mask_data: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
 
@@ -223,6 +231,9 @@ def forward(
             position_ids=position_ids,
             hidden_states=hidden_states,
             attn_metadata=attn_metadata,
+            attention_mask=CustomAttentionMask.CUSTOM if attention_mask_data
+            is not None else PredefinedAttentionMask.CAUSAL,
+            attention_mask_data=attention_mask_data,
             **kwargs,
         )
         hidden_states = self.post_attention_layernorm(hidden_states)
@@ -267,6 +278,8 @@ def forward(
         input_ids: Optional[torch.IntTensor] = None,
         position_ids: Optional[torch.IntTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
+        local_attention_mask_data: Optional[torch.Tensor] = None,
+        global_attention_mask_data: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -280,9 +293,13 @@ def forward(
             hidden_states = inputs_embeds.to(self.dtype)
 
         for decoder_layer in self.layers:
-            hidden_states = decoder_layer(position_ids=position_ids,
-                                          hidden_states=hidden_states,
-                                          attn_metadata=attn_metadata)
+            hidden_states = decoder_layer(
+                position_ids=position_ids,
+                hidden_states=hidden_states,
+                attn_metadata=attn_metadata,
+                attention_mask_data=local_attention_mask_data
+                if decoder_layer.self_attn.is_sliding else
+                global_attention_mask_data)
 
         hidden_states = self.norm(hidden_states)
         return hidden_states
@@ -301,21 +318,130 @@ def __init__(
                          hidden_size=model_config.pretrained_config.hidden_size,
                          vocab_size=model_config.pretrained_config.vocab_size)
 
+    def get_context_mask(
+        self,
+        image_token_mask: torch.BoolTensor,
+        effective_sliding_window: Optional[int] = None,
+    ):
+        """
+        Returns an attention mask in which text tokens attend to each other causally, while
+        image tokens attend causally to preceding tokens and bidirectionally to all other image tokens.
+        Args:
+            image_token_mask: A boolean tensor of shape (sequence_length,) where True indicates an image token.
+            effective_sliding_window: The effective sliding window size for the attention mask. Default is None,
+                which means no sliding window. For Gemma3, this is the sliding window size from the config
+                (e.g. 512 for the 1B model).
+        Returns:
+            A boolean attention mask of shape (sequence_length, sequence_length).
+        """
+        device = image_token_mask.device
+        sequence_length = len(image_token_mask)
+        if effective_sliding_window is None or effective_sliding_window >= sequence_length:
+            causal_mask = torch.arange(
+                sequence_length, device=device).unsqueeze(0) <= torch.arange(
+                    sequence_length, device=device).unsqueeze(1)
+        else:
+            attention_mask_1 = torch.arange(
+                sequence_length, device=device).unsqueeze(0) <= torch.arange(
+                    sequence_length, device=device).unsqueeze(1)
+            attention_mask_2 = torch.arange(
+                sequence_length, device=device).unsqueeze(0) > torch.arange(
+                    sequence_length,
+                    device=device).unsqueeze(1) - effective_sliding_window
+            causal_mask = attention_mask_1 & attention_mask_2
+
+        # Apply a bidirectional mask for image tokens.
+        token_type_ids = torch.zeros(sequence_length,
+                                     dtype=torch.int32,
+                                     device=device)
+        # 1 for image tokens, 0 for text tokens.
+        token_type_ids[image_token_mask] = 1
+        token_type_mask = token_type_ids.unsqueeze(
+            0) == token_type_ids.unsqueeze(1)
+        # If text token, do not change anything.
+        token_type_mask[token_type_ids == 0] = False
+        causal_mask = causal_mask.masked_fill(token_type_mask, True)
+        return causal_mask
+
+    # ASSUMPTIONS:
+    # 1) Chunked prefill is disabled to avoid chunking image tokens, as they need bidirectional attention.
+    # 2) KV cache reuse is disabled to avoid partially matched image tokens (the entire image must be reused to get things correct).
+    def get_flashinfer_attention_mask(
+            self,
+            image_token_mask: torch.BoolTensor,
+            attn_metadata: AttentionMetadata,
+            effective_sliding_window: Optional[int] = None) -> torch.Tensor:
+        """
+        This is specifically needed for context-phase requests. We currently don't create a custom mask for
+        generation requests because the FlashInfer backend doesn't use it there and nothing special is needed
+        for generation requests.
+        - This function is only called for a batch when at least one context request in the batch has image tokens.
+        - In the context phase, each sample's input_ids may mix text tokens and image tokens, where the tokens
+          belonging to an image appear as a contiguous blob.
+          Example: torch.IntTensor([2, 3, 4, 5, img_idx, img_idx, img_idx, ..., img_idx, 100])
+        - While text tokens attend to other tokens in a causal fashion, image tokens attend causally as well as
+          bidirectionally to the other image tokens. Hence the need for custom masking.
+        Args:
+            image_token_mask: A boolean tensor of shape (len(input_ids),) where True indicates an image token.
+                This corresponds to the concatenated list of tokens for all samples in the batch.
+            attn_metadata: The attention metadata for the batch.
+            effective_sliding_window: The effective sliding window size for the attention mask. Default is None,
+                which means no sliding window. For Gemma3, this is the sliding window size from the config
+                (e.g. 512 for the 1B model).
+        Returns:
+            A flattened boolean mask of shape (sum(q_len[i] * k_len[i] for i in range(batch_size)),).
+        """
+
+        assert isinstance(
+            attn_metadata, FlashInferAttentionMetadata
+        ), "Only FlashInfer backend supports custom mask currently."
+        num_contexts = attn_metadata.num_contexts
+        assert num_contexts > 0, "There should be at least one context request in the batch for custom mask."
+
+        qo_indptr = attn_metadata.qo_indptr[:num_contexts + 1]
+        cached_token_lens = attn_metadata.cached_token_lens[:num_contexts]
+        assert (cached_token_lens == 0).all(
+        ), "cached_token_lens should be 0 for context requests since chunked prefill and kv cache reuse must be disabled."
+
+        # Create masks for context requests.
+        context_mask_list = []
+        for i in range(num_contexts):
+            mask_i = self.get_context_mask(
+                image_token_mask=image_token_mask[qo_indptr[i]:qo_indptr[i + 1]],
+                effective_sliding_window=effective_sliding_window,
+            )
+            context_mask_list.append(mask_i.flatten())
+        return torch.cat(context_mask_list, dim=0).contiguous()
+
     def forward(
         self,
         attn_metadata: AttentionMetadata,
         input_ids: torch.IntTensor = None,
         position_ids: Optional[torch.IntTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         return_context_logits: bool = False,
+        image_token_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
 
+        local_attention_mask_data = None
+        global_attention_mask_data = None
+        if image_token_mask is not None:
+            global_attention_mask_data = self.get_flashinfer_attention_mask(
+                image_token_mask=image_token_mask,
+                attn_metadata=attn_metadata,
+                effective_sliding_window=None,
+            )
+            local_attention_mask_data = self.get_flashinfer_attention_mask(
+                image_token_mask=image_token_mask,
+                attn_metadata=attn_metadata,
+                effective_sliding_window=self.config.sliding_window,
+            )
+
         output = self.model(
             input_ids=input_ids,
             attn_metadata=attn_metadata,
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
+            local_attention_mask_data=local_attention_mask_data,
+            global_attention_mask_data=global_attention_mask_data,
         )
 
         return self.logits_processor.forward(
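
For reference, the standalone sketch below mirrors the masking logic this change adds: a causal (optionally sliding-window) mask combined with bidirectional attention among image tokens, with per-request masks flattened the way the FlashInfer custom-mask path consumes them. The helper name build_context_mask and the toy inputs are illustrative only and are not part of the change.

from typing import Optional

import torch


def build_context_mask(image_token_mask: torch.BoolTensor,
                       sliding_window: Optional[int] = None) -> torch.Tensor:
    # Causal part: query position i may attend to key positions j <= i.
    n = image_token_mask.numel()
    idx = torch.arange(n, device=image_token_mask.device)
    mask = idx.unsqueeze(0) <= idx.unsqueeze(1)
    if sliding_window is not None and sliding_window < n:
        # Restrict each query to the last `sliding_window` key positions.
        mask &= idx.unsqueeze(0) > idx.unsqueeze(1) - sliding_window
    # Image tokens additionally attend to every other image token,
    # bidirectionally and regardless of the window.
    both_image = image_token_mask.unsqueeze(0) & image_token_mask.unsqueeze(1)
    return mask | both_image


# Toy context request: two text tokens, three image tokens, one text token.
image_token_mask = torch.tensor([False, False, True, True, True, False])
global_mask = build_context_mask(image_token_mask)                  # no window
local_mask = build_context_mask(image_token_mask, sliding_window=2)  # windowed

# For a batch of context requests, the per-request (seq_len x seq_len) masks are
# flattened and concatenated into sum(q_len[i] * k_len[i]) booleans, which is the
# layout expected for a FlashInfer custom mask.
flat_mask = torch.cat(
    [build_context_mask(m).flatten() for m in (image_token_mask,)])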