From 1ec66e9851c4231fbe82ecc6bf23030fa836e681 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Fri, 31 May 2024 11:30:01 -0700
Subject: [PATCH] make sure cross attention context is also kv cached

---
 setup.cfg                        |  1 +
 setup.py                         |  2 +-
 tests/test_kv_cache.py           | 15 +++++++++------
 x_transformers/x_transformers.py |  5 ++++-
 4 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index d1844e11..52667225 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -4,3 +4,4 @@ test=pytest
 [tool:pytest]
 addopts = --verbose -s
 python_files = tests/*.py
+python_paths = "."
diff --git a/setup.py b/setup.py
index 91803f68..f919e34b 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.30.3',
+  version = '1.30.4',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
diff --git a/tests/test_kv_cache.py b/tests/test_kv_cache.py
index 3993f44a..860cf5a5 100644
--- a/tests/test_kv_cache.py
+++ b/tests/test_kv_cache.py
@@ -1,6 +1,6 @@
 import torch
 
-from x_transformers import (
+from x_transformers.x_transformers import (
     TransformerWrapper,
     Decoder,
     AutoregressiveWrapper
@@ -13,24 +13,27 @@ def test_kv_cache():
         max_seq_len = 1024,
         attn_layers = Decoder(
             dim = 8,
-            depth = 1,
-            heads = 4
+            depth = 2,
+            heads = 4,
+            cross_attend = True
         )
     )
 
     model.eval()
 
-    prompts = torch.zeros((1, 16))
+    prompts = torch.zeros((2, 16))
+    context = torch.randn(2, 8, 8)
 
     logits, cache = model(
         prompts,
+        context = context,
         return_intermediates = True
     )
 
     sampled = logits[:, -1].argmax(dim = -1, keepdim = True)
     prompts = torch.cat((prompts, sampled), dim = -1)
 
-    next_logits = model(prompts)
-    next_logits_with_cache = model(prompts, cache = cache)
+    next_logits = model(prompts, context = context)
+    next_logits_with_cache = model(prompts, context = context, cache = cache)
 
     assert torch.allclose(next_logits[:, -1], next_logits_with_cache[:, -1], atol = 1e-6)
diff --git a/x_transformers/x_transformers.py b/x_transformers/x_transformers.py
index 94947073..7e640b47 100644
--- a/x_transformers/x_transformers.py
+++ b/x_transformers/x_transformers.py
@@ -868,7 +868,7 @@ def forward(
 
         k, v, r = map(lambda t: maybe(rearrange)(t, 'b n (h d) -> b h n d', h = kv_h), (k, v, r))
 
-        if exists(cache) and not has_context:
+        if exists(cache):
             ck, cv = cache.cached_kv
 
             if exists(mem):
@@ -1338,6 +1338,9 @@ def forward(
 
         if exists(cache):
             assert not self.training and self.causal and not any([*map(exists, (mask, attn_mask))])
 
+            if exists(context):
+                context = context[:, :0]
+
             if cache_age > 0:
                 x = x[:, -cache_age:] # for spec decoding, may be greater than 1
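
The decoding pattern the updated test exercises looks roughly like the sketch below. It is not part of the patch; it assumes the public TransformerWrapper / Decoder interface as used in tests/test_kv_cache.py above, and the hyperparameters, tensor shapes, and loop length are illustrative only. The point of the change is that the cross-attention keys/values computed from `context` during the priming pass are reused from the cache on later incremental steps instead of being recomputed.

# Sketch only -- mirrors the updated test above; shapes/hyperparameters are illustrative.
import torch
from x_transformers.x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 8,
        depth = 2,
        heads = 4,
        cross_attend = True   # decoder cross-attends to an external `context`
    )
).eval()

prompts = torch.randint(0, 20000, (2, 16))   # token ids
context = torch.randn(2, 8, 8)               # (batch, context length, dim)

with torch.no_grad():
    # priming pass: self- and cross-attention k/v are computed and returned in `cache`
    logits, cache = model(prompts, context = context, return_intermediates = True)

    for _ in range(4):
        sampled = logits[:, -1].argmax(dim = -1, keepdim = True)
        prompts = torch.cat((prompts, sampled), dim = -1)

        # with `cache` supplied, only the new token is processed and the cached
        # cross-attention k/v for `context` are reused rather than recomputed
        logits, cache = model(prompts, context = context, cache = cache, return_intermediates = True)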