From 7c56d238eb8a9495cd1ab326aea0ed585fc54501 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Fri, 18 Oct 2024 08:33:59 -0700
Subject: [PATCH] redo topk attention to allow for straight through

---
 README.md                        | 18 +++++++++-
 setup.py                         |  2 +-
 tests/test_x_transformers.py     | 22 ------------
 x_transformers/attend.py         | 61 +++++++++++++++++++-------------
 x_transformers/x_transformers.py |  4 +--
 5 files changed, 56 insertions(+), 51 deletions(-)

diff --git a/README.md b/README.md
index 1dc8ffb6..1b61da5e 100644
--- a/README.md
+++ b/README.md
@@ -525,7 +525,23 @@ model = TransformerWrapper(
         dim = 512,
         depth = 6,
         heads = 8,
-        attn_sparse_topk = 8 # keep only the top 8 values before attention (softmax)
+        attn_sparse_topk = 8, # keep only the top 8 values before attention (softmax)
+        attn_sparse_topk_straight_through = True # straight through the gradients of the full softmax
     )
 )
 ```
+
+For the extreme case of a `topk` value of `1`, you can use the following
+
+```python
+model = TransformerWrapper(
+    num_tokens = 20000,
+    max_seq_len = 1024,
+    attn_layers = Decoder(
+        dim = 512,
+        depth = 6,
+        heads = 8,
+        attn_hard = True # will only propagate the single value at the argmax of the qk logits. offered in case it addresses https://arxiv.org/abs/2410.01104
+    )
+)
+```
diff --git a/setup.py b/setup.py
index 2e3f2bd7..21cf0429 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.39.4',
+  version = '1.40.1',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
diff --git a/tests/test_x_transformers.py b/tests/test_x_transformers.py
index bc541685..7e9bb43d 100644
--- a/tests/test_x_transformers.py
+++ b/tests/test_x_transformers.py
@@ -280,28 +280,6 @@ def test_mos():
 
     eval_logits = model(x)
 
-def test_sigsoftmax():
-    model = TransformerWrapper(
-        num_tokens = 20000,
-        max_seq_len = 1024,
-        mixture_of_softmax = True,
-        sigsoftmax_logits = True,
-        attn_layers = Decoder(
-            attn_sigsoftmax = True,
-            dim = 128,
-            depth = 6,
-            heads = 8
-        )
-    )
-
-    x = torch.randint(0, 20000, (2, 1024))
-
-    logits = model(x)
-
-    model.eval()
-
-    eval_logits = model(x)
-
 @pytest.mark.parametrize('attn_one_kv_head', (True, False))
 def test_l2_distance(attn_one_kv_head):
diff --git a/x_transformers/attend.py b/x_transformers/attend.py
index ff76794c..c36bc09c 100644
--- a/x_transformers/attend.py
+++ b/x_transformers/attend.py
@@ -109,12 +109,35 @@ def qk_l2_dist_squared(q, k):
 
 # one-hot straight through softmax
 
-def one_hot_straight_through(t, temperature = 1.):
-    one_hot_indices = t.argmax(dim = -1, keepdim = True)
-    one_hot = torch.zeros_like(t).scatter(-1, one_hot_indices, 1.)
+def one_hot_straight_through(logits, temperature = 1.):
+    one_hot_indices = logits.argmax(dim = -1, keepdim = True)
+    one_hot = torch.zeros_like(logits).scatter(-1, one_hot_indices, 1.)
 
-    t = (t / temperature).softmax(dim = -1)
-    return one_hot + t - t.detach()
+    soft_attn = (logits / temperature).softmax(dim = -1)
+    return one_hot + soft_attn - soft_attn.detach()
+
+# sparse topk attention - only keep topk attn logits for softmax
+# optional straight through with masked out logits by setting `attn_sparse_topk_straight_through = True`
+
+def sparse_topk_attn(
+    logits,
+    sparse_topk,
+    temperature = 1.,
+    straight_through = False
+):
+    orig_logits = logits
+
+    mask_value = -torch.finfo(logits.dtype).max
+    top_values, _ = logits.topk(sparse_topk, dim = -1)
+    sparse_topk_mask = (logits >= top_values[..., -1:]) & (logits > mask_value)
+    logits = logits.masked_fill(~sparse_topk_mask, mask_value)
+    topk_attn = logits.softmax(dim = -1)
+
+    if not straight_through:
+        return topk_attn
+
+    soft_attn = (orig_logits / temperature).softmax(dim = -1)
+    return topk_attn.detach() + soft_attn - soft_attn.detach()
 
 # functions for creating causal mask
 # need a special one for onnx cpu (no support for .triu)
@@ -141,6 +164,7 @@ def __init__(
         post_talking_heads = False,
         pre_scale_post_talking_heads = False,
         sparse_topk = None,
+        sparse_topk_straight_through = False,
         scale = None,
         qk_norm = False,
         l2_distance = False,
@@ -152,7 +176,6 @@ def __init__(
         add_zero_kv = False,
         selective = False,
         hard = False,
-        sigsoftmax = False,
         cope = None,
         onnxable = False,
         sdp_kwargs: dict = dict(
@@ -171,9 +194,13 @@ def __init__(
 
         # attention type
 
+        is_sparse_topk_attn = exists(sparse_topk)
+
         assert not (flash and sigmoid), 'sigmoid attention not available for flash'
         assert not (flash and hard), 'hard attention not available for flash'
-        assert at_most_one_of(sigmoid, hard, l2_distance)
+        assert not (flash and is_sparse_topk_attn), 'topk attention not available for flash'
+
+        assert at_most_one_of(sigmoid, hard, l2_distance, is_sparse_topk_attn)
 
         if exists(custom_attn_fn):
             self.attn_fn = custom_attn_fn
@@ -181,6 +208,8 @@ def __init__(
             self.attn_fn = F.sigmoid
         elif hard:
             self.attn_fn = one_hot_straight_through
+        elif is_sparse_topk_attn:
+            self.attn_fn = partial(sparse_topk_attn, sparse_topk = sparse_topk, straight_through = sparse_topk_straight_through)
         else:
             softmax_fn = partial(F.softmax, dim = -1)
             self.attn_fn = partial(softmax_fn, dtype = torch.float32) if not qk_norm else softmax_fn
@@ -214,16 +243,6 @@ def __init__(
         assert not (selective and not causal), 'selective attention is designed for autoregressive'
         self.selective = selective
 
-        # sparse topk
-
-        assert not (flash and sparse_topk), 'sparse topk not compatible with flash attention'
-        self.sparse_topk = sparse_topk
-
-        # sig softmax
-
-        assert not (flash and sigsoftmax), 'sigsoftmax not available for flash attention'
-        self.sigsoftmax = sigsoftmax
-
         # l2 distance attention
 
         self.l2_distance = l2_distance
@@ -476,11 +495,6 @@ def forward(
             causal_mask = self.create_causal_mask(i, j, device = device)
             sim = sim.masked_fill(causal_mask, mask_value)
 
-        if exists(self.sparse_topk):
-            top_values, _ = sim.topk(self.sparse_topk, dim = -1)
-            sparse_topk_mask = (sim >= top_values[..., -1:]) & (sim > mask_value)
-            sim = sim.masked_fill(~sparse_topk_mask, mask_value)
-
         row_is_entirely_masked = None
 
         if exists(mask):
@@ -494,9 +508,6 @@ def forward(
 
         pre_softmax_attn = sim
 
-        if self.sigsoftmax:
-            sim = sim + sim.sigmoid().log()
-
         attn = self.attn_fn(sim)
 
         attn = attn.type(dtype)
diff --git a/x_transformers/x_transformers.py b/x_transformers/x_transformers.py
index a04d8d6f..d2920329 100644
--- a/x_transformers/x_transformers.py
+++ b/x_transformers/x_transformers.py
@@ -912,6 +912,7 @@ def __init__(
         pre_scale_post_talking_heads = False,
         head_scale = False,
         sparse_topk = None,
+        sparse_topk_straight_through = False,
         num_mem_kv = 0,
         dropout = 0.,
         on_attn = False,
@@ -920,7 +921,6 @@ def __init__(
         gate_values = False,
         zero_init_output = False,
         hard = False,
-        sigsoftmax = False,
         max_attend_past = None,
         qk_norm = False,
         qk_norm_groups = 1,
@@ -1044,6 +1044,7 @@ def __init__(
            pre_scale_post_talking_heads = pre_scale_post_talking_heads,
            dropout = dropout,
            sparse_topk = sparse_topk,
+            sparse_topk_straight_through = sparse_topk_straight_through,
            hard = hard,
            qk_norm = qk_norm,
            scale = qk_norm_scale if qk_norm else self.scale,
@@ -1054,7 +1055,6 @@ def __init__(
            add_zero_kv = add_zero_kv,
            flash = flash,
            softclamp_logits = softclamp_logits,
-            sigsoftmax = sigsoftmax,
            logit_softclamp_value = logit_softclamp_value,
            cope = cope,
            onnxable = onnxable
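
For illustration only (not part of the patch): a minimal sketch of what the straight-through variant does, assuming the patch is applied and the module-level `sparse_topk_attn` from `x_transformers/attend.py` is imported directly. The forward pass returns the sparsified attention weights, while the backward pass receives the gradients of the full softmax, so logits masked out of the top-k still get a training signal.

```python
import torch
from x_transformers.attend import sparse_topk_attn  # module-level function added by this patch

# random attention logits of shape (batch, heads, query len, key len)
logits = torch.randn(2, 4, 16, 16, requires_grad = True)

attn = sparse_topk_attn(logits, sparse_topk = 8, straight_through = True)

# forward: each row keeps only its top 8 attention weights, the rest are exactly zero
print((attn > 0).sum(dim = -1).unique())    # typically tensor([8])

# backward: gradients come from the dense softmax, so even masked-out logits receive signal
loss = (attn * torch.randn_like(attn)).sum()
loss.backward()
print((logits.grad != 0).float().mean())    # close to 1.0, i.e. dense gradients
```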