
Commit

attention fix
lucidrains committed Apr 4, 2022
1 parent a022f6d commit b65a15f
Showing 2 changed files with 17 additions and 7 deletions.
22 changes: 16 additions & 6 deletions lightweight_gan/lightweight_gan.py
@@ -237,18 +237,24 @@ def __init__(self, dim, dim_head = 64, heads = 8, kernel_size = 3):
 
         self.kernel_size = kernel_size
         self.nonlin = nn.GELU()
+
+        self.to_lin_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
+        self.to_lin_kv = DepthWiseConv2d(dim, inner_dim * 2, 3, padding = 1, bias = False)
+
         self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
-        self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, 3, padding = 1, bias = False)
-        self.to_out = nn.Conv2d(inner_dim, dim, 1)
+        self.to_kv = nn.Conv2d(dim, inner_dim * 2, 1, bias = False)
+
+        self.to_out = nn.Conv2d(inner_dim * 2, dim, 1)
 
     def forward(self, fmap):
         h, x, y = self.heads, *fmap.shape[-2:]
-        q, k, v = (self.to_q(fmap), *self.to_kv(fmap).chunk(2, dim = 1))
-        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) c x y', h = h), (q, k, v))
 
         # linear attention
 
-        lin_q, lin_k, lin_v = map(lambda t: rearrange(t, 'b d ... -> b (...) d'), (q, k, v))
+        lin_q, lin_k, lin_v = (self.to_lin_q(fmap), *self.to_lin_kv(fmap).chunk(2, dim = 1))
+        lin_q, lin_k, lin_v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) c x y', h = h), (lin_q, lin_k, lin_v))
+
+        lin_q, lin_k, lin_v = map(lambda t: rearrange(t, 'b d ... -> b (...) d'), (lin_q, lin_k, lin_v))
 
         lin_q = lin_q.softmax(dim = -1)
         lin_k = lin_k.softmax(dim = -2)
@@ -261,6 +267,9 @@ def forward(self, fmap):
 
         # conv-like full attention
 
+        q, k, v = (self.to_q(fmap), *self.to_kv(fmap).chunk(2, dim = 1))
+        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) c x y', h = h), (q, k, v))
+
         k = F.unfold(k, kernel_size = self.kernel_size, padding = self.kernel_size // 2)
         v = F.unfold(v, kernel_size = self.kernel_size, padding = self.kernel_size // 2)
 
@@ -278,7 +287,8 @@ def forward(self, fmap):
 
         # add outputs of linear attention + conv like full attention
 
-        out = self.nonlin(lin_out + full_out)
+        lin_out = self.nonlin(lin_out)
+        out = torch.cat((lin_out, full_out), dim = 1)
         return self.to_out(out)
 
 # dataset
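For context, below is a minimal, self-contained sketch of what the attention block computes after this commit: a global linear-attention branch and a local, conv-like full-attention branch, each with its own projections, whose outputs are concatenated (rather than summed, as before) ahead of the final 1x1 output projection. The class name HybridAttention, the inline DepthWiseConv2d stand-in, and the einsum bookkeeping between the lines visible in the diff are illustrative assumptions, not the repository's verbatim code.

import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange

class DepthWiseConv2d(nn.Module):
    # assumed stand-in for the repo's helper: a depthwise conv followed by a pointwise conv
    def __init__(self, dim_in, dim_out, kernel_size, padding = 0, bias = False):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim_in, dim_in, kernel_size, padding = padding, groups = dim_in, bias = bias),
            nn.Conv2d(dim_in, dim_out, 1, bias = bias)
        )

    def forward(self, x):
        return self.net(x)

class HybridAttention(nn.Module):
    def __init__(self, dim, dim_head = 64, heads = 8, kernel_size = 3):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.kernel_size = kernel_size
        self.nonlin = nn.GELU()
        inner_dim = dim_head * heads

        # separate projections for the linear-attention branch (new in this commit)
        self.to_lin_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
        self.to_lin_kv = DepthWiseConv2d(dim, inner_dim * 2, 3, padding = 1, bias = False)

        # projections for the conv-like full-attention branch
        self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
        self.to_kv = nn.Conv2d(dim, inner_dim * 2, 1, bias = False)

        # the output projection now receives the concatenation of both branches
        self.to_out = nn.Conv2d(inner_dim * 2, dim, 1)

    def forward(self, fmap):
        h, x, y = self.heads, *fmap.shape[-2:]

        # linear attention: softmax q over features, k over positions, then two einsums
        lin_q, lin_k, lin_v = (self.to_lin_q(fmap), *self.to_lin_kv(fmap).chunk(2, dim = 1))
        lin_q, lin_k, lin_v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = h), (lin_q, lin_k, lin_v))
        lin_q = lin_q.softmax(dim = -1) * self.scale
        lin_k = lin_k.softmax(dim = -2)
        context = einsum('b n d, b n e -> b d e', lin_k, lin_v)
        lin_out = einsum('b n d, b d e -> b n e', lin_q, context)
        lin_out = rearrange(lin_out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)

        # conv-like full attention over a kernel_size x kernel_size neighbourhood via unfold
        q, k, v = (self.to_q(fmap), *self.to_kv(fmap).chunk(2, dim = 1))
        q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) d x y', h = h), (q, k, v))
        k, v = map(lambda t: F.unfold(t, kernel_size = self.kernel_size, padding = self.kernel_size // 2), (k, v))
        k, v = map(lambda t: rearrange(t, 'b (d j) n -> b n j d', j = self.kernel_size ** 2), (k, v))
        q = rearrange(q, 'b d x y -> b (x y) d') * self.scale
        attn = einsum('b n d, b n j d -> b n j', q, k).softmax(dim = -1)
        full_out = einsum('b n j, b n j d -> b n d', attn, v)
        full_out = rearrange(full_out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)

        # combine: GELU on the linear branch, concatenate, project back to dim
        out = torch.cat((self.nonlin(lin_out), full_out), dim = 1)
        return self.to_out(out)

# quick shape check
# attn = HybridAttention(dim = 64)
# attn(torch.randn(1, 64, 32, 32)).shape   # -> torch.Size([1, 64, 32, 32])

Concatenating the two branches and widening to_out to inner_dim * 2 lets the 1x1 output projection learn how to mix the global and local paths, instead of fixing the combination as an elementwise sum.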
2 changes: 1 addition & 1 deletion lightweight_gan/version.py
@@ -1 +1 @@
-__version__ = '0.21.1'
+__version__ = '0.21.2'
