release custom embedder and unembedder, contributed by @pradeep-pyro
lucidrains committed Nov 7, 2024
1 parent 409ba0f commit 55148ba
Showing 3 changed files with 24 additions and 10 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.42.7',
+  version = '1.42.8',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
27 changes: 20 additions & 7 deletions tests/test_x_transformers.py
@@ -1,5 +1,8 @@
 import pytest

 import torch
+from torch import nn
+from torch.nn import Module
+
 from x_transformers.x_transformers import (
     XTransformer,
@@ -401,26 +404,31 @@ def test_embedder(embedder_type):
     num_tokens = 20000
     dim = 128
     token_emb_kwargs = {}
+
     if embedder_type == 'embedding':
-        embedder = torch.nn.Embedding(num_tokens, dim)
+        embedder = nn.Embedding(num_tokens, dim)
     elif embedder_type == 'none':
         embedder = None
     else:
-        class CustomEmbedder(torch.nn.Module):
+        class CustomEmbedder(Module):
             """
             Made up embedder that sums two embeddings. Just to check if we can pass additional input to the embedder's
             forward pass without breaking the model.
             """
             def __init__(self, num_tokens, dim):
                 super().__init__()
-                self.embed_x = torch.nn.Embedding(num_tokens, dim)
-                self.embed_y = torch.nn.Embedding(num_tokens, dim)
+                self.embed_x = nn.Embedding(num_tokens, dim)
+                self.embed_y = nn.Embedding(num_tokens, dim)
+
             def forward(self, x, y):
                 return self.embed_x(x) + self.embed_y(y)
+
             def init_(self):
                 pass
+
         embedder = CustomEmbedder(num_tokens, dim)
         token_emb_kwargs['y'] = torch.randint(0, num_tokens, (2, 1024))
+
     model = TransformerWrapper(
         num_tokens = num_tokens,
         max_seq_len = 1024,
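Editorial note: below is a minimal usage sketch of the custom-embedder path that the test above exercises. The `token_emb` constructor argument and the `token_emb_kwargs` forward argument follow this commit's naming, but the relevant parts of `TransformerWrapper`'s signature are collapsed in the diff, so treat this as an illustration under those assumptions rather than canonical usage.

import torch
from torch import nn
from x_transformers import TransformerWrapper, Decoder

class SummedEmbedder(nn.Module):
    # toy embedder in the spirit of the test's CustomEmbedder:
    # sums a token embedding with the embedding of an extra input `y`
    def __init__(self, num_tokens, dim):
        super().__init__()
        self.embed_x = nn.Embedding(num_tokens, dim)
        self.embed_y = nn.Embedding(num_tokens, dim)

    def forward(self, x, y):
        return self.embed_x(x) + self.embed_y(y)

num_tokens, dim = 256, 128

model = TransformerWrapper(
    num_tokens = num_tokens,
    max_seq_len = 1024,
    token_emb = SummedEmbedder(num_tokens, dim),   # custom embedder (this release)
    attn_layers = Decoder(dim = dim, depth = 2, heads = 4)
)

x = torch.randint(0, num_tokens, (2, 1024))
y = torch.randint(0, num_tokens, (2, 1024))

# extra embedder inputs are routed through `token_emb_kwargs`, as in the test
logits = model(x, token_emb_kwargs = dict(y = y))
assert logits.shape == (2, 1024, num_tokens)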
@@ -442,16 +450,19 @@ def init_(self):
 def test_to_logits(to_logits):
     num_tokens = 20000
     dim = 128
+
     to_logits_kwargs = {}
+
     if to_logits == 'linear':
         logit_mapper = LinearNoBias(dim, num_tokens)
     elif to_logits == 'none':
         logit_mapper = None
     else:
-        class PointerNetworkLogits(torch.nn.Module):
+        class PointerNetworkLogits(Module):
             def __init__(self, dim):
                 super().__init__()
-                self.proj_to_pointers = torch.nn.Linear(dim, dim)
+                self.proj_to_pointers = nn.Linear(dim, dim)
+
             def forward(self, model_embeddings, input_embeddings):
                 pointers = self.proj_to_pointers(model_embeddings)
                 logits = torch.matmul(pointers, input_embeddings.permute(0, 2, 1))
@@ -472,5 +483,7 @@ def forward(self, model_embeddings, input_embeddings):
     )

     x = torch.randint(0, num_tokens, (2, 1024))
+
     output = model(x, to_logits_kwargs=to_logits_kwargs)
-    assert output.shape == (2, 1024, 20000)
+
+    assert output.shape == (2, 1024, 20000)
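A companion sketch for the custom unembedder path in test_to_logits above: a module passed as `to_logits` replaces the default LinearNoBias output head (see the x_transformers.py hunk below), and any extra inputs it needs are forwarded through `to_logits_kwargs`, a forward argument visible in the test. The collapsed parts of the test and of `TransformerWrapper`'s signature are assumed, not shown in this diff.

import torch
from torch import nn
from x_transformers import TransformerWrapper, Decoder

class PointerLogits(nn.Module):
    # toy pointer-style head mirroring the test's PointerNetworkLogits:
    # logits are similarities between projected model embeddings and
    # externally supplied candidate embeddings
    def __init__(self, dim):
        super().__init__()
        self.proj_to_pointers = nn.Linear(dim, dim)

    def forward(self, model_embeddings, input_embeddings):
        pointers = self.proj_to_pointers(model_embeddings)
        return torch.matmul(pointers, input_embeddings.permute(0, 2, 1))

num_tokens, dim = 256, 128

model = TransformerWrapper(
    num_tokens = num_tokens,
    max_seq_len = 1024,
    to_logits = PointerLogits(dim),                # custom unembedder (this release)
    attn_layers = Decoder(dim = dim, depth = 2, heads = 4)
)

x = torch.randint(0, num_tokens, (2, 1024))
candidates = torch.randn(2, 512, dim)              # 512 candidate embeddings per batch element

# extra head inputs are routed through `to_logits_kwargs`, as in the test
out = model(x, to_logits_kwargs = dict(input_embeddings = candidates))
assert out.shape == (2, 1024, 512)                 # one logit per candidate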
5 changes: 3 additions & 2 deletions x_transformers/x_transformers.py
@@ -2371,12 +2371,12 @@ def __init__(
         if return_only_embed:
             self.to_logits = None
         elif tie_embedding:
-            assert isinstance(self.token_emb, TokenEmbedding), 'can only tie embedding if using `TokenEmbedding`'
+            assert isinstance(token_emb, TokenEmbedding), 'can only tie embedding if using `TokenEmbedding`'
             self.to_logits = lambda t: t @ self.token_emb.emb.weight.t()
         elif num_output_heads > 1:
             self.to_logits = ModuleList([LinearNoBias(dim, logits_dim) for _ in range(num_output_heads)])
         else:
-            self.to_logits = LinearNoBias(dim, logits_dim) if to_logits is None else to_logits
+            self.to_logits = LinearNoBias(dim, logits_dim) if not exists(to_logits) else to_logits

         # memory tokens (like [cls]) from Memory Transformers paper

@@ -2399,6 +2399,7 @@ def __init__(
     def init_(self):
         if hasattr(self.token_emb, 'init_'):
             self.token_emb.init_()
+
         if self.l2norm_embed:
             if not isinstance(self.pos_emb, always):
                 nn.init.normal_(self.pos_emb.emb.weight, std = 1e-5)
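A side note on the `to_logits is None` → `not exists(to_logits)` change in the first x_transformers.py hunk above: it is a stylistic switch to the codebase's usual None-check helper, which is conventionally defined roughly as follows (the helper itself is not part of this diff):

def exists(val):
    # truthy only when a value was actually provided
    return val is not None

Either way, the default LinearNoBias(dim, logits_dim) head is only constructed when no custom `to_logits` module was passed in.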
