From a022f6d32ca24885ca0019216ecc84c11210eae9 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Sun, 3 Apr 2022 16:52:58 -0700
Subject: [PATCH] upgrade attention

---
 lightweight_gan/lightweight_gan.py | 43 +++++++++++++++++++++++-------
 lightweight_gan/version.py         |  2 +-
 2 files changed, 35 insertions(+), 10 deletions(-)

diff --git a/lightweight_gan/lightweight_gan.py b/lightweight_gan/lightweight_gan.py
index 638a9d1..b4125b8 100644
--- a/lightweight_gan/lightweight_gan.py
+++ b/lightweight_gan/lightweight_gan.py
@@ -228,12 +228,14 @@ def forward(self, x):
         return self.net(x)
 
 class LinearAttention(nn.Module):
-    def __init__(self, dim, dim_head = 64, heads = 8):
+    def __init__(self, dim, dim_head = 64, heads = 8, kernel_size = 3):
         super().__init__()
         self.scale = dim_head ** -0.5
         self.heads = heads
+        self.dim_head = dim_head
         inner_dim = dim_head * heads
+        self.kernel_size = kernel_size
         self.nonlin = nn.GELU()
         self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
         self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, 3, padding = 1, bias = False)
         self.to_out = nn.Conv2d(inner_dim, dim, 1)
@@ -242,18 +244,41 @@ def __init__(self, dim, dim_head = 64, heads = 8):
     def forward(self, fmap):
         h, x, y = self.heads, *fmap.shape[-2:]
         q, k, v = (self.to_q(fmap), *self.to_kv(fmap).chunk(2, dim = 1))
-        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h), (q, k, v))
+        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) c x y', h = h), (q, k, v))
 
-        q = q.softmax(dim = -1)
-        k = k.softmax(dim = -2)
+        # linear attention
 
-        q = q * self.scale
+        lin_q, lin_k, lin_v = map(lambda t: rearrange(t, 'b d ... -> b (...) d'), (q, k, v))
 
-        context = einsum('b n d, b n e -> b d e', k, v)
-        out = einsum('b n d, b d e -> b n e', q, context)
-        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)
+        lin_q = lin_q.softmax(dim = -1)
+        lin_k = lin_k.softmax(dim = -2)
 
-        out = self.nonlin(out)
+        lin_q = lin_q * self.scale
+
+        context = einsum('b n d, b n e -> b d e', lin_k, lin_v)
+        lin_out = einsum('b n d, b d e -> b n e', lin_q, context)
+        lin_out = rearrange(lin_out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)
+
+        # conv-like full attention
+
+        k = F.unfold(k, kernel_size = self.kernel_size, padding = self.kernel_size // 2)
+        v = F.unfold(v, kernel_size = self.kernel_size, padding = self.kernel_size // 2)
+
+        k, v = map(lambda t: rearrange(t, 'b (d j) n -> b n j d', d = self.dim_head), (k, v))
+
+        q = rearrange(q, 'b c ... -> b (...) c') * self.scale
+
+        sim = einsum('b i d, b i j d -> b i j', q, k)
+        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
+
+        attn = sim.softmax(dim = -1)
+
+        full_out = einsum('b i j, b i j d -> b i d', attn, v)
+        full_out = rearrange(full_out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)
+
+        # add outputs of linear attention + conv like full attention
+
+        out = self.nonlin(lin_out + full_out)
         return self.to_out(out)
 
 # dataset
diff --git a/lightweight_gan/version.py b/lightweight_gan/version.py
index e453371..8c306aa 100644
--- a/lightweight_gan/version.py
+++ b/lightweight_gan/version.py
@@ -1 +1 @@
-__version__ = '0.21.0'
+__version__ = '0.21.1'
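
[Note, not part of the patch] The sketch below re-traces the new forward pass on raw tensors so the two branches are easier to follow: the linear-attention branch builds a small d-by-e context matrix instead of a full attention map, while the conv-like branch uses F.unfold so each query attends only to a kernel_size x kernel_size neighbourhood of keys and values. This is a minimal illustration under assumed shapes, not library code: toy_attention and its inputs are made up for the example, and it omits the module's GELU non-linearity and 1x1 output convolution.

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import einsum

def toy_attention(q, k, v, dim_head = 64, kernel_size = 3):
    # q, k, v: (batch * heads, dim_head, height, width), i.e. the layout after
    # the patch's first rearrange 'b (h c) x y -> (b h) c x y'
    scale = dim_head ** -0.5
    height, width = q.shape[-2:]

    # linear attention branch: softmax over feature / sequence dims, then
    # aggregate a (d, e) context matrix rather than an (n, n) attention map
    lin_q, lin_k, lin_v = map(lambda t: rearrange(t, 'b d ... -> b (...) d'), (q, k, v))
    lin_q = lin_q.softmax(dim = -1) * scale
    lin_k = lin_k.softmax(dim = -2)
    context = einsum('b n d, b n e -> b d e', lin_k, lin_v)
    lin_out = einsum('b n d, b d e -> b n e', lin_q, context)

    # conv-like full attention branch: gather a kernel_size ** 2 neighbourhood
    # of keys / values per position with F.unfold and attend within it
    k = F.unfold(k, kernel_size = kernel_size, padding = kernel_size // 2)
    v = F.unfold(v, kernel_size = kernel_size, padding = kernel_size // 2)
    k, v = map(lambda t: rearrange(t, 'b (d j) n -> b n j d', d = dim_head), (k, v))
    q = rearrange(q, 'b c ... -> b (...) c') * scale
    sim = einsum('b i d, b i j d -> b i j', q, k)
    sim = sim - sim.amax(dim = -1, keepdim = True).detach()
    attn = sim.softmax(dim = -1)
    full_out = einsum('b i j, b i j d -> b i d', attn, v)

    # both branches yield (batch * heads, height * width, dim_head); sum them
    out = lin_out + full_out
    return rearrange(out, 'b (x y) d -> b d x y', x = height, y = width)

# smoke test: batch-of-heads of 2, 64-dim heads, 16x16 feature map
q = torch.randn(2, 64, 16, 16)
k = torch.randn(2, 64, 16, 16)
v = torch.randn(2, 64, 16, 16)
print(toy_attention(q, k, v).shape)  # torch.Size([2, 64, 16, 16])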