Skip to content

Commit

Permalink
forgot the sigmoid for Q multiple action head. also make sure dueling…
Browse files Browse the repository at this point in the history
… is supported for multiple action heads, even though not clear whether dueling is used these days or not
  • Loading branch information
lucidrains committed Nov 30, 2023
1 parent 2ad6f4d commit 775ec64
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 7 deletions.
24 changes: 18 additions & 6 deletions q_transformer/q_robotic_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,7 @@ def __init__(
self,
dim,
expansion_factor = 2,
action_bins =256
action_bins = 256
):
super().__init__()
dim_hidden = dim * expansion_factor
Expand Down Expand Up @@ -657,7 +657,8 @@ def __init__(
action_bins = 256,
attn_depth = 2,
attn_dim_head = 32,
attn_heads = 8
attn_heads = 8,
dueling = False
):
super().__init__()
self.num_actions = num_actions
Expand All @@ -676,6 +677,10 @@ def __init__(

self.final_norm = nn.LayerNorm(dim)

self.dueling = dueling
if dueling:
self.to_values = nn.Parameter(torch.zeros(num_actions, dim))

@property
def device(self):
return self.action_embeddings.device
Expand Down Expand Up @@ -757,9 +762,17 @@ def forward(
num_actions = embed.shape[-2]
action_bin_embeddings = self.action_bin_embeddings[:num_actions]

q_values = einsum('b n d, n a d -> b n a', embed, action_bin_embeddings)
if self.dueling:
advantages = einsum('b n d, n a d -> b n a', embed, action_bin_embeddings)

return q_values
values = einsum('b n d, n d -> b n', embed, self.to_values[:num_actions])
values = rearrange(values, 'b n -> b n 1')

q_values = values + (advantages - reduce(advantages, '... a -> ... 1', 'mean'))
else:
q_values = einsum('b n d, n a d -> b n a', embed, action_bin_embeddings)

return q_values.sigmoid()

# Robotic Transformer

Expand Down Expand Up @@ -853,11 +866,10 @@ def __init__(
dueling = dueling
)
else:
assert not dueling, 'dueling not supported yet for action transformer decoder'

self.q_head = QHeadMultipleActions(
attend_dim,
action_bins = action_bins,
dueling = dueling,
**q_head_attn_kwargs
)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'q-transformer',
packages = find_packages(exclude=[]),
version = '0.0.25',
version = '0.0.27',
license='MIT',
description = 'Q-Transformer',
author = 'Phil Wang',
Expand Down

0 comments on commit 775ec64

Please sign in to comment.