Commit 06c354b

remove unused functions
1 parent 163b097 commit 06c354b

File tree

2 files changed: +0 additions, -73 deletions


tensorrt_llm/_torch/models/modeling_gemma3vl.py

Lines changed: 0 additions & 42 deletions
@@ -107,48 +107,6 @@ def __call__(
         }
 
 
-def get_gemma3_causal_mask(
-        input_ids: torch.Tensor,
-        image_token_index: int,
-        sliding_window: Optional[int] = None,
-):
-    print("[get_gemma3_causal_mask] input_ids: ", input_ids)
-    assert input_ids.ndim == 1, "input_ids should be a 1D tensor."
-    # Get token type ids. 0 corresponds to text tokens, 1 corresponds to image tokens.
-    token_type_ids = torch.zeros_like(input_ids, device=input_ids.device)
-    image_token_mask = (input_ids == image_token_index).to(
-        device=input_ids.device, dtype=torch.bool)
-    token_type_ids[image_token_mask] = 1
-
-    sequence_length = input_ids.shape[-1]
-    # TODO: Use causal when sliding_window is larger than sequence_length.
-    if sliding_window is None:
-        causal_mask = torch.arange(
-            sequence_length,
-            device=input_ids.device).unsqueeze(0) <= torch.arange(
-                sequence_length, device=input_ids.device).unsqueeze(1)
-    else:
-        attention_mask_1 = torch.arange(
-            sequence_length,
-            device=input_ids.device).unsqueeze(0) <= torch.arange(
-                sequence_length, device=input_ids.device).unsqueeze(1)
-        attention_mask_2 = torch.arange(
-            sequence_length,
-            device=input_ids.device).unsqueeze(0) > torch.arange(
-                sequence_length,
-                device=input_ids.device).unsqueeze(1) - sliding_window
-        causal_mask = attention_mask_1 & attention_mask_2
-
-    # Apply a bidirectional mask for image tokens.
-    if token_type_ids is not None:
-        token_type_mask = token_type_ids.unsqueeze(
-            0) == token_type_ids.unsqueeze(1)
-        # If text token, do not change anything.
-        token_type_mask[token_type_ids == 0] = False
-        causal_mask = causal_mask.masked_fill(token_type_mask, True)
-    return causal_mask
-
-
 @register_auto_model("Gemma3ForConditionalGeneration")
 @register_input_processor(Gemma3InputProcessor, model_type="gemma3")
 class Gemma3Model(PreTrainedModel):
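
For reference, the deleted helper combined a causal (and optionally sliding-window) mask with bidirectional attention among image tokens. Below is a minimal standalone sketch of that logic on a toy input; the IMAGE_TOKEN_ID value and the toy sequence are hypothetical, and this is not wired into the TRT-LLM attention path.

import torch
from typing import Optional

IMAGE_TOKEN_ID = 99  # hypothetical image-token id, for illustration only


def toy_gemma3_mask(input_ids: torch.Tensor,
                    sliding_window: Optional[int] = None) -> torch.Tensor:
    """Causal (optionally sliding-window) mask with bidirectional attention
    among image tokens, mirroring the logic removed above."""
    positions = torch.arange(input_ids.shape[-1], device=input_ids.device)
    # Causal part: query i may attend to keys j <= i.
    mask = positions.unsqueeze(0) <= positions.unsqueeze(1)
    if sliding_window is not None:
        # Sliding-window part: only the last `sliding_window` keys
        # (including the current position) remain visible.
        mask = mask & (positions.unsqueeze(0) >
                       positions.unsqueeze(1) - sliding_window)
    # Image tokens attend to each other bidirectionally.
    is_image = input_ids == IMAGE_TOKEN_ID
    mask = mask | (is_image.unsqueeze(0) & is_image.unsqueeze(1))
    return mask


# Toy sequence: a text token, three image tokens, then more text.
ids = torch.tensor([1, 99, 99, 99, 2, 3])
print(toy_gemma3_mask(ids, sliding_window=2).int())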

tests/unittest/_torch/modeling/test_modeling_gemma3.py

Lines changed: 0 additions & 31 deletions
@@ -212,37 +212,6 @@ def test_gemma3_sanity(self):
 
         kv_cache_manager.shutdown()
 
-    def generate_causal_mask(self,
-                             batch_size,
-                             target_length,
-                             sequence_length,
-                             device=None):
-        mask = torch.tril(
-            torch.ones((target_length, sequence_length),
-                       dtype=torch.bool,
-                       device=device))
-        # Expand to (batch_size, 1, target_length, sequence_length)
-        mask = mask.unsqueeze(0).unsqueeze(1).expand(batch_size, 1,
-                                                     target_length,
-                                                     sequence_length)
-        return mask
-
-    def generate_sliding_window_mask(self, batch_size: int, target_length: int,
-                                     cache_position: torch.Tensor,
-                                     device: torch.device,
-                                     attention_window_size: int):
-        # TRTLLM's sliding window attention is inclusive.
-        effective_window_size = attention_window_size + 1
-        attention_mask_1 = torch.arange(
-            target_length,
-            device=device).unsqueeze(0) <= cache_position.unsqueeze(-1)
-        attention_mask_2 = torch.arange(target_length, device=device).unsqueeze(
-            0) > cache_position.unsqueeze(-1) - effective_window_size
-        attention_mask = attention_mask_1 & attention_mask_2
-        attention_mask = attention_mask[None, None, :, :].expand(
-            batch_size, 1, -1, -1)
-        return attention_mask
-
     @parameterized.expand([
         Scenario(backend="TRTLLM", config_name="1B"),
         Scenario(backend="VANILLA", config_name="1B"),
