offer l2 distance attention for starters and cite
lucidrains committed Sep 30, 2024
1 parent f1a8e66 commit 2d26af6
Showing 5 changed files with 65 additions and 4 deletions.
10 changes: 10 additions & 0 deletions README.md
@@ -2239,6 +2239,16 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
booktitle = {Neural Information Processing Systems},
year = {2018},
url = {https://api.semanticscholar.org/CorpusID:44064935}
}
```

```bibtex
@article{Kim2020TheLC,
title = {The Lipschitz Constant of Self-Attention},
author = {Hyunjik Kim and George Papamakarios and Andriy Mnih},
journal = {ArXiv},
year = {2020},
volume = {abs/2006.04710},
url = {https://api.semanticscholar.org/CorpusID:219530837}
}
```
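
The Kim et al. paper cited above shows that standard dot-product self-attention is not Lipschitz and proposes an L2 (distance-based) variant that is; this commit exposes that variant behind a flag. A minimal usage sketch, mirroring the test added in this commit:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8,
        attn_l2_distance = True   # attention logits become negative squared L2 distance instead of dot product
    )
)

x = torch.randint(0, 20000, (1, 1024))

logits = model(x)
```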

2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'x-transformers',
packages = find_packages(exclude=['examples']),
version = '1.37.4',
version = '1.37.6',
license='MIT',
description = 'X-Transformers - Pytorch',
author = 'Phil Wang',
19 changes: 19 additions & 0 deletions tests/test_x_transformers.py
@@ -301,3 +301,22 @@ def test_sigsoftmax():
model.eval()

eval_logits = model(x)

@pytest.mark.parametrize('attn_one_kv_head', (True, False))
def test_l2_distance(attn_one_kv_head):

model = TransformerWrapper(
num_tokens = 20000,
max_seq_len = 1024,
attn_layers = Decoder(
dim = 512,
depth = 12,
heads = 8,
attn_l2_distance = True,
attn_one_kv_head = attn_one_kv_head,
)
)

x = torch.randint(0, 256, (1, 1024))

model(x)
36 changes: 33 additions & 3 deletions x_transformers/attend.py
@@ -13,7 +13,7 @@
from packaging import version
from dataclasses import dataclass

from einops import rearrange, repeat
from einops import rearrange, repeat, pack, unpack

# constants

@@ -39,9 +39,16 @@ def default(val, d):
def compact(arr):
return [*filter(exists, arr)]

def softclamp(t, value):
@torch.jit.script
def softclamp(t: Tensor, value: float):
return (t / value).tanh() * value

def pack_one(t, pattern):
return pack([t], pattern)

def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]

def once(fn):
called = False
@wraps(fn)
@@ -55,6 +62,18 @@ def inner(x):

print_once = once(print)

# alternative distance functions

def qk_l2_dist_squared(q, k):
if k.ndim == 3:
k = repeat(k, 'b j d -> b h j d', h = q.shape[1])

q, packed_shape = pack_one(q, '* i d')
k, _ = pack_one(k, '* j d')

l2_dist_squared = torch.cdist(q, k) ** 2
return unpack_one(l2_dist_squared, packed_shape, '* i j')
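
As a quick aside on the helper above (a hedged sanity check, not part of the commit): the squared pairwise distances from `torch.cdist` should match the explicit ||q||^2 + ||k||^2 - 2 q.k expansion, including the shared key/value head path where k arrives without a head dimension:

```python
import torch
from einops import repeat, pack, unpack

q = torch.randn(2, 8, 16, 64)   # (batch, heads, query len, dim head)
k = torch.randn(2, 32, 64)      # (batch, key len, dim head) - single shared kv head

# same steps as qk_l2_dist_squared: broadcast k over heads, flatten batch and heads, cdist, unflatten
k_expanded = repeat(k, 'b j d -> b h j d', h = q.shape[1])
q_flat, ps = pack([q], '* i d')
k_flat, _  = pack([k_expanded], '* j d')
dist_sq = unpack(torch.cdist(q_flat, k_flat) ** 2, ps, '* i j')[0]

# explicit expansion of the squared euclidean distance
expanded = (
    (q ** 2).sum(dim = -1, keepdim = True)
    + (k_expanded ** 2).sum(dim = -1)[:, :, None, :]
    - 2 * torch.einsum('b h i d, b h j d -> b h i j', q, k_expanded)
)

assert torch.allclose(dist_sq, expanded, atol = 1e-3)
```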

# functions for creating causal mask
# need a special one for onnx cpu (no support for .triu)

@@ -80,6 +99,7 @@ def __init__(
sparse_topk = None,
scale = None,
qk_norm = False,
l2_distance = False,
flash = False,
softclamp_logits = False,
logit_softclamp_value = 50.,
@@ -123,6 +143,11 @@ def __init__(
assert not (flash and sigsoftmax), 'sigsoftmax not available for flash attention'
self.sigsoftmax = sigsoftmax

# l2 distance attention

assert not (flash and l2_distance), 'l2 distance attention does not work with flash attention just yet'
self.l2_distance = l2_distance

# add a key / value token composed of zeros
# in case this helps controlling outliers, proposed by https://www.evanmiller.org/attention-is-off-by-one.html

@@ -325,7 +350,12 @@ def forward(

kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'

sim = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale
if not self.l2_distance:
sim = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k)
else:
sim = -qk_l2_dist_squared(q, k)

sim = sim * scale

if exists(prev_attn):
sim = sim + prev_attn
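
One way to see how the new branch relates to the old one: since -||q - k||^2 = 2 q.k - ||q||^2 - ||k||^2, and the -||q||^2 term is constant across keys for a given query, the L2 variant is (up to the shared scale) dot-product attention with doubled logits plus a per-key bias of -||k||^2. A hedged check of that identity at the softmax level:

```python
import torch

q = torch.randn(2, 4, 6, 8)   # (batch, heads, query len, dim head)
k = torch.randn(2, 4, 7, 8)   # (batch, heads, key len, dim head)

l2_logits  = -(torch.cdist(q, k) ** 2)
dot_logits = 2 * torch.einsum('b h i d, b h j d -> b h i j', q, k) - (k * k).sum(dim = -1)[:, :, None, :]

# the remaining difference, -||q_i||^2, is constant per query row and cancels in the softmax
assert torch.allclose(l2_logits.softmax(dim = -1), dot_logits.softmax(dim = -1), atol = 1e-4)
```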
2 changes: 2 additions & 0 deletions x_transformers/x_transformers.py
@@ -923,6 +923,7 @@ def __init__(
qk_norm_groups = 1,
qk_norm_scale = 10,
qk_norm_dim_scale = False,
l2_distance = False,
one_kv_head = False,
kv_heads = None,
shared_kv = False,
@@ -1037,6 +1038,7 @@ def __init__(
sparse_topk = sparse_topk,
qk_norm = qk_norm,
scale = qk_norm_scale if qk_norm else self.scale,
l2_distance = l2_distance,
add_zero_kv = add_zero_kv,
flash = flash,
softclamp_logits = softclamp_logits,
