diff --git a/nGPT_pytorch/nTransformer.py b/nGPT_pytorch/nTransformer.py
index 761ab19..ba618b6 100644
--- a/nGPT_pytorch/nTransformer.py
+++ b/nGPT_pytorch/nTransformer.py
@@ -157,7 +157,8 @@ def __init__(
             enable_mem_efficient = True
         ),
         norm_eps = 0.,
-        num_hyperspheres = 1
+        num_hyperspheres = 1,
+        rotary_embed = True
     ):
         super().__init__()
         self.heads = heads
@@ -182,7 +183,7 @@ def __init__(
 
         # rotary
 
-        self.rotary_emb = RotaryEmbedding(dim_head)
+        self.rotary_emb = RotaryEmbedding(dim_head) if rotary_embed else None
 
         # qk rmsnorm + scale
 
@@ -218,8 +219,9 @@ def forward(
 
         # rotary positions
 
-        q = self.rotary_emb.rotate_queries_or_keys(q)
-        k = self.rotary_emb.rotate_queries_or_keys(k)
+        if exists(self.rotary_emb):
+            q = self.rotary_emb.rotate_queries_or_keys(q)
+            k = self.rotary_emb.rotate_queries_or_keys(k)
 
         # for non-autoregressive masking
 
@@ -295,6 +297,7 @@ def __init__(
         tied_embedding = False,
         num_hyperspheres = 1,
         causal = True,
+        rotary_embed = False,
         # below are all the scale related hyperparameters, for controlling effective relative learning rates throughout the network
         alpha_init: float | None = None, # this would set the alpha init for all residuals, but would be overridden by alpha_attn_init and alpha_ff_init if they are specified
         s_logit_init: float = 1.,
@@ -364,7 +367,8 @@ def __init__(
                s_qk_scale = s_qk_scale_,
                flash_kwargs = attn_flash_kwargs,
                norm_eps = norm_eps,
-                num_hyperspheres = num_hyperspheres
+                num_hyperspheres = num_hyperspheres,
+                rotary_embed = rotary_embed
            )
 
            ff = FeedForward(
diff --git a/pyproject.toml b/pyproject.toml
index 65257ae..cdcf4f6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "nGPT-pytorch"
-version = "0.1.15"
+version = "0.1.16"
 description = "nGPT"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
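
In short, this change makes rotary positional embeddings optional: `Attention` gains a `rotary_embed` flag (default `True`), the transformer-level constructor exposes the same flag (default `False`) and threads it through, and the forward pass only rotates q/k when the rotary module was actually created. Below is a minimal, self-contained sketch of that pattern; `TinyAttention` and `exists` are illustrative stand-ins rather than the repo's real classes, and only `RotaryEmbedding(dim_head)` / `rotate_queries_or_keys` mirror calls visible in the diff.

```python
# Sketch of the optional-rotary pattern introduced above.
# `RotaryEmbedding` comes from the rotary-embedding-torch package the repo depends on;
# `TinyAttention` and `exists` are hypothetical stand-ins for illustration only.
import torch
from torch import nn
from rotary_embedding_torch import RotaryEmbedding

def exists(v):
    return v is not None

class TinyAttention(nn.Module):
    def __init__(self, dim_head = 64, rotary_embed = True):
        super().__init__()
        # same guard as the diff: only instantiate rotary embeddings when requested
        self.rotary_emb = RotaryEmbedding(dim_head) if rotary_embed else None

    def forward(self, q, k):
        # and only rotate queries / keys when the module exists
        if exists(self.rotary_emb):
            q = self.rotary_emb.rotate_queries_or_keys(q)
            k = self.rotary_emb.rotate_queries_or_keys(k)
        return q, k

q = k = torch.randn(1, 8, 128, 64)                        # (batch, heads, seq, dim_head)
q_rot, k_rot = TinyAttention(rotary_embed = True)(q, k)   # rotary applied
q_raw, k_raw = TinyAttention(rotary_embed = False)(q, k)  # q/k pass through unchanged
```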