redo topk attention to allow for straight through
lucidrains committed Oct 18, 2024
1 parent 79ef49e commit 7c56d23
Showing 5 changed files with 56 additions and 51 deletions.
18 changes: 17 additions & 1 deletion README.md
@@ -525,7 +525,23 @@ model = TransformerWrapper(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_sparse_topk = 8 # keep only the top 8 values before attention (softmax)
        attn_sparse_topk = 8, # keep only the top 8 values before attention (softmax)
        attn_sparse_topk_straight_through = True # straight through the original softmax gradients
    )
)
```

In the extreme case of a `topk` value of `1`, you can use the following

```python
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_hard = True # only propagate the value at the argmax of the qk logits. offered in case it addresses https://arxiv.org/abs/2410.01104
    )
)
```
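
Both options above rely on the same straight-through trick: the forward pass uses the hard (or top-k sparsified) attention weights, while gradients flow as if the plain softmax had been used. A minimal standalone sketch of the idea in plain PyTorch (illustrative only; the library's actual implementations are `one_hot_straight_through` and `sparse_topk_attn` in `x_transformers/attend.py` below):

```python
import torch
import torch.nn.functional as F

# attention logits for a single query over 6 keys
logits = torch.randn(6, requires_grad = True)
values = torch.arange(6.)  # stand-in for the values being attended over

soft = F.softmax(logits, dim = -1)
hard = F.one_hot(soft.argmax(dim = -1), num_classes = 6).float()

# straight-through: the forward pass sees `hard`, the backward pass sees `soft`
attn = hard + soft - soft.detach()

out = (attn * values).sum()  # forward: exactly the argmax'd value
out.backward()
print(logits.grad)           # nonzero for every logit, via the softmax path
```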
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
  name = 'x-transformers',
  packages = find_packages(exclude=['examples']),
  version = '1.39.4',
  version = '1.40.1',
  license='MIT',
  description = 'X-Transformers - Pytorch',
  author = 'Phil Wang',
22 changes: 0 additions & 22 deletions tests/test_x_transformers.py
@@ -280,28 +280,6 @@ def test_mos():

    eval_logits = model(x)

def test_sigsoftmax():
    model = TransformerWrapper(
        num_tokens = 20000,
        max_seq_len = 1024,
        mixture_of_softmax = True,
        sigsoftmax_logits = True,
        attn_layers = Decoder(
            attn_sigsoftmax = True,
            dim = 128,
            depth = 6,
            heads = 8
        )
    )

    x = torch.randint(0, 20000, (2, 1024))

    logits = model(x)

    model.eval()

    eval_logits = model(x)

@pytest.mark.parametrize('attn_one_kv_head', (True, False))
def test_l2_distance(attn_one_kv_head):

61 changes: 36 additions & 25 deletions x_transformers/attend.py
@@ -109,12 +109,35 @@ def qk_l2_dist_squared(q, k):

# one-hot straight through softmax

def one_hot_straight_through(t, temperature = 1.):
one_hot_indices = t.argmax(dim = -1, keepdim = True)
one_hot = torch.zeros_like(t).scatter(-1, one_hot_indices, 1.)
def one_hot_straight_through(logits, temperature = 1.):
one_hot_indices = logits.argmax(dim = -1, keepdim = True)
one_hot = torch.zeros_like(logits).scatter(-1, one_hot_indices, 1.)

t = (t / temperature).softmax(dim = -1)
return one_hot + t - t.detach()
soft_attn = (logits / temperature).softmax(dim = -1)
return one_hot + soft_attn - soft_attn.detach()

# sparse topk attention - only keep topk attn logits for softmax
# optional straight through with masked out logits by setting `attn_sparse_topk_straight_through = True`

def sparse_topk_attn(
logits,
sparse_topk,
temperature = 1.,
straight_through = False
):
orig_logits = logits

mask_value = -torch.finfo(logits.dtype).max
top_values, _ = logits.topk(sparse_topk, dim = -1)
sparse_topk_mask = (logits >= top_values[..., -1:]) & (logits > mask_value)
logits = logits.masked_fill(~sparse_topk_mask, mask_value)
topk_attn = logits.softmax(dim = -1)

if not straight_through:
return topk_attn

soft_attn = (orig_logits / temperature).softmax(dim = -1)
return topk_attn.detach() + soft_attn - soft_attn.detach()

# functions for creating causal mask
# need a special one for onnx cpu (no support for .triu)
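
A quick sanity check of the new `sparse_topk_attn` above (a sketch, assuming a version of `x_transformers.attend` that includes this commit): with `straight_through = True` the forward values are identical to plain top-k attention, but gradients follow the full softmax over the original logits, so masked-out positions still receive a learning signal.

```python
import torch
from x_transformers.attend import sparse_topk_attn  # assumes this commit or later

# (batch, heads, query_len, key_len) logits and stand-in value heads
logits = torch.randn(1, 2, 16, 16, requires_grad = True)
values = torch.randn(1, 2, 16, 8)

topk_only = sparse_topk_attn(logits, sparse_topk = 4, straight_through = False)
st_attn = sparse_topk_attn(logits, sparse_topk = 4, straight_through = True)

# forward values agree: the straight-through correction is numerically zero
assert torch.allclose(topk_only, st_attn)

# the backward pass of the straight-through version goes through the full softmax,
# so every logit gets a gradient, not only the surviving top-k entries
(st_attn @ values).sum().backward()
print(logits.grad.abs().sum())  # nonzero: gradients reach all logits
```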
@@ -141,6 +164,7 @@ def __init__(
        post_talking_heads = False,
        pre_scale_post_talking_heads = False,
        sparse_topk = None,
        sparse_topk_straight_through = False,
        scale = None,
        qk_norm = False,
        l2_distance = False,
@@ -152,7 +176,6 @@ def __init__(
        add_zero_kv = False,
        selective = False,
        hard = False,
        sigsoftmax = False,
        cope = None,
        onnxable = False,
        sdp_kwargs: dict = dict(
@@ -171,16 +194,22 @@ def __init__(

        # attention type

        is_sparse_topk_attn = exists(sparse_topk)

        assert not (flash and sigmoid), 'sigmoid attention not available for flash'
        assert not (flash and hard), 'hard attention not available for flash'
        assert at_most_one_of(sigmoid, hard, l2_distance)
        assert not (flash and is_sparse_topk_attn), 'topk attention not available for flash'

        assert at_most_one_of(sigmoid, hard, l2_distance, is_sparse_topk_attn)

        if exists(custom_attn_fn):
            self.attn_fn = custom_attn_fn
        elif sigmoid:
            self.attn_fn = F.sigmoid
        elif hard:
            self.attn_fn = one_hot_straight_through
        elif is_sparse_topk_attn:
            self.attn_fn = partial(sparse_topk_attn, sparse_topk = sparse_topk, straight_through = sparse_topk_straight_through)
        else:
            softmax_fn = partial(F.softmax, dim = -1)
            self.attn_fn = partial(softmax_fn, dtype = torch.float32) if not qk_norm else softmax_fn
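
With this dispatch, `Attend` routes the attention logits through `sparse_topk_attn` whenever `sparse_topk` is set, binding the options via `functools.partial`. A hedged usage sketch of the `Attend` module on its own (shapes and the `(out, intermediates)` return value are inferred from the surrounding code):

```python
import torch
from x_transformers.attend import Attend

attend = Attend(
    causal = True,
    sparse_topk = 8,                      # keep only the top 8 logits per query row
    sparse_topk_straight_through = True   # forward: top-k attention, backward: full softmax gradients
)

# queries, keys, values of shape (batch, heads, seq_len, dim_head)
q, k, v = (torch.randn(1, 8, 256, 64) for _ in range(3))

out, intermediates = attend(q, k, v)
print(out.shape)  # torch.Size([1, 8, 256, 64])
```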
@@ -214,16 +243,6 @@ def __init__(
        assert not (selective and not causal), 'selective attention is designed for autoregressive'
        self.selective = selective

        # sparse topk

        assert not (flash and sparse_topk), 'sparse topk not compatible with flash attention'
        self.sparse_topk = sparse_topk

        # sig softmax

        assert not (flash and sigsoftmax), 'sigsoftmax not available for flash attention'
        self.sigsoftmax = sigsoftmax

        # l2 distance attention

        self.l2_distance = l2_distance
@@ -476,11 +495,6 @@ def forward(
            causal_mask = self.create_causal_mask(i, j, device = device)
            sim = sim.masked_fill(causal_mask, mask_value)

        if exists(self.sparse_topk):
            top_values, _ = sim.topk(self.sparse_topk, dim = -1)
            sparse_topk_mask = (sim >= top_values[..., -1:]) & (sim > mask_value)
            sim = sim.masked_fill(~sparse_topk_mask, mask_value)

        row_is_entirely_masked = None

        if exists(mask):
@@ -494,9 +508,6 @@

        pre_softmax_attn = sim

        if self.sigsoftmax:
            sim = sim + sim.sigmoid().log()

        attn = self.attn_fn(sim)

        attn = attn.type(dtype)
4 changes: 2 additions & 2 deletions x_transformers/x_transformers.py
@@ -912,6 +912,7 @@ def __init__(
        pre_scale_post_talking_heads = False,
        head_scale = False,
        sparse_topk = None,
        sparse_topk_straight_through = False,
        num_mem_kv = 0,
        dropout = 0.,
        on_attn = False,
@@ -920,7 +921,6 @@ def __init__(
        gate_values = False,
        zero_init_output = False,
        hard = False,
        sigsoftmax = False,
        max_attend_past = None,
        qk_norm = False,
        qk_norm_groups = 1,
@@ -1044,6 +1044,7 @@ def __init__(
            pre_scale_post_talking_heads = pre_scale_post_talking_heads,
            dropout = dropout,
            sparse_topk = sparse_topk,
            sparse_topk_straight_through = sparse_topk_straight_through,
            hard = hard,
            qk_norm = qk_norm,
            scale = qk_norm_scale if qk_norm else self.scale,
@@ -1054,7 +1055,6 @@
            add_zero_kv = add_zero_kv,
            flash = flash,
            softclamp_logits = softclamp_logits,
            sigsoftmax = sigsoftmax,
            logit_softclamp_value = logit_softclamp_value,
            cope = cope,
            onnxable = onnxable
