diff --git a/optimum/bettertransformer/models/attention.py b/optimum/bettertransformer/models/attention.py index 702aca3257b..d7ded056941 100644 --- a/optimum/bettertransformer/models/attention.py +++ b/optimum/bettertransformer/models/attention.py @@ -242,6 +242,7 @@ def opt_forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: raise_on_head_mask(layer_head_mask) @@ -336,6 +337,7 @@ def t5_forward( query_length=None, use_cache=False, output_attentions=False, + **kwargs, ): raise_on_head_mask(layer_head_mask) @@ -466,6 +468,7 @@ def bart_forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" raise_on_head_mask(layer_head_mask) @@ -583,6 +586,7 @@ def llama_forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions is True: raise ValueError("output_attentions=True can not be supported with BetterTransformer.") @@ -768,6 +772,7 @@ def gpt_bigcode_forward( encoder_attention_mask: Optional[torch.Tensor] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: if output_attentions is True: raise ValueError("output_attentions=True can not be supported with BetterTransformer.") @@ -826,6 +831,7 @@ def bloom_forward( head_mask: Optional[torch.Tensor] = None, use_cache: bool = False, output_attentions: bool = False, + **kwargs, ): raise_on_head_mask(head_mask) @@ -910,6 +916,7 @@ def falcon_forward( head_mask: Optional[torch.Tensor] = None, use_cache: bool = False, output_attentions: bool = False, + **kwargs, ): fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads