make sure masks work with funnel transformer
lucidrains committed Nov 6, 2020
1 parent 72541d6 commit 91b310f
Showing 3 changed files with 24 additions and 10 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -174,7 +174,8 @@ model = TransformerWrapper(
 )
 
 x = torch.randint(1, 20000, (1, 1024))
-model(x) # (1, 1024, 20000)
+mask = torch.ones_like(x).bool()
+model(x, mask = mask) # (1, 1024, 20000)
 ```
 
 ## Citations
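Note: `torch.ones_like(x).bool()` attends to every position; in practice the mask usually marks real tokens versus padding. A minimal sketch of that, assuming a hypothetical pad token id of 0 (not something the README specifies):

```python
import torch

pad_id = 0                           # hypothetical padding token id
x = torch.randint(1, 20000, (1, 1024))
x[:, -10:] = pad_id                  # pretend the last 10 positions are padding
mask = (x != pad_id)                 # bool mask, True where a real token sits
# model(x, mask = mask)              # same call as in the snippet above
```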
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '0.0.23',
+  version = '0.0.24',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
29 changes: 21 additions & 8 deletions x_transformers/x_transformers.py
@@ -22,7 +22,8 @@ def default(val, d):
 
 def residualize(f):
     def fn(x, *args, **kwargs):
-        return f(x, *args, **kwargs) + x
+        out, *rest = f(x, *args, **kwargs)
+        return (out + x, *rest)
     return fn
 
 # keyword argument helpers
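For orientation: the decorator above is made tuple-aware because `AttentionWithDownsample.forward` (next hunk) now returns `(output, mask)` instead of a bare tensor. A minimal standalone sketch of the new behavior; `attn_stub` is a hypothetical stand-in, not code from this repository:

```python
import torch

def residualize(f):
    # add a residual connection to the first return value,
    # passing any extra return values (e.g. an updated mask) through untouched
    def fn(x, *args, **kwargs):
        out, *rest = f(x, *args, **kwargs)
        return (out + x, *rest)
    return fn

def attn_stub(x):
    # stand-in for a layer that returns (tensor, extra)
    return x * 2, None

x = torch.ones(2, 4)
out, extra = residualize(attn_stub)(x)
assert torch.allclose(out, x * 3)  # 2x + x
```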
@@ -241,14 +242,26 @@ def forward(self, x, context = None, mask = None, context_mask = None, rel_pos =
 
 class AttentionWithDownsample(Attention):
     def forward(self, x, num_memory_tokens = 0, downsample = False, **kwargs):
+        mask = None
+
         if downsample:
             b, n, *_ = x.shape
+            is_odd = (n % 2) == 1
+
             mem, x = x[:, :num_memory_tokens], x[:, num_memory_tokens:]
-            x, remainder = (x[:, :-1], x[:, -1:]) if (n % 2) == 1 else (x[:, :], x[:, 0:0])
+            x, remainder = (x[:, :-1], x[:, -1:]) if is_odd else (x[:, :], x[:, 0:0])
             x = reduce(x, 'b (n c) d -> b n d', 'mean', c = 2)
             x = torch.cat((mem, x, remainder), dim = 1)
 
-        return super().forward(x, **kwargs)
+            mask = kwargs.pop('mask', None)
+            if exists(mask):
+                mask = mask[:, num_memory_tokens:]
+                mask = F.pad(mask, (0, 1), value = False) if is_odd else mask
+                mask = mask.reshape(b, -1, 2).any(dim = -1)
+                mask = F.pad(mask, (num_memory_tokens, 0), value = True)
+                kwargs.update(mask = mask)
+
+        return super().forward(x, **kwargs), mask
 
 class Encoder(nn.Module):
     def __init__(self, dim, depth, heads = 8, use_scalenorm = False, rel_pos_bias = False, **kwargs):
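A side note on the mask handling added above: non-memory positions are pooled in pairs with a logical `any`, so a pooled slot stays unmasked if either of its two source tokens was unmasked, and memory-token positions are padded back in as always-True. A small sketch of just that pooling with toy sizes (the tensors here are illustrative, not from the library):

```python
import torch
import torch.nn.functional as F

num_memory_tokens = 2
# 2 memory tokens followed by 5 regular tokens, the last 2 of which are padding
mask = torch.tensor([[True, True,  True, True, True, False, False]])
b = mask.shape[0]

m = mask[:, num_memory_tokens:]                          # drop memory positions
is_odd = (m.shape[-1] % 2) == 1
m = F.pad(m, (0, 1), value = False) if is_odd else m     # pad to an even length
m = m.reshape(b, -1, 2).any(dim = -1)                    # keep a pair if either token was kept
m = F.pad(m, (num_memory_tokens, 0), value = True)       # memory tokens are always visible
print(m)  # tensor([[True, True, True, True, False]])
```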
@@ -300,10 +313,7 @@ def __init__(self, dim, depths, heads = 8, use_scalenorm = False, rel_pos_bias =
                 prenorm_fn(FeedForward(dim, **ff_kwargs))
             ]))
 
-            self.bottlenecks.append(nn.ModuleList([
-                rel_pos,
-                layers
-            ]))
+            self.bottlenecks.append(nn.ModuleList([rel_pos, layers]))
 
     def forward(self, x, context = None, mask = None):
         n = x.shape[1]
@@ -318,9 +328,12 @@ def forward(self, x, context = None, mask = None):
             downsample = layer_ind != 0 and ind == 0
             self_attn = residualize(self_attn) if not downsample else self_attn
 
-            x = self_attn(x, mask = mask, rel_pos = rel_pos, downsample = downsample, num_memory_tokens = num_mem)
+            x, new_mask = self_attn(x, mask = mask, rel_pos = rel_pos, downsample = downsample, num_memory_tokens = num_mem)
             x = ff(x) + x
 
+            if exists(new_mask):
+                mask = new_mask
+
         mem, x = x[:, :num_mem], x[:, num_mem:]
         # upsample by repeating tokens as specified in paper
         x = repeat(x, 'b n d -> b (n m) d', m = 2 ** (num_downsamples - 1))
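The unchanged upsample line at the bottom of this hunk tiles every pooled token back out with einops. A standalone sketch of just that operation, with arbitrarily chosen shapes:

```python
import torch
from einops import repeat

num_downsamples = 3           # illustrative value, not taken from the config
x = torch.randn(1, 128, 512)  # (batch, pooled sequence length, dim)

# duplicate each pooled token 2 ** (num_downsamples - 1) times,
# undoing the halving applied at every downsample step
x = repeat(x, 'b n d -> b (n m) d', m = 2 ** (num_downsamples - 1))
print(x.shape)  # torch.Size([1, 512, 512])
```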
