From b7477b60aa3f903dfd0937c35c6a5390d7f73f24 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Wed, 29 Nov 2023 14:10:29 -0800 Subject: [PATCH] complete the autoregressive discrete formulation of q-learning for high action spaces --- README.md | 4 +- q_transformer/mocks.py | 2 +- q_transformer/q_transformer.py | 81 ++++++++++++++++++++++++++++++---- setup.py | 2 +- 4 files changed, 76 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 408153c..7b1fc76 100644 --- a/README.md +++ b/README.md @@ -17,12 +17,10 @@ I will be keeping around the logic for Q-learning on single action just for fina - [x] add optional deep dueling architecture - [x] add n-step Q learning - [x] build the conservative regularization - +- [x] build out main proposal in paper (autoregressive discrete actions until last action, reward given only on last) - [x] improvise decoder head variant, instead of concatenating previous actions at the frames + learned tokens stage. in other words, use classic encoder - decoder - [ ] allow for cross attention to fine frame / learned tokens -- [ ] build out main proposal in paper (autoregressive discrete actions until last action, reward given only on last) - - [ ] build out a simple dataset creator class, taking in the environment as an iterator / generator - [ ] see if the main idea in this paper is applicable to language models here - [ ] consult some RL experts and figure out if there are any new headways into resolving delusional bias diff --git a/q_transformer/mocks.py b/q_transformer/mocks.py index fd849de..7f198bf 100644 --- a/q_transformer/mocks.py +++ b/q_transformer/mocks.py @@ -60,7 +60,7 @@ def __getitem__(self, _): instruction = "please clean the kitchen" state = torch.randn(*self.time_shape, 3, *self.video_shape) - action = torch.randint(0, self.num_action_bins + 1, (*action_dims, *self.time_shape)) + action = torch.randint(0, self.num_action_bins + 1, (*self.time_shape, *action_dims)) next_state = torch.randn(3, *self.video_shape) reward = torch.randint(0, 2, self.time_shape) done = torch.zeros(self.time_shape, dtype = torch.bool) diff --git a/q_transformer/q_transformer.py b/q_transformer/q_transformer.py index b33f557..5089117 100644 --- a/q_transformer/q_transformer.py +++ b/q_transformer/q_transformer.py @@ -252,7 +252,7 @@ def q_learn( reward: TensorType['b', float], done: TensorType['b', bool], *, - monte_carlo_return = -1e4 + monte_carlo_return = None ) -> Tuple[Tensor, QIntermediates]: # 'next' stands for the very next time step (whether state, q, actions etc) @@ -270,7 +270,7 @@ def q_learn( # the max Q value is taken as the optimal action is implicitly the one with the highest Q score q_next = self.ema_model(next_states, instructions).amax(dim = -1) - q_next = q_next.clamp(min = monte_carlo_return) + q_next.clamp_(min = default(monte_carlo_return, -1e4)) # Bellman's equation. most important line of code, hopefully done correctly @@ -294,7 +294,7 @@ def n_step_q_learn( rewards: TensorType['b', 't', float], dones: TensorType['b', 't', bool], *, - monte_carlo_return = -1e4 + monte_carlo_return = None ) -> Tuple[Tensor, QIntermediates]: """ @@ -338,7 +338,7 @@ def n_step_q_learn( q_pred = unpack_one(q_pred, time_ps, '*') q_next = self.ema_model(next_states, instructions).amax(dim = -1) - q_next = q_next.clamp(min = monte_carlo_return) + q_next.clamp_(min = default(monte_carlo_return, -1e4)) # prepare rewards and discount factors across timesteps @@ -369,7 +369,7 @@ def autoregressive_q_learn( rewards: TensorType['b', 't', float], dones: TensorType['b', 't', bool], *, - monte_carlo_return = -1e4 + monte_carlo_return = None ) -> Tuple[Tensor, QIntermediates]: """ @@ -385,20 +385,84 @@ def autoregressive_q_learn( a - action bins q - q values """ - + monte_carlo_return = default(monte_carlo_return, -1e4) num_timesteps, device = states.shape[1], states.device # fold time steps into batch states, time_ps = pack_one(states, '* c f h w') + actions, _ = pack_one(actions, '* n') # repeat instructions per timestep repeated_instructions = repeat_tuple_el(instructions, num_timesteps) + # anything after the first done flag will be considered terminal + + dones = dones.cumsum(dim = -1) > 0 + dones = F.pad(dones, (1, -1), value = False) + + not_terminal = (~dones).float() + + # rewards should not be given on and after terminal step + + rewards = rewards * not_terminal + + # because greek unicode is nice to look at + γ = self.discount_factor_gamma - raise NotImplementedError + # get predicted Q for each action + # unpack back to (b, t, n) + + q_pred_all_actions = self.model(states, repeated_instructions, actions = actions) + + q_pred_all_actions, btn_ps = pack_one(q_pred_all_actions, '* a') + flattened_actions, _ = pack_one(actions, '*') + + q_pred = batch_select_indices(q_pred_all_actions, flattened_actions) + + q_pred = unpack_one(q_pred, btn_ps, '*') + q_pred = unpack_one(q_pred, time_ps, '* n') + + # get q_next + + q_next = self.ema_model(next_states, instructions) + q_next = q_next.max(dim = -1).values + q_next.clamp_(min = monte_carlo_return) + + # get target Q + # unpack back to - (b, t, n) + + q_target_all_actions = self.ema_model(states, repeated_instructions, actions = actions) + q_target = q_target_all_actions.max(dim = -1).values + + q_target.clamp_(min = monte_carlo_return) + q_target = unpack_one(q_target, time_ps, '* n') + + # main contribution of the paper is the following logic + # section 4.1 - eq. 1 + + # first take care of the loss for all actions except for the very last one + + q_pred_rest_actions, q_pred_last_action = q_pred[..., :-1], q_pred[..., -1] + q_target_first_action, q_target_rest_actions = q_target[..., 0], q_target[..., 1:] + + losses_all_actions_but_last = F.mse_loss(q_pred_rest_actions, q_target_rest_actions, reduction = 'none') + + # next take care of the very last action, which incorporates the rewards + + q_target_last_action, _ = pack([q_target_first_action[..., 1:], q_next], 'b *') + + q_target_last_action = rewards + q_target_last_action + + losses_last_action = F.mse_loss(q_pred_last_action, q_target_last_action, reduction = 'none') + + # flatten and average + + losses, _ = pack([losses_all_actions_but_last, losses_last_action], '*') + + return losses.mean(), QIntermediates(q_pred_all_actions, q_pred, q_next, q_target) def learn( self, @@ -434,10 +498,11 @@ def learn( batch = actions.shape[0] q_preds = q_intermediates.q_pred_all_actions + num_action_bins = q_preds.shape[-1] num_non_dataset_actions = num_action_bins - 1 - actions = rearrange(actions, '... -> ... 1') + actions = rearrange(actions, '... -> (...) 1') dataset_action_mask = torch.zeros_like(q_preds).scatter_(-1, actions, torch.ones_like(q_preds)) diff --git a/setup.py b/setup.py index a3f175d..41553a7 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'q-transformer', packages = find_packages(exclude=[]), - version = '0.0.21', + version = '0.0.22', license='MIT', description = 'Q-Transformer', author = 'Phil Wang',