diff --git a/nGPT_pytorch/nGPT.py b/nGPT_pytorch/nGPT.py
index c66247c..9022f27 100644
--- a/nGPT_pytorch/nGPT.py
+++ b/nGPT_pytorch/nGPT.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from functools import partial
 
 import torch
@@ -19,6 +21,11 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+def cast_tuple(t, length = 1):
+    out = t if isinstance(t, tuple) else ((t,) * length)
+    assert len(out) == length
+    return out
+
 def l2norm(t, dim = -1):
     return F.normalize(t, dim = dim, p = 2)
 
@@ -101,7 +108,9 @@ def __init__(
         dim_head = 64,
         heads = 8,
         norm_qk = True,
-        manual_norm_weights = False
+        manual_norm_weights = False,
+        s_qk_init = 1.,
+        s_qk_scale = None
     ):
         super().__init__()
         NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights)
@@ -167,7 +176,11 @@ def __init__(
         dim,
         *,
         expand_factor = 4,
-        manual_norm_weights = False
+        manual_norm_weights = False,
+        s_hidden_init = 1.,
+        s_hidden_scale = 1.,
+        s_gate_init = 1.,
+        s_gate_scale = 1.
     ):
         super().__init__()
         NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights)
@@ -178,8 +191,8 @@ def __init__(
         self.to_hidden = NormLinear_(dim, dim_inner)
         self.to_gate = NormLinear_(dim, dim_inner)
 
-        self.hidden_scale = Scale(dim_inner)
-        self.gate_scale = Scale(dim_inner)
+        self.hidden_scale = Scale(dim_inner, s_hidden_init, s_hidden_scale)
+        self.gate_scale = Scale(dim_inner, s_gate_init, s_gate_scale)
 
         self.to_out = NormLinear_(dim_inner, dim, norm_dim_in = False)
 
@@ -206,31 +219,98 @@ def __init__(
         attn_norm_qk = True, # they say the query/key normalization is optional
         ff_expand_factor = 4.,
         ce_ignore_index = -1,
-        residual_lerp_scale_init = None,
         manual_norm_weights = False,
-        tied_embedding = False
+        tied_embedding = False,
+        # below are all the scale related hyperparameters, for controlling effective relative learning rates throughout the network
+        alpha_init: float | None = None, # this would set the alpha init for all residuals, but would be overridden by alpha_attn_init and alpha_ff_init if they are specified
+        s_logit_init: float = 1.,
+        s_logit_scale: float | None = None,
+        alpha_attn_init: float | tuple[float, ...] | None = None,
+        alpha_attn_scale: float | tuple[float, ...] | None = None,
+        alpha_ff_init: float | tuple[float, ...] | None = None,
+        alpha_ff_scale: float | tuple[float, ...] | None = None,
+        s_qk_init: float | tuple[float, ...] = 1.,
+        s_qk_scale: float | tuple[float, ...] | None = None,
+        s_ff_hidden_init: float | tuple[float, ...] = 1.,
+        s_ff_hidden_scale: float | tuple[float, ...] = 1.,
+        s_ff_gate_init: float | tuple[float, ...] = 1.,
+        s_ff_gate_scale: float | tuple[float, ...] = 1.
     ):
         super().__init__()
         NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights)
 
         self.dim = dim
 
-        residual_lerp_scale_init = default(residual_lerp_scale_init, 1. / depth)
+        alpha_init = default(alpha_init, 1. / depth)
 
         self.token_embed = NormLinear_(dim, num_tokens)
 
         self.layers = ModuleList([])
 
-        for _ in range(depth):
-            self.layers.append(ModuleList([
-                Attention(dim, dim_head = dim_head, heads = heads, norm_qk = attn_norm_qk, manual_norm_weights = manual_norm_weights),
-                FeedForward(dim, expand_factor = ff_expand_factor, manual_norm_weights = manual_norm_weights),
-                Scale(dim, residual_lerp_scale_init, dim ** -0.5),
-                Scale(dim, residual_lerp_scale_init, dim ** -0.5),
-            ]))
+        scale_hparams = (
+            alpha_attn_init,
+            alpha_attn_scale,
+            alpha_ff_init,
+            alpha_ff_scale,
+            s_qk_init,
+            s_qk_scale,
+            s_ff_hidden_init,
+            s_ff_hidden_scale,
+            s_ff_gate_init,
+            s_ff_gate_scale
+        )
+
+        scale_hparams = tuple(cast_tuple(hparam, depth) for hparam in scale_hparams)
+
+        for (
+            alpha_attn_init_,
+            alpha_attn_scale_,
+            alpha_ff_init_,
+            alpha_ff_scale_,
+            s_qk_init_,
+            s_qk_scale_,
+            s_ff_hidden_init_,
+            s_ff_hidden_scale_,
+            s_ff_gate_init_,
+            s_ff_gate_scale_
+        ) in zip(*scale_hparams):
+
+            attn = Attention(
+                dim,
+                dim_head = dim_head,
+                heads = heads,
+                norm_qk = attn_norm_qk,
+                manual_norm_weights = manual_norm_weights,
+                s_qk_init = s_qk_init_,
+                s_qk_scale = s_qk_scale_,
+            )
+
+            ff = FeedForward(
+                dim,
+                expand_factor = ff_expand_factor,
+                manual_norm_weights = manual_norm_weights,
+                s_hidden_init = s_ff_hidden_init_,
+                s_hidden_scale = s_ff_hidden_scale_,
+                s_gate_init = s_ff_gate_init_,
+                s_gate_scale = s_ff_gate_scale_
+            )
+
+            attn_interp_factor = Scale(
+                dim,
+                default(alpha_attn_init_, alpha_init),
+                default(alpha_attn_scale_, dim ** -0.5)
+            )
+
+            ff_interp_factor = Scale(
+                dim,
+                default(alpha_ff_init_, alpha_init),
+                default(alpha_ff_scale_, dim ** -0.5)
+            )
+
+            self.layers.append(ModuleList([attn, ff, attn_interp_factor, ff_interp_factor]))
 
         self.to_logits = NormLinear_(dim, num_tokens) if not tied_embedding else None
 
-        self.logit_scale = Scale(num_tokens, 1., dim ** -0.5)
+        self.logit_scale = Scale(num_tokens, s_logit_init, default(s_logit_scale, dim ** -0.5))
 
         self.ignore_index = ce_ignore_index
diff --git a/pyproject.toml b/pyproject.toml
index 27e42fe..f0d2dc5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "nGPT-pytorch"
-version = "0.0.11"
+version = "0.0.12"
 description = "nGPT"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/train.py b/train.py
index f69468b..2bf8dec 100644
--- a/train.py
+++ b/train.py
@@ -92,6 +92,7 @@ def base_decoding(
     dim = 512,
     depth = 8,
     manual_norm_weights = True,
+    tied_embedding = True
 ).to(device)
 
 # prepare enwik8 data
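Usage sketch (not part of the diff above): a minimal example of how the new scale hyperparameters introduced here might be passed to nGPT. Scalar values are broadcast to every layer by cast_tuple, while a tuple of length depth sets a separate value per layer. The num_tokens value, the package-root import, and the forward call with return_loss are assumptions based on this repo's README and train.py, not part of this change.

    import torch
    from nGPT_pytorch import nGPT

    # hypothetical configuration; scalar scale hyperparameters are broadcast
    # across all `depth` layers by cast_tuple, while a tuple of length `depth`
    # sets one value per layer
    model = nGPT(
        num_tokens = 256,
        dim = 512,
        depth = 8,
        tied_embedding = True,
        alpha_init = 1. / 8,            # residual interpolation init, defaults to 1. / depth when left as None
        s_qk_init = 1.,                 # scalar -> same value for every layer
        alpha_attn_init = (0.05,) * 8   # per-layer tuple, length must equal depth
    )

    ids = torch.randint(0, 256, (2, 1024))
    loss = model(ids, return_loss = True)   # assumes the forward signature used in train.py
    loss.backward()

Passing a tuple whose length does not match depth trips the assert in cast_tuple, so per-layer overrides must always be specified for every layer.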