119 changes: 117 additions & 2 deletions tests/modules/layers/test_multi_head_attention.py
@@ -6,9 +6,12 @@

import pytest
import torch
-from tests.test_utils import assert_expected
+from tests.test_utils import assert_expected, init_weights_with_constant
from torch import nn
-from torchmultimodal.modules.layers.multi_head_attention import MultiHeadSelfAttention
+from torchmultimodal.modules.layers.multi_head_attention import (
+    MultiHeadAttentionWithCache,
+    MultiHeadSelfAttention,
+)


class TestMultiHeadSelfAttention:
@@ -52,3 +55,115 @@ def test_scripting(
        q = torch.Tensor([[[1, 2, 3, 1], [4, 3, 2, 1], [1, 1, 1, 1]]])
        scripted_model = torch.jit.script(multi_head_self_attn)
        assert_expected(scripted_model(q), multi_head_self_attn(q), rtol=0, atol=1e-4)


class TestMultiHeadAttentionWithCache:
    @pytest.fixture
    def dim_q(self):
        return 4

    @pytest.fixture
    def dim_kv(self):
        return 2

    @pytest.fixture
    def q(self):
        return torch.Tensor([[[1, 2, 3, 1], [4, 3, 2, 1], [1, 1, 1, 1]]])

    @pytest.fixture
    def current_key_value(self):
        return torch.Tensor(
            [
                [
                    [[8.0, 8.0], [11.0, 11.0], [5.0, 5.0]],
                    [[8.0, 8.0], [11.0, 11.0], [5.0, 5.0]],
                ]
            ]
        )

    @pytest.fixture
    def past_key_value(self):
        return torch.Tensor(
            [
                [
                    [[7.0, 7.0], [9.0, 9.0], [4.0, 4.0]],
                    [[7.0, 7.0], [9.0, 9.0], [4.0, 4.0]],
                ]
            ]
        )

    @pytest.fixture
    def multi_head_self_attn_use_cache(self, dim_q):
        mha = MultiHeadAttentionWithCache(dim_q, dim_q, num_heads=2, use_cache=True)
        init_weights_with_constant(mha)
        mha.eval()
        return mha

    @pytest.fixture
    def multi_head_cross_attn(self, dim_q, dim_kv):
        mha = MultiHeadAttentionWithCache(dim_q, dim_kv, num_heads=2)
        init_weights_with_constant(mha)
        mha.eval()
        return mha

    def test_multi_head_self_attention_use_cache(
        self,
        multi_head_self_attn_use_cache,
        current_key_value,
        past_key_value,
        q,
    ):
        actual = multi_head_self_attn_use_cache(
            q, q, q, past_key_value=(past_key_value, past_key_value)
        )
        expected = torch.tensor(
            [
                [
                    [45.0, 45.0, 45.0, 45.0],
                    [45.0, 45.0, 45.0, 45.0],
                    [45.0, 45.0, 45.0, 45.0],
                ]
            ]
        )
        assert_expected(actual.attn_output, expected, rtol=0, atol=1e-4)
        # Check that the cache is properly updated
        assert_expected(
            actual.past_key_value[0],
            torch.cat([past_key_value, current_key_value], dim=2),
        )
        assert_expected(
            actual.past_key_value[1],
            torch.cat([past_key_value, current_key_value], dim=2),
        )

    def test_multi_head_cross_attention(self, multi_head_cross_attn, q):
        kv = torch.Tensor([[[3, 2], [1, 1]]])
        actual = multi_head_cross_attn(q, kv, kv)
        expected = torch.tensor(
            [
                [
                    [25.0, 25.0, 25.0, 25.0],
                    [25.0, 25.0, 25.0, 25.0],
                    [25.0, 25.0, 25.0, 25.0],
                ],
            ]
        )
        assert_expected(actual, expected, rtol=0, atol=1e-4)

    def test_scripting(
        self,
        multi_head_self_attn_use_cache,
        q,
    ):
        scripted_model = torch.jit.script(multi_head_self_attn_use_cache)
        assert_expected(
            scripted_model(q, q, q).attn_output,
            multi_head_self_attn_use_cache(q, q, q).attn_output,
            rtol=0,
            atol=1e-4,
        )

    def test_multi_head_cross_attention_invalid_input(self, multi_head_cross_attn, q):
        kv = torch.Tensor([[[3, 2]], [[1, 1]]])
        with pytest.raises(ValueError):
            multi_head_cross_attn(q, kv, kv)
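Side note on the expected constants in these tests: they can be reproduced by hand, assuming init_weights_with_constant fills every Linear weight and bias of the module with 1.0 (an assumption about that test helper; the constant_proj helper below is a hypothetical stand-in, not part of the library). Each projection then maps a row x to sum(x) + 1 in every output feature, softmax concentrates essentially all attention on the largest projected key, and the output projection over 4 features gives 4 * 11 + 1 = 45 for the cached self-attention case and 4 * 6 + 1 = 25 for the cross-attention case. A minimal sketch of the arithmetic:

# Hand-check of the test constants, assuming init_weights_with_constant sets all
# weights and biases to 1.0; constant_proj is a hypothetical stand-in for such a layer.
import torch


def constant_proj(x: torch.Tensor, out_features: int) -> torch.Tensor:
    # An nn.Linear whose weight and bias are all ones maps each input row to
    # sum(row) + 1, repeated across every output feature.
    return (x.sum(dim=-1, keepdim=True) + 1.0).expand(*x.shape[:-1], out_features)


q = torch.Tensor([[[1, 2, 3, 1], [4, 3, 2, 1], [1, 1, 1, 1]]])
print(constant_proj(q, 4))  # rows become 8, 11 and 5 -> matches the key/value fixtures

# The largest projected key (11 for the cached self-attention test, 6 for cross-attention)
# dominates the softmax, so each attended feature is ~11 (resp. ~6) and the final
# projection over 4 features yields 4 * 11 + 1 = 45 and 4 * 6 + 1 = 25.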
101 changes: 100 additions & 1 deletion torchmultimodal/modules/layers/multi_head_attention.py
@@ -4,12 +4,19 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

-from typing import Optional
+from typing import NamedTuple, Optional, Tuple, Union

import torch

import torch.nn.functional as F
from torch import nn, Tensor


class MHAWithCacheOutput(NamedTuple):
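    """Output of MultiHeadAttentionWithCache when use_cache is True: the attention
    output plus the updated (key, value) cache for subsequent decoding steps."""
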
    attn_output: Tensor
    past_key_value: Tuple[Tensor, Tensor]


class MultiHeadSelfAttention(nn.Module):
"""
Multihead self attention.
@@ -66,3 +73,95 @@ def forward(

        attn_out = self.output_proj(attn)
        return attn_out


class MultiHeadAttentionWithCache(nn.Module):
"""
MultiHeadAttention module for both self-attention(SA) and cross-attention(CA).
This class supports a cache mechanism for decoders to store previous states through
"past_key_value". Key/value states should be only cached for self-attention cases.
q, k, v share the same dimension for self-attention,
but different for cross-attention, CA requires encoder hidden states dim as k, v dims.

Args:
dim_q (int): query embedding dimension
dim_kv (int): key, value embedding dimension,
same as dim_q for SA; equals to encoder dimension for cross-attention
num_heads (int): number of attention heads
dropout (float): dropout rate
"""

    def __init__(
        self,
        dim_q: int,
        dim_kv: int,
        num_heads: int,
        dropout: float = 0.0,
        use_cache: bool = False,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.q_proj = nn.Linear(dim_q, dim_q)
        self.k_proj = nn.Linear(dim_kv, dim_q)
        self.v_proj = nn.Linear(dim_kv, dim_q)
        self.output_proj = nn.Linear(dim_q, dim_q)
        self.dropout = dropout
        self.use_cache = use_cache

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        attn_mask: Optional[Tensor] = None,
        past_key_value: Optional[Tuple[Tensor, Tensor]] = None,
        is_causal: bool = False,
    ) -> Union[Tensor, MHAWithCacheOutput]:
        """
        Args:
            query (Tensor): input query of shape bsz x target_seq_len x embed_dim
            key (Tensor): key of shape bsz x source_seq_len x embed_dim
            value (Tensor): value of shape bsz x source_seq_len x embed_dim
            attn_mask (optional Tensor): attention mask of shape bsz x target_seq_len x source_seq_len.
                Two types of masks are supported: a boolean mask where a value of True
                indicates that the element *should* take part in attention, or a float
                mask of the same dtype as query, key, value that is added to the attention score.
            past_key_value (optional tuple of Tensors): cached key and value as returned by a previous
                call of this module, each of shape bsz x num_heads x cached_seq_len x head_dim.
                The tuple has size 2, with the cached key first and the cached value second.
            is_causal (bool): if True, apply causal attention masking; attn_mask must be None when
                is_causal is True. is_causal is a hint that the mask is a causal mask; providing an
                incorrect hint can result in incorrect execution.

        Returns:
            If use_cache is off, the attn_output tensor of shape bsz x target_seq_len x embed_dim;
            otherwise a named tuple of attn_output and the cached key and value.
        """
        bsz = query.size(0)
        embed_dim = query.size(-1)
        # head_dim relies on embed_dim being evenly divisible by num_heads
        head_dim = embed_dim // self.num_heads
        query = self.q_proj(query)
        key = self.k_proj(key)
        value = self.v_proj(value)

        # bsz x seq_len x embed_dim => bsz x num_heads x seq_len x head_dim
        query = query.view(bsz, -1, self.num_heads, head_dim).transpose(1, 2)
        if key.size(0) != bsz or value.size(0) != bsz:
            raise ValueError("key and value should have the same bsz as query.")
        key = key.view(bsz, -1, self.num_heads, head_dim).transpose(1, 2)
        value = value.view(bsz, -1, self.num_heads, head_dim).transpose(1, 2)

        # concat key value with cached values
        if past_key_value is not None:
            key = torch.cat([past_key_value[0], key], dim=2)
            value = torch.cat([past_key_value[1], value], dim=2)
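            # after concatenation the cache covers the full (past + current) sequence:
            # bsz x num_heads x (cached_seq_len + seq_len) x head_dim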

        # attn_mask, dropout and the is_causal hint are all handled inside scaled_dot_product_attention
        attn = F.scaled_dot_product_attention(
            query, key, value, attn_mask, self.dropout, is_causal
        )
        attn = attn.transpose(1, 2).reshape(bsz, -1, embed_dim)

        # add dense layer after attention
        attn_output = self.output_proj(attn)
        if self.use_cache:
            return MHAWithCacheOutput(attn_output, (key, value))
        return attn_output
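
To make the intended call pattern concrete, here is a minimal, hypothetical usage sketch of MultiHeadAttentionWithCache based on the signatures in this diff: self-attention with use_cache=True threads the returned past_key_value into the next decoding step, while cross-attention passes encoder hidden states as key and value. The sizes and the decoding loop are illustrative assumptions, not taken from the repo.

import torch
from torchmultimodal.modules.layers.multi_head_attention import MultiHeadAttentionWithCache

# Illustrative sizes, not from the library or its tests.
bsz, dim_q, dim_kv, num_heads = 2, 8, 6, 2

# Self-attention with caching: run a causal "prefill" over the prompt once,
# then feed one new token per step, threading the returned cache back in.
self_attn = MultiHeadAttentionWithCache(dim_q, dim_q, num_heads=num_heads, use_cache=True)
self_attn.eval()

prompt = torch.randn(bsz, 4, dim_q)
out = self_attn(prompt, prompt, prompt, is_causal=True)  # MHAWithCacheOutput
past_key_value = out.past_key_value                      # projected keys/values for the prompt

for _ in range(3):
    x = torch.randn(bsz, 1, dim_q)                        # one new target token per step
    out = self_attn(x, x, x, past_key_value=past_key_value)
    past_key_value = out.past_key_value                   # cache grows along the sequence dim
    y = out.attn_output                                   # bsz x 1 x dim_q

# Cross-attention: keys/values come from encoder hidden states (dim_kv); without
# use_cache the module returns a plain tensor.
cross_attn = MultiHeadAttentionWithCache(dim_q, dim_kv, num_heads=num_heads)
cross_attn.eval()
decoder_states = torch.randn(bsz, 4, dim_q)
encoder_states = torch.randn(bsz, 5, dim_kv)
attn_out = cross_attn(decoder_states, encoder_states, encoder_states)  # bsz x 4 x dim_q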