Commit

fix rotary embedding caching
lucidrains committed Apr 5, 2022
1 parent 05a7c46 commit 7889d97
Showing 4 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion palm_pytorch/palm_pytorch.py
@@ -111,7 +111,7 @@ def get_rotary_embedding(self, n, device):
             return self.pos_emb[:n]

         pos_emb = self.rotary_emb(n, device=device)
-        self.register_buffer("position", pos_emb, persistent=False)
+        self.register_buffer("pos_emb", pos_emb, persistent=False)
         return pos_emb

     def forward(self, x):
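The one-line change above is the whole fix: get_rotary_embedding looks the cache up through self.pos_emb, but the buffer was being registered under the name "position", so the cached attribute was never populated and the rotary embedding was recomputed on every call. A minimal sketch of the patched caching path follows; the initial buffer registration and the cache-hit guard are assumptions, since the diff only shows the surrounding context lines.

from torch import nn

class RotaryCacheSketch(nn.Module):
    # hypothetical container class, for illustration only
    def __init__(self, rotary_emb):
        super().__init__()
        self.rotary_emb = rotary_emb
        # assumed: the cache starts out empty
        self.register_buffer("pos_emb", None, persistent=False)

    def get_rotary_embedding(self, n, device):
        # assumed guard: reuse the cached tensor when it is long enough
        if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
            return self.pos_emb[:n]

        pos_emb = self.rotary_emb(n, device=device)
        # the fix: register under "pos_emb" so the lookup above can find it
        self.register_buffer("pos_emb", pos_emb, persistent=False)
        return pos_emb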
2 changes: 1 addition & 1 deletion palm_pytorch/triton/palm.py
@@ -95,7 +95,7 @@ def get_rotary_embedding(self, n, device):
             return self.pos_emb[:n]

         pos_emb = self.rotary_emb(n, device=device)
-        self.register_buffer("position", pos_emb, persistent=False)
+        self.register_buffer("pos_emb", pos_emb, persistent=False)
         return pos_emb

     def forward(self, x):
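The Triton variant above carries the identical one-line fix. For reference, a small standalone snippet (not from the repository) showing why the name passed to register_buffer matters: the tensor is only exposed as an attribute under exactly that name, so writing to "position" leaves self.pos_emb untouched.

import torch
from torch import nn

class BufferDemo(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("pos_emb", None, persistent=False)

    def cache_wrong(self, t):
        # the old behaviour: registers a separate buffer named "position"
        self.register_buffer("position", t, persistent=False)

    def cache_right(self, t):
        # the fixed behaviour: updates the buffer the caller actually reads
        self.register_buffer("pos_emb", t, persistent=False)

m = BufferDemo()
m.cache_wrong(torch.randn(4, 8))
print(m.pos_emb is None)    # True - the cache was never filled
m.cache_right(torch.randn(4, 8))
print(m.pos_emb.shape)      # torch.Size([4, 8])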
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
     name="PaLM-pytorch",
     packages=find_packages(exclude=[]),
-    version="0.0.11",
+    version="0.0.12",
     license="MIT",
     description="PaLM: Scaling Language Modeling with Pathways - Pytorch",
     author="Phil Wang",
2 changes: 1 addition & 1 deletion train.py
@@ -8,7 +8,7 @@
 from torch.nn import functional as F
 from torch.utils.data import DataLoader, Dataset

-from palm_pytorch.triton import PaLM
+from palm_pytorch import PaLM
 from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper

 # constants
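The train.py change points the example script back at the plain PyTorch implementation instead of the Triton variant. A minimal usage sketch with the corrected import is below; the constructor arguments (num_tokens, dim, depth) and max_seq_len are placeholder values, not taken from this commit.

import torch
from palm_pytorch import PaLM
from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper

# hypothetical sizes, for illustration only
model = PaLM(num_tokens=256, dim=512, depth=8)
model = AutoregressiveWrapper(model, max_seq_len=2048)

tokens = torch.randint(0, 256, (1, 2048))
loss = model(tokens)   # assumed: the wrapper returns the language-model loss
loss.backward()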
