Skip to content

Commit

Permalink
make room for special tokens in categories embedding table
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Dec 20, 2020
1 parent 20e3252 commit f831f2b
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'tab-transformer-pytorch',
packages = find_packages(),
version = '0.0.8',
version = '0.0.9',
license='MIT',
description = 'Tab Transformer - Pytorch',
author = 'Phil Wang',
Expand Down
14 changes: 11 additions & 3 deletions tab_transformer_pytorch/tab_transformer_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def __init__(
dim_out = 1,
mlp_hidden_mults = (4, 2),
mlp_act = None,
num_special_tokens = 2,
continuous_mean_std = None
):
super().__init__()
Expand All @@ -125,14 +126,21 @@ def __init__(
self.num_categories = len(categories)
self.num_unique_categories = sum(categories)

# create category embeddings table

self.num_special_tokens = num_special_tokens
total_tokens = self.num_unique_categories + num_special_tokens

self.categorical_embeds = nn.Embedding(total_tokens, dim)

# for automatically offsetting unique category ids to the correct position in the categories embedding table

categories_offset = F.pad(torch.tensor(list(categories)), (1, 0), value = 0).cumsum(dim = -1)[:-1]
categories_offset = F.pad(torch.tensor(list(categories)), (1, 0), value = num_special_tokens)
categories_offset = categories_offset.cumsum(dim = -1)[:-1]
self.register_buffer('categories_offset', categories_offset)

self.categorical_embeds = nn.Embedding(self.num_unique_categories, dim)

# continuous

if exists(continuous_mean_std):
assert continuous_mean_std.shape == (num_continuous, 2), f'continuous_mean_std must have a shape of ({num_continuous}, 2) where the last dimension contains the mean and variance respectively'
self.register_buffer('continuous_mean_std', continuous_mean_std)
Expand Down

0 comments on commit f831f2b

Please sign in to comment.