From 7889d9726be5694d36795f7d0505756dc908b577 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Tue, 5 Apr 2022 16:58:01 -0700
Subject: [PATCH] fix rotary embedding caching

---
 palm_pytorch/palm_pytorch.py | 2 +-
 palm_pytorch/triton/palm.py  | 2 +-
 setup.py                     | 2 +-
 train.py                     | 2 +-
 4 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/palm_pytorch/palm_pytorch.py b/palm_pytorch/palm_pytorch.py
index c97267a..ab5e3f8 100644
--- a/palm_pytorch/palm_pytorch.py
+++ b/palm_pytorch/palm_pytorch.py
@@ -111,7 +111,7 @@ def get_rotary_embedding(self, n, device):
             return self.pos_emb[:n]
 
         pos_emb = self.rotary_emb(n, device=device)
-        self.register_buffer("position", pos_emb, persistent=False)
+        self.register_buffer("pos_emb", pos_emb, persistent=False)
         return pos_emb
 
     def forward(self, x):
diff --git a/palm_pytorch/triton/palm.py b/palm_pytorch/triton/palm.py
index efac08a..38bb436 100644
--- a/palm_pytorch/triton/palm.py
+++ b/palm_pytorch/triton/palm.py
@@ -95,7 +95,7 @@ def get_rotary_embedding(self, n, device):
             return self.pos_emb[:n]
 
         pos_emb = self.rotary_emb(n, device=device)
-        self.register_buffer("position", pos_emb, persistent=False)
+        self.register_buffer("pos_emb", pos_emb, persistent=False)
        return pos_emb
 
     def forward(self, x):
diff --git a/setup.py b/setup.py
index 12718c3..13109b8 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
     name="PaLM-pytorch",
     packages=find_packages(exclude=[]),
-    version="0.0.11",
+    version="0.0.12",
     license="MIT",
     description="PaLM: Scaling Language Modeling with Pathways - Pytorch",
     author="Phil Wang",
diff --git a/train.py b/train.py
index 64f6c55..28d9ab1 100644
--- a/train.py
+++ b/train.py
@@ -8,7 +8,7 @@
 from torch.nn import functional as F
 from torch.utils.data import DataLoader, Dataset
 
-from palm_pytorch.triton import PaLM
+from palm_pytorch import PaLM
 from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
 
 # constants
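
Note: both palm hunks rename the cached buffer so that the lookup at the top of get_rotary_embedding and the register_buffer call below it refer to the same attribute. A minimal sketch of that caching pattern follows; the RotaryEmbedding module, the Block wrapper, and the exact cache check are assumptions added for illustration — only the buffer rename itself comes from the hunks above.

import torch
from torch import nn

class RotaryEmbedding(nn.Module):
    # illustrative rotary-frequency module, not taken from the patch
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, max_seq_len, *, device):
        seq = torch.arange(max_seq_len, device=device).type_as(self.inv_freq)
        freqs = torch.einsum("i, j -> i j", seq, self.inv_freq)
        return torch.cat((freqs, freqs), dim=-1)

class Block(nn.Module):
    def __init__(self, dim_head=64):
        super().__init__()
        self.rotary_emb = RotaryEmbedding(dim_head)
        # lazily-filled cache; non-persistent so it never enters the state_dict
        self.register_buffer("pos_emb", None, persistent=False)

    def get_rotary_embedding(self, n, device):
        # cache hit: reuse the embedding computed on an earlier forward pass
        if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
            return self.pos_emb[:n]

        pos_emb = self.rotary_emb(n, device=device)
        # the buffer must be re-registered under the same name the cache check
        # reads ("pos_emb"); registering it as "position", as before this patch,
        # left self.pos_emb empty, so the embedding was recomputed every call
        self.register_buffer("pos_emb", pos_emb, persistent=False)
        return pos_emb

block = Block()
emb = block.get_rotary_embedding(128, device="cpu")        # computed and cached
emb_again = block.get_rotary_embedding(64, device="cpu")   # served from the cache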