From 428c1e5ffb634d506e1a96813cfff3c6d8765c87 Mon Sep 17 00:00:00 2001
From: Phil Wang <lucidrains@gmail.com>
Date: Sat, 18 Jun 2022 08:48:31 -0700
Subject: [PATCH] release palm lite version, thanks to @conceptofmind

---
 palm_pytorch/palm_lite.py | 38 +++++++++++++++++++-------------------
 setup.py                  |  3 ++-
 2 files changed, 21 insertions(+), 20 deletions(-)

diff --git a/palm_pytorch/palm_lite.py b/palm_pytorch/palm_lite.py
index 000d6cb..cb37bdb 100644
--- a/palm_pytorch/palm_lite.py
+++ b/palm_pytorch/palm_lite.py
@@ -97,7 +97,7 @@ def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
         attn_inner_dim = dim_head * heads
         ff_inner_dim = dim * ff_mult
 
-        self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))
+        self.fused_dims = (attn_inner_dim, dim_head, (ff_inner_dim * 2))
 
         self.heads = heads
         self.scale = dim_head**-0.5
@@ -139,9 +139,9 @@ def forward(self, x):
 
         x = self.norm(x)
 
-        # attention queries, keys, values, and feedforward inner
+        # attention queries, keys or values (shared key / values is a personal discovery of mine), and feedforward inner
 
-        q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)
+        q, kv, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)
 
         # split heads
         # they use multi-query single-key-value attention, yet another Noam Shazeer paper
@@ -156,7 +156,7 @@
 
         # similarity
 
-        sim = einsum("b h i d, b j d -> b h i j", q, k)
+        sim = einsum("b h i d, b j d -> b h i j", q, kv)
 
         # add the alibi bias
 
@@ -170,7 +170,7 @@
         # attention
 
         attn = sim.softmax(dim=-1)
-        out = einsum("b h i j, b j d -> b h i d", attn, v)
+        out = einsum("b h i j, b j d -> b h i d", attn, kv)
 
         # merge heads
 
@@ -197,21 +197,21 @@ def PaLM(*, dim, num_tokens, depth, dim_head=64, heads=8, ff_mult=4):
 
 # For testing functionality of the model
 
-# if __name__ == "__main__":
+if __name__ == "__main__":
 
-#     palm = PaLM(
-#         num_tokens = 20000,
-#         dim = 512,
-#         depth = 1,
-#         heads = 8,
-#         dim_head = 64,
-#     )
+    palm = PaLM(
+        num_tokens = 20000,
+        dim = 512,
+        depth = 1,
+        heads = 8,
+        dim_head = 64,
+    )
 
-#     tokens = torch.randint(0, 20000, (1, 2048))
-#     logits = palm(tokens) # (1, 2048, 20000)
+    tokens = torch.randint(0, 20000, (1, 2048))
+    logits = palm(tokens) # (1, 2048, 20000)
 
-#     n_params_torch = sum(
-#         p.numel() for p in palm.parameters() if p.requires_grad
-#     )
+    n_params_torch = sum(
+        p.numel() for p in palm.parameters() if p.requires_grad
+    )
 
-#     print(f"Number of parameters in torch model: {n_params_torch}")
\ No newline at end of file
+    print(f"Number of parameters in torch model: {n_params_torch}")
diff --git a/setup.py b/setup.py
index c53f640..98e9ffd 100644
--- a/setup.py
+++ b/setup.py
@@ -3,11 +3,12 @@
 setup(
     name="PaLM-pytorch",
     packages=find_packages(exclude=[]),
-    version="0.1.0",
+    version="0.2.0",
     license="MIT",
     description="PaLM: Scaling Language Modeling with Pathways - Pytorch",
     author="Phil Wang",
     author_email="lucidrains@gmail.com",
+    long_description_content_type = 'text/markdown',
     url="https://github.com/lucidrains/PaLM-pytorch",
     keywords=[
         "artificial general intelligence",
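
Note (reviewer aside, not part of the patch): the substance of the palm_lite.py change is folding the separate key and value projections into a single shared kv tensor, i.e. multi-query attention where one head-less projection serves as both keys and values, broadcast across all query heads. A minimal standalone sketch of that pattern, with hypothetical names (SharedKVAttention is not code from this repo), might look like:

import torch
from torch import einsum, nn

class SharedKVAttention(nn.Module):
    def __init__(self, dim, dim_head=64, heads=8):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5

        # queries get one projection per head; keys and values share a single
        # head-less projection, mirroring the patched fused_dims above
        self.to_q = nn.Linear(dim, dim_head * heads, bias=False)
        self.to_kv = nn.Linear(dim, dim_head, bias=False)  # serves as both k and v
        self.to_out = nn.Linear(dim_head * heads, dim, bias=False)

    def forward(self, x):
        b, n, _ = x.shape

        q = self.to_q(x).reshape(b, n, self.heads, -1).transpose(1, 2)  # (b, h, n, d)
        kv = self.to_kv(x)  # (b, n, d), shared across all heads

        sim = einsum("b h i d, b j d -> b h i j", q, kv) * self.scale
        attn = sim.softmax(dim=-1)
        out = einsum("b h i j, b j d -> b h i d", attn, kv)  # kv doubles as the value tensor

        return self.to_out(out.transpose(1, 2).reshape(b, n, -1))

if __name__ == "__main__":
    attn = SharedKVAttention(dim=512)
    x = torch.randn(1, 128, 512)
    print(attn(x).shape)  # torch.Size([1, 128, 512])

Sharing one kv projection shrinks the fused projection by dim_head columns per layer relative to separate k and v, which is why fused_dims drops from a 4-tuple to a 3-tuple in the hunk at line 97.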