Skip to content

Commit

Permalink
multiple actions is ready for q-learning!
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 29, 2023
1 parent 0750e9d commit 8bfb82c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 5 deletions.
32 changes: 28 additions & 4 deletions q_transformer/robotic_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,11 @@ def get_optimal_actions(

return action_bins

def forward(self, encoded_state):
def forward(
self,
encoded_state: Tensor,
actions: Optional[Tensor] = None
):
"""
einops
b - batch
Expand All @@ -732,7 +736,22 @@ def forward(self, encoded_state):

sos_token = reduce(encoded_state, 'b ... d -> b 1 d', 'mean')

embed = self.transformer(sos_token)
if exists(actions):
batch, num_actions = actions.shape
action_embeddings = self.action_bin_embeddings[:num_actions]

action_embeddings = repeat(action_embeddings, 'n a d -> b n a d', b = batch)
past_action_bins = repeat(actions, 'b n -> b n 1 d', d = action_embeddings.shape[-1])

bin_embeddings = action_embeddings.gather(-2, past_action_bins)
bin_embeddings = rearrange(bin_embeddings, 'b n 1 d -> b n d')

tokens = torch.cat((sos_token, bin_embeddings), dim = -2)
tokens = tokens[:, :self.num_actions] # last action bin not needed for the proposed q-learning
else:
tokens = sos_token

embed = self.transformer(tokens)
embed = self.final_norm(embed)

num_actions = embed.shape[-2]
Expand Down Expand Up @@ -784,7 +803,7 @@ def __init__(

attend_dim = vit.embed_dim

# q-transformer related action embeddings - redo
# q-transformer related action embeddings

assert num_actions >= 1

Expand Down Expand Up @@ -944,6 +963,11 @@ def forward(
# head that returns the q values
# supporting both single and multiple actions

q_values = self.q_head(encoded_state)
if self.is_single_action:
assert not exists(actions), 'actions should not be passed in for single action robotic transformer'

q_values = self.q_head(encoded_state)
else:
q_values = self.q_head(encoded_state, actions = actions)

return q_values
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'q-transformer',
packages = find_packages(exclude=[]),
version = '0.0.20',
version = '0.0.21',
license='MIT',
description = 'Q-Transformer',
author = 'Phil Wang',
Expand Down

0 comments on commit 8bfb82c

Please sign in to comment.