diff --git a/specforge/core/eagle3.py b/specforge/core/eagle3.py
index 183c2f42..a7ef992f 100644
--- a/specforge/core/eagle3.py
+++ b/specforge/core/eagle3.py
@@ -276,13 +276,6 @@ def forward(
             input_ids = padding(input_ids, left=False)
             position_mask = padding(position_mask, left=False)
             loss_mask = padding(loss_mask, left=False)
-            if self.attention_backend == "sdpa":
-                ind = torch.arange(seq_length, device=attention_mask.device)
-                ind0 = ind[idx:]
-                ind1 = ind[: seq_length - idx]
-                attention_mask[:, :, ind0, ind1] = torch.finfo(
-                    attention_mask.dtype
-                ).min  # Flex attention mask shirnking is handled inside attention module
 
         return plosses, vlosses, acces
 
@@ -658,13 +651,6 @@ def forward(
             input_ids = padding(input_ids, left=False)
             position_mask = padding(position_mask, left=False)
             loss_mask = padding(loss_mask, left=False)
-            if self.attention_backend == "sdpa":
-                ind = torch.arange(seq_length, device=attention_mask.device)
-                ind0 = ind[idx:]
-                ind1 = ind[: seq_length - idx]
-                attention_mask[:, :, ind0, ind1] = torch.finfo(
-                    attention_mask.dtype
-                ).min  # Flex attention mask shirnking is handled inside attention module
 
         return plosses, vlosses, acces
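
For reference, a minimal standalone sketch of what the removed sdpa branch computed: at TTT step idx it wrote dtype-min into the diagonal band (row, row - idx) of the 4-D additive attention mask, blocking those positions for SDPA. The indexing below mirrors the deleted lines; the wrapper name shrink_sdpa_mask and the tensor sizes in the usage lines are illustrative assumptions, not part of the patch or the SpecForge API.

import torch

def shrink_sdpa_mask(attention_mask: torch.Tensor, idx: int) -> torch.Tensor:
    # attention_mask: additive mask of shape [batch, heads, seq_length, seq_length],
    # with 0 for allowed positions and dtype-min for blocked ones.
    seq_length = attention_mask.shape[-1]
    ind = torch.arange(seq_length, device=attention_mask.device)
    ind0 = ind[idx:]                # row indices idx .. seq_length - 1
    ind1 = ind[: seq_length - idx]  # column indices 0 .. seq_length - 1 - idx
    # Pairwise advanced indexing hits positions (idx + k, k), i.e. the band (row, row - idx),
    # exactly as the removed in-place update did.
    attention_mask[:, :, ind0, ind1] = torch.finfo(attention_mask.dtype).min
    return attention_mask

# Illustrative usage with made-up sizes.
mask = torch.zeros(2, 1, 8, 8, dtype=torch.float32)
mask = shrink_sdpa_mask(mask, idx=3)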