add attention on attention (gating on final attention output, with queries incorporated)
lucidrains committed Nov 8, 2020
1 parent f2063a4 commit dc8ec28
Showing 3 changed files with 31 additions and 12 deletions.
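For background: the attention-on-attention mechanism of Huang et al. (cited in the README change below) computes an information vector and a gate from the concatenation of the query with the attention output, then multiplies the two, so each position's output is gated by the query that produced it. A minimal sketch of the paper's sigmoid-gated formulation (the function and layer names here are illustrative, not part of the library):

```python
# Illustrative sketch of attention-on-attention (Huang et al., 2019):
# gate the attention output using the query it was computed from.
import torch
import torch.nn as nn

def attention_on_attention(q, attn_out, to_info, to_gate):
    # q, attn_out: (batch, seq_len, dim)
    # to_info, to_gate: nn.Linear(dim * 2, dim); hypothetical parameter names
    x = torch.cat((q, attn_out), dim = -1)         # incorporate the queries
    return to_info(x) * torch.sigmoid(to_gate(x))  # gated information vector

dim = 64
to_info, to_gate = nn.Linear(dim * 2, dim), nn.Linear(dim * 2, dim)
out = attention_on_attention(torch.randn(2, 16, dim), torch.randn(2, 16, dim), to_info, to_gate)
```

This commit adopts the same concatenate-then-gate idea but replaces the sigmoid gate with a GEGLU projection, as the diff below shows.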
README.md: 11 additions, 0 deletions
@@ -366,3 +366,14 @@ model(x, mask = mask) # (1, 1024, 20000)
     primaryClass = {cs.CV}
 }
 ```
+
+```bibtex
+@misc{huang2019attention,
+    title = {Attention on Attention for Image Captioning},
+    author = {Lun Huang and Wenmin Wang and Jie Chen and Xiao-Yong Wei},
+    year = {2019},
+    eprint = {1908.06954},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CV}
+}
+```
setup.py: 1 addition, 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '0.0.25',
+  version = '0.0.26',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
x_transformers/x_transformers.py: 19 additions, 11 deletions
@@ -148,7 +148,8 @@ def __init__(
         sparse_topk = None,
         use_entmax15 = False,
         num_mem_kv = 0,
-        dropout = 0.
+        dropout = 0.,
+        on_attn = False
     ):
         super().__init__()
         self.scale = dim_head ** -0.5
@@ -159,7 +160,6 @@ def __init__(
         inner_dim = dim_head * heads
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
-        self.to_out = nn.Linear(inner_dim, dim)
         self.dropout = nn.Dropout(dropout)
 
         # talking heads
@@ -180,14 +180,18 @@ def __init__(
             self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
             self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
 
+        # attention on attention
+        self.attn_on_attn = on_attn
+        self.to_out = GEGLU(inner_dim * 2, dim) if on_attn else nn.Linear(inner_dim, dim)
+
     def forward(self, x, context = None, mask = None, context_mask = None, rel_pos = None):
         b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
         kv_input = default(context, x)
 
-        q = self.to_q(x)
+        q_ = self.to_q(x)
         kv = self.to_kv(kv_input).chunk(2, dim = -1)
 
-        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, *kv))
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q_, *kv))
 
         if self.num_mem_kv > 0:
             mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), (self.mem_k, self.mem_v))
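GEGLU is defined elsewhere in x_transformers.py and does not appear in this diff. A minimal sketch consistent with the call GEGLU(inner_dim * 2, dim) above, assuming the usual GELU-gated linear unit form:

```python
# Assumed sketch of GEGLU (definition not shown in this diff): project to
# twice the output width, split in half, gate one half with GELU of the other.
import torch.nn as nn
import torch.nn.functional as F

class GEGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim = -1)
        return x * F.gelu(gate)
```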
@@ -197,7 +201,7 @@ def forward(self, x, context = None, mask = None, context_mask = None, rel_pos = None):
         dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
 
         if talking_heads:
-            dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj)
+            dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()
 
         if exists(rel_pos):
             dots = rel_pos(dots)
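For reference, the talking-heads step above (Shazeer et al., Talking-Heads Attention) mixes the attention logits across heads with a learned heads-by-heads matrix; the only change in this hunk is the added .contiguous(). A standalone sketch of the mixing:

```python
# Standalone sketch of the talking-heads projection applied to attention logits.
import torch
from torch import einsum

batch, heads, n = 2, 8, 16
dots = torch.randn(batch, heads, n, n)        # pre-softmax attention logits
pre_softmax_proj = torch.randn(heads, heads)  # learned parameter in the module
dots = einsum('b h i j, h k -> b k i j', dots, pre_softmax_proj).contiguous()
```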
@@ -218,8 +222,8 @@ def forward(self, x, context = None, mask = None, context_mask = None, rel_pos = None):
             del mask
 
         if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
-            v, _ = dots.topk(dim = -1)
-            vk = v[..., -1].unsqueeze(-1).expand_as(dots)
+            top, _ = dots.topk(self.sparse_topk, dim = -1)
+            vk = top[..., -1].unsqueeze(-1).expand_as(dots)
             mask = dots < vk
             dots.masked_fill_(mask, float('-inf'))
             del mask
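This hunk also repairs the explicit sparse top-k path: the old code called dots.topk without the required k argument and shadowed the value tensor v. The corrected logic, as a standalone sketch:

```python
# Standalone sketch of sparse top-k masking: keep only the k largest logits
# per query position and push everything else to -inf before the softmax.
import torch

def sparse_topk_mask(dots, k):
    # dots: (batch, heads, q_len, k_len) pre-softmax attention logits
    top, _ = dots.topk(k, dim = -1)
    vk = top[..., -1].unsqueeze(-1).expand_as(dots)
    return dots.masked_fill(dots < vk, float('-inf'))

masked = sparse_topk_mask(torch.randn(1, 8, 16, 16), k = 4)
```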
@@ -228,10 +232,14 @@ def forward(self, x, context = None, mask = None, context_mask = None, rel_pos = None):
         attn = self.dropout(attn)
 
         if talking_heads:
-            dots = einsum('b h i j, h k -> b k i j', dots, self.post_softmax_proj)
+            dots = einsum('b h i j, h k -> b k i j', dots, self.post_softmax_proj).contiguous()
 
         out = einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
+
+        if self.attn_on_attn:
+            out = torch.cat((q_, out), dim = -1)
+
         return self.to_out(out)
 
 class Encoder(nn.Module):
@@ -259,7 +267,7 @@ def forward(self, x, context = None, mask = None):
         return x
 
 class Decoder(nn.Module):
-    def __init__(self, dim, depth, dim_head = 64, heads = 8, cross_attend = False, use_scalenorm = False, rel_pos_bias = False, **kwargs):
+    def __init__(self, dim, depth, heads = 8, cross_attend = False, use_scalenorm = False, rel_pos_bias = False, **kwargs):
         super().__init__()
         self.dim = dim
         self.layers = nn.ModuleList([])
@@ -273,8 +281,8 @@ def __init__(self, dim, depth, dim_head = 64, heads = 8, cross_attend = False, use_scalenorm = False, rel_pos_bias = False, **kwargs):

         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                prenorm_fn(Attention(dim, dim_head = dim_head, heads = heads, causal = True)),
-                prenorm_fn(Attention(dim, dim_head = dim_head, heads = heads)) if cross_attend else None,
+                prenorm_fn(Attention(dim, heads = heads, causal = True, **attn_kwargs)),
+                prenorm_fn(Attention(dim, heads = heads, **attn_kwargs)) if cross_attend else None,
                 prenorm_fn(FeedForward(dim, **ff_kwargs)),
             ]))
     def forward(self, x, context = None, mask = None, context_mask = None):
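A hypothetical usage sketch, assuming that Decoder kwargs prefixed with attn_ are routed to Attention as attn_kwargs (that routing happens outside the lines shown in this diff):

```python
# Hypothetical usage: enable attention-on-attention on every decoder layer.
# Assumes 'attn_on_attn = True' reaches Attention as 'on_attn = True'.
import torch
from x_transformers.x_transformers import Decoder

model = Decoder(
    dim = 512,
    depth = 6,
    heads = 8,
    attn_on_attn = True
)

x = torch.randn(1, 1024, 512)
out = model(x)  # (1, 1024, 512)
```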
