From 775ec64a0ed9a987c42730f607e36c9d84bb2573 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Thu, 30 Nov 2023 08:20:48 -0800 Subject: [PATCH] forgot the sigmoid for Q multiple action head. also make sure dueling is supported for multiple action heads, even though not clear whether dueling is used these days or not --- q_transformer/q_robotic_transformer.py | 24 ++++++++++++++++++------ setup.py | 2 +- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/q_transformer/q_robotic_transformer.py b/q_transformer/q_robotic_transformer.py index 2b059da..da00f34 100644 --- a/q_transformer/q_robotic_transformer.py +++ b/q_transformer/q_robotic_transformer.py @@ -568,7 +568,7 @@ def __init__( self, dim, expansion_factor = 2, - action_bins =256 + action_bins = 256 ): super().__init__() dim_hidden = dim * expansion_factor @@ -657,7 +657,8 @@ def __init__( action_bins = 256, attn_depth = 2, attn_dim_head = 32, - attn_heads = 8 + attn_heads = 8, + dueling = False ): super().__init__() self.num_actions = num_actions @@ -676,6 +677,10 @@ def __init__( self.final_norm = nn.LayerNorm(dim) + self.dueling = dueling + if dueling: + self.to_values = nn.Parameter(torch.zeros(num_actions, dim)) + @property def device(self): return self.action_embeddings.device @@ -757,9 +762,17 @@ def forward( num_actions = embed.shape[-2] action_bin_embeddings = self.action_bin_embeddings[:num_actions] - q_values = einsum('b n d, n a d -> b n a', embed, action_bin_embeddings) + if self.dueling: + advantages = einsum('b n d, n a d -> b n a', embed, action_bin_embeddings) - return q_values + values = einsum('b n d, n d -> b n', embed, self.to_values[:num_actions]) + values = rearrange(values, 'b n -> b n 1') + + q_values = values + (advantages - reduce(advantages, '... a -> ... 
a -> ... 1', 'mean')) + else: + q_values = einsum('b n d, n a d -> b n a', embed, action_bin_embeddings) + + return q_values.sigmoid() # Robotic Transformer @@ -853,11 +866,10 @@ def __init__( dueling = dueling ) else: - assert not dueling, 'dueling not supported yet for action transformer decoder' - self.q_head = QHeadMultipleActions( attend_dim, action_bins = action_bins, + dueling = dueling, **q_head_attn_kwargs ) diff --git a/setup.py b/setup.py index f6fff89..a19cf7c 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'q-transformer', packages = find_packages(exclude=[]), - version = '0.0.25', + version = '0.0.27', license='MIT', description = 'Q-Transformer', author = 'Phil Wang',