Commit

upgrade attention
lucidrains committed Apr 3, 2022
1 parent 7086332 commit a022f6d
Showing 2 changed files with 35 additions and 10 deletions.
43 changes: 34 additions & 9 deletions lightweight_gan/lightweight_gan.py
@@ -228,12 +228,14 @@ def forward(self, x):
         return self.net(x)
 
 class LinearAttention(nn.Module):
-    def __init__(self, dim, dim_head = 64, heads = 8):
+    def __init__(self, dim, dim_head = 64, heads = 8, kernel_size = 3):
         super().__init__()
         self.scale = dim_head ** -0.5
         self.heads = heads
+        self.dim_head = dim_head
         inner_dim = dim_head * heads
 
+        self.kernel_size = kernel_size
         self.nonlin = nn.GELU()
         self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
         self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, 3, padding = 1, bias = False)
@@ -242,18 +244,41 @@ def __init__(self, dim, dim_head = 64, heads = 8):
     def forward(self, fmap):
         h, x, y = self.heads, *fmap.shape[-2:]
         q, k, v = (self.to_q(fmap), *self.to_kv(fmap).chunk(2, dim = 1))
-        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h), (q, k, v))
+        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) c x y', h = h), (q, k, v))
 
-        q = q.softmax(dim = -1)
-        k = k.softmax(dim = -2)
+        # linear attention
 
-        q = q * self.scale
+        lin_q, lin_k, lin_v = map(lambda t: rearrange(t, 'b d ... -> b (...) d'), (q, k, v))
 
-        context = einsum('b n d, b n e -> b d e', k, v)
-        out = einsum('b n d, b d e -> b n e', q, context)
-        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)
+        lin_q = lin_q.softmax(dim = -1)
+        lin_k = lin_k.softmax(dim = -2)
 
-        out = self.nonlin(out)
+        lin_q = lin_q * self.scale
+
+        context = einsum('b n d, b n e -> b d e', lin_k, lin_v)
+        lin_out = einsum('b n d, b d e -> b n e', lin_q, context)
+        lin_out = rearrange(lin_out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)
+
+        # conv-like full attention
+
+        k = F.unfold(k, kernel_size = self.kernel_size, padding = self.kernel_size // 2)
+        v = F.unfold(v, kernel_size = self.kernel_size, padding = self.kernel_size // 2)
+
+        k, v = map(lambda t: rearrange(t, 'b (d j) n -> b n j d', d = self.dim_head), (k, v))
+
+        q = rearrange(q, 'b c ... -> b (...) c') * self.scale
+
+        sim = einsum('b i d, b i j d -> b i j', q, k)
+        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
+
+        attn = sim.softmax(dim = -1)
+
+        full_out = einsum('b i j, b i j d -> b i d', attn, v)
+        full_out = rearrange(full_out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)
+
+        # add outputs of linear attention + conv like full attention
+
+        out = self.nonlin(lin_out + full_out)
         return self.to_out(out)
 
 # dataset
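For readers skimming the diff: the module now sums two branches, the original global linear attention and a new conv-like full attention in which every position attends only to the kernel_size x kernel_size neighbourhood gathered with F.unfold. Below is a minimal, self-contained sketch of just those tensor operations (not the library code; the sizes b_h, d, x, y and the standalone variable names are made up for illustration, with the heads already folded into the batch dimension as in the module).

```python
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import einsum

b_h, d, x, y = 16, 64, 16, 16      # (batch * heads), dim_head, feature map height / width
kernel_size = 3
scale = d ** -0.5

q = torch.randn(b_h, d, x, y)
k = torch.randn(b_h, d, x, y)
v = torch.randn(b_h, d, x, y)

# linear attention branch: softmax(q) @ (softmax(k)^T @ v), linear in the number of positions
lin_q, lin_k, lin_v = map(lambda t: rearrange(t, 'b d ... -> b (...) d'), (q, k, v))
lin_q = lin_q.softmax(dim = -1) * scale                     # softmax over the feature dimension
lin_k = lin_k.softmax(dim = -2)                             # softmax over the sequence dimension
context = einsum('b n d, b n e -> b d e', lin_k, lin_v)     # (b*h, d, d)
lin_out = einsum('b n d, b d e -> b n e', lin_q, context)   # (b*h, x*y, d)

# conv-like full attention branch: each position attends to its unfolded 3x3 neighbourhood
k_loc = F.unfold(k, kernel_size = kernel_size, padding = kernel_size // 2)   # (b*h, d*9, x*y)
v_loc = F.unfold(v, kernel_size = kernel_size, padding = kernel_size // 2)
k_loc, v_loc = map(lambda t: rearrange(t, 'b (d j) n -> b n j d', d = d), (k_loc, v_loc))

q_loc = rearrange(q, 'b c ... -> b (...) c') * scale        # (b*h, x*y, d)
sim = einsum('b i d, b i j d -> b i j', q_loc, k_loc)       # (b*h, x*y, 9)
sim = sim - sim.amax(dim = -1, keepdim = True).detach()     # numerical stability before softmax
attn = sim.softmax(dim = -1)
full_out = einsum('b i j, b i j d -> b i d', attn, v_loc)   # (b*h, x*y, d)

# in the module, each branch is reshaped back to (b, heads * dim_head, x, y) before being
# summed, passed through GELU, and projected by the final 1x1 conv
out = lin_out + full_out
print(out.shape)   # torch.Size([16, 256, 64])
```

The linear branch costs O(n * d^2) and the local branch O(n * k^2 * d) for n = x * y positions, so the combined attention stays linear in the number of spatial positions.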
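A quick shape check of the upgraded module, assuming the import path lightweight_gan.lightweight_gan and that the 1x1 output projection maps back to dim (both as in the repository, not shown in this diff):

```python
import torch
from lightweight_gan.lightweight_gan import LinearAttention

attn = LinearAttention(dim = 256, dim_head = 64, heads = 8, kernel_size = 3)
fmap = torch.randn(1, 256, 32, 32)   # (batch, channels, height, width)

out = attn(fmap)
print(out.shape)   # expected: torch.Size([1, 256, 32, 32]), spatial size and channel count preserved
```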
2 changes: 1 addition & 1 deletion lightweight_gan/version.py
@@ -1 +1 @@
-__version__ = '0.21.0'
+__version__ = '0.21.1'
