Commit 6245f29

vanilla q-learning for a robot transformer taking single action
lucidrains committed Nov 26, 2023
1 parent 73b2607 commit 6245f29
Showing 3 changed files with 91 additions and 15 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -10,7 +10,8 @@ Implementation of <a href="https://qtransformer.github.io/">Q-Transformer</a>, S

## Todo

-- [ ] first work way towards single action support
+- [x] first work way towards single action support

- [ ] build out main proposal in paper (autoregressive discrete actions until last action, reward given only on last)
- [ ] do n-step Q learning, even though not that big of improvement
- [ ] figure out the conservative regularization, read prior work
88 changes: 79 additions & 9 deletions q_transformer/q_transformer.py
@@ -33,6 +33,17 @@ def cycle(dl):
        for batch in dl:
            yield batch

+# tensor helpers
+
+def batch_select_indices(t, indices):
+    batch = t.shape[0]
+    batch_arange = torch.arange(batch, device = indices.device)
+    batch_arange = rearrange(batch_arange, 'b -> b 1')
+    indices = rearrange(indices, 'b -> b 1')
+
+    selected = t[batch_arange, indices]
+    return rearrange(selected, 'b 1 -> b')

# Q learning on robotic transformer

class QLearner(Module):
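
For a 2-d q-value tensor of shape `(batch, action_bins)`, `batch_select_indices` behaves like a `torch.gather` along the last dimension. A minimal sketch of that equivalence (the `_gather` variant and the shapes are illustrative, not from the repo):

```python
import torch
from einops import rearrange

def batch_select_indices_gather(t, indices):
    # pick t[i, indices[i]] for every row i, same as batch_select_indices
    selected = t.gather(-1, rearrange(indices, 'b -> b 1'))
    return rearrange(selected, 'b 1 -> b')

q = torch.randn(4, 256)            # q-values per discretized action bin
a = torch.randint(0, 256, (4,))    # chosen bin per batch element
assert torch.equal(batch_select_indices_gather(q, a), q[torch.arange(4), a])
```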
@@ -158,6 +169,10 @@ def load(self, path):

        self.optimizer.load_state_dict(pkg['optimizer'])

+    @property
+    def device(self):
+        return self.accelerator.device

    @property
    def is_main(self):
        return self.accelerator.is_main_process
@@ -171,22 +186,77 @@ def print(self, msg):
    def wait(self):
        return self.accelerator.wait_for_everyone()

+    def q_learn(
+        self,
+        instructions: Tuple[str],
+        states: Tensor,
+        actions: Tensor,
+        next_states: Tensor,
+        reward: Tensor,
+        done: Tensor
+    ) -> Tensor:
+
+        # 'next' stands for the very next time step (whether state, q, actions etc)
+
+        γ = self.discount_factor_gamma
+        q_eval = self.model
+        q_target = self.ema_model
+        not_terminal = (~done).float()
+
+        # first make a prediction with online q robotic transformer
+
+        q_pred = batch_select_indices(q_eval(states, instructions), actions)
+
+        # use an exponentially smoothed copy of model for the future q target. more stable than setting q_target to q_eval after each batch
+
+        q_next = q_target(next_states, instructions).amax(dim = -1)
+
+        # Bellman's equation. most important line of code, hopefully done correctly
+
+        q_target = reward + γ * not_terminal * q_next
+
+        # now just force the online model to be able to predict this target
+
+        loss = F.mse_loss(q_pred, q_target)
+
+        # that's it. 4 loc for the heart of q-learning
+        # return loss and some of the intermediates for logging
+
+        return loss, (q_pred, q_next, q_target)

    def forward(self):
        step = self.step.item()

        replay_buffer_iter = cycle(self.dataloader)

+        self.model.train()
+        self.ema_model.train()
+
        while step < self.num_train_steps:

-            # sample from replay buffer and q-learn
+            # zero grads
+
+            self.optimizer.zero_grad()
+
+            # main q-learning algorithm
+
+            with self.accelerator.autocast():
+
+                loss, _ = self.q_learn(*next(replay_buffer_iter))
+
+            self.accelerator.backward(loss)
+
+            self.print(f'loss: {loss.item():.3f}')
+
+            # take optimizer step
+
+            self.optimizer.step()
+
+            # update target ema
+
+            self.wait()
+
-            (
-                instruction,
-                state,
-                action,
-                next_state,
-                reward,
-                done
-            ) = next(replay_buffer_iter)
            self.ema_model.update()

            # increment step

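The line `q_target = reward + γ * not_terminal * q_next` is the standard one-step Bellman backup: the target for Q(s, a) is the reward plus the discounted best next-state q-value, with the bootstrap term zeroed out on terminal transitions. A toy numeric walk-through of the same arithmetic (all values made up, shapes inferred from the code above):

```python
import torch
import torch.nn.functional as F

gamma = 0.99

q_pred_all = torch.tensor([[0.1, 0.7, 0.2],    # online model Q(s, .), 2 transitions x 3 bins
                           [0.4, 0.3, 0.9]])
q_next_all = torch.tensor([[0.5, 0.6, 0.1],    # ema target model Q(s', .)
                           [0.2, 0.8, 0.3]])
actions = torch.tensor([1, 2])
reward  = torch.tensor([0., 1.])
done    = torch.tensor([False, True])

q_pred   = q_pred_all[torch.arange(2), actions]       # Q(s, a)          -> [0.70, 0.90]
q_next   = q_next_all.amax(dim = -1)                  # max_a' Q(s', a') -> [0.60, 0.80]
q_target = reward + gamma * (~done).float() * q_next  # [0.594, 1.0], terminal row keeps only reward
loss     = F.mse_loss(q_pred, q_target)               # what the online model is trained against
```
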
15 changes: 10 additions & 5 deletions q_transformer/robotic_transformer.py
@@ -613,10 +613,11 @@ def __init__(

        self.cond_drop_prob = cond_drop_prob

-        self.to_logits = nn.Sequential(
+        self.to_q_values = nn.Sequential(
            LayerNorm(attend_dim),
            nn.Linear(attend_dim, num_actions * action_bins),
-            Rearrange('... (a b) -> ... a b', b = action_bins)
+            Rearrange('... (a b) -> ... a b', b = action_bins),
+            nn.Sigmoid()
        )

    @property

@@ -697,7 +698,11 @@ def forward(

        attended_tokens = self.transformer(learned_tokens, cond_fns = transformer_cond_fns, attn_mask = ~attn_mask)

-        pooled = reduce(attended_tokens, 'b (f n) d -> b f d', 'mean', f = frames)
+        pooled = reduce(attended_tokens, 'b (f n) d -> b d', 'mean', f = frames)

-        logits = self.to_logits(pooled)
-        return logits
+        q_values = self.to_q_values(pooled)
+
+        if self.num_actions == 1:
+            q_values = rearrange(q_values, '... 1 b -> ... b')
+
+        return q_values
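
With this change the head emits sigmoid-bounded q-values, one per discretized action bin, and the single-action case squeezes out the action axis. A standalone shape sketch, assuming illustrative sizes (`attend_dim = 512`, `num_actions = 1`, `action_bins = 256`) and torch's `nn.LayerNorm` standing in for the repo's custom `LayerNorm`:

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

attend_dim, num_actions, action_bins = 512, 1, 256

to_q_values = nn.Sequential(
    nn.LayerNorm(attend_dim),
    nn.Linear(attend_dim, num_actions * action_bins),
    Rearrange('... (a b) -> ... a b', b = action_bins),  # (batch, num_actions, action_bins)
    nn.Sigmoid()                                         # bound q-values to [0, 1]
)

pooled = torch.randn(4, attend_dim)    # mean-pooled transformer tokens, 'b (f n) d -> b d'
q_values = to_q_values(pooled)         # (4, 1, 256)

if num_actions == 1:
    q_values = q_values.squeeze(1)     # (4, 256) -- same as rearrange('... 1 b -> ... b')
```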
