
Commit

add sandwich coefficient, from sandwich transformers. also allow user to define any arbitrary layer configuration, using `a` for self attention, `f` for feedforward, and `c` for cross attention
lucidrains committed Nov 9, 2020
1 parent 0fe3be7 commit ff7e307
Showing 2 changed files with 57 additions and 43 deletions.
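For orientation, here is a minimal usage sketch of the two new options (a sketch only, assuming the `Decoder` class defined in the diff below; the shapes and hyperparameters are illustrative):

import torch
from x_transformers.x_transformers import Decoder

# sandwich transformer: sandwich_coef extra self attention layers at the bottom
# and the same number of extra feedforwards at the top
decoder = Decoder(dim = 512, depth = 6, heads = 8, sandwich_coef = 2)

# or spell out an arbitrary layer order explicitly:
# 'a' = self attention, 'c' = cross attention, 'f' = feedforward
decoder = Decoder(dim = 512, depth = 6, heads = 8, custom_layers = ('a', 'c', 'f') * 6)

x = torch.randn(1, 1024, 512)
context = torch.randn(1, 256, 512)
out = decoder(x, context = context)  # (1, 1024, 512)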
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '0.0.28',
+  version = '0.0.29',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
98 changes: 56 additions & 42 deletions x_transformers/x_transformers.py
@@ -4,17 +4,12 @@
 import torch.nn.functional as F
 from functools import partial
 from inspect import isfunction
-from enum import Enum
 
 from einops import rearrange, repeat, reduce
 from entmax import entmax15
 
 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
 
-# constants
-
-LAYER_TYPE = Enum('LAYER_TYPE', ['ff', 'attn', 'cross_attn'])
-
 # helpers
 
 def exists(val):
@@ -109,7 +104,7 @@ def __init__(self, dim, eps = 1e-5):
         self.eps = eps
         self.g = nn.Parameter(torch.ones(1))
 
-    def forward(self, x, **kwargs):
+    def forward(self, x):
         n = torch.norm(x, dim = -1, keepdim = True).clamp(min = self.eps)
         return x / n * self.g
 
@@ -232,7 +227,7 @@ def forward(self, x, context = None, mask = None, context_mask = None, rel_pos =
             dots = rel_pos(dots)
 
         if exists(input_mask):
-            dots.masked_fill_(input_mask, float('-inf'))
+            dots.masked_fill_(~input_mask, float('-inf'))
             del input_mask
 
         if self.causal:
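(Aside: the sign flip above inverts the mask convention — `mask` is now expected to be True for tokens to attend to and False for tokens to ignore. A minimal sketch of the new semantics, with hypothetical values:)

import torch

mask = torch.tensor([True, True, False])  # keep, keep, pad
dots = torch.zeros(3)
dots.masked_fill_(~mask, float('-inf'))   # padded position is excluded from the softmax
print(dots)                               # tensor([0., 0., -inf])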
@@ -262,12 +257,25 @@

         return self.to_out(out)
 
-class Encoder(nn.Module):
-    def __init__(self, dim, depth, heads = 8, use_scalenorm = False, use_rezero = False, rel_pos_bias = False, **kwargs):
+class AttentionLayers(nn.Module):
+    def __init__(
+        self,
+        dim,
+        depth,
+        heads = 8,
+        causal = False,
+        cross_attend = False,
+        use_scalenorm = False,
+        use_rezero = False,
+        rel_pos_bias = False,
+        custom_layers = None,
+        sandwich_coef = None,
+        **kwargs
+    ):
         super().__init__()
         self.dim = dim
         self.layers = nn.ModuleList([])
-        self.rel_pos = RelativePositionBias() if rel_pos_bias else None
+        self.rel_pos = RelativePositionBias(causal = causal) if rel_pos_bias else None
 
         norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
         prenorm_fn = partial(PreNorm, dim, norm_class = norm_class)
@@ -276,45 +284,51 @@ def __init__(self, dim, depth, heads = 8, use_scalenorm = False, use_rezero = Fa
         ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
         attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)
 
-        for _ in range(depth):
-            self.layers.append(nn.ModuleList([
-                prenorm_fn(Attention(dim, heads = heads, **attn_kwargs)),
-                prenorm_fn(FeedForward(dim, **ff_kwargs))
-            ]))
-    def forward(self, x, context = None, mask = None):
-        for (self_attn, ff) in self.layers:
-            x = self_attn(x, mask = mask, rel_pos = self.rel_pos) + x
-            x = ff(x) + x
-        return x
+        default_block = ('a', 'f') if not cross_attend else ('a', 'c', 'f')
+
-class Decoder(nn.Module):
-    def __init__(self, dim, depth, heads = 8, cross_attend = False, use_scalenorm = False, use_rezero = False, rel_pos_bias = False, **kwargs):
-        super().__init__()
-        self.dim = dim
-        self.layers = nn.ModuleList([])
-        self.rel_pos = RelativePositionBias(causal = True) if rel_pos_bias else None
+        if exists(custom_layers):
+            layer_types = custom_layers
+        elif exists(sandwich_coef):
+            assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
+            layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
+        else:
+            layer_types = default_block * depth
 
-        norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
-        prenorm_fn = partial(PreNorm, dim, norm_class = norm_class)
-        prenorm_fn = Rezero if use_rezero else prenorm_fn
+        self.layer_types = layer_types
 
-        ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
-        attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)
+        for layer_type in self.layer_types:
+            if layer_type == 'a':
+                layer = Attention(dim, heads = heads, causal = causal, **attn_kwargs)
+            elif layer_type == 'c':
+                layer = Attention(dim, heads = heads, **attn_kwargs)
+            elif layer_type == 'f':
+                layer = FeedForward(dim, **ff_kwargs)
+            else:
+                raise Exception(f'invalid layer type {layer_type}')
 
-        for _ in range(depth):
-            self.layers.append(nn.ModuleList([
-                prenorm_fn(Attention(dim, heads = heads, causal = True, **attn_kwargs)),
-                prenorm_fn(Attention(dim, heads = heads, **attn_kwargs)) if cross_attend else None,
-                prenorm_fn(FeedForward(dim, **ff_kwargs)),
-            ]))
+            self.layers.append(prenorm_fn(layer))
 
     def forward(self, x, context = None, mask = None, context_mask = None):
-        for (self_attn, cross_attn, ff) in self.layers:
-            x = self_attn(x, rel_pos = self.rel_pos) + x
-            if exists(cross_attn):
-                x = cross_attn(x, context = context, mask = mask, context_mask = context_mask) + x
-            x = ff(x) + x
+        for (layer_type, block) in zip(self.layer_types, self.layers):
+            if layer_type == 'a':
+                x = block(x, mask = mask, rel_pos = self.rel_pos) + x
+            elif layer_type == 'c':
+                x = block(x, context = context, mask = mask, context_mask = context_mask) + x
+            elif layer_type == 'f':
+                x = block(x) + x
         return x
 
+class Encoder(AttentionLayers):
+    def __init__(self, **kwargs):
+        assert 'causal' not in kwargs, 'cannot set causality on encoder'
+        assert 'cross_attend' not in kwargs, 'encoder cannot cross attend'
+        super().__init__(causal = False, **kwargs)
+
+class Decoder(AttentionLayers):
+    def __init__(self, **kwargs):
+        assert 'causal' not in kwargs, 'cannot set causality on decoder'
+        super().__init__(causal = True, **kwargs)
 
 class ViTransformerWrapper(nn.Module):
     def __init__(
         self,
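To make the sandwich layout concrete, the `layer_types` expression from the hunk above can be evaluated by hand (a sketch with a hypothetical depth and coefficient, no cross attention):

depth, sandwich_coef = 6, 2
default_block = ('a', 'f')

layer_types = ('a',) * sandwich_coef \
    + default_block * (depth - sandwich_coef) \
    + ('f',) * sandwich_coef

print(layer_types)
# ('a', 'a', 'a', 'f', 'a', 'f', 'a', 'f', 'a', 'f', 'f', 'f')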
