make sure cross attention context is also kv cached
lucidrains committed May 31, 2024
1 parent ecb90fc commit 1ec66e9
Showing 4 changed files with 15 additions and 8 deletions.
1 change: 1 addition & 0 deletions setup.cfg
@@ -4,3 +4,4 @@ test=pytest
 [tool:pytest]
 addopts = --verbose -s
 python_files = tests/*.py
+python_paths = "."
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.30.3',
+  version = '1.30.4',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
15 changes: 9 additions & 6 deletions tests/test_kv_cache.py
@@ -1,6 +1,6 @@
 import torch

-from x_transformers import (
+from x_transformers.x_transformers import (
     TransformerWrapper,
     Decoder,
     AutoregressiveWrapper
@@ -13,24 +13,27 @@ def test_kv_cache():
         max_seq_len = 1024,
         attn_layers = Decoder(
             dim = 8,
-            depth = 1,
-            heads = 4
+            depth = 2,
+            heads = 4,
+            cross_attend = True
         )
     )

     model.eval()

-    prompts = torch.zeros((1, 16))
+    prompts = torch.zeros((2, 16))
+    context = torch.randn(2, 8, 8)

     logits, cache = model(
         prompts,
+        context = context,
         return_intermediates = True
     )

     sampled = logits[:, -1].argmax(dim = -1, keepdim = True)
     prompts = torch.cat((prompts, sampled), dim = -1)

-    next_logits = model(prompts)
-    next_logits_with_cache = model(prompts, cache = cache)
+    next_logits = model(prompts, context = context)
+    next_logits_with_cache = model(prompts, context = context, cache = cache)

     assert torch.allclose(next_logits[:, -1], next_logits_with_cache[:, -1], atol = 1e-6)
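The updated test only checks a single cached decoding step. As a rough sketch of how the same API extends to multi-step decoding with a fixed cross attention context, assuming num_tokens = 256 (a value chosen for illustration, not part of the commit) and assuming the intermediates returned from a cached call can be fed back as the cache for the next step:

import torch
from x_transformers.x_transformers import TransformerWrapper, Decoder

# hyperparameters mirror the test; num_tokens = 256 is an assumed value for illustration
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 8,
        depth = 2,
        heads = 4,
        cross_attend = True
    )
)
model.eval()

prompts = torch.zeros((2, 16), dtype = torch.long)
context = torch.randn(2, 8, 8)  # fixed cross attention context, last dim matches `dim`

with torch.no_grad():
    # first pass builds the kv cache, which after this commit also holds the cross attention k/v
    logits, cache = model(prompts, context = context, return_intermediates = True)

    for _ in range(4):
        sampled = logits[:, -1].argmax(dim = -1, keepdim = True)
        prompts = torch.cat((prompts, sampled), dim = -1)

        # cached steps still pass the same context; its k/v come from the cache rather than being recomputed
        logits, cache = model(prompts, context = context, cache = cache, return_intermediates = True)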
5 changes: 4 additions & 1 deletion x_transformers/x_transformers.py
@@ -868,7 +868,7 @@ def forward(

         k, v, r = map(lambda t: maybe(rearrange)(t, 'b n (h d) -> b h n d', h = kv_h), (k, v, r))

-        if exists(cache) and not has_context:
+        if exists(cache):
             ck, cv = cache.cached_kv

             if exists(mem):
@@ -1338,6 +1338,9 @@ def forward(
         if exists(cache):
             assert not self.training and self.causal and not any([*map(exists, (mask, attn_mask))])

+            if exists(context):
+                context = context[:, :0]
+
             if cache_age > 0:
                 x = x[:, -cache_age:] # for spec decoding, may be greater than 1
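The two hunks work together: on a cached step the incoming context is sliced to zero length, so the cross attention projections produce no new key/value positions, and with the `not has_context` guard removed, the cached `ck, cv` are reused for cross attention as well. A minimal toy sketch of that idea (not the library's actual attention code; the projection layers and shapes are illustrative):

import torch
from torch import nn

dim = 8
to_q, to_k, to_v = nn.Linear(dim, dim), nn.Linear(dim, dim), nn.Linear(dim, dim)

def cross_attend(x, context, cached_kv = None):
    q = to_q(x)
    k, v = to_k(context), to_v(context)  # with a zero-length context, these are empty
    if cached_kv is not None:
        ck, cv = cached_kv
        k, v = torch.cat((ck, k), dim = -2), torch.cat((cv, v), dim = -2)
    attn = (q @ k.transpose(-1, -2) * dim ** -0.5).softmax(dim = -1)
    return attn @ v, (k, v)

x = torch.randn(2, 16, dim)
context = torch.randn(2, 8, dim)

out, kv_cache = cross_attend(x, context)                         # first pass: project the full context
new_token = torch.randn(2, 1, dim)
out_step, _ = cross_attend(new_token, context[:, :0], kv_cache)  # cached step: only the cached k/v are used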

Expand Down
