diff --git a/ivy/data_classes/container/layers.py b/ivy/data_classes/container/layers.py index b56ed5c84a9d6..8e28923ad204a 100644 --- a/ivy/data_classes/container/layers.py +++ b/ivy/data_classes/container/layers.py @@ -1094,10 +1094,10 @@ def _static_multi_head_attention( def multi_head_attention( self: ivy.Container, - /, - *, key: Optional[Union[ivy.Array, ivy.NativeArray, ivy.Container]] = None, value: Optional[Union[ivy.Array, ivy.NativeArray, ivy.Container]] = None, + /, + *, num_heads: Union[int, ivy.Container] = 8, scale: Optional[Union[float, ivy.Container]] = None, attention_mask: Optional[ diff --git a/ivy/functional/ivy/layers.py b/ivy/functional/ivy/layers.py index a471d62d08a78..a378c626db5da 100644 --- a/ivy/functional/ivy/layers.py +++ b/ivy/functional/ivy/layers.py @@ -710,10 +710,10 @@ def scaled_dot_product_attention( @handle_array_function def multi_head_attention( query: Union[ivy.Array, ivy.NativeArray], - /, - *, key: Optional[Union[ivy.Array, ivy.NativeArray]] = None, value: Optional[Union[ivy.Array, ivy.NativeArray]] = None, + /, + *, num_heads: int = 8, scale: Optional[float] = None, attention_mask: Optional[Union[ivy.Array, ivy.NativeArray]] = None,