From b67fb2a357f2570e662ed8e13f7f705f4272ee4e Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Tue, 3 Oct 2023 11:24:13 +0000
Subject: [PATCH] add kwargs everywhere

---
 optimum/bettertransformer/models/attention.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/optimum/bettertransformer/models/attention.py b/optimum/bettertransformer/models/attention.py
index a7c82569cf3..d7ded056941 100644
--- a/optimum/bettertransformer/models/attention.py
+++ b/optimum/bettertransformer/models/attention.py
@@ -242,6 +242,7 @@ def opt_forward(
     attention_mask: Optional[torch.Tensor] = None,
     layer_head_mask: Optional[torch.Tensor] = None,
     output_attentions: bool = False,
+    **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     raise_on_head_mask(layer_head_mask)

@@ -336,6 +337,7 @@ def t5_forward(
     query_length=None,
     use_cache=False,
     output_attentions=False,
+    **kwargs,
 ):
     raise_on_head_mask(layer_head_mask)

@@ -466,6 +468,7 @@ def bart_forward(
     attention_mask: Optional[torch.Tensor] = None,
     layer_head_mask: Optional[torch.Tensor] = None,
     output_attentions: bool = False,
+    **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     """Input shape: Batch x Time x Channel"""
     raise_on_head_mask(layer_head_mask)
@@ -769,6 +772,7 @@ def gpt_bigcode_forward(
     encoder_attention_mask: Optional[torch.Tensor] = None,
     use_cache: Optional[bool] = False,
     output_attentions: Optional[bool] = False,
+    **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
     if output_attentions is True:
         raise ValueError("output_attentions=True can not be supported with BetterTransformer.")
@@ -827,6 +831,7 @@ def bloom_forward(
     head_mask: Optional[torch.Tensor] = None,
     use_cache: bool = False,
     output_attentions: bool = False,
+    **kwargs,
 ):
     raise_on_head_mask(head_mask)

@@ -911,6 +916,7 @@ def falcon_forward(
     head_mask: Optional[torch.Tensor] = None,
     use_cache: bool = False,
     output_attentions: bool = False,
+    **kwargs,
 ):
     fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
     num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
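
Note: a minimal sketch of the failure mode this patch guards against, under the assumption that a newer transformers release starts passing an extra keyword argument (the name `padding_mask` below is illustrative, not taken from the patch) into the attention forwards that BetterTransformer overrides. Without a `**kwargs` catch-all, the override raises a TypeError; with it, the unknown keyword is absorbed. The function names here are hypothetical and exist only for the demonstration.

# Illustrative sketch, not code from the patch: why the overriding
# forwards need a **kwargs catch-all in their signatures.
def forward_without_kwargs(hidden_states, attention_mask=None, output_attentions=False):
    return hidden_states

def forward_with_kwargs(hidden_states, attention_mask=None, output_attentions=False, **kwargs):
    # Unknown keywords (e.g. an argument introduced by a later
    # transformers release) are silently absorbed instead of raising.
    return hidden_states

try:
    forward_without_kwargs("h", padding_mask=None)
except TypeError as err:
    print("without **kwargs:", err)  # unexpected keyword argument 'padding_mask'

print("with **kwargs:", forward_with_kwargs("h", padding_mask=None))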