-
-
Notifications
You must be signed in to change notification settings - Fork 417
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
just give multi input its own wrapper
- Loading branch information
1 parent
cc436b3
commit 03f35a8
Showing
5 changed files
with
311 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,294 @@ | ||
from __future__ import annotations | ||
|
||
import torch | ||
from torch import nn, Tensor | ||
from torch.nn import Module, ModuleDict | ||
import torch.nn.functional as F | ||
|
||
from typing import Dict | ||
|
||
from einops import pack, repeat, unpack | ||
|
||
from x_transformers.x_transformers import ( | ||
AttentionLayers, | ||
ScaledSinusoidalEmbedding, | ||
AbsolutePositionalEmbedding, | ||
LayerIntermediates, | ||
LayerNorm, | ||
always, | ||
pad_at_dim, | ||
is_empty | ||
) | ||
|
||
# helper functions | ||
|
||
def exists(val): | ||
return val is not None | ||
|
||
def default(val, d): | ||
if exists(val): | ||
return val | ||
return d() if callable(d) else d | ||
|
||
|
||
class MultiInputTransformerWrapper(Module): | ||
def __init__( | ||
self, | ||
*, | ||
num_tokens: Dict[str, int] = dict(), | ||
max_seq_len, | ||
attn_layers: AttentionLayers, | ||
emb_dim = None, | ||
max_mem_len = 0, | ||
shift_mem_down = 0, | ||
emb_dropout = 0., | ||
post_emb_norm = False, | ||
num_memory_tokens = None, | ||
memory_tokens_interspersed_every = None, | ||
tie_embedding = False, | ||
logits_dim = None, | ||
return_only_embed = False, | ||
num_output_heads = 1, | ||
use_abs_pos_emb = True, | ||
scaled_sinu_pos_emb = False, | ||
emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1 | ||
attn_z_loss_weight = 1e-4, | ||
): | ||
super().__init__() | ||
|
||
dim = attn_layers.dim | ||
emb_dim = default(emb_dim, dim) | ||
self.emb_dim = emb_dim | ||
|
||
self.max_seq_len = max_seq_len | ||
self.max_mem_len = max_mem_len | ||
self.shift_mem_down = shift_mem_down | ||
|
||
no_abs_pos_emb = max_seq_len == 0 or not (use_abs_pos_emb and not attn_layers.disable_abs_pos_emb) | ||
|
||
if no_abs_pos_emb: | ||
self.pos_emb = always(0) | ||
elif scaled_sinu_pos_emb: | ||
self.pos_emb = ScaledSinusoidalEmbedding(emb_dim) | ||
else: | ||
self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) | ||
|
||
# additional embeddings - say type embedding from BERT | ||
|
||
self.embeds = ModuleDict({f'{name}_embed': nn.Embedding(one_num_tokens, emb_dim) for name, one_num_tokens in num_tokens.items()}) | ||
|
||
# fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290 | ||
|
||
self.emb_frac_gradient = emb_frac_gradient | ||
|
||
self.post_emb_norm = LayerNorm(emb_dim) if post_emb_norm else nn.Identity() | ||
self.emb_dropout = nn.Dropout(emb_dropout) | ||
|
||
self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() | ||
self.attn_layers = attn_layers | ||
|
||
assert num_output_heads > 0 | ||
|
||
# output head, usually to logits of num_tokens | ||
|
||
logits_dim = default(logits_dim, num_tokens) | ||
|
||
self.has_multiple_heads = False | ||
|
||
if return_only_embed: | ||
self.to_logits = None | ||
elif tie_embedding: | ||
self.to_logits = lambda t: t @ self.token_emb.emb.weight.t() | ||
elif num_output_heads > 1: | ||
self.has_multiple_heads = True | ||
self.to_logits = ModuleList([nn.Linear(dim, logits_dim, bias = False) for _ in range(num_output_heads)]) | ||
else: | ||
self.to_logits = nn.Linear(dim, logits_dim, bias = False) | ||
|
||
# memory tokens (like [cls]) from Memory Transformers paper | ||
|
||
num_memory_tokens = default(num_memory_tokens, 0) | ||
self.num_memory_tokens = num_memory_tokens | ||
if num_memory_tokens > 0: | ||
self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) | ||
|
||
self.memory_tokens_interspersed_every = memory_tokens_interspersed_every | ||
|
||
# whether can do cached kv decoding | ||
|
||
self.can_cache_kv = self.num_memory_tokens == 0 | ||
self.can_cache_kv_outside_max_seq_len = no_abs_pos_emb | ||
|
||
def forward( | ||
self, | ||
x: Dict[str, Tensor], | ||
return_embeddings = False, | ||
return_logits_and_embeddings = False, | ||
return_intermediates = False, | ||
mask = None, | ||
return_mems = False, | ||
return_attn = False, | ||
mems = None, | ||
mem_masks = None, | ||
pos = None, | ||
prepend_embeds = None, | ||
prepend_mask = None, | ||
sum_embeds = None, | ||
return_attn_z_loss = False, | ||
attn_z_loss_weight = 1e-4, | ||
seq_start_pos = None, | ||
cache: LayerIntermediates | None = None, | ||
**kwargs | ||
): | ||
assert not is_empty(x) | ||
first_input = list(x.values())[0] | ||
|
||
b, n, device, num_mems, has_memory_tokens, emb_frac_gradient = *first_input.shape, first_input.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient | ||
|
||
return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss | ||
return_embeddings = return_embeddings | (not exists(self.to_logits)) | ||
|
||
# token embedding | ||
|
||
assert len(x) == len(self.embeds) | ||
|
||
token_emb = 0. | ||
|
||
for name, embed_id in x.items(): | ||
embed_key = f'{name}_embed' | ||
|
||
assert embed_key in self.embeds | ||
embed = self.embeds[embed_key](embed_id) | ||
|
||
token_emb = token_emb + embed | ||
|
||
# absolute positional embedding | ||
|
||
external_pos_emb = exists(pos) and pos.dtype != torch.long | ||
pos_emb = self.pos_emb(first_input, pos = pos, seq_start_pos = seq_start_pos) if not external_pos_emb else pos | ||
|
||
token_emb = token_emb + pos_emb | ||
|
||
# for summing embeddings passed externally - needs this for self-conditioning in non-autoregressive training | ||
|
||
if exists(sum_embeds): | ||
token_emb = token_emb + sum_embeds | ||
|
||
# set back to `x` | ||
|
||
x = token_emb | ||
|
||
# post embedding norm, purportedly leads to greater stabilization | ||
|
||
x = self.post_emb_norm(x) | ||
|
||
# whether to append embeds, as in PaLI, for image embeddings | ||
|
||
if exists(prepend_embeds): | ||
prepend_seq, prepend_dim = prepend_embeds.shape[1:] | ||
assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as text model dimensions' | ||
|
||
x = torch.cat((prepend_embeds, x), dim = -2) | ||
|
||
if exists(prepend_mask) or exists(mask): | ||
mask = default(mask, lambda: torch.ones((b, n), device = device, dtype = torch.bool)) | ||
prepend_mask = default(prepend_mask, lambda: torch.ones((b, prepend_seq), device = device, dtype = torch.bool)) | ||
|
||
mask = torch.cat((prepend_mask, mask), dim = -1) | ||
|
||
# whether to reduce the gradient going to the embedding, from cogview paper, corroborated by GLM-130B model | ||
|
||
if emb_frac_gradient < 1: | ||
assert emb_frac_gradient > 0 | ||
x = x * emb_frac_gradient + x.detach() * (1 - emb_frac_gradient) | ||
|
||
# embedding dropout | ||
|
||
x = self.emb_dropout(x) | ||
|
||
x = self.project_emb(x) | ||
|
||
if has_memory_tokens: | ||
mem_every = self.memory_tokens_interspersed_every | ||
|
||
if exists(mem_every): | ||
assert mem_every > 0 | ||
assert isinstance(self.attn_layers, Decoder), 'only for decoder' | ||
next_seq_len = math.ceil(n / mem_every) * mem_every | ||
|
||
x = pad_at_dim(x, (0, next_seq_len - n), dim = -2, value = 0.) | ||
x = rearrange(x, 'b (n m) d -> (b n) m d', m = mem_every) | ||
|
||
mem = repeat(self.memory_tokens, 'n d -> b n d', b = x.shape[0]) | ||
x, mem_packed_shape = pack((mem, x), 'b * d') | ||
|
||
# auto-handle masking after appending memory tokens | ||
if not exists(mem_every) and exists(mask): | ||
mask = pad_at_dim(mask, (num_mems, 0), dim = -1, value = True) | ||
|
||
if exists(mem_every): | ||
x = rearrange(x, '(b n) m d -> b (n m) d', b = b) | ||
|
||
if self.shift_mem_down and exists(mems): | ||
mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:] | ||
mems = [*mems_r, *mems_l] | ||
|
||
x, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs) | ||
|
||
# handle memories post-attention | ||
|
||
if has_memory_tokens: | ||
if exists(mem_every): | ||
x = rearrange(x, 'b (n m) d -> (b n) m d', m = (mem_every + num_mems)) | ||
|
||
mem, x = unpack(x, mem_packed_shape, 'b * d') | ||
|
||
intermediates.memory_tokens = mem | ||
|
||
if exists(mem_every): | ||
x = rearrange(x, '(b n) m d -> b (n m) d', b = b) | ||
|
||
x = x[:, :n] | ||
|
||
# projecting to logits | ||
|
||
if not return_embeddings: | ||
if self.has_multiple_heads: | ||
logits = tuple(fn(x) for fn in self.to_logits) | ||
else: | ||
logits = self.to_logits(x) | ||
|
||
# different returns | ||
|
||
if return_logits_and_embeddings: | ||
out = (logits, x) | ||
elif return_embeddings: | ||
out = x | ||
else: | ||
out = logits | ||
|
||
# aux loss | ||
|
||
if return_attn_z_loss: | ||
pre_softmax_attns = [t.pre_softmax_attn for t in intermediates.attn_intermediates] | ||
intermediates.attn_z_loss = calc_z_loss(pre_softmax_attns, weight = attn_z_loss_weight) | ||
return_intermediates = True | ||
|
||
if return_mems: | ||
hiddens = intermediates.hiddens | ||
new_mems = [torch.cat(pair, dim = -2) for pair in zip(mems, hiddens)] if exists(mems) else hiddens | ||
new_mems = [t[..., -self.max_mem_len:, :].detach() for t in new_mems] | ||
|
||
if not return_intermediates: | ||
return out, new_mems | ||
|
||
intermediates.mems = new_mems | ||
|
||
if return_intermediates: | ||
return out, intermediates | ||
|
||
if return_attn: | ||
attn_maps = [t.post_softmax_attn for t in intermediates.attn_intermediates] | ||
return out, attn_maps | ||
|
||
return out |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters