diff --git a/setup.py b/setup.py
index 8b7a52bb..6f7d1dd5 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.37.7',
+  version = '1.37.8',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
diff --git a/x_transformers/attend.py b/x_transformers/attend.py
index 48828a7d..665695c7 100644
--- a/x_transformers/attend.py
+++ b/x_transformers/attend.py
@@ -211,12 +211,12 @@ def flash_attn(

         if self.l2_distance:
             k_norm_sq = k.norm(dim = -1, keepdim = True) ** 2
-            k = F.pad(k, (0, 1), value = 1.)
-            k = torch.cat((k, -k_norm_sq), dim = -1)
+            k = F.pad(k, (0, 1), value = -1.)
+            k = torch.cat((k, k_norm_sq), dim = -1)

             q_norm_sq = q.norm(dim = -1, keepdim = True) ** 2
-            q = torch.cat((2 * q, -q_norm_sq), dim = -1)
-            q = F.pad(q, (0, 1), value = 1.)
+            q = torch.cat((2 * q, q_norm_sq), dim = -1)
+            q = F.pad(q, (0, 1), value = -1.)

         # handle scale - by default they scale by dim_head ** -0.5, but need to take care if using cosine sim attention
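Reviewer note (not part of the patch): the l2_distance branch turns distance-based attention into a plain dot product by augmenting the key and query vectors, relying on the identity 2 q.k - ||q||^2 - ||k||^2 = -||q - k||^2, so the standard flash attention kernel can be reused. Below is a minimal standalone sketch, assuming toy 4-D tensors in the usual (batch, heads, seq, dim_head) layout; the names q_aug and k_aug are illustrative and not identifiers from the repo. It checks that the augmentation on the `+` lines above still reproduces the negative squared L2 distance.

import torch
import torch.nn.functional as F

# toy tensors in (batch, heads, seq, dim_head) layout
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)

# key augmentation as in the patched branch: k -> (k, -1, ||k||^2)
k_norm_sq = k.norm(dim = -1, keepdim = True) ** 2
k_aug = F.pad(k, (0, 1), value = -1.)
k_aug = torch.cat((k_aug, k_norm_sq), dim = -1)

# query augmentation as in the patched branch: q -> (2q, ||q||^2, -1)
q_norm_sq = q.norm(dim = -1, keepdim = True) ** 2
q_aug = torch.cat((2 * q, q_norm_sq), dim = -1)
q_aug = F.pad(q_aug, (0, 1), value = -1.)

# dot product of augmented vectors: 2 q.k - ||q||^2 - ||k||^2 = -||q - k||^2
sim = q_aug @ k_aug.transpose(-1, -2)
neg_sq_dist = -((q.unsqueeze(-2) - k.unsqueeze(-3)) ** 2).sum(dim = -1)

assert torch.allclose(sim, neg_sq_dist, atol = 1e-4)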