Skip to content

Commit

Permalink
able to turn off rotary for nTransformer
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Oct 28, 2024
1 parent 9c1f52c commit 6fcd3c6
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
14 changes: 9 additions & 5 deletions nGPT_pytorch/nTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,8 @@ def __init__(
enable_mem_efficient = True
),
norm_eps = 0.,
num_hyperspheres = 1
num_hyperspheres = 1,
rotary_embed = True
):
super().__init__()
self.heads = heads
Expand All @@ -182,7 +183,7 @@ def __init__(

# rotary

self.rotary_emb = RotaryEmbedding(dim_head)
self.rotary_emb = RotaryEmbedding(dim_head) if rotary_embed else None

# qk rmsnorm + scale

Expand Down Expand Up @@ -218,8 +219,9 @@ def forward(

# rotary positions

q = self.rotary_emb.rotate_queries_or_keys(q)
k = self.rotary_emb.rotate_queries_or_keys(k)
if exists(self.rotary_emb):
q = self.rotary_emb.rotate_queries_or_keys(q)
k = self.rotary_emb.rotate_queries_or_keys(k)

# for non-autoregressive masking

Expand Down Expand Up @@ -295,6 +297,7 @@ def __init__(
tied_embedding = False,
num_hyperspheres = 1,
causal = True,
rotary_embed = False,
# below are all the scale related hyperparameters, for controlling effective relative learning rates throughout the network
alpha_init: float | None = None, # this would set the alpha init for all residuals, but would be overridden by alpha_attn_init and alpha_ff_init if they are specified
s_logit_init: float = 1.,
Expand Down Expand Up @@ -364,7 +367,8 @@ def __init__(
s_qk_scale = s_qk_scale_,
flash_kwargs = attn_flash_kwargs,
norm_eps = norm_eps,
num_hyperspheres = num_hyperspheres
num_hyperspheres = num_hyperspheres,
rotary_embed = rotary_embed
)

ff = FeedForward(
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "nGPT-pytorch"
version = "0.1.15"
version = "0.1.16"
description = "nGPT"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down

0 comments on commit 6fcd3c6

Please sign in to comment.