From 96c551501bdcbbaf41ec47f7904fffc55ccf7cc4 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Fri, 3 May 2024 06:24:53 -0700
Subject: [PATCH] support returning multiple heads in transformer wrapper, for
 eventual prophet net feature

---
 setup.py                         |  2 +-
 x_transformers/x_transformers.py | 91 +++++++++++++++++++++-----------
 2 files changed, 60 insertions(+), 33 deletions(-)

diff --git a/setup.py b/setup.py
index f62ec5d4..7e4eaeee 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.28.4',
+  version = '1.28.5',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
diff --git a/x_transformers/x_transformers.py b/x_transformers/x_transformers.py
index 06b83e9d..c48ed0bc 100644
--- a/x_transformers/x_transformers.py
+++ b/x_transformers/x_transformers.py
@@ -3,8 +3,9 @@
 from packaging import version
 
 import torch
-from torch import nn, einsum, Tensor
 import torch.nn.functional as F
+from torch import nn, einsum, Tensor
+from torch.nn import Module, ModuleList, ModuleDict
 from torch.cuda.amp import autocast
 
 from functools import partial, wraps
@@ -191,13 +192,13 @@ def dropout_seq(seq, mask, dropout):
 
 # activations
 
-class ReluSquared(nn.Module):
+class ReluSquared(Module):
     def forward(self, x):
         return F.relu(x) ** 2
 
 # embedding
 
-class TokenEmbedding(nn.Module):
+class TokenEmbedding(Module):
     def __init__(self, dim, num_tokens, l2norm_embed = False):
         super().__init__()
         self.l2norm_embed = l2norm_embed
@@ -209,7 +210,7 @@ def forward(self, x):
 
 # positional embeddings
 
-class AbsolutePositionalEmbedding(nn.Module):
+class AbsolutePositionalEmbedding(Module):
     def __init__(self, dim, max_seq_len, l2norm_embed = False):
         super().__init__()
         self.scale = dim ** -0.5 if not l2norm_embed else 1.
@@ -231,7 +232,7 @@ def forward(self, x, pos = None, seq_start_pos = None):
         pos_emb = pos_emb * self.scale
         return l2norm(pos_emb) if self.l2norm_embed else pos_emb
 
-class ScaledSinusoidalEmbedding(nn.Module):
+class ScaledSinusoidalEmbedding(Module):
     def __init__(self, dim, theta = 10000):
         super().__init__()
         assert divisible_by(dim, 2)
@@ -255,7 +256,7 @@ def forward(self, x, pos = None, seq_start_pos = None):
         emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
         return emb * self.scale
 
-class RelativePositionBias(nn.Module):
+class RelativePositionBias(Module):
     def __init__(self, scale, causal = False, num_buckets = 32, max_distance = 128, heads = 8):
         super().__init__()
         self.scale = scale
@@ -300,13 +301,13 @@ def forward(self, i, j):
         bias = rearrange(values, 'i j h -> h i j')
         return bias * self.scale
 
-class DynamicPositionBias(nn.Module):
+class DynamicPositionBias(Module):
     def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
         super().__init__()
         assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1'
         self.log_distance = log_distance
 
-        self.mlp = nn.ModuleList([])
+        self.mlp = ModuleList([])
 
         self.mlp.append(Sequential(
             nn.Linear(1, dim),
@@ -352,7 +353,7 @@ def forward(self, i, j):
         bias = rearrange(bias, 'i j h -> h i j')
         return bias
 
-class AlibiPositionalBias(nn.Module):
+class AlibiPositionalBias(Module):
     def __init__(self, heads, total_heads, **kwargs):
         super().__init__()
         self.heads = heads
@@ -401,7 +402,7 @@ def forward(self, i, j):
 
         return self.bias
 
-class RotaryEmbedding(nn.Module):
+class RotaryEmbedding(Module):
     def __init__(
         self,
         dim,
@@ -476,7 +477,7 @@ def apply_rotary_pos_emb(t, freqs, scale = 1):
 
 # norms
 
-class Scale(nn.Module):
+class Scale(Module):
     def __init__(self, value, fn):
         super().__init__()
         self.value = value
@@ -491,7 +492,7 @@ def forward(self, x, **kwargs):
 
         return (scale_fn(out[0]), *out[1:])
 
-class LayerNorm(nn.Module):
+class LayerNorm(Module):
     def __init__(self, dim):
         """
         bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
@@ -506,7 +507,7 @@ def forward(self, x):
 if version.parse(torch.__version__) >= version.parse('2.1.0'):
     LayerNorm = partial(nn.LayerNorm, bias = False)
 
-class ScaleNorm(nn.Module):
+class ScaleNorm(Module):
     def __init__(self, dim):
         super().__init__()
         self.scale = dim ** 0.5
@@ -515,7 +516,7 @@ def __init__(self, dim):
     def forward(self, x):
         return F.normalize(x, dim = -1) * self.scale * self.g
 
-class RMSNorm(nn.Module):
+class RMSNorm(Module):
     def __init__(self, dim):
         super().__init__()
         self.scale = dim ** 0.5
@@ -524,7 +525,7 @@ def __init__(self, dim):
     def forward(self, x):
         return F.normalize(x, dim = -1) * self.scale * self.g
 
-class SimpleRMSNorm(nn.Module):
+class SimpleRMSNorm(Module):
     def __init__(self, dim):
         super().__init__()
         self.scale = dim ** 0.5
@@ -534,7 +535,7 @@ def forward(self, x):
 
 # residual and residual gates
 
-class Residual(nn.Module):
+class Residual(Module):
     def __init__(self, dim, scale_residual = False, scale_residual_constant = 1.):
         super().__init__()
         self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
@@ -549,7 +550,7 @@ def forward(self, x, residual):
 
         return x + residual
 
-class GRUGating(nn.Module):
+class GRUGating(Module):
     def __init__(self, dim, scale_residual = False, **kwargs):
         super().__init__()
         self.gru = nn.GRUCell(dim, dim)
@@ -579,7 +580,7 @@ def shift(t, amount, mask = None):
 
     return pad_at_dim(t, (amount, -amount), dim = - 2, value = 0.)
 
-class ShiftTokens(nn.Module):
+class ShiftTokens(Module):
     def __init__(self, shifts, fn):
         super().__init__()
         self.fn = fn
@@ -598,7 +599,7 @@ def forward(self, x, **kwargs):
 
 # feedforward
 
-class GLU(nn.Module):
+class GLU(Module):
     def __init__(
         self,
         dim_in,
@@ -615,7 +616,7 @@ def forward(self, x):
         x, gate = self.proj(x).chunk(2, dim = -1)
         return x * self.act(gate) * self.mult_bias
 
-class FeedForward(nn.Module):
+class FeedForward(Module):
     def __init__(
         self,
         dim,
@@ -665,7 +666,7 @@ def forward(self, x):
 
 # attention. it is all we need
 
-class Attention(nn.Module):
+class Attention(Module):
     def __init__(
         self,
         dim,
@@ -996,7 +997,7 @@ def forward(
 
         return out, intermediates
 
-class AttentionLayers(nn.Module):
+class AttentionLayers(Module):
     def __init__(
         self,
         dim,
@@ -1058,7 +1059,7 @@ def __init__(
         self.dim = dim
         self.depth = depth
         self.causal = causal
-        self.layers = nn.ModuleList([])
+        self.layers = ModuleList([])
 
         self.disable_abs_pos_emb = default(disable_abs_pos_emb, (rel_pos_bias or rotary_pos_emb))
 
@@ -1215,13 +1216,13 @@ def __init__(
             post_branch_norm = norm_fn() if sandwich_norm else None
             post_main_norm = norm_fn() if not pre_norm else None
 
-            norms = nn.ModuleList([
+            norms = ModuleList([
                 pre_branch_norm,
                 post_branch_norm,
                 post_main_norm
             ])
 
-            self.layers.append(nn.ModuleList([
+            self.layers.append(ModuleList([
                 norms,
                 layer,
                 residual
@@ -1427,7 +1428,7 @@ class CrossAttender(AttentionLayers):
     def __init__(self, **kwargs):
         super().__init__(cross_attend = True, only_cross = True, **kwargs)
 
-class ViTransformerWrapper(nn.Module):
+class ViTransformerWrapper(Module):
     def __init__(
         self,
         *,
@@ -1508,7 +1509,7 @@ def forward(
 
         return logits, embed
 
-class TransformerWrapper(nn.Module):
+class TransformerWrapper(Module):
     def __init__(
         self,
         *,
@@ -1525,6 +1526,7 @@ def __init__(
         memory_tokens_interspersed_every = None,
         tie_embedding = False,
         logits_dim = None,
+        num_output_heads = 1,
         use_abs_pos_emb = True,
         scaled_sinu_pos_emb = False,
         l2norm_embed = False,
@@ -1559,7 +1561,7 @@ def __init__(
         self.embeds = None
 
         if len(embed_num_tokens) > 0:
-            self.embeds = nn.ModuleDict({f'{name}_embed': nn.Embedding(num_tokens, emb_dim) for name, num_tokens in embed_num_tokens.items()})
+            self.embeds = ModuleDict({f'{name}_embed': nn.Embedding(num_tokens, emb_dim) for name, num_tokens in embed_num_tokens.items()})
 
         # fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290
 
@@ -1573,8 +1575,19 @@ def __init__(
 
         self.init_()
 
+        # output head, usually to logits of num_tokens
+
         logits_dim = default(logits_dim, num_tokens)
-        self.to_logits = nn.Linear(dim, logits_dim, bias = False) if not tie_embedding else lambda t: t @ self.token_emb.emb.weight.t()
+
+        self.has_multiple_heads = False
+
+        if tie_embedding:
+            self.to_logits = lambda t: t @ self.token_emb.emb.weight.t()
+        elif num_output_heads > 1:
+            self.has_multiple_heads = True
+            self.to_logits = ModuleList([nn.Linear(dim, logits_dim, bias = False) for _ in range(num_output_heads)])
+        else:
+            self.to_logits = nn.Linear(dim, logits_dim, bias = False)
 
         # memory tokens (like [cls]) from Memory Transformers paper
 
@@ -1705,6 +1718,8 @@ def forward(
 
         x, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
 
+        # handle memories post-attention
+
         if has_memory_tokens:
             if exists(mem_every):
                 x = rearrange(x, 'b (n m) d -> (b n) m d', m = (mem_every + num_mems))
@@ -1718,12 +1733,24 @@ def forward(
 
             x = x[:, :n]
 
+        # projecting to logits
+
+        if not return_embeddings:
+            if self.has_multiple_heads:
+                logits = tuple(fn(x) for fn in self.to_logits)
+            else:
+                logits = self.to_logits(x)
+
+        # different returns
+
         if return_logits_and_embeddings:
-            out = (self.to_logits(x), x)
+            out = (logits, x)
         elif return_embeddings:
             out = x
         else:
-            out = self.to_logits(x)
+            out = logits
+
+        # aux loss
 
         if return_attn_z_loss:
             pre_softmax_attns = list(map(lambda t: t.pre_softmax_attn, intermediates.attn_intermediates))
@@ -1749,7 +1776,7 @@ def forward(
 
         return out
 
-class XTransformer(nn.Module):
+class XTransformer(Module):
     def __init__(
         self,
         *,
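
Usage sketch (not part of the patch above): with num_output_heads greater than 1, TransformerWrapper keeps a ModuleList of linear output heads and forward returns a tuple with one logits tensor per head instead of a single tensor. The snippet below is a minimal illustration assuming the usual TransformerWrapper / Decoder construction from this repository; the hyperparameter values are arbitrary.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    num_output_heads = 2,          # new in this patch
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

x = torch.randint(0, 20000, (1, 1024))

# with two heads, the wrapper returns a tuple of logits, one per linear head,
# each of shape (1, 1024, 20000) since logits_dim defaults to num_tokens
logits_head_1, logits_head_2 = model(x)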