
Commit 81e281c

ebsmothers authored and facebook-github-bot committed
Fix attn_mask shape in MHA docstrings (#441)
Summary: Pull Request resolved: #441

n/a

Reviewed By: ankitade, pikapecan

Differential Revision: D47998161

fbshipit-source-id: 48804597c69979f76bcad92ade24a632dfda7f9c
1 parent 1aa2ed2 commit 81e281c

File tree

1 file changed: +5, -2 lines changed


torchmultimodal/modules/layers/multi_head_attention.py

Lines changed: 5 additions & 2 deletions
@@ -46,7 +46,9 @@ def forward(
        """
        Args:
            query (Tensor): input query of shape bsz x seq_len x embed_dim
-           attn_mask (optional Tensor): attention mask of shape bsz x seq_len x seq_len. Two types of masks are supported.
+           attn_mask (optional Tensor): attention mask of shape bsz x num_heads x seq_len x seq_len.
+               Note that the num_heads dimension can equal 1 and the mask will be broadcasted to all heads.
+               Two types of masks are supported.
            A boolean mask where a value of True indicates that the element should take part in attention.
            A float mask of the same type as query that is added to the attention score.
            is_causal (bool): If true, does causal attention masking. attn_mask should be set to None if this is set to True
@@ -124,7 +126,8 @@ def forward(
            query (Tensor): input query of shape bsz x target_seq_len x embed_dim
            key (Tensor): key of shape bsz x source_seq_len x embed_dim
            value (Tensor): value of shape bsz x source_seq_len x embed_dim
-           attn_mask (optional Tensor): Attention mask of shape bsz x target_seq_len x source_seq_len.
+           attn_mask (optional Tensor): Attention mask of shape bsz x num_heads x target_seq_len x source_seq_len.
+               Note that the num_heads dimension can equal 1 and the mask will be broadcasted to all heads.
            Two types of masks are supported. A boolean mask where a value of True
            indicates that the element *should* take part in attention.
            A float mask of the same type as query, key, value that is added to the attention score.
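For context, the shape convention in the updated docstrings matches the attn_mask broadcasting behavior of PyTorch's torch.nn.functional.scaled_dot_product_attention: a mask of shape bsz x num_heads x seq_len x seq_len where the num_heads dimension can be 1 and broadcast across heads, given either as a boolean mask (True means the position takes part in attention) or as a float mask added to the attention scores. The sketch below illustrates that convention with the PyTorch functional API directly, not with torchmultimodal's MultiHeadAttention; the tensor sizes and variable names are illustrative assumptions, not taken from the repo.

    # Minimal sketch of the mask shapes described in the updated docstrings,
    # using PyTorch's F.scaled_dot_product_attention directly (an assumption
    # for illustration, not the torchmultimodal module itself).
    import torch
    import torch.nn.functional as F

    bsz, num_heads, seq_len, head_dim = 2, 4, 8, 16
    q = torch.randn(bsz, num_heads, seq_len, head_dim)
    k = torch.randn(bsz, num_heads, seq_len, head_dim)
    v = torch.randn(bsz, num_heads, seq_len, head_dim)

    # Boolean mask of shape bsz x num_heads x seq_len x seq_len; True means the
    # element takes part in attention. Keeping num_heads as 1 lets the mask
    # broadcast to all heads, as the docstring notes.
    bool_mask = torch.ones(bsz, 1, seq_len, seq_len, dtype=torch.bool).tril()
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)

    # Equivalent float mask of the same shape: 0 where attention is allowed,
    # -inf where it is not; it is added to the attention scores before softmax.
    float_mask = torch.zeros(bsz, 1, seq_len, seq_len).masked_fill(
        ~bool_mask, float("-inf")
    )
    out_float = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask)

    # With is_causal=True, attn_mask should be left as None, per the docstring.
    out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)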
