From 8bfb82cea94dfa898284fe1601f5314e024b15d5 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Wed, 29 Nov 2023 11:52:10 -0800 Subject: [PATCH] multiple actions is ready for q-learning! --- q_transformer/robotic_transformer.py | 32 ++++++++++++++++++++++++---- setup.py | 2 +- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/q_transformer/robotic_transformer.py b/q_transformer/robotic_transformer.py index 7203311..9d9a0ae 100644 --- a/q_transformer/robotic_transformer.py +++ b/q_transformer/robotic_transformer.py @@ -719,7 +719,11 @@ def get_optimal_actions( return action_bins - def forward(self, encoded_state): + def forward( + self, + encoded_state: Tensor, + actions: Optional[Tensor] = None + ): """ einops b - batch @@ -732,7 +736,22 @@ def forward(self, encoded_state): sos_token = reduce(encoded_state, 'b ... d -> b 1 d', 'mean') - embed = self.transformer(sos_token) + if exists(actions): + batch, num_actions = actions.shape + action_embeddings = self.action_bin_embeddings[:num_actions] + + action_embeddings = repeat(action_embeddings, 'n a d -> b n a d', b = batch) + past_action_bins = repeat(actions, 'b n -> b n 1 d', d = action_embeddings.shape[-1]) + + bin_embeddings = action_embeddings.gather(-2, past_action_bins) + bin_embeddings = rearrange(bin_embeddings, 'b n 1 d -> b n d') + + tokens = torch.cat((sos_token, bin_embeddings), dim = -2) + tokens = tokens[:, :self.num_actions] # last action bin not needed for the proposed q-learning + else: + tokens = sos_token + + embed = self.transformer(tokens) embed = self.final_norm(embed) num_actions = embed.shape[-2] @@ -784,7 +803,7 @@ def __init__( attend_dim = vit.embed_dim - # q-transformer related action embeddings - redo + # q-transformer related action embeddings assert num_actions >= 1 @@ -944,6 +963,11 @@ def forward( # head that returns the q values # supporting both single and multiple actions - q_values = self.q_head(encoded_state) + if self.is_single_action: + assert not exists(actions), 'actions should not be passed in for single action robotic transformer' + + q_values = self.q_head(encoded_state) + else: + q_values = self.q_head(encoded_state, actions = actions) return q_values diff --git a/setup.py b/setup.py index 54c581a..a3f175d 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'q-transformer', packages = find_packages(exclude=[]), - version = '0.0.20', + version = '0.0.21', license='MIT', description = 'Q-Transformer', author = 'Phil Wang',