just give multi input its own wrapper
lucidrains committed Aug 6, 2024
1 parent cc436b3 commit 03f35a8
Showing 5 changed files with 311 additions and 10 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.32.4',
+  version = '1.32.6',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
20 changes: 11 additions & 9 deletions tests/test_x_transformers.py
@@ -1,11 +1,14 @@
 import torch
 
 from x_transformers.x_transformers import (
     XTransformer,
     TransformerWrapper,
     Decoder,
     AutoregressiveWrapper
 )
 
+from x_transformers.multi_input import MultiInputTransformerWrapper
+
 def test_readme():
     model = XTransformer(
         dim = 512,
@@ -129,28 +132,27 @@ def test_attn_softclamp_logits():
     model(x)
 
 def test_multiple_input_embeds():
-    model = TransformerWrapper(
-        num_tokens = 20000,
-        max_seq_len = 1024,
-        return_only_embed = True,
-        embed_num_tokens = dict(
+    model = MultiInputTransformerWrapper(
+        num_tokens = dict(
+            note = 20000,
             pitch = 32,
             tone = 16
         ),
+        max_seq_len = 1024,
+        return_only_embed = True,
         attn_layers = Decoder(
             dim = 128,
             depth = 6,
             heads = 8
         )
     )
 
-    x = torch.randint(0, 20000, (2, 1024))
-
-    embed_ids = dict(
+    x = dict(
+        note = torch.randint(0, 20000, (2, 1024)),
         pitch = torch.randint(0, 32, (2, 1024)),
         tone = torch.randint(0, 16, (2, 1024))
     )
 
-    embed = model(x, embed_ids = embed_ids)
+    embed = model(x)
 
     assert embed.shape == (2, 1024, 128)
2 changes: 2 additions & 0 deletions x_transformers/__init__.py
@@ -20,6 +20,8 @@
     ContinuousAutoregressiveWrapper
 )
 
+from x_transformers.multi_input import MultiInputTransformerWrapper
+
 from x_transformers.xval import (
     XValTransformerWrapper,
     XValAutoregressiveWrapper
294 changes: 294 additions & 0 deletions x_transformers/multi_input.py
@@ -0,0 +1,294 @@
from __future__ import annotations

import math

import torch
from torch import nn, Tensor
from torch.nn import Module, ModuleList, ModuleDict
import torch.nn.functional as F

from typing import Dict

from einops import pack, rearrange, repeat, unpack

from x_transformers.x_transformers import (
    AttentionLayers,
    Decoder,
    ScaledSinusoidalEmbedding,
    AbsolutePositionalEmbedding,
    LayerIntermediates,
    LayerNorm,
    always,
    pad_at_dim,
    is_empty,
    calc_z_loss
)

# helper functions

def exists(val):
return val is not None

def default(val, d):
if exists(val):
return val
return d() if callable(d) else d


class MultiInputTransformerWrapper(Module):
def __init__(
self,
*,
num_tokens: Dict[str, int] = dict(),
max_seq_len,
attn_layers: AttentionLayers,
emb_dim = None,
max_mem_len = 0,
shift_mem_down = 0,
emb_dropout = 0.,
post_emb_norm = False,
num_memory_tokens = None,
memory_tokens_interspersed_every = None,
tie_embedding = False,
logits_dim = None,
return_only_embed = False,
num_output_heads = 1,
use_abs_pos_emb = True,
scaled_sinu_pos_emb = False,
emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
attn_z_loss_weight = 1e-4,
):
super().__init__()

dim = attn_layers.dim
emb_dim = default(emb_dim, dim)
self.emb_dim = emb_dim

self.max_seq_len = max_seq_len
self.max_mem_len = max_mem_len
self.shift_mem_down = shift_mem_down

no_abs_pos_emb = max_seq_len == 0 or not (use_abs_pos_emb and not attn_layers.disable_abs_pos_emb)

if no_abs_pos_emb:
self.pos_emb = always(0)
elif scaled_sinu_pos_emb:
self.pos_emb = ScaledSinusoidalEmbedding(emb_dim)
else:
self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len)

# token embeddings - one embedding table per named input

self.embeds = ModuleDict({f'{name}_embed': nn.Embedding(one_num_tokens, emb_dim) for name, one_num_tokens in num_tokens.items()})

# fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290

self.emb_frac_gradient = emb_frac_gradient

self.post_emb_norm = LayerNorm(emb_dim) if post_emb_norm else nn.Identity()
self.emb_dropout = nn.Dropout(emb_dropout)

self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
self.attn_layers = attn_layers

assert num_output_heads > 0

# output head, usually to logits of num_tokens

logits_dim = default(logits_dim, num_tokens)

self.has_multiple_heads = False

if return_only_embed:
self.to_logits = None
elif tie_embedding:
self.to_logits = lambda t: t @ self.token_emb.emb.weight.t()
elif num_output_heads > 1:
self.has_multiple_heads = True
self.to_logits = ModuleList([nn.Linear(dim, logits_dim, bias = False) for _ in range(num_output_heads)])
else:
self.to_logits = nn.Linear(dim, logits_dim, bias = False)

# memory tokens (like [cls]) from Memory Transformers paper

num_memory_tokens = default(num_memory_tokens, 0)
self.num_memory_tokens = num_memory_tokens
if num_memory_tokens > 0:
self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))

self.memory_tokens_interspersed_every = memory_tokens_interspersed_every

# whether can do cached kv decoding

self.can_cache_kv = self.num_memory_tokens == 0
self.can_cache_kv_outside_max_seq_len = no_abs_pos_emb

def forward(
self,
x: Dict[str, Tensor],
return_embeddings = False,
return_logits_and_embeddings = False,
return_intermediates = False,
mask = None,
return_mems = False,
return_attn = False,
mems = None,
mem_masks = None,
pos = None,
prepend_embeds = None,
prepend_mask = None,
sum_embeds = None,
return_attn_z_loss = False,
attn_z_loss_weight = 1e-4,
seq_start_pos = None,
cache: LayerIntermediates | None = None,
**kwargs
):
assert not is_empty(x)
first_input = list(x.values())[0]

b, n, device, num_mems, has_memory_tokens, emb_frac_gradient = *first_input.shape, first_input.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient

return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss
return_embeddings = return_embeddings | (not exists(self.to_logits))

# token embedding

assert len(x) == len(self.embeds)

token_emb = 0.

for name, embed_id in x.items():
embed_key = f'{name}_embed'

assert embed_key in self.embeds
embed = self.embeds[embed_key](embed_id)

token_emb = token_emb + embed

# absolute positional embedding

external_pos_emb = exists(pos) and pos.dtype != torch.long
pos_emb = self.pos_emb(first_input, pos = pos, seq_start_pos = seq_start_pos) if not external_pos_emb else pos

token_emb = token_emb + pos_emb

# for summing embeddings passed externally - needs this for self-conditioning in non-autoregressive training

if exists(sum_embeds):
token_emb = token_emb + sum_embeds

# set back to `x`

x = token_emb

# post embedding norm, purportedly leads to greater stabilization

x = self.post_emb_norm(x)

# whether to prepend embeds, as in PaLI, for image embeddings

if exists(prepend_embeds):
prepend_seq, prepend_dim = prepend_embeds.shape[1:]
assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as text model dimensions'

x = torch.cat((prepend_embeds, x), dim = -2)

if exists(prepend_mask) or exists(mask):
mask = default(mask, lambda: torch.ones((b, n), device = device, dtype = torch.bool))
prepend_mask = default(prepend_mask, lambda: torch.ones((b, prepend_seq), device = device, dtype = torch.bool))

mask = torch.cat((prepend_mask, mask), dim = -1)

# whether to reduce the gradient going to the embedding, from cogview paper, corroborated by GLM-130B model

if emb_frac_gradient < 1:
assert emb_frac_gradient > 0
x = x * emb_frac_gradient + x.detach() * (1 - emb_frac_gradient)

# embedding dropout

x = self.emb_dropout(x)

x = self.project_emb(x)

if has_memory_tokens:
mem_every = self.memory_tokens_interspersed_every

if exists(mem_every):
assert mem_every > 0
assert isinstance(self.attn_layers, Decoder), 'only for decoder'
next_seq_len = math.ceil(n / mem_every) * mem_every

x = pad_at_dim(x, (0, next_seq_len - n), dim = -2, value = 0.)
x = rearrange(x, 'b (n m) d -> (b n) m d', m = mem_every)

mem = repeat(self.memory_tokens, 'n d -> b n d', b = x.shape[0])
x, mem_packed_shape = pack((mem, x), 'b * d')

# auto-handle masking after appending memory tokens
if not exists(mem_every) and exists(mask):
mask = pad_at_dim(mask, (num_mems, 0), dim = -1, value = True)

if exists(mem_every):
x = rearrange(x, '(b n) m d -> b (n m) d', b = b)

if self.shift_mem_down and exists(mems):
mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
mems = [*mems_r, *mems_l]

x, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)

# handle memories post-attention

if has_memory_tokens:
if exists(mem_every):
x = rearrange(x, 'b (n m) d -> (b n) m d', m = (mem_every + num_mems))

mem, x = unpack(x, mem_packed_shape, 'b * d')

intermediates.memory_tokens = mem

if exists(mem_every):
x = rearrange(x, '(b n) m d -> b (n m) d', b = b)

x = x[:, :n]

# projecting to logits

if not return_embeddings:
if self.has_multiple_heads:
logits = tuple(fn(x) for fn in self.to_logits)
else:
logits = self.to_logits(x)

# different returns

if return_logits_and_embeddings:
out = (logits, x)
elif return_embeddings:
out = x
else:
out = logits

# aux loss

if return_attn_z_loss:
pre_softmax_attns = [t.pre_softmax_attn for t in intermediates.attn_intermediates]
intermediates.attn_z_loss = calc_z_loss(pre_softmax_attns, weight = attn_z_loss_weight)
return_intermediates = True

if return_mems:
hiddens = intermediates.hiddens
new_mems = [torch.cat(pair, dim = -2) for pair in zip(mems, hiddens)] if exists(mems) else hiddens
new_mems = [t[..., -self.max_mem_len:, :].detach() for t in new_mems]

if not return_intermediates:
return out, new_mems

intermediates.mems = new_mems

if return_intermediates:
return out, intermediates

if return_attn:
attn_maps = [t.post_softmax_attn for t in intermediates.attn_intermediates]
return out, attn_maps

return out
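
For reference, the new wrapper gives each named input its own nn.Embedding, sums the per-input embeddings together with the positional embedding, and feeds the result through the shared attn_layers. Below is a minimal usage sketch, mirroring the updated test above; it assumes the package is installed with this commit so that MultiInputTransformerWrapper is importable from the top-level x_transformers namespace, and the optional mask simply follows the forward signature shown in this file.

import torch
from x_transformers import Decoder, MultiInputTransformerWrapper

# one embedding table per named input; the embedded inputs are summed before the attention layers
model = MultiInputTransformerWrapper(
    num_tokens = dict(note = 20000, pitch = 32, tone = 16),
    max_seq_len = 1024,
    return_only_embed = True,   # no logits head, return hidden states directly
    attn_layers = Decoder(dim = 128, depth = 6, heads = 8)
)

# every named input must be supplied, with the same (batch, seq) shape
inputs = dict(
    note  = torch.randint(0, 20000, (2, 1024)),
    pitch = torch.randint(0, 32, (2, 1024)),
    tone  = torch.randint(0, 16, (2, 1024))
)

mask = torch.ones(2, 1024, dtype = torch.bool)   # optional padding mask

embed = model(inputs, mask = mask)   # (2, 1024, 128)

As in the test, the sketch sticks to return_only_embed = True, since the default output head expects a single integer logits_dim.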
3 changes: 3 additions & 0 deletions x_transformers/x_transformers.py
@@ -51,6 +51,9 @@ def cast_tuple(val, depth):
 def divisible_by(num, den):
     return (num % den) == 0
 
+def is_empty(x):
+    return len(x) == 0
+
 def maybe(fn):
     @wraps(fn)
     def inner(x, *args, **kwargs):
