From 4669c7ca7d49a940fa97f0344dbb11c8e581f58b Mon Sep 17 00:00:00 2001 From: lucidrains Date: Wed, 29 Nov 2023 12:00:56 -0800 Subject: [PATCH] finally approaching the main contribution of the paper. to be finished tomorrow morning --- q_transformer/q_transformer.py | 37 +++++++++++++++++++++++++++- q_transformer/robotic_transformer.py | 1 + 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/q_transformer/q_transformer.py b/q_transformer/q_transformer.py index 02d147b..24892e2 100644 --- a/q_transformer/q_transformer.py +++ b/q_transformer/q_transformer.py @@ -113,6 +113,8 @@ def __init__( ): super().__init__() + self.is_multiple_actions = model.num_actions > 1 + # q-learning related hyperparameters self.discount_factor_gamma = discount_factor_gamma @@ -358,6 +360,34 @@ def n_step_q_learn( return loss, QIntermediates(q_pred_all_actions, q_pred, q_next, q_target) + def autoregressive_q_learn( + self, + instructions: Tuple[str], + states: TensorType['b', 't', 'c', 'f', 'h', 'w', float], + actions: TensorType['b', 't', 'n', int], + next_states: TensorType['b', 'c', 'f', 'h', 'w', float], + rewards: TensorType['b', 't', float], + dones: TensorType['b', 't', bool], + *, + monte_carlo_return = None + + ) -> Tuple[Tensor, QIntermediates]: + """ + einops + + b - batch + c - channels + f - frames + h - height + w - width + t - timesteps + n - number of actions + a - action bins + q - q values + """ + + raise NotImplementedError + def learn( self, *args, @@ -374,9 +404,14 @@ def learn( # main q-learning loss, whether single or n-step - if self.n_step_q_learning: + if self.is_multiple_actions: + td_loss, q_intermediates = self.autoregressive_q_learn(*args, **q_learn_kwargs) + num_timesteps = actions.shape[1] + + elif self.n_step_q_learning: td_loss, q_intermediates = self.n_step_q_learn(*args, **q_learn_kwargs) num_timesteps = actions.shape[1] + else: td_loss, q_intermediates = self.q_learn(*args, **q_learn_kwargs) num_timesteps = 1 diff --git a/q_transformer/robotic_transformer.py b/q_transformer/robotic_transformer.py index 9d9a0ae..2b059da 100644 --- a/q_transformer/robotic_transformer.py +++ b/q_transformer/robotic_transformer.py @@ -807,6 +807,7 @@ def __init__( assert num_actions >= 1 + self.num_actions = num_actions self.is_single_action = num_actions == 1 self.action_bins = action_bins