diff --git a/README.md b/README.md
index 41449b0..e52c99a 100644
--- a/README.md
+++ b/README.md
@@ -10,7 +10,8 @@ Implementation of Q-Transformer, S
 
 ## Todo
 
-- [ ] first work way towards single action support
+- [x] first work way towards single action support
+    - [ ] build out main proposal in paper (autoregressive discrete actions until last action, reward given only on last)
 - [ ] do n-step Q learning, even though not that big of improvement
 - [ ] figure out the conservative regularization, read prior work
diff --git a/q_transformer/q_transformer.py b/q_transformer/q_transformer.py
index 30ad726..0388b9f 100644
--- a/q_transformer/q_transformer.py
+++ b/q_transformer/q_transformer.py
@@ -33,6 +33,17 @@ def cycle(dl):
     for batch in dl:
         yield batch
 
+# tensor helpers
+
+def batch_select_indices(t, indices):
+    batch = t.shape[0]
+    batch_arange = torch.arange(batch, device = indices.device)
+    batch_arange = rearrange(batch_arange, 'b -> b 1')
+    indices = rearrange(indices, 'b -> b 1')
+
+    selected = t[batch_arange, indices]
+    return rearrange(selected, 'b 1 -> b')
+
 # Q learning on robotic transformer
 
 class QLearner(Module):
@@ -158,6 +169,10 @@ def load(self, path):
 
         self.optimizer.load_state_dict(pkg['optimizer'])
 
+    @property
+    def device(self):
+        return self.accelerator.device
+
     @property
     def is_main(self):
         return self.accelerator.is_main_process
@@ -171,22 +186,77 @@ def print(self, msg):
     def wait(self):
         return self.accelerator.wait_for_everyone()
 
+    def q_learn(
+        self,
+        instructions: Tuple[str],
+        states: Tensor,
+        actions: Tensor,
+        next_states: Tensor,
+        reward: Tensor,
+        done: Tensor
+    ) -> Tensor:
+
+        # 'next' stands for the very next time step (whether state, q, actions etc)
+
+        γ = self.discount_factor_gamma
+        q_eval = self.model
+        q_target = self.ema_model
+        not_terminal = (~done).float()
+
+        # first make a prediction with online q robotic transformer
+
+        q_pred = batch_select_indices(q_eval(states, instructions), actions)
+
+        # use an exponentially smoothed copy of model for the future q target. more stable than setting q_target to q_eval after each batch
+
+        q_next = q_target(next_states, instructions).amax(dim = -1)
+
+        # Bellman's equation. most important line of code, hopefully done correctly
+
+        q_target = reward + γ * not_terminal * q_next
+
+        # now just force the online model to be able to predict this target
+
+        loss = F.mse_loss(q_pred, q_target)
+
+        # that's it. 4 loc for the heart of q-learning
+        # return loss and some of the intermediates for logging
+
+        return loss, (q_pred, q_next, q_target)
+
     def forward(self):
         step = self.step.item()
 
+        replay_buffer_iter = cycle(self.dataloader)
+
+        self.model.train()
+        self.ema_model.train()
+
         while step < self.num_train_steps:
 
-            # sample from replay buffer and q-learn
+            # zero grads
+
+            self.optimizer.zero_grad()
+
+            # main q-learning algorithm
+
+            with self.accelerator.autocast():
+
+                loss, _ = self.q_learn(*next(replay_buffer_iter))
+
+            self.accelerator.backward(loss)
+
+            self.print(f'loss: {loss.item():.3f}')
+
+            # take optimizer step
+
+            self.optimizer.step()
+
+            # update target ema
+
+            self.wait()
 
-            (
-                instruction,
-                state,
-                action,
-                next_state,
-                reward,
-                done
-            ) = next(replay_buffer_iter)
+            self.ema_model.update()
 
             # increment step
diff --git a/q_transformer/robotic_transformer.py b/q_transformer/robotic_transformer.py
index c1e44e3..3426b05 100644
--- a/q_transformer/robotic_transformer.py
+++ b/q_transformer/robotic_transformer.py
@@ -613,10 +613,11 @@ def __init__(
 
         self.cond_drop_prob = cond_drop_prob
 
-        self.to_logits = nn.Sequential(
+        self.to_q_values = nn.Sequential(
             LayerNorm(attend_dim),
             nn.Linear(attend_dim, num_actions * action_bins),
-            Rearrange('... (a b) -> ... a b', b = action_bins)
+            Rearrange('... (a b) -> ... a b', b = action_bins),
+            nn.Sigmoid()
         )
 
     @property
@@ -697,7 +698,11 @@ def forward(
 
         attended_tokens = self.transformer(learned_tokens, cond_fns = transformer_cond_fns, attn_mask = ~attn_mask)
 
-        pooled = reduce(attended_tokens, 'b (f n) d -> b f d', 'mean', f = frames)
+        pooled = reduce(attended_tokens, 'b (f n) d -> b d', 'mean', f = frames)
 
-        logits = self.to_logits(pooled)
-        return logits
+        q_values = self.to_q_values(pooled)
+
+        if self.num_actions == 1:
+            q_values = rearrange(q_values, '... 1 b -> ... b')
+
+        return q_values
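
A note for review, not part of the patch: the new batch_select_indices helper is a per-row selection of Q(s, a), and, assuming t is a (batch, action_bins) tensor of q-values and indices is a (batch,) long tensor of action ids, it matches a plain torch.gather. A minimal sketch under those assumptions:

# Review sketch (not part of the patch): batch_select_indices as a gather.
# Assumed shapes: q_values (batch, action_bins), actions (batch,) of long dtype.

import torch

def batch_select_indices_via_gather(q_values, actions):
    # pick Q(s, a) for the action taken in each batch row
    return q_values.gather(-1, actions.unsqueeze(-1)).squeeze(-1)

q_values = torch.randn(4, 256)
actions = torch.randint(0, 256, (4,))

# same result as the advanced-indexing version in the diff
assert torch.equal(
    batch_select_indices_via_gather(q_values, actions),
    q_values[torch.arange(4), actions]
)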
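
Also for review: a self-contained sketch of the single-action TD step that q_learn performs, with random tensors standing in for the outputs of the online and EMA robotic transformers. The shapes, the gamma value, and the explicit detach on the target are illustrative assumptions, not claims about the patch.

# Review sketch (not part of the patch): the Bellman backup behind q_learn.

import torch
import torch.nn.functional as F

def single_action_td_loss(q_online, q_ema_next, actions, reward, done, gamma = 0.98):
    # Q(s, a) predicted by the online network for the action actually taken
    q_pred = q_online.gather(-1, actions.unsqueeze(-1)).squeeze(-1)

    # bootstrap from the EMA target network, zeroed on terminal transitions
    q_next = q_ema_next.amax(dim = -1)
    q_target = reward + gamma * (~done).float() * q_next

    # regress the online estimate onto the (stop-gradient) Bellman target
    return F.mse_loss(q_pred, q_target.detach())

# usage with stand-in tensors: (batch, action_bins) q-values, (batch,) rewards / dones
q_online = torch.rand(4, 256)
q_ema_next = torch.rand(4, 256)
actions = torch.randint(0, 256, (4,))
reward = torch.rand(4)
done = torch.zeros(4, dtype = torch.bool)

loss = single_action_td_loss(q_online, q_ema_next, actions, reward, done)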