dog food hyper connections
lucidrains committed Dec 27, 2024
1 parent a85dfa9 commit b4aa3f2
Showing 3 changed files with 57 additions and 32 deletions.
6 changes: 3 additions & 3 deletions setup.py
@@ -3,7 +3,7 @@
setup(
name = 'tab-transformer-pytorch',
packages = find_packages(),
version = '0.3.0',
version = '0.4.0',
license='MIT',
description = 'Tab Transformer - Pytorch',
long_description_content_type = 'text/markdown',
@@ -17,8 +17,8 @@
'tabular data'
],
install_requires=[
'einops>=0.3',
'torch>=1.6'
'einops>=0.8',
'torch>=2.3'
],
classifiers=[
'Development Status :: 4 - Beta',
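
A quick environment check against the bumped requirements; this is an illustrative sketch, not part of the repository, and it assumes the hyper_connections package (imported by both model files below) is installed alongside the pinned einops and torch versions.

import einops
import torch
from hyper_connections import HyperConnections  # new import used by both model files in this commit

print(einops.__version__)  # setup.py now requires einops>=0.8
print(torch.__version__)   # setup.py now requires torch>=2.3
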
39 changes: 26 additions & 13 deletions tab_transformer_pytorch/ft_transformer.py
@@ -1,12 +1,15 @@
import torch
import torch.nn.functional as F
from torch import nn, einsum
from torch.nn import Module, ModuleList
import torch.nn.functional as F

from einops import rearrange, repeat

from hyper_connections import HyperConnections

# feedforward and attention

class GEGLU(nn.Module):
class GEGLU(Module):
def forward(self, x):
x, gates = x.chunk(2, dim = -1)
return x * F.gelu(gates)
@@ -20,7 +23,7 @@ def FeedForward(dim, mult = 4, dropout = 0.):
nn.Linear(dim * mult, dim)
)

class Attention(nn.Module):
class Attention(Module):
def __init__(
self,
dim,
@@ -62,43 +65,51 @@ def forward(self, x):

# transformer

class Transformer(nn.Module):
class Transformer(Module):
def __init__(
self,
dim,
depth,
heads,
dim_head,
attn_dropout,
ff_dropout
ff_dropout,
num_residual_streams = 4
):
super().__init__()
self.layers = nn.ModuleList([])

init_hyper_conn, self.expand_streams, self.reduce_streams = HyperConnections.get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)

self.layers = ModuleList([])

for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout),
FeedForward(dim, dropout = ff_dropout),
self.layers.append(ModuleList([
init_hyper_conn(dim = dim, branch = Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout)),
init_hyper_conn(dim = dim, branch = FeedForward(dim, dropout = ff_dropout)),
]))

def forward(self, x, return_attn = False):
post_softmax_attns = []

x = self.expand_streams(x)

for attn, ff in self.layers:
attn_out, post_softmax_attn = attn(x)
post_softmax_attns.append(post_softmax_attn)

x = attn_out + x
x = ff(x) + x

x = self.reduce_streams(x)

if not return_attn:
return x

return x, torch.stack(post_softmax_attns)

# numerical embedder

class NumericalEmbedder(nn.Module):
class NumericalEmbedder(Module):
def __init__(self, dim, num_numerical_types):
super().__init__()
self.weights = nn.Parameter(torch.randn(num_numerical_types, dim))
@@ -110,7 +121,7 @@ def forward(self, x):

# main class

class FTTransformer(nn.Module):
class FTTransformer(Module):
def __init__(
self,
*,
@@ -123,7 +134,8 @@ def __init__(
dim_out = 1,
num_special_tokens = 2,
attn_dropout = 0.,
ff_dropout = 0.
ff_dropout = 0.,
num_residual_streams = 4
):
super().__init__()
assert all(map(lambda n: n > 0, categories)), 'number of each category must be positive'
@@ -169,7 +181,8 @@ def __init__(
heads = heads,
dim_head = dim_head,
attn_dropout = attn_dropout,
ff_dropout = ff_dropout
ff_dropout = ff_dropout,
num_residual_streams = num_residual_streams
)

# to logits
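
The wiring above can be exercised outside the model. Below is a minimal sketch of the residual-stream pattern using the same three calls the diff introduces (get_init_and_expand_reduce_stream_functions, init_hyper_conn, and the expand/reduce helpers); the toy LayerNorm + Linear branch and the tensor shapes are illustrative stand-ins for the real Attention / FeedForward branches.

import torch
from torch import nn
from hyper_connections import HyperConnections

num_residual_streams = 4
dim = 32

# same factory call as in the diff; disable = True falls back to a plain residual stream
init_hyper_conn, expand_streams, reduce_streams = HyperConnections.get_init_and_expand_reduce_stream_functions(
    num_residual_streams,
    disable = num_residual_streams == 1
)

toy_branch = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim))  # stand-in for Attention / FeedForward
layer = init_hyper_conn(dim = dim, branch = toy_branch)

x = torch.randn(2, 16, dim)   # (batch, tokens, dim)
x = expand_streams(x)         # widen the residual into num_residual_streams streams
x = layer(x)                  # branch runs once, wrapper mixes its output back across the streams
x = reduce_streams(x)         # collapse back to (batch, tokens, dim)
print(x.shape)
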
44 changes: 28 additions & 16 deletions tab_transformer_pytorch/tab_transformer_pytorch.py
@@ -1,9 +1,12 @@
import torch
import torch.nn.functional as F
from torch import nn, einsum
from torch.nn import Module, ModuleList
import torch.nn.functional as F

from einops import rearrange, repeat

from hyper_connections import HyperConnections

# helpers

def exists(val):
@@ -14,15 +17,15 @@ def default(val, d):

# classes

class Residual(nn.Module):
class Residual(Module):
def __init__(self, fn):
super().__init__()
self.fn = fn

def forward(self, x, **kwargs):
return self.fn(x, **kwargs) + x

class PreNorm(nn.Module):
class PreNorm(Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
@@ -33,12 +36,12 @@ def forward(self, x, **kwargs):

# attention

class GEGLU(nn.Module):
class GEGLU(Module):
def forward(self, x):
x, gates = x.chunk(2, dim = -1)
return x * F.gelu(gates)

class FeedForward(nn.Module):
class FeedForward(Module):
def __init__(self, dim, mult = 4, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
@@ -51,7 +54,7 @@ def __init__(self, dim, mult = 4, dropout = 0.):
def forward(self, x, **kwargs):
return self.net(x)

class Attention(nn.Module):
class Attention(Module):
def __init__(
self,
dim,
@@ -84,42 +87,49 @@ def forward(self, x):

# transformer

class Transformer(nn.Module):
class Transformer(Module):
def __init__(
self,
dim,
depth,
heads,
dim_head,
attn_dropout,
ff_dropout
ff_dropout,
num_residual_streams = 4
):
super().__init__()
self.layers = nn.ModuleList([])
self.layers = ModuleList([])

init_hyper_conn, self.expand_streams, self.reduce_streams = HyperConnections.get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)

for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout)),
PreNorm(dim, FeedForward(dim, dropout = ff_dropout)),
self.layers.append(ModuleList([
init_hyper_conn(dim = dim, branch = PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout))),
init_hyper_conn(dim = dim, branch = PreNorm(dim, FeedForward(dim, dropout = ff_dropout))),
]))

def forward(self, x, return_attn = False):
post_softmax_attns = []

x = self.expand_streams(x)

for attn, ff in self.layers:
attn_out, post_softmax_attn = attn(x)
post_softmax_attns.append(post_softmax_attn)

x = x + attn_out
x = ff(x) + x

x = self.reduce_streams(x)

if not return_attn:
return x

return x, torch.stack(post_softmax_attns)
# mlp

class MLP(nn.Module):
class MLP(Module):
def __init__(self, dims, act = None):
super().__init__()
dims_pairs = list(zip(dims[:-1], dims[1:]))
@@ -142,7 +152,7 @@ def forward(self, x):

# main class

class TabTransformer(nn.Module):
class TabTransformer(Module):
def __init__(
self,
*,
@@ -160,7 +170,8 @@ def __init__(
attn_dropout = 0.,
ff_dropout = 0.,
use_shared_categ_embed = True,
shared_categ_dim_divisor = 8. # in paper, they reserve dimension / 8 for category shared embedding
shared_categ_dim_divisor = 8., # in paper, they reserve dimension / 8 for category shared embedding
num_residual_streams = 4
):
super().__init__()
assert all(map(lambda n: n > 0, categories)), 'number of each category must be positive'
@@ -214,7 +225,8 @@ def __init__(
heads = heads,
dim_head = dim_head,
attn_dropout = attn_dropout,
ff_dropout = ff_dropout
ff_dropout = ff_dropout,
num_residual_streams = num_residual_streams
)

# mlp to logits
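
A usage sketch of the two updated models with the new num_residual_streams argument. Only the arguments visible in the hunks above are confirmed by this commit; the remaining constructor arguments and the forward signatures (num_continuous, dim_out, mlp_hidden_mults, mlp_act, and the (x_categ, x_cont) call) are assumed from the library's existing API, and the values are made up for illustration.

import torch
from torch import nn
from tab_transformer_pytorch import TabTransformer, FTTransformer  # import path assumed

x_categ = torch.randint(0, 5, (1, 5))  # ids for 5 categorical columns (illustrative)
x_cont = torch.randn(1, 10)            # 10 continuous columns (illustrative)

tab_model = TabTransformer(
    categories = (10, 5, 6, 5, 8),     # cardinality per categorical column
    num_continuous = 10,               # assumed from existing API
    dim = 32,
    dim_out = 1,
    depth = 6,
    heads = 8,
    attn_dropout = 0.1,
    ff_dropout = 0.1,
    mlp_hidden_mults = (4, 2),         # assumed from existing API
    mlp_act = nn.ReLU(),               # assumed from existing API
    num_residual_streams = 4           # new in this commit; 1 disables the hyper connections
)

ft_model = FTTransformer(
    categories = (10, 5, 6, 5, 8),
    num_continuous = 10,               # assumed from existing API
    dim = 32,
    dim_out = 1,
    depth = 6,
    heads = 8,
    attn_dropout = 0.1,
    ff_dropout = 0.1,
    num_residual_streams = 4           # new in this commit
)

print(tab_model(x_categ, x_cont).shape)  # expected (1, 1)
print(ft_model(x_categ, x_cont).shape)   # expected (1, 1)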
