Skip to content

Commit

Permalink
fix attention output with symbolic tensors and attention scores (kera…
Browse files Browse the repository at this point in the history
  • Loading branch information
Surya2k1 authored Dec 26, 2024
1 parent f54c127 commit 8907bcb
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
2 changes: 1 addition & 1 deletion keras/src/layers/attention/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def compute_output_spec(
output_spec = KerasTensor(output_shape, dtype=self.compute_dtype)

# Handle attention scores if requested
if self._return_attention_scores:
if self._return_attention_scores or return_attention_scores:
scores_shape = (
query.shape[0],
query.shape[1],
Expand Down
12 changes: 12 additions & 0 deletions keras/src/layers/attention/attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,3 +417,15 @@ def test_return_attention_scores_true_tuple_then_unpack(self):
self.assertEqual(
attention_scores.shape, (2, 8, 4)
) # Attention scores shape

def test_return_attention_scores_with_symbolic_tensors(self):
"""Test to check outputs with symbolic tensors with
return_attention_scores = True"""
attention = layers.Attention()
x = layers.Input(shape=(3, 5))
y = layers.Input(shape=(4, 5))
output, attention_scores = attention(
[x, y], return_attention_scores=True
)
self.assertEqual(output.shape, (None, 3, 5)) # Output shape
self.assertEqual(attention_scores.shape, (None, 3, 4))

0 comments on commit 8907bcb

Please sign in to comment.