From fa76dafd9d11c536e5fa38eb8e928ed174e1a67d Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Fri, 22 Nov 2024 17:55:30 +0000
Subject: [PATCH] seeing really big improvements with per token learned value residual mixing values

---
 setup.py                         |  2 +-
 tests/test_x_transformers.py     |  6 +++++-
 x_transformers/x_transformers.py | 29 ++++++++++++++++++++++++-----
 3 files changed, 30 insertions(+), 7 deletions(-)

diff --git a/setup.py b/setup.py
index 9245bccc..ec14f823 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.42.12',
+  version = '1.42.14',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
diff --git a/tests/test_x_transformers.py b/tests/test_x_transformers.py
index bc5e5c4d..2f7c36f9 100644
--- a/tests/test_x_transformers.py
+++ b/tests/test_x_transformers.py
@@ -331,7 +331,10 @@ def test_reinject_input():
 
     model(x) # (1, 1024, 20000)
 
-def test_value_residual():
+@pytest.mark.parametrize('learned_value_residual_mix', (False, True))
+def test_value_residual(
+    learned_value_residual_mix: bool
+):
 
     model = TransformerWrapper(
         num_tokens = 20000,
@@ -341,6 +344,7 @@
             depth = 6,
             heads = 8,
             add_value_residual = True,
+            learned_value_residual_mix = learned_value_residual_mix
         )
     )
 
diff --git a/x_transformers/x_transformers.py b/x_transformers/x_transformers.py
index f2ff60d7..8018f2fd 100644
--- a/x_transformers/x_transformers.py
+++ b/x_transformers/x_transformers.py
@@ -1072,6 +1072,7 @@ def __init__(
         logit_softclamp_value = 50.,
         neutreno_value_residual = False, # Nguyen et al. https://arxiv.org/abs/2312.00751
         neutreno_alpha = 0.4,
+        learned_value_residual_mix = False,
         onnxable = False,
         attend_sdp_kwargs: dict = dict(
             enable_flash = True,
@@ -1231,6 +1232,14 @@ def __init__(
             self.mem_k = nn.Parameter(torch.randn(kv_heads, num_mem_kv, dim_head))
             self.mem_v = nn.Parameter(torch.randn(kv_heads, num_mem_kv, dim_head))
 
+        # maybe learned value residual mixer per token
+
+        self.to_value_residual_mix = nn.Sequential(
+            nn.Linear(dim, 1),
+            nn.Sigmoid(),
+            Rearrange('b n 1 -> b 1 n 1')
+        ) if learned_value_residual_mix else always(0.5)
+
         # attention on attention
 
         self.attn_on_attn = on_attn
@@ -1303,7 +1312,8 @@ def forward(
                 diff_values = repeat(diff_values, 'b h n d -> b (r h) n d', r = h // kv_h)
             else:
                 # https://arxiv.org/abs/2410.17897v1
-                v = 0.5 * (v + value_residual)
+                value_residual_mix = self.to_value_residual_mix(q_input)
+                v = v * value_residual_mix + value_residual * (1. - value_residual_mix)
 
         # take care of caching
 
@@ -1541,8 +1551,9 @@
         use_layerscale = False,
         layerscale_init_value = 0.,
         unet_skips = False,
-        reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
-        add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1
+        reinject_input = False,              # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
+        add_value_residual = False,          # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1
+        learned_value_residual_mix = False,  # seeing big improvements when the value residual mix is learned per token - credit goes to @faresobeid for taking the first step with a learned scalar mix, then @Blinkdl for taking it a step further with a data dependent one; here the mix is learned per token
         rel_pos_kwargs: dict = dict(),
         **kwargs
     ):
@@ -1786,6 +1797,10 @@
 
         self.add_value_residual = add_value_residual
 
+        is_first_self_attn = True
+        is_first_cross_attn = True
+        learned_value_residual_mix &= add_value_residual
+
        # iterate and construct layers

        for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
@@ -1801,9 +1816,13 @@
             # attention, cross attention, feedforward
 
             if layer_type == 'a':
-                layer = Attention(dim, heads = heads, causal = causal, **attn_kwargs)
+                self_attn_learned_value_residual = learned_value_residual_mix and not is_first_self_attn
+                layer = Attention(dim, heads = heads, causal = causal, learned_value_residual_mix = self_attn_learned_value_residual, **attn_kwargs)
+                is_first_self_attn = False
             elif layer_type == 'c':
-                layer = Attention(dim, heads = heads, **{**attn_kwargs, **cross_attn_kwargs})
+                cross_attn_learned_value_residual = learned_value_residual_mix and not is_first_cross_attn
+                layer = Attention(dim, heads = heads, learned_value_residual_mix = cross_attn_learned_value_residual, **{**attn_kwargs, **cross_attn_kwargs})
+                is_first_cross_attn = False
             elif layer_type == 'f':
                 layer = FeedForward(dim, **ff_kwargs)
                 layer = layer if not macaron else Scale(0.5, layer)
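
For reference only (not part of the patch itself), a minimal usage sketch of the new flag, mirroring the parametrized test above; the model sizes and random input are illustrative.

import torch
from x_transformers import TransformerWrapper, Decoder

# value residual (resformer) must be enabled for the per token learned mix to take effect,
# since learned_value_residual_mix is gated by add_value_residual inside AttentionLayers
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 128,
        depth = 6,
        heads = 8,
        add_value_residual = True,
        learned_value_residual_mix = True
    )
)

x = torch.randint(0, 20000, (2, 1024))
logits = model(x) # (2, 1024, 20000)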