
Commit

add same functionality as add_zero_attn in pytorch mha, with attn_add_zero_kv = True
lucidrains committed Jul 25, 2023
1 parent 3451614 commit db58c04
Showing 3 changed files with 18 additions and 1 deletion.
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
  name = 'x-transformers',
  packages = find_packages(exclude=['examples']),
-  version = '1.16.21',
+  version = '1.16.22',
  license='MIT',
  description = 'X-Transformers - Pytorch',
  author = 'Phil Wang',
15 changes: 15 additions & 0 deletions x_transformers/attend.py
@@ -74,6 +74,7 @@ def __init__(
        scale = None,
        qk_norm = False,
        flash = False,
+        add_zero_kv = False,
        onnxable = False
    ):
        super().__init__()
Expand Down Expand Up @@ -102,6 +103,11 @@ def __init__(
assert not (flash and sparse_topk), 'sparse topk not compatible with flash attention'
self.sparse_topk = sparse_topk

# add a key / value token composed of zeros
# in case this helps controlling outliers, proposed by https://www.evanmiller.org/attention-is-off-by-one.html

self.add_zero_kv = add_zero_kv

# flash attention

self.flash = flash
@@ -221,6 +227,15 @@ def forward(

        scale = default(self.scale, q.shape[-1] ** -0.5)

+        if self.add_zero_kv:
+            k, v = map(lambda t: F.pad(t, (0, 0, 1, 0), value = 0.), (k, v))
+
+            if exists(mask):
+                mask = F.pad(mask, (1, 0), value = True)
+
+            if exists(attn_bias):
+                attn_bias = F.pad(attn_bias, (1, 0), value = 0.)
+
        if self.flash:
            assert not exists(prev_attn), 'residual attention not compatible with flash attention'
            return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias)
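For context, a minimal standalone sketch of what the added padding does, assuming keys and values are shaped (batch, heads, seq, dim_head) and the boolean mask is shaped (batch, seq): an all-zero key / value token is prepended along the sequence dimension and always left attendable, giving the softmax a "null" slot to place attention on.

import torch
import torch.nn.functional as F

# toy shapes: (batch, heads, seq, dim_head) for keys / values, (batch, seq) for the mask
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)
mask = torch.ones(2, 16, dtype = torch.bool)

# prepend one all-zero token along the sequence dimension of keys and values
k, v = map(lambda t: F.pad(t, (0, 0, 1, 0), value = 0.), (k, v))

# extend the mask so the zero token is never masked out
mask = F.pad(mask, (1, 0), value = True)

print(k.shape, v.shape, mask.shape)  # (2, 8, 17, 64), (2, 8, 17, 64), (2, 17)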
2 changes: 2 additions & 0 deletions x_transformers/x_transformers.py
@@ -630,6 +630,7 @@ def __init__(
        value_dim_head = None,
        tensor_product = False, # https://arxiv.org/abs/2208.06061
        cascading_heads = False,
+        add_zero_kv = False, # same as add_zero_attn in pytorch
        onnxable = False
    ):
        super().__init__()
@@ -692,6 +693,7 @@ def __init__(
            sparse_topk = sparse_topk,
            qk_norm = qk_norm,
            scale = qk_norm_scale if qk_norm else self.scale,
+            add_zero_kv = add_zero_kv,
            flash = flash,
            onnxable = onnxable
        )
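Per the commit message, the new flag is surfaced through the usual attn_-prefixed layer kwargs. A minimal usage sketch (the model hyperparameters here are placeholders, not part of the commit):

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_add_zero_kv = True   # forwarded to the attention layers as add_zero_kv = True
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)   # (1, 1024, 20000)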
