From 832464f9458dece74fe935143bde53087e4fd2bf Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Sat, 6 Jul 2024 06:55:43 -0700
Subject: [PATCH] handle key padding mask directly passed into Attend

---
 setup.py                 | 2 +-
 x_transformers/attend.py | 5 +++++
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index 95f4cf15..9f56293b 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.31.8',
+  version = '1.31.9',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
diff --git a/x_transformers/attend.py b/x_transformers/attend.py
index 852b92d1..cd30ae24 100644
--- a/x_transformers/attend.py
+++ b/x_transformers/attend.py
@@ -268,6 +268,11 @@ def forward(

         causal = self.causal

+        # handle key padding mask
+
+        if exists(mask) and mask.ndim == 2:
+            mask = rearrange(mask, 'b j -> b 1 1 j')
+
         # handle kv cached decoding

         if n == 1 and causal:
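
Below is a minimal, self-contained sketch (not part of the patch) of what the new
branch enables: a plain 2D key padding mask of shape (batch, k_len) passed into
Attend is lifted to (batch, 1, 1, k_len) so it broadcasts against attention logits
of shape (batch, heads, q_len, k_len). The tensor names and the masked_fill step
are illustrative assumptions, not code from x_transformers.

import torch
from einops import rearrange

batch, heads, q_len, k_len = 2, 8, 4, 6

# raw attention logits, shape (b, h, i, j)
sim = torch.randn(batch, heads, q_len, k_len)

# 2D key padding mask: True = real token, False = padding
mask = torch.ones(batch, k_len, dtype = torch.bool)
mask[:, -2:] = False  # pretend the last two keys are padding

# the patched branch: lift (b, j) to (b, 1, 1, j) so the mask
# broadcasts over the head and query dimensions
if mask.ndim == 2:
    mask = rearrange(mask, 'b j -> b 1 1 j')

# illustrative use: padded keys get large negative logits,
# so they receive (effectively) zero attention after softmax
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
attn = sim.softmax(dim = -1)

assert attn.shape == (batch, heads, q_len, k_len)
assert attn[..., -2:].abs().sum() < 1e-6  # no weight on padded keys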