From 91b310f8e6ddf909e663845313609d43f829dd77 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Fri, 6 Nov 2020 14:26:08 -0800
Subject: [PATCH] make sure masks work with funnel transformer

---
 README.md                        |  3 ++-
 setup.py                         |  2 +-
 x_transformers/x_transformers.py | 29 +++++++++++++++++++++--------
 3 files changed, 24 insertions(+), 10 deletions(-)

diff --git a/README.md b/README.md
index c1ab99ad..22cdde43 100644
--- a/README.md
+++ b/README.md
@@ -174,7 +174,8 @@ model = TransformerWrapper(
 )
 
 x = torch.randint(1, 20000, (1, 1024))
-model(x) # (1, 1024, 20000)
+mask = torch.ones_like(x).bool()
+model(x, mask = mask) # (1, 1024, 20000)
 ```
 
 ## Citations
diff --git a/setup.py b/setup.py
index c464d7b8..3bc5311a 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '0.0.23',
+  version = '0.0.24',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
diff --git a/x_transformers/x_transformers.py b/x_transformers/x_transformers.py
index e9319641..f0992b98 100644
--- a/x_transformers/x_transformers.py
+++ b/x_transformers/x_transformers.py
@@ -22,7 +22,8 @@ def default(val, d):
 
 def residualize(f):
     def fn(x, *args, **kwargs):
-        return f(x, *args, **kwargs) + x
+        out, *rest = f(x, *args, **kwargs)
+        return (out + x, *rest)
     return fn
 
 # keyword argument helpers
@@ -241,14 +242,26 @@ def forward(self, x, context = None, mask = None, context_mask = None, rel_pos =
 
 class AttentionWithDownsample(Attention):
     def forward(self, x, num_memory_tokens = 0, downsample = False, **kwargs):
+        mask = None
+
         if downsample:
             b, n, *_ = x.shape
+            is_odd = (n % 2) == 1
+
             mem, x = x[:, :num_memory_tokens], x[:, num_memory_tokens:]
-            x, remainder = (x[:, :-1], x[:, -1:]) if (n % 2) == 1 else (x[:, :], x[:, 0:0])
+            x, remainder = (x[:, :-1], x[:, -1:]) if is_odd else (x[:, :], x[:, 0:0])
             x = reduce(x, 'b (n c) d -> b n d', 'mean', c = 2)
             x = torch.cat((mem, x, remainder), dim = 1)
 
-        return super().forward(x, **kwargs)
+            mask = kwargs.pop('mask', None)
+            if exists(mask):
+                mask = mask[:, num_memory_tokens:]
+                mask = F.pad(mask, (0, 1), value = False) if is_odd else mask
+                mask = mask.reshape(b, -1, 2).any(dim = -1)
+                mask = F.pad(mask, (num_memory_tokens, 0), value = True)
+                kwargs.update(mask = mask)
+
+        return super().forward(x, **kwargs), mask
 
 class Encoder(nn.Module):
     def __init__(self, dim, depth, heads = 8, use_scalenorm = False, rel_pos_bias = False, **kwargs):
@@ -300,10 +313,7 @@ def __init__(self, dim, depths, heads = 8, use_scalenorm = False, rel_pos_bias =
                     prenorm_fn(FeedForward(dim, **ff_kwargs))
                 ]))
 
-            self.bottlenecks.append(nn.ModuleList([
-                rel_pos,
-                layers
-            ]))
+            self.bottlenecks.append(nn.ModuleList([rel_pos, layers]))
 
     def forward(self, x, context = None, mask = None):
         n = x.shape[1]
@@ -318,9 +328,12 @@ def forward(self, x, context = None, mask = None):
                 downsample = layer_ind != 0 and ind == 0
                 self_attn = residualize(self_attn) if not downsample else self_attn
-                x = self_attn(x, mask = mask, rel_pos = rel_pos, downsample = downsample, num_memory_tokens = num_mem)
+                x, new_mask = self_attn(x, mask = mask, rel_pos = rel_pos, downsample = downsample, num_memory_tokens = num_mem)
                 x = ff(x) + x
 
+                if exists(new_mask):
+                    mask = new_mask
+
         mem, x = x[:, :num_mem], x[:, num_mem:]
 
         # upsample by repeating tokens as specified in paper
         x = repeat(x, 'b n d -> b (n m) d', m = 2 ** (num_downsamples - 1))
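
Note (not part of the patch): below is a minimal standalone sketch of the pairwise mask pooling this diff adds to `AttentionWithDownsample.forward`. The helper name `downsample_mask` and the example values are illustrative assumptions, not code from the repo.

```python
# Illustrative sketch only -- loosely mirrors the mask handling added in the patch:
# drop memory-token positions, pad to an even length, pool adjacent positions
# with `any`, then re-prepend always-True slots for the memory tokens.
import torch
import torch.nn.functional as F

def downsample_mask(mask, num_memory_tokens = 0):  # hypothetical helper, not in x-transformers
    b = mask.shape[0]
    mask = mask[:, num_memory_tokens:]

    # pad with False so tokens can be grouped into pairs
    if mask.shape[-1] % 2 == 1:
        mask = F.pad(mask, (0, 1), value = False)

    # a pooled position stays unmasked if either source position was unmasked
    mask = mask.reshape(b, -1, 2).any(dim = -1)

    # memory tokens are always attended to
    return F.pad(mask, (num_memory_tokens, 0), value = True)

mask = torch.tensor([[True, True, True, False, False]])
print(downsample_mask(mask))  # tensor([[ True,  True, False]])
```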