
Commit 163b097

feat: Custom masking utils for Gemma3 VLM
Signed-off-by: Balaram Buddharaju <[email protected]>
1 parent 9d894bc commit 163b097

File tree: 6 files changed (+408, -24 lines)


tensorrt_llm/_torch/attention_backend/flashinfer.py

Lines changed: 16 additions & 3 deletions
@@ -14,7 +14,7 @@
 
 from ..utils import get_global_attrs, get_model_extra_attrs
 from .interface import (AttentionBackend, AttentionMask, AttentionMetadata,
-                        PredefinedAttentionMask)
+                        CustomAttentionMask, PredefinedAttentionMask)
 
 try:
     check_cuda_arch()
@@ -366,6 +366,12 @@ def _plan_with_params(self, plan_params: PlanParams) -> PlanParams:
         is_causal = plan_params.attention_mask_type == AttentionMaskType.causal
 
         def prefill_plan():
+            # Setting `window_left` to -1 for custom attention mask is important.
+            # Else, FlashInfer proceeds to use SWA regardless of attention_mask_data.
+            if plan_params.attention_mask_data is not None:
+                window_left = -1
+            else:
+                window_left = plan_params.window_left
             prefill_wrapper.plan(
                 self.qo_indptr[:self.num_contexts + 1],
                 self.paged_kv_indptr_prefill[:self.num_contexts + 1],
@@ -377,9 +383,10 @@ def prefill_plan():
                 self.page_size,
                 causal=is_causal,
                 sm_scale=plan_params.sm_scale,
-                window_left=plan_params.window_left,
+                window_left=window_left,
                 q_data_type=plan_params.q_dtype,
                 kv_data_type=plan_params.kv_dtype,
+                custom_mask=plan_params.attention_mask_data,
             )
 
         if plan_params in self._plan_params_to_wrappers:
@@ -473,8 +480,14 @@ def forward(self,
                 *,
                 attention_window_size: Optional[int] = None,
                 attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL,
+                attention_mask_data: Optional[torch.Tensor] = None,
                 **kwargs) -> torch.Tensor:
-        if attention_mask == PredefinedAttentionMask.CAUSAL:
+        if attention_mask == CustomAttentionMask.CUSTOM:
+            assert attention_mask_data is not None, "attention_mask_data is required for custom attention mask."
+            attention_mask_type = int(AttentionMaskType.custom_mask)
+            attention_mask_data = attention_mask_data if attention_mask_data.ndim == 1 else attention_mask_data.flatten(
+            )
+        elif attention_mask == PredefinedAttentionMask.CAUSAL:
             attention_mask_type = int(AttentionMaskType.causal)
             attention_mask_data = None
         elif attention_mask == PredefinedAttentionMask.FULL:
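
Note on the hunks above: when a custom mask is supplied, the plan forces `window_left = -1` so FlashInfer does not apply sliding-window attention on top of the user-provided mask; any window restriction must already be baked into `attention_mask_data`, and the backend expects that data as one flattened boolean tensor covering all context requests. A minimal sketch of that layout, with illustrative names that are not part of the diff or of FlashInfer's API:

import torch

def flatten_custom_masks(per_request_masks) -> torch.Tensor:
    # Concatenate per-request (q_len x k_len) boolean masks into the 1-D
    # layout the prefill plan consumes via `custom_mask=`.
    return torch.cat([m.flatten() for m in per_request_masks], dim=0).contiguous()

# Two context requests of lengths 3 and 2 (no cached tokens, so q_len == k_len):
# the flattened mask has 3*3 + 2*2 = 13 elements.
masks = [torch.ones(3, 3).tril().bool(),
         torch.ones(2, 2).tril().bool()]
assert flatten_custom_masks(masks).numel() == 13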

tensorrt_llm/_torch/attention_backend/interface.py

Lines changed: 8 additions & 1 deletion
@@ -500,8 +500,15 @@ class PredefinedAttentionMask(str, Enum):
     FULL = "full"
 
 
+class CustomAttentionMask(str, Enum):
+    """
+    Custom attention mask types
+    """
+    CUSTOM = "custom"
+
+
 # May extend to custom attention mask type
-AttentionMask = Union[PredefinedAttentionMask]
+AttentionMask = Union[PredefinedAttentionMask, CustomAttentionMask]
 
 
 class AttentionBackend(Generic[TMetadata]):
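
For reference, both enums derive from `str`, so the widened `AttentionMask` alias stays a plain value type and a backend can branch on it directly. A small sketch of that dispatch, assuming a hypothetical helper name (`resolve_mask_kind`) that is not part of the diff:

from enum import Enum
from typing import Optional, Union

import torch

class PredefinedAttentionMask(str, Enum):
    CAUSAL = "causal"
    FULL = "full"

class CustomAttentionMask(str, Enum):
    CUSTOM = "custom"

AttentionMask = Union[PredefinedAttentionMask, CustomAttentionMask]

def resolve_mask_kind(mask: AttentionMask,
                      mask_data: Optional[torch.Tensor] = None) -> str:
    # Custom masks must arrive with flattened mask data; predefined ones carry none.
    if mask == CustomAttentionMask.CUSTOM:
        assert mask_data is not None, "attention_mask_data is required for custom attention mask."
        return "custom_mask"
    return mask.value

assert resolve_mask_kind(PredefinedAttentionMask.CAUSAL) == "causal"
assert resolve_mask_kind(CustomAttentionMask.CUSTOM, torch.ones(4, dtype=torch.bool)) == "custom_mask"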

tensorrt_llm/_torch/models/modeling_gemma3.py

Lines changed: 133 additions & 7 deletions
@@ -10,8 +10,9 @@
 from tensorrt_llm.functional import PositionEmbeddingType, RotaryScalingType
 from tensorrt_llm.mapping import Mapping
 
-from ..attention_backend import AttentionMetadata
-from ..attention_backend.interface import (PositionalEmbeddingParams,
+from ..attention_backend import AttentionMetadata, FlashInferAttentionMetadata
+from ..attention_backend.interface import (AttentionMask, CustomAttentionMask,
+                                           PositionalEmbeddingParams,
                                            PredefinedAttentionMask, RopeParams)
 from ..distributed import AllReduceParams
 from ..model_config import ModelConfig
@@ -101,14 +102,19 @@ def forward(
         position_ids: Optional[torch.IntTensor],
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
-        attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.
-        CAUSAL,
+        attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL,
         mrope_config: Optional[dict] = None,
         all_reduce_params: Optional[AllReduceParams] = None,
         lora_params: Optional[dict] = None,
+        attention_mask_data: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
 
+        if attention_mask_data is not None:
+            assert isinstance(
+                attn_metadata, FlashInferAttentionMetadata
+            ), "Only FlashInfer backend supports custom attention mask currently."
+            assert attention_mask == CustomAttentionMask.CUSTOM
         return super().forward(position_ids=position_ids,
                                hidden_states=hidden_states,
                                attn_metadata=attn_metadata,
@@ -117,6 +123,7 @@ def forward(
                                all_reduce_params=all_reduce_params,
                                lora_params=lora_params,
                                attention_window_size=self.attention_window_size,
+                               attention_mask_data=attention_mask_data,
                                **kwargs)
 
     def apply_qk_norm(self, q, k):
@@ -214,6 +221,7 @@ def forward(
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor] = None,
+        attention_mask_data: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
 
@@ -223,6 +231,9 @@ def forward(
             position_ids=position_ids,
             hidden_states=hidden_states,
             attn_metadata=attn_metadata,
+            attention_mask=CustomAttentionMask.CUSTOM if attention_mask_data
+            is not None else PredefinedAttentionMask.CAUSAL,
+            attention_mask_data=attention_mask_data,
             **kwargs,
         )
         hidden_states = self.post_attention_layernorm(hidden_states)
@@ -267,6 +278,8 @@ def forward(
         input_ids: Optional[torch.IntTensor] = None,
         position_ids: Optional[torch.IntTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
+        local_attention_mask_data: Optional[torch.Tensor] = None,
+        global_attention_mask_data: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -280,9 +293,13 @@ def forward(
         hidden_states = inputs_embeds.to(self.dtype)
 
         for decoder_layer in self.layers:
-            hidden_states = decoder_layer(position_ids=position_ids,
-                                          hidden_states=hidden_states,
-                                          attn_metadata=attn_metadata)
+            hidden_states = decoder_layer(
+                position_ids=position_ids,
+                hidden_states=hidden_states,
+                attn_metadata=attn_metadata,
+                attention_mask_data=local_attention_mask_data
+                if decoder_layer.self_attn.is_sliding else
+                global_attention_mask_data)
 
         hidden_states = self.norm(hidden_states)
         return hidden_states
@@ -301,21 +318,130 @@ def __init__(
             hidden_size=model_config.pretrained_config.hidden_size,
             vocab_size=model_config.pretrained_config.vocab_size)
 
+    def get_context_mask(
+        self,
+        image_token_mask: torch.BoolTensor,
+        effective_sliding_window: Optional[int] = None,
+    ):
+        """
+        Returns an attention mask such that text tokens attend to each other in causal fashion while image
+        tokens attend in causal fashion as well as to all other image tokens in a bidirectional manner.
+        Args:
+            image_token_mask: A boolean tensor of shape (sequence_length,) where True indicates an image token.
+            effective_sliding_window: The effective sliding window size for the attention mask. Default is None, which means no sliding window.
+                For Gemma3, this is the sliding window size from config (e.g. 512 for the 1B model).
+        Returns:
+            A boolean attention mask of shape (sequence_length, sequence_length).
+        """
+        device = image_token_mask.device
+        sequence_length = len(image_token_mask)
+        if effective_sliding_window is None or effective_sliding_window >= sequence_length:
+            causal_mask = torch.arange(
+                sequence_length, device=device).unsqueeze(0) <= torch.arange(
+                    sequence_length, device=device).unsqueeze(1)
+        else:
+            attention_mask_1 = torch.arange(
+                sequence_length, device=device).unsqueeze(0) <= torch.arange(
+                    sequence_length, device=device).unsqueeze(1)
+            attention_mask_2 = torch.arange(
+                sequence_length, device=device).unsqueeze(0) > torch.arange(
+                    sequence_length,
+                    device=device).unsqueeze(1) - effective_sliding_window
+            causal_mask = attention_mask_1 & attention_mask_2
+
+        # Apply a bidirectional mask for image tokens.
+        token_type_ids = torch.zeros(sequence_length,
+                                     dtype=torch.int32,
+                                     device=device)
+        # 1 for image tokens, 0 for text tokens.
+        token_type_ids[image_token_mask] = 1
+        token_type_mask = token_type_ids.unsqueeze(
+            0) == token_type_ids.unsqueeze(1)
+        # If text token, do not change anything.
+        token_type_mask[token_type_ids == 0] = False
+        causal_mask = causal_mask.masked_fill(token_type_mask, True)
+        return causal_mask
+
+    # ASSUMPTIONS:
+    # 1) Chunked prefill is disabled to avoid chunking image tokens as they need bidirectional attention.
+    # 2) KV cache reuse is disabled to avoid partially matched image tokens (entire image must be reused to get things correct).
+    def get_flashinfer_attention_mask(
+            self,
+            image_token_mask: torch.BoolTensor,
+            attn_metadata: AttentionMetadata,
+            effective_sliding_window: Optional[int] = None) -> torch.Tensor:
+        """
+        This is specifically needed for context phase requests. Currently, we don't create custom mask for generation requests because FlashInfer backend
+        doesn't use it anyway and there's nothing special we need to do for generation requests.
+        - This function will only be called for a batch when there's at least one context request in the batch with image tokens.
+        - In context phase, each sample's input_ids may have a mix of image tokens and text tokens where tokens corresponding to an image
+          appear as a contiguous blob. Example: torch.IntTensor([2, 3, 4, 5, img_idx, img_idx, img_idx, ..., img_idx, 100])
+        - While the text tokens attend to other tokens in a causal fashion, image tokens attend to others in a causal fashion as well as
+          attend to other image tokens in a bidirectional manner. Hence, the need for custom masking.
+        Args:
+            image_token_mask: A boolean tensor of shape (len(input_ids),) where True indicates an image token. This corresponds to the concatenated
+                list of tokens for all samples in the batch.
+            attn_metadata: The attention metadata for the batch.
+            effective_sliding_window: The effective sliding window size for the attention mask. Default is None, which means no sliding window.
+                For Gemma3, this is the sliding window size from config (e.g. 512 for the 1B model).
+        Returns:
+            A flattened boolean mask of shape (sum(q_len[i] * k_len[i] for i in range(batch_size)),).
+        """
+
+        assert isinstance(
+            attn_metadata, FlashInferAttentionMetadata
+        ), "Only FlashInfer backend supports custom mask currently."
+        num_contexts = attn_metadata.num_contexts
+        assert num_contexts > 0, "There should be at least one context request in the batch for custom mask."
+
+        qo_indptr = attn_metadata.qo_indptr[:num_contexts + 1]
+        cached_token_lens = attn_metadata.cached_token_lens[:num_contexts]
+        assert (cached_token_lens == 0).all(
+        ), "cached_token_lens should be 0 for context requests since chunked prefill and kv cache reuse must be disabled."
+
+        # Create masks for context requests.
+        context_mask_list = []
+        for i in range(num_contexts):
+            mask_i = self.get_context_mask(
+                image_token_mask=image_token_mask[qo_indptr[i]:qo_indptr[i +
+                                                                         1]],
+                effective_sliding_window=effective_sliding_window,
+            )
+            context_mask_list.append(mask_i.flatten())
+        return torch.cat(context_mask_list, dim=0).contiguous()
+
     def forward(
         self,
         attn_metadata: AttentionMetadata,
         input_ids: torch.IntTensor = None,
         position_ids: Optional[torch.IntTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         return_context_logits: bool = False,
+        image_token_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
 
+        local_attention_mask_data = None
+        global_attention_mask_data = None
+        if image_token_mask is not None:
+            global_attention_mask_data = self.get_flashinfer_attention_mask(
+                image_token_mask=image_token_mask,
+                attn_metadata=attn_metadata,
+                effective_sliding_window=None,
+            )
+            local_attention_mask_data = self.get_flashinfer_attention_mask(
+                image_token_mask=image_token_mask,
+                attn_metadata=attn_metadata,
+                effective_sliding_window=self.config.sliding_window,
+            )
+
         output = self.model(
             input_ids=input_ids,
             attn_metadata=attn_metadata,
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
+            local_attention_mask_data=local_attention_mask_data,
+            global_attention_mask_data=global_attention_mask_data,
         )
 
         return self.logits_processor.forward(
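
The docstrings above capture the key property of these masks: text tokens attend causally (optionally within a sliding window), while image tokens additionally attend to every other image token bidirectionally. A standalone sketch of that rule on a toy sequence, using illustrative names rather than the model's API:

from typing import Optional

import torch

def build_context_mask(image_token_mask: torch.Tensor,
                       effective_sliding_window: Optional[int] = None) -> torch.Tensor:
    # Illustrative re-implementation of the masking rule described above.
    seq = image_token_mask.numel()
    q = torch.arange(seq).unsqueeze(1)  # query positions (rows)
    k = torch.arange(seq).unsqueeze(0)  # key positions (columns)
    causal = k <= q
    if effective_sliding_window is not None and effective_sliding_window < seq:
        causal &= k > (q - effective_sliding_window)
    token_type = image_token_mask.to(torch.int32)  # 1 for image, 0 for text
    same_type = token_type.unsqueeze(0) == token_type.unsqueeze(1)
    same_type[token_type == 0] = False  # text query rows stay causal
    return causal.masked_fill(same_type, True)

# Sequence: [text, image, image, text]
mask = build_context_mask(torch.tensor([False, True, True, False]))
assert bool(mask[1, 2])       # first image token sees the second (bidirectional)
assert not bool(mask[0, 3])   # text tokens remain strictly causal

The local/global split in forward() follows from this: the same construction runs twice, once with effective_sliding_window=None for global layers and once with the config's sliding_window for local layers, and each decoder layer picks one based on self_attn.is_sliding.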

tensorrt_llm/_torch/models/modeling_gemma3vl.py

Lines changed: 58 additions & 3 deletions
@@ -101,12 +101,54 @@ def __call__(
         input_ids = preprocess_outputs[0]["mm_processor_kwargs"]["input_ids"]
         mm_features = self._process(pixel_values)
         multimodal_data = {}
-        multimodal_data["multimodal_embedding"] = mm_features
+        multimodal_data["multimodal_embedding"] = mm_features.squeeze(dim=0)
         return input_ids[0].to(torch.int32).tolist(), {
             "multimodal_data": multimodal_data
         }
 
 
+def get_gemma3_causal_mask(
+    input_ids: torch.Tensor,
+    image_token_index: int,
+    sliding_window: Optional[int] = None,
+):
+    print("[get_gemma3_causal_mask] input_ids: ", input_ids)
+    assert input_ids.ndim == 1, "input_ids should be a 1D tensor."
+    # Get token type ids. 0 corresponds to text tokens, 1 corresponds to image tokens.
+    token_type_ids = torch.zeros_like(input_ids, device=input_ids.device)
+    image_token_mask = (input_ids == image_token_index).to(
+        device=input_ids.device, dtype=torch.bool)
+    token_type_ids[image_token_mask] = 1
+
+    sequence_length = input_ids.shape[-1]
+    # TODO: Use causal when sliding_window is larger than sequence_length.
+    if sliding_window is None:
+        causal_mask = torch.arange(
+            sequence_length,
+            device=input_ids.device).unsqueeze(0) <= torch.arange(
+                sequence_length, device=input_ids.device).unsqueeze(1)
+    else:
+        attention_mask_1 = torch.arange(
+            sequence_length,
+            device=input_ids.device).unsqueeze(0) <= torch.arange(
+                sequence_length, device=input_ids.device).unsqueeze(1)
+        attention_mask_2 = torch.arange(
+            sequence_length,
+            device=input_ids.device).unsqueeze(0) > torch.arange(
+                sequence_length,
+                device=input_ids.device).unsqueeze(1) - sliding_window
+        causal_mask = attention_mask_1 & attention_mask_2
+
+    # Apply a bidirectional mask for image tokens.
+    if token_type_ids is not None:
+        token_type_mask = token_type_ids.unsqueeze(
+            0) == token_type_ids.unsqueeze(1)
+        # If text token, do not change anything.
+        token_type_mask[token_type_ids == 0] = False
+        causal_mask = causal_mask.masked_fill(token_type_mask, True)
+    return causal_mask
+
+
 @register_auto_model("Gemma3ForConditionalGeneration")
 @register_input_processor(Gemma3InputProcessor, model_type="gemma3")
 class Gemma3Model(PreTrainedModel):
@@ -129,6 +171,7 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
 
         self.model_config = model_config
         self.vocab_size = config.text_config.vocab_size
+        self.sliding_window = config.text_config.sliding_window
         self.model_dtype = getattr(config.text_config, "torch_dtype",
                                    torch.float16)
         logger.info(f"[Gemma3Model::__init__]{self.dtype=} {self.model_dtype=}")
@@ -172,14 +215,26 @@ def forward(
             mm_embed
         ) == num_context_requests, "Number of multimodal features (if provided) should be equal to number of context requests"
 
+        mm_token_ids = torch.tensor([self.image_token_index
+                                     ]).to(input_ids.device)
+        mm_token_mask = None
+        if len(mm_embed) > 0:
+            # Get token type ids. 0 corresponds to text tokens, 1 corresponds to image tokens.
+            mm_token_mask = torch.isin(input_ids, mm_token_ids)
         input_ids, inputs_embeds = fuse_input_embeds(
             embedding_layer=self.llm.model.embed_tokens,
             input_ids=input_ids,
            mm_embeds=mm_embed,
             mm_token_ids=torch.tensor([self.image_token_index
                                        ]).to(input_ids.device))
-        logits = self.llm.forward(attn_metadata, input_ids, position_ids,
-                                  inputs_embeds, return_context_logits)
+        logits = self.llm.forward(
+            attn_metadata=attn_metadata,
+            input_ids=input_ids,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            return_context_logits=return_context_logits,
+            image_token_mask=mm_token_mask,
+        )
         return logits
 
 
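For context, the new forward path only derives mm_token_mask when the batch actually carries multimodal embeddings (len(mm_embed) > 0), so text-only batches keep the plain causal path. A toy sketch of that masking step; the image token id below is a made-up placeholder, not Gemma3's real value:

import torch

IMAGE_TOKEN_INDEX = 255999  # placeholder id for illustration only
input_ids = torch.tensor([2, 3, IMAGE_TOKEN_INDEX, IMAGE_TOKEN_INDEX, 100])
mm_token_ids = torch.tensor([IMAGE_TOKEN_INDEX])

# Mirrors the forward() hunk: True marks image tokens, False marks text tokens.
mm_token_mask = torch.isin(input_ids, mm_token_ids)
print(mm_token_mask)  # tensor([False, False,  True,  True, False])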
