
Commit

add kwargs everywhere
younesbelkada committed Oct 3, 2023
1 parent d0f27d2 commit b67fb2a
Showing 1 changed file with 6 additions and 0 deletions.
optimum/bettertransformer/models/attention.py: 6 additions & 0 deletions
@@ -242,6 +242,7 @@ def opt_forward(
     attention_mask: Optional[torch.Tensor] = None,
     layer_head_mask: Optional[torch.Tensor] = None,
     output_attentions: bool = False,
+    **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     raise_on_head_mask(layer_head_mask)

@@ -336,6 +337,7 @@ def t5_forward(
     query_length=None,
     use_cache=False,
     output_attentions=False,
+    **kwargs,
 ):
     raise_on_head_mask(layer_head_mask)

@@ -466,6 +468,7 @@ def bart_forward(
     attention_mask: Optional[torch.Tensor] = None,
     layer_head_mask: Optional[torch.Tensor] = None,
     output_attentions: bool = False,
+    **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     """Input shape: Batch x Time x Channel"""
     raise_on_head_mask(layer_head_mask)
@@ -769,6 +772,7 @@ def gpt_bigcode_forward(
     encoder_attention_mask: Optional[torch.Tensor] = None,
     use_cache: Optional[bool] = False,
     output_attentions: Optional[bool] = False,
+    **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
     if output_attentions is True:
         raise ValueError("output_attentions=True can not be supported with BetterTransformer.")
@@ -827,6 +831,7 @@ def bloom_forward(
     head_mask: Optional[torch.Tensor] = None,
     use_cache: bool = False,
     output_attentions: bool = False,
+    **kwargs,
 ):
     raise_on_head_mask(head_mask)

@@ -911,6 +916,7 @@ def falcon_forward(
     head_mask: Optional[torch.Tensor] = None,
     use_cache: bool = False,
     output_attentions: bool = False,
+    **kwargs,
 ):
     fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
     num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
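Why the trailing **kwargs matters, as a minimal sketch (not part of the commit itself): these patched forwards override the original attention forwards, and newer transformers releases may pass extra keyword arguments (for example a padding_mask) that older signatures never declared. Without a catch-all parameter such a call fails with a TypeError; with **kwargs the unexpected argument is simply absorbed. The function names and the padding_mask argument below are hypothetical illustrations, not code from the repository.

def patched_forward_old(hidden_states, attention_mask=None, output_attentions=False):
    # Old-style signature: an extra keyword from a newer caller breaks the call.
    return hidden_states

def patched_forward_new(hidden_states, attention_mask=None, output_attentions=False, **kwargs):
    # Signature following this commit's pattern: extra keywords are absorbed and ignored.
    return hidden_states

try:
    patched_forward_old("states", padding_mask=None)  # hypothetical extra kwarg from a newer caller
except TypeError as err:
    print(f"without **kwargs: {err}")

patched_forward_new("states", padding_mask=None)  # same call now succeeds
print("with **kwargs: ok")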
