Skip to content

Commit

Permalink
move from keyword arg to pos arg multihead (ivy-llc#23376)
Browse files Browse the repository at this point in the history
  • Loading branch information
Killua7362 authored Sep 10, 2023
1 parent c010d17 commit 0d17da0
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions ivy/data_classes/container/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1094,10 +1094,10 @@ def _static_multi_head_attention(

def multi_head_attention(
self: ivy.Container,
/,
*,
key: Optional[Union[ivy.Array, ivy.NativeArray, ivy.Container]] = None,
value: Optional[Union[ivy.Array, ivy.NativeArray, ivy.Container]] = None,
/,
*,
num_heads: Union[int, ivy.Container] = 8,
scale: Optional[Union[float, ivy.Container]] = None,
attention_mask: Optional[
Expand Down
4 changes: 2 additions & 2 deletions ivy/functional/ivy/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,10 +710,10 @@ def scaled_dot_product_attention(
@handle_array_function
def multi_head_attention(
query: Union[ivy.Array, ivy.NativeArray],
/,
*,
key: Optional[Union[ivy.Array, ivy.NativeArray]] = None,
value: Optional[Union[ivy.Array, ivy.NativeArray]] = None,
/,
*,
num_heads: int = 8,
scale: Optional[float] = None,
attention_mask: Optional[Union[ivy.Array, ivy.NativeArray]] = None,
Expand Down

0 comments on commit 0d17da0

Please sign in to comment.