Merge pull request #2388 from laclouis5/fix-mqa-v2
Fix MQA V2
rwightman authored Jan 2, 2025
2 parents 2d734d9 + 2d5277e commit d23facd
Showing 2 changed files with 23 additions and 9 deletions.
tests/test_layers.py (18 additions, 4 deletions)

@@ -2,7 +2,7 @@
import torch
import torch.nn as nn

-from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn, Attention2d
+from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn, Attention2d, MultiQueryAttentionV2

import importlib
import os
@@ -121,6 +121,23 @@ def test_get_act_fn_none():
    assert get_act_fn('') is None


+@pytest.mark.parametrize("dim", [128])
+@pytest.mark.parametrize("dim_out", [128, 256])
+@pytest.mark.parametrize("use_m", [True, False])
+def test_mqa_v2(dim, dim_out, use_m):
+    mqa = MultiQueryAttentionV2(dim, dim_out)
+
+    x = torch.randn(1, dim, 32, 48)
+    if use_m:
+        m = torch.randn(1, dim, 16, 24)
+    else:
+        m = None
+
+    y = mqa(x, m=m)
+
+    assert (y.shape) == (1, dim_out, 32, 48)
+
+
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("expand_first", [True, False])
@pytest.mark.parametrize("head_first", [True, False])
@@ -141,6 +158,3 @@ def test_attn2d(bias, expand_first, head_first, attn_mask):
    o2 = attn(x, mask)

    assert torch.allclose(o1, o2, atol=1e-5), f"{torch.abs(o1 - o2).max()}"
-
-
-
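
As a usage note, the new test drives both paths through the layer: self-attention (m=None) and cross-attention against a separate memory tensor. A minimal standalone sketch, assuming a timm build that includes this commit, mirrors the shapes used in the test:

import torch
from timm.layers import MultiQueryAttentionV2

# Project 128 input channels to 256 output channels.
mqa = MultiQueryAttentionV2(128, 256)
x = torch.randn(1, 128, 32, 48)

# Self-attention: with m omitted, the layer now falls back to x internally.
y = mqa(x)
assert y.shape == (1, 256, 32, 48)

# Cross-attention: m may have a different spatial extent than x.
m = torch.randn(1, 128, 16, 24)
y = mqa(x, m=m)
assert y.shape == (1, 256, 32, 48)

Before this fix, passing a real memory tensor crashed on 'm or x', and even the m=None path mis-reshaped its output, which is what the shape assertion above pins down.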
timm/layers/attention2d.py (5 additions, 5 deletions)

@@ -59,24 +59,24 @@ def _reshape_input(self, t):

    def forward(self, x, m: Optional[torch.Tensor] = None):
        """Run layer computation."""
-        s = x.shape
-        m = m or x
+        b, _, h, w = x.shape
+        m = m if m is not None else x

        reshaped_x = self._reshape_input(x)
        reshaped_m = self._reshape_input(m)

        q = torch.einsum('bnd,hkd->bnhk', reshaped_x, self.query_proj)
        k = torch.einsum('bmd,dk->bmk', reshaped_m, self.key_proj)

-        attn = torch.einsum('bnhk,bmk->bnhm', q, k)
+        attn = torch.einsum('bnhk,bmk->bnhm', q, k) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        v = torch.einsum('bmd,dv->bmv', reshaped_m, self.value_proj)
        o = torch.einsum('bnhm,bmv->bnhv', attn, v)
-        result = torch.einsum('bnhv,dhv->bnd', o, self.out_proj)
+        result = torch.einsum('bnhv,dhv->bdn', o, self.out_proj)
        result = self.proj_drop(result)
-        return result.reshape(s)
+        return result.reshape(b, -1, h, w)


class MultiQueryAttention2d(nn.Module):
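A note on the first change: 'm = m or x' was the crash. Python's 'or' evaluates the truthiness of its left operand, and PyTorch raises for any tensor with more than one element, so calling the layer with an actual memory tensor failed immediately; only the m=None fallback ever worked. A minimal sketch of the failure mode, outside the layer:

import torch

x = torch.randn(1, 128, 32, 48)
m = torch.randn(1, 128, 16, 24)

try:
    y = m or x  # pre-fix pattern: calls bool() on a multi-element tensor
except RuntimeError as err:
    # PyTorch: boolean value of a multi-element tensor is ambiguous
    print(err)

# The fixed form tests only for None and never coerces a tensor to bool:
y = m if m is not None else x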

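The remaining changes fix the math. The added '* self.scale' applies the pre-softmax attention scaling that this forward pass previously omitted. The output projection used to emit 'bnhv,dhv->bnd' (batch, flattened spatial n = H*W, channels) and then reshape straight back to the input shape, which scrambles channels against spatial positions and cannot work at all when dim_out differs from the input dim. Writing the projection as 'bnhv,dhv->bdn' puts channels ahead of the spatial axis, so reshape(b, -1, h, w) is a clean view back to NCHW. A small shape check, with head count, value dim, and output dim chosen arbitrarily and out_proj's (d, h, v) layout read off the einsum subscripts:

import torch

B, H, W = 1, 32, 48
n = H * W                    # flattened spatial axis
heads, v_dim, d_out = 8, 16, 256

o = torch.randn(B, n, heads, v_dim)          # per-head attention output
out_proj = torch.randn(d_out, heads, v_dim)

result = torch.einsum('bnhv,dhv->bdn', o, out_proj)
assert result.shape == (B, d_out, n)

# Channels-first layout makes the final reshape a straight NCHW view.
assert result.reshape(B, -1, H, W).shape == (B, d_out, H, W)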