
Commit 96c5515

support returning multiple heads in transformer wrapper, for eventual prophet net feature
lucidrains committed May 3, 2024
1 parent 3673f74 commit 96c5515
Showing 2 changed files with 60 additions and 33 deletions.
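The headline change is a new num_output_heads argument on TransformerWrapper: when it is greater than 1 (and embeddings are not tied), the output projection becomes a ModuleList of linear heads and forward returns a tuple of logits, one tensor per head, as groundwork for an eventual ProphetNet-style objective. A minimal usage sketch, assuming the usual x-transformers constructor; only the num_output_heads keyword and the tuple return come from this diff, the other values are illustrative:

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    num_output_heads = 2,    # new in this commit
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

ids = torch.randint(0, 20000, (1, 1024))

logits_a, logits_b = model(ids)    # a tuple of two (1, 1024, 20000) logit tensors, one per head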
setup.py: 2 changes (1 addition & 1 deletion)
@@ -3,7 +3,7 @@
setup(
name = 'x-transformers',
packages = find_packages(exclude=['examples']),
version = '1.28.4',
version = '1.28.5',
license='MIT',
description = 'X-Transformers - Pytorch',
author = 'Phil Wang',
x_transformers/x_transformers.py: 91 changes (59 additions & 32 deletions)
@@ -3,8 +3,9 @@
from packaging import version

import torch
from torch import nn, einsum, Tensor
import torch.nn.functional as F
from torch import nn, einsum, Tensor
from torch.nn import Module, ModuleList, ModuleDict
from torch.cuda.amp import autocast

from functools import partial, wraps
@@ -191,13 +192,13 @@ def dropout_seq(seq, mask, dropout):

# activations

class ReluSquared(nn.Module):
class ReluSquared(Module):
def forward(self, x):
return F.relu(x) ** 2

# embedding

class TokenEmbedding(nn.Module):
class TokenEmbedding(Module):
def __init__(self, dim, num_tokens, l2norm_embed = False):
super().__init__()
self.l2norm_embed = l2norm_embed
@@ -209,7 +210,7 @@ def forward(self, x):

# positional embeddings

class AbsolutePositionalEmbedding(nn.Module):
class AbsolutePositionalEmbedding(Module):
def __init__(self, dim, max_seq_len, l2norm_embed = False):
super().__init__()
self.scale = dim ** -0.5 if not l2norm_embed else 1.
@@ -231,7 +232,7 @@ def forward(self, x, pos = None, seq_start_pos = None):
pos_emb = pos_emb * self.scale
return l2norm(pos_emb) if self.l2norm_embed else pos_emb

class ScaledSinusoidalEmbedding(nn.Module):
class ScaledSinusoidalEmbedding(Module):
def __init__(self, dim, theta = 10000):
super().__init__()
assert divisible_by(dim, 2)
@@ -255,7 +256,7 @@ def forward(self, x, pos = None, seq_start_pos = None):
emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
return emb * self.scale

class RelativePositionBias(nn.Module):
class RelativePositionBias(Module):
def __init__(self, scale, causal = False, num_buckets = 32, max_distance = 128, heads = 8):
super().__init__()
self.scale = scale
@@ -300,13 +301,13 @@ def forward(self, i, j):
bias = rearrange(values, 'i j h -> h i j')
return bias * self.scale

class DynamicPositionBias(nn.Module):
class DynamicPositionBias(Module):
def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
super().__init__()
assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1'
self.log_distance = log_distance

self.mlp = nn.ModuleList([])
self.mlp = ModuleList([])

self.mlp.append(Sequential(
nn.Linear(1, dim),
@@ -352,7 +353,7 @@ def forward(self, i, j):
bias = rearrange(bias, 'i j h -> h i j')
return bias

class AlibiPositionalBias(nn.Module):
class AlibiPositionalBias(Module):
def __init__(self, heads, total_heads, **kwargs):
super().__init__()
self.heads = heads
@@ -401,7 +402,7 @@ def forward(self, i, j):

return self.bias

class RotaryEmbedding(nn.Module):
class RotaryEmbedding(Module):
def __init__(
self,
dim,
@@ -476,7 +477,7 @@ def apply_rotary_pos_emb(t, freqs, scale = 1):

# norms

class Scale(nn.Module):
class Scale(Module):
def __init__(self, value, fn):
super().__init__()
self.value = value
@@ -491,7 +492,7 @@ def forward(self, x, **kwargs):

return (scale_fn(out[0]), *out[1:])

class LayerNorm(nn.Module):
class LayerNorm(Module):
def __init__(self, dim):
"""
bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
@@ -506,7 +507,7 @@ def forward(self, x):
if version.parse(torch.__version__) >= version.parse('2.1.0'):
LayerNorm = partial(nn.LayerNorm, bias = False)

class ScaleNorm(nn.Module):
class ScaleNorm(Module):
def __init__(self, dim):
super().__init__()
self.scale = dim ** 0.5
@@ -515,7 +516,7 @@ def __init__(self, dim):
def forward(self, x):
return F.normalize(x, dim = -1) * self.scale * self.g

class RMSNorm(nn.Module):
class RMSNorm(Module):
def __init__(self, dim):
super().__init__()
self.scale = dim ** 0.5
@@ -524,7 +525,7 @@ def __init__(self, dim):
def forward(self, x):
return F.normalize(x, dim = -1) * self.scale * self.g

class SimpleRMSNorm(nn.Module):
class SimpleRMSNorm(Module):
def __init__(self, dim):
super().__init__()
self.scale = dim ** 0.5
@@ -534,7 +535,7 @@ def forward(self, x):

# residual and residual gates

class Residual(nn.Module):
class Residual(Module):
def __init__(self, dim, scale_residual = False, scale_residual_constant = 1.):
super().__init__()
self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
@@ -549,7 +550,7 @@ def forward(self, x, residual):

return x + residual

class GRUGating(nn.Module):
class GRUGating(Module):
def __init__(self, dim, scale_residual = False, **kwargs):
super().__init__()
self.gru = nn.GRUCell(dim, dim)
@@ -579,7 +580,7 @@ def shift(t, amount, mask = None):

return pad_at_dim(t, (amount, -amount), dim = - 2, value = 0.)

class ShiftTokens(nn.Module):
class ShiftTokens(Module):
def __init__(self, shifts, fn):
super().__init__()
self.fn = fn
@@ -598,7 +599,7 @@ def forward(self, x, **kwargs):

# feedforward

class GLU(nn.Module):
class GLU(Module):
def __init__(
self,
dim_in,
@@ -615,7 +616,7 @@ def forward(self, x):
x, gate = self.proj(x).chunk(2, dim = -1)
return x * self.act(gate) * self.mult_bias

class FeedForward(nn.Module):
class FeedForward(Module):
def __init__(
self,
dim,
@@ -665,7 +666,7 @@ def forward(self, x):

# attention. it is all we need

class Attention(nn.Module):
class Attention(Module):
def __init__(
self,
dim,
@@ -996,7 +997,7 @@ def forward(

return out, intermediates

class AttentionLayers(nn.Module):
class AttentionLayers(Module):
def __init__(
self,
dim,
@@ -1058,7 +1059,7 @@ def __init__(
self.dim = dim
self.depth = depth
self.causal = causal
self.layers = nn.ModuleList([])
self.layers = ModuleList([])

self.disable_abs_pos_emb = default(disable_abs_pos_emb, (rel_pos_bias or rotary_pos_emb))

@@ -1215,13 +1216,13 @@ def __init__(
post_branch_norm = norm_fn() if sandwich_norm else None
post_main_norm = norm_fn() if not pre_norm else None

norms = nn.ModuleList([
norms = ModuleList([
pre_branch_norm,
post_branch_norm,
post_main_norm
])

self.layers.append(nn.ModuleList([
self.layers.append(ModuleList([
norms,
layer,
residual
@@ -1427,7 +1428,7 @@ class CrossAttender(AttentionLayers):
def __init__(self, **kwargs):
super().__init__(cross_attend = True, only_cross = True, **kwargs)

class ViTransformerWrapper(nn.Module):
class ViTransformerWrapper(Module):
def __init__(
self,
*,
@@ -1508,7 +1509,7 @@ def forward(

return logits, embed

class TransformerWrapper(nn.Module):
class TransformerWrapper(Module):
def __init__(
self,
*,
@@ -1525,6 +1526,7 @@ def __init__(
memory_tokens_interspersed_every = None,
tie_embedding = False,
logits_dim = None,
num_output_heads = 1,
use_abs_pos_emb = True,
scaled_sinu_pos_emb = False,
l2norm_embed = False,
@@ -1559,7 +1561,7 @@ def __init__(
self.embeds = None

if len(embed_num_tokens) > 0:
self.embeds = nn.ModuleDict({f'{name}_embed': nn.Embedding(num_tokens, emb_dim) for name, num_tokens in embed_num_tokens.items()})
self.embeds = ModuleDict({f'{name}_embed': nn.Embedding(num_tokens, emb_dim) for name, num_tokens in embed_num_tokens.items()})

# fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290

@@ -1573,8 +1575,19 @@ def __init__(

self.init_()

# output head, usually to logits of num_tokens

logits_dim = default(logits_dim, num_tokens)
self.to_logits = nn.Linear(dim, logits_dim, bias = False) if not tie_embedding else lambda t: t @ self.token_emb.emb.weight.t()

self.has_multiple_heads = False

if tie_embedding:
self.to_logits = lambda t: t @ self.token_emb.emb.weight.t()
elif num_output_heads > 1:
self.has_multiple_heads = True
self.to_logits = ModuleList([nn.Linear(dim, logits_dim, bias = False) for _ in range(num_output_heads)])
else:
self.to_logits = nn.Linear(dim, logits_dim, bias = False)

# memory tokens (like [cls]) from Memory Transformers paper

@@ -1705,6 +1718,8 @@ def forward(

x, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)

# handle memories post-attention

if has_memory_tokens:
if exists(mem_every):
x = rearrange(x, 'b (n m) d -> (b n) m d', m = (mem_every + num_mems))
@@ -1718,12 +1733,24 @@ def forward(

x = x[:, :n]

# projecting to logits

if not return_embeddings:
if self.has_multiple_heads:
logits = tuple(fn(x) for fn in self.to_logits)
else:
logits = self.to_logits(x)

# different returns

if return_logits_and_embeddings:
out = (self.to_logits(x), x)
out = (logits, x)
elif return_embeddings:
out = x
else:
out = self.to_logits(x)
out = logits

# aux loss

if return_attn_z_loss:
pre_softmax_attns = list(map(lambda t: t.pre_softmax_attn, intermediates.attn_intermediates))
@@ -1749,7 +1776,7 @@ def forward(

return out

class XTransformer(nn.Module):
class XTransformer(Module):
def __init__(
self,
*,
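The "prophet net feature" in the commit message refers to ProphetNet-style training, where each head predicts a different future offset (the next token, the token after that, and so on). The sketch below only illustrates how the tuple of logits returned by a two-head wrapper might be consumed for such an objective; it is not part of this commit, and the offsets, loss averaging, and function name are assumptions.

import torch
import torch.nn.functional as F

def prophet_style_loss(model, seq):
    # head i is trained to predict the token (i + 1) positions ahead;
    # the input is truncated by 2 tokens so both shifted targets exist
    logits_per_head = model(seq[:, :-2])    # tuple of (b, n, num_tokens) logits, one per head
    losses = []
    for i, logits in enumerate(logits_per_head):
        n = logits.shape[1]
        target = seq[:, 1 + i : n + 1 + i]    # targets shifted by (i + 1)
        losses.append(F.cross_entropy(logits.transpose(1, 2), target))
    return sum(losses) / len(losses)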
