From c99bdbd9339b70b1035b6847828f7f6821b527d9 Mon Sep 17 00:00:00 2001 From: AnnaTz <111577222+AnnaTz@users.noreply.github.com> Date: Wed, 27 Sep 2023 18:02:35 +0100 Subject: [PATCH] fixed bug in testing strategy --- .../test_ivy/test_functional/test_nn/test_layers.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/ivy_tests/test_ivy/test_functional/test_nn/test_layers.py b/ivy_tests/test_ivy/test_functional/test_nn/test_layers.py index 332be410ab5b2..f186b7f0757ed 100644 --- a/ivy_tests/test_ivy/test_functional/test_nn/test_layers.py +++ b/ivy_tests/test_ivy/test_functional/test_nn/test_layers.py @@ -225,6 +225,8 @@ def _mha_helper(draw, same_pre_embed_dim=False, batch_second=False): ) ) + if _self_attention and _qkv_same_dim: + _num_keys = _num_queries _static_shape = (_num_batches * num_heads, _num_keys, int(_embed_dim // num_heads)) static_k = draw( st.one_of( @@ -251,10 +253,7 @@ def _mha_helper(draw, same_pre_embed_dim=False, batch_second=False): ) ) - _mask_shape = ( - _num_queries, - _num_queries if _self_attention and _qkv_same_dim else _num_keys, - ) + _mask_shape = (_num_queries, _num_keys) if len(_batch_dim) and draw(st.booleans()): _mask_shape = (_num_batches * num_heads, *_mask_shape) attention_mask = draw( @@ -272,7 +271,7 @@ def _mha_helper(draw, same_pre_embed_dim=False, batch_second=False): st.one_of( helpers.array_values( dtype="bool", - shape=(*_batch_dim, _mask_shape[-1]), + shape=(*_batch_dim, _num_keys), ), st.none(), )