From 49232e12bb297db927810409c827d98ef29dc600 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Mon, 4 Apr 2022 16:47:12 -0700
Subject: [PATCH] fix prelayernorm in attention

---
 palm_pytorch/palm_pytorch.py | 7 +++++--
 setup.py                     | 2 +-
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/palm_pytorch/palm_pytorch.py b/palm_pytorch/palm_pytorch.py
index 5a18052..1bce6c1 100644
--- a/palm_pytorch/palm_pytorch.py
+++ b/palm_pytorch/palm_pytorch.py
@@ -99,14 +99,17 @@ def forward(self, x):
         """
 
         n, device, h = x.shape[1], x.device, self.heads
-        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
 
         # pre layernorm
 
         x = self.norm(x)
 
+        # queries, keys, values
+
+        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
+
         # split heads
-        # they use multi-query attention, yet another Noam Shazeer paper
+        # they use multi-query single-key-value attention, yet another Noam Shazeer paper
         # they found no performance loss past a certain scale, and more efficient decoding obviously
         # https://arxiv.org/abs/1911.02150
 
diff --git a/setup.py b/setup.py
index 0a13899..a4fb367 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'PaLM-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.8',
+  version = '0.0.9',
   license='MIT',
   description = 'PaLM: Scaling Language Modeling with Pathways - Pytorch',
   author = 'Phil Wang',
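
Note: the sketch below illustrates the pattern this patch enforces, namely that the pre-layernorm runs before the q/k/v projections so attention reads from the normalized tokens, combined with multi-query attention (a single shared key/value head). It is a minimal, hypothetical module, not the repository's actual block: class and argument names are illustrative, and the real palm_pytorch attention additionally uses rotary embeddings and a fused parallel attention + feedforward layout that are omitted here for brevity.

import torch
from torch import nn, einsum
from einops import rearrange

class PreLNMultiQueryAttention(nn.Module):
    def __init__(self, dim = 512, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        # multi-query attention: one shared key/value head, https://arxiv.org/abs/1911.02150
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        # pre layernorm comes first ...
        x = self.norm(x)

        # ... so queries, keys, values are projected from the normalized tokens
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))

        # only the queries are split into heads; the single key / value head is shared
        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)

        sim = einsum('b h i d, b j d -> b h i j', q, k) * self.scale

        # causal mask: positions may not attend to the future
        i, j = sim.shape[-2:]
        causal_mask = torch.ones((i, j), dtype = torch.bool, device = x.device).triu(j - i + 1)
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        attn = sim.softmax(dim = -1)
        out = einsum('b h i j, b j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# quick shape check
attn = PreLNMultiQueryAttention()
tokens = torch.randn(2, 128, 512)
assert attn(tokens).shape == tokens.shape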