
Commit

add dropouts
lucidrains committed Nov 5, 2020
1 parent f36b4b0 commit e033224
Showing 2 changed files with 9 additions and 2 deletions.
setup.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '0.0.18',
+  version = '0.0.19',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
x_transformers/x_transformers.py (9 changes: 8 additions & 1 deletion)
@@ -147,7 +147,8 @@ def __init__(
         talking_heads = False,
         sparse_topk = None,
         use_entmax15 = False,
-        num_mem_kv = 0
+        num_mem_kv = 0,
+        dropout = 0.
     ):
         super().__init__()
         self.scale = dim_head ** -0.5
@@ -159,6 +160,7 @@ def __init__(
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
         self.to_out = nn.Linear(inner_dim, dim)
+        self.dropout = nn.Dropout(dropout)

         # talking heads
         self.talking_heads = talking_heads
@@ -223,6 +225,7 @@ def forward(self, x, context = None, mask = None, context_mask = None, rel_pos =
             del mask

         attn = self.attn_fn(dots, dim = -1)
+        attn = self.dropout(attn)

         if talking_heads:
             dots = einsum('b h i j, h k -> b k i j', dots, self.post_softmax_proj)
@@ -337,13 +340,16 @@ def __init__(
         num_tokens,
         max_seq_len,
         attn_layers,
+        emb_dropout = 0.,
         num_memory_tokens = 0
     ):
         super().__init__()
         dim = attn_layers.dim
         self.max_seq_len = max_seq_len
         self.token_emb = nn.Embedding(num_tokens, dim)
         self.pos_emb = nn.Embedding(max_seq_len, dim)
+        self.emb_dropout = nn.Dropout(emb_dropout)

         self.attn_layers = attn_layers
         self.norm = nn.LayerNorm(dim)

@@ -363,6 +369,7 @@ def forward(self, x, return_embeddings = False, mask = None, **kwargs):
         b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
         x = self.token_emb(x)
         x += self.pos_emb(torch.arange(n, device = device))
+        x = self.emb_dropout(x)

         if num_mem > 0:
             mem = repeat(self.memory_tokens, 'n d -> b n d', b = b)
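
The two dropouts introduced here act at different points in the model: Attention now drops entries of the post-softmax attention map, and TransformerWrapper drops activations right after the token and positional embeddings are summed. The snippet below is a minimal, self-contained PyTorch sketch of those two placements, not the library's own code; the shapes, dropout rates, and tensor names are illustrative assumptions.

# Illustrative sketch only (not the library's code): where the two new dropouts sit.
# All shapes and rates below are assumptions chosen for demonstration.
import torch
import torch.nn as nn

dim, heads, seq_len, num_tokens, batch = 512, 8, 128, 20000, 2
dim_head = dim // heads
attn_dropout = nn.Dropout(0.1)   # analogous to the new Attention dropout
emb_dropout  = nn.Dropout(0.1)   # analogous to the new TransformerWrapper emb_dropout

# embedding dropout: applied right after token + positional embeddings are summed
token_emb = nn.Embedding(num_tokens, dim)
pos_emb = nn.Embedding(seq_len, dim)
tokens = torch.randint(0, num_tokens, (batch, seq_len))
x = token_emb(tokens) + pos_emb(torch.arange(seq_len))
x = emb_dropout(x)

# attention dropout: applied to the post-softmax attention weights, before
# they are used to aggregate the values
q = k = v = x.reshape(batch, seq_len, heads, dim_head).transpose(1, 2)
dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) * dim_head ** -0.5
attn = dots.softmax(dim = -1)
attn = attn_dropout(attn)        # analogue of the attn = self.dropout(attn) line added above
out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)

In terms of the public constructors touched by this diff, the corresponding knobs are the new dropout keyword on Attention and the emb_dropout keyword on TransformerWrapper, both defaulting to 0.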

0 comments on commit e033224
