Skip to content

Commit

Permalink
fixed bug in testing strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
AnnaTz committed Sep 27, 2023
1 parent 8f5ee33 commit c99bdbd
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions ivy_tests/test_ivy/test_functional/test_nn/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,8 @@ def _mha_helper(draw, same_pre_embed_dim=False, batch_second=False):
)
)

if _self_attention and _qkv_same_dim:
_num_keys = _num_queries
_static_shape = (_num_batches * num_heads, _num_keys, int(_embed_dim // num_heads))
static_k = draw(
st.one_of(
Expand All @@ -251,10 +253,7 @@ def _mha_helper(draw, same_pre_embed_dim=False, batch_second=False):
)
)

_mask_shape = (
_num_queries,
_num_queries if _self_attention and _qkv_same_dim else _num_keys,
)
_mask_shape = (_num_queries, _num_keys)
if len(_batch_dim) and draw(st.booleans()):
_mask_shape = (_num_batches * num_heads, *_mask_shape)
attention_mask = draw(
Expand All @@ -272,7 +271,7 @@ def _mha_helper(draw, same_pre_embed_dim=False, batch_second=False):
st.one_of(
helpers.array_values(
dtype="bool",
shape=(*_batch_dim, _mask_shape[-1]),
shape=(*_batch_dim, _num_keys),
),
st.none(),
)
Expand Down

0 comments on commit c99bdbd

Please sign in to comment.