Commit
fix forwarding attn_softmax_bf16 param for gemma
tthakkal committed Oct 22, 2024
1 parent: 03fa6dd · commit: 183dff2
Showing 1 changed file with 3 additions and 0 deletions.
optimum/habana/transformers/models/gemma/modeling_gemma.py (3 additions, 0 deletions)
@@ -723,6 +723,7 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
         token_idx: Optional[torch.Tensor] = None,
+        attn_softmax_bf16: Optional[bool] = False,
         use_flash_attention: Optional[bool] = False,
         flash_attention_recompute: Optional[bool] = False,
         flash_attention_causal_mask: Optional[bool] = False,
@@ -751,6 +752,7 @@ def forward(
             return_dict=return_dict,
             cache_position=cache_position,
             token_idx=token_idx,
+            attn_softmax_bf16=attn_softmax_bf16,
             use_flash_attention=use_flash_attention,
             flash_attention_recompute=flash_attention_recompute,
             flash_attention_causal_mask=flash_attention_causal_mask,
@@ -862,6 +864,7 @@ def prepare_inputs_for_generation(
                 "attention_mask": attention_mask,
                 "num_logits_to_keep": num_logits_to_keep,
                 "token_idx": token_idx,
+                "attn_softmax_bf16": kwargs.get("attn_softmax_bf16", False),
             }
         )
         return model_inputs
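
For readers tracing the fix, here is a minimal, self-contained sketch (an illustration, not code from the repository) of the forwarding pattern the commit restores: a generation-time flag such as attn_softmax_bf16 is picked out of kwargs in prepare_inputs_for_generation and then relayed by forward() to the inner decoder, where it would ultimately select a bfloat16 softmax inside the attention layers. The TinyModelForCausalLM class and its decoder method are hypothetical stand-ins for the Gaudi Gemma classes.

# Minimal sketch (assumption, not from the commit): how a generation-time kwarg
# such as attn_softmax_bf16 is threaded from the caller down to forward().
from typing import Optional

class TinyModelForCausalLM:
    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # Mirror of the patched Gemma code: pick the flag out of the generation
        # kwargs, defaulting to False when the caller did not set it.
        model_inputs = {"input_ids": input_ids}
        model_inputs.update(
            {
                "attn_softmax_bf16": kwargs.get("attn_softmax_bf16", False),
            }
        )
        return model_inputs

    def forward(self, input_ids, attn_softmax_bf16: Optional[bool] = False):
        # The fix: accept the flag and pass it on instead of dropping it here.
        return self.decoder(input_ids, attn_softmax_bf16=attn_softmax_bf16)

    def decoder(self, input_ids, attn_softmax_bf16: bool = False):
        # Hypothetical stand-in for the inner decoder stack, which would choose
        # a bf16 vs default-precision softmax in its attention layers.
        return {"ids": input_ids, "softmax_in_bf16": attn_softmax_bf16}

model = TinyModelForCausalLM()
inputs = model.prepare_inputs_for_generation([[1, 2, 3]], attn_softmax_bf16=True)
print(model.forward(**inputs))  # softmax_in_bf16 is True end to end

Before this commit, the Gemma forward() neither accepted nor relayed the flag, so a value set at generation time never reached the attention layers and the softmax ran in the default precision regardless of the setting.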
