release palm lite version, thanks to @conceptofmind
lucidrains committed Jun 18, 2022
1 parent a3b02a2 commit 428c1e5
Showing 2 changed files with 21 additions and 20 deletions.
38 changes: 19 additions & 19 deletions palm_pytorch/palm_lite.py
@@ -97,7 +97,7 @@ def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):

        attn_inner_dim = dim_head * heads
        ff_inner_dim = dim * ff_mult
-       self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))
+       self.fused_dims = (attn_inner_dim, dim_head, (ff_inner_dim * 2))

        self.heads = heads
        self.scale = dim_head**-0.5
@@ -139,9 +139,9 @@ def forward(self, x):

        x = self.norm(x)

-       # attention queries, keys, values, and feedforward inner
+       # attention queries, keys or values (shared key / values is a personal discovery of mine), and feedforward inner

-       q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)
+       q, kv, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)

        # split heads
        # they use multi-query single-key-value attention, yet another Noam Shazeer paper
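
To make the shape bookkeeping behind these two hunks concrete: the fused projection now emits per-head queries, a single shared key/value head, and the doubled feedforward input, so the split returns three chunks instead of four. A minimal sketch with illustrative sizes (the Linear layer, names, and bias choice below are stand-ins, not lifted from palm_lite.py):

import torch
from torch import nn

# illustrative hyperparameters, mirroring the defaults used further down in this diff
dim, dim_head, heads, ff_mult = 512, 64, 8, 4

attn_inner_dim = dim_head * heads            # 512: queries still get one projection per head
ff_inner_dim = dim * ff_mult                 # 2048
fused_dims = (attn_inner_dim, dim_head, ff_inner_dim * 2)   # q, shared kv, feedforward (width doubled in the file)

fused_proj = nn.Linear(dim, sum(fused_dims), bias=False)    # stand-in for fused_attn_ff_proj

x = torch.randn(1, 2048, dim)                # (batch, seq, dim)
q, kv, ff = fused_proj(x).split(fused_dims, dim=-1)
print(q.shape, kv.shape, ff.shape)           # (1, 2048, 512) (1, 2048, 64) (1, 2048, 4096)
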
@@ -156,7 +156,7 @@ def forward(self, x):

        # similarity

-       sim = einsum("b h i d, b j d -> b h i j", q, k)
+       sim = einsum("b h i d, b j d -> b h i j", q, kv)

        # add the alibi bias

@@ -170,7 +170,7 @@ def forward(self, x):
        # attention

        attn = sim.softmax(dim=-1)
-       out = einsum("b h i j, b j d -> b h i d", attn, v)
+       out = einsum("b h i j, b j d -> b h i d", attn, kv)

        # merge heads
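
The two einsums above are where the shared tensor pays off: kv carries no head axis, so the same (b, j, d) tensor broadcasts across every query head and serves as both keys and values. A small shape check with made-up sizes, omitting the scaling, alibi bias, and causal mask that the surrounding code applies:

import torch
from torch import einsum

q  = torch.randn(1, 8, 2048, 64)    # per-head queries: (batch, heads, seq, dim_head)
kv = torch.randn(1, 2048, 64)       # one shared key/value head: (batch, seq, dim_head)

sim  = einsum("b h i d, b j d -> b h i j", q, kv)     # (1, 8, 2048, 2048)
attn = sim.softmax(dim=-1)
out  = einsum("b h i j, b j d -> b h i d", attn, kv)  # (1, 8, 2048, 64): kv reused as the values
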

@@ -197,21 +197,21 @@ def PaLM(*, dim, num_tokens, depth, dim_head=64, heads=8, ff_mult=4):

# For testing functionality of the model

-# if __name__ == "__main__":
+if __name__ == "__main__":

-# palm = PaLM(
-# num_tokens = 20000,
-# dim = 512,
-# depth = 1,
-# heads = 8,
-# dim_head = 64,
-# )
+    palm = PaLM(
+        num_tokens = 20000,
+        dim = 512,
+        depth = 1,
+        heads = 8,
+        dim_head = 64,
+    )

-# tokens = torch.randint(0, 20000, (1, 2048))
-# logits = palm(tokens) # (1, 2048, 20000)
+    tokens = torch.randint(0, 20000, (1, 2048))
+    logits = palm(tokens) # (1, 2048, 20000)

-# n_params_torch = sum(
-# p.numel() for p in palm.parameters() if p.requires_grad
-# )
+    n_params_torch = sum(
+        p.numel() for p in palm.parameters() if p.requires_grad
+    )

-# print(f"Number of parameters in torch model: {n_params_torch}")
+    print(f"Number of parameters in torch model: {n_params_torch}")
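
With the smoke test no longer commented out, running the file directly builds a one-layer model, pushes a random batch through it, and prints the parameter count. For a quick end-to-end check outside that block, something like the following should work, assuming the module is importable as palm_pytorch.palm_lite (inferred from the path in this diff); the cross-entropy step is a generic next-token training sketch, not code from the repository:

import torch
import torch.nn.functional as F
from palm_pytorch.palm_lite import PaLM   # import path inferred from palm_pytorch/palm_lite.py

palm = PaLM(num_tokens=20000, dim=512, depth=1, heads=8, dim_head=64)

tokens = torch.randint(0, 20000, (1, 2048))
logits = palm(tokens)                     # (1, 2048, 20000)

# generic next-token cross entropy: predict token t+1 from positions up to t
loss = F.cross_entropy(
    logits[:, :-1].reshape(-1, 20000),
    tokens[:, 1:].reshape(-1),
)
loss.backward()
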
3 changes: 2 additions & 1 deletion setup.py
@@ -3,11 +3,12 @@
setup(
    name="PaLM-pytorch",
    packages=find_packages(exclude=[]),
-   version="0.1.0",
+   version="0.2.0",
    license="MIT",
    description="PaLM: Scaling Language Modeling with Pathways - Pytorch",
    author="Phil Wang",
    author_email="[email protected]",
+   long_description_content_type = 'text/markdown',
    url="https://github.com/lucidrains/PaLM-pytorch",
    keywords=[
        "artificial general intelligence",
