
Commit

complete the autoregressive discrete formulation of q-learning for high action spaces
lucidrains committed Nov 29, 2023
1 parent a650866 commit b7477b6
Showing 4 changed files with 76 additions and 13 deletions.
4 changes: 1 addition & 3 deletions README.md
@@ -17,12 +17,10 @@ I will be keeping around the logic for Q-learning on single action just for fina
- [x] add optional deep dueling architecture
- [x] add n-step Q learning
- [x] build the conservative regularization

- [x] build out main proposal in paper (autoregressive discrete actions until last action, reward given only on last)
- [x] improvise decoder head variant, instead of concatenating previous actions at the frames + learned tokens stage. in other words, use classic encoder - decoder
- [ ] allow for cross attention to fine frame / learned tokens

- [ ] build out main proposal in paper (autoregressive discrete actions until last action, reward given only on last)

- [ ] build out a simple dataset creator class, taking in the environment as an iterator / generator
- [ ] see if the main idea in this paper is applicable to language models <a href="https://github.com/lucidrains/llama-qrlhf">here</a>
- [ ] consult some RL experts and figure out if there are any new headways into resolving <a href="https://www.cs.toronto.edu/~cebly/Papers/CONQUR_ICML_2020_camera_ready.pdf">delusional bias</a>
2 changes: 1 addition & 1 deletion q_transformer/mocks.py
@@ -60,7 +60,7 @@ def __getitem__(self, _):

instruction = "please clean the kitchen"
state = torch.randn(*self.time_shape, 3, *self.video_shape)
-action = torch.randint(0, self.num_action_bins + 1, (*action_dims, *self.time_shape))
+action = torch.randint(0, self.num_action_bins + 1, (*self.time_shape, *action_dims))
next_state = torch.randn(3, *self.video_shape)
reward = torch.randint(0, 2, self.time_shape)
done = torch.zeros(self.time_shape, dtype = torch.bool)
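The mocks.py change swaps the layout of the mock action tensor so that time leads and the action dimensions trail, matching the per-timestep multi-action format the autoregressive Q-learning path expects. A minimal sketch of the resulting shapes, with hypothetical values standing in for the mock environment's attributes:

```python
import torch

# hypothetical stand-ins for the mock environment's attributes
time_shape = (4,)        # timesteps per trajectory
action_dims = (8,)       # number of discrete action dimensions
num_action_bins = 256

# after this commit: time first, action dimensions last
action = torch.randint(0, num_action_bins + 1, (*time_shape, *action_dims))

assert action.shape == (4, 8)
print(action[0].shape)   # torch.Size([8]) - one bin index per action dimension at timestep 0
```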
81 changes: 73 additions & 8 deletions q_transformer/q_transformer.py
@@ -252,7 +252,7 @@ def q_learn(
reward: TensorType['b', float],
done: TensorType['b', bool],
*,
-monte_carlo_return = -1e4
+monte_carlo_return = None

) -> Tuple[Tensor, QIntermediates]:
# 'next' stands for the very next time step (whether state, q, actions etc)
@@ -270,7 +270,7 @@
# the max Q value is taken as the optimal action is implicitly the one with the highest Q score

q_next = self.ema_model(next_states, instructions).amax(dim = -1)
-q_next = q_next.clamp(min = monte_carlo_return)
+q_next.clamp_(min = default(monte_carlo_return, -1e4))

# Bellman's equation. most important line of code, hopefully done correctly
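The line implementing the Bellman update is collapsed in this hunk, so the following is only a sketch of the standard single-action, single-step form it corresponds to, with toy values and the clamped bootstrap shown just above:

```python
import torch

γ = 0.98                                              # discount_factor_gamma
rewards = torch.tensor([1.0, 0.0])                    # (batch,)
dones = torch.tensor([True, False])                   # (batch,)
q_next = torch.tensor([0.7, 0.9]).clamp(min = -1e4)   # max_a' Q_ema(s', a'), clamped as above

# standard one-step target: r + γ * max_a' Q(s', a'), with the bootstrap zeroed on terminal steps
q_target = rewards + (~dones).float() * γ * q_next

print(q_target)   # tensor([1.0000, 0.8820])
```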

@@ -294,7 +294,7 @@ def n_step_q_learn(
rewards: TensorType['b', 't', float],
dones: TensorType['b', 't', bool],
*,
-monte_carlo_return = -1e4
+monte_carlo_return = None

) -> Tuple[Tensor, QIntermediates]:
"""
@@ -338,7 +338,7 @@ def n_step_q_learn(
q_pred = unpack_one(q_pred, time_ps, '*')

q_next = self.ema_model(next_states, instructions).amax(dim = -1)
-q_next = q_next.clamp(min = monte_carlo_return)
+q_next.clamp_(min = default(monte_carlo_return, -1e4))

# prepare rewards and discount factors across timesteps
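The computation itself is collapsed below this comment; purely as an illustration of the idea (toy numbers, not the repository's exact code), an n-step target folds discounted rewards into a bootstrapped value:

```python
import torch

γ = 0.98
rewards = torch.tensor([[1.0, 0.0, 1.0]])      # (batch, timesteps)
q_next = torch.tensor([0.5])                   # clamped bootstrap after the last step

discounts = γ ** torch.arange(rewards.shape[-1])                              # 1, γ, γ², ...
n_step_target = (rewards * discounts).sum(dim = -1) + γ ** rewards.shape[-1] * q_next

print(n_step_target)   # tensor([2.4310])
```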

@@ -369,7 +369,7 @@ def autoregressive_q_learn(
rewards: TensorType['b', 't', float],
dones: TensorType['b', 't', bool],
*,
-monte_carlo_return = -1e4
+monte_carlo_return = None

) -> Tuple[Tensor, QIntermediates]:
"""
@@ -385,20 +385,84 @@
a - action bins
q - q values
"""

monte_carlo_return = default(monte_carlo_return, -1e4)
num_timesteps, device = states.shape[1], states.device

# fold time steps into batch

states, time_ps = pack_one(states, '* c f h w')
actions, _ = pack_one(actions, '* n')

# repeat instructions per timestep

repeated_instructions = repeat_tuple_el(instructions, num_timesteps)

# anything after the first done flag will be considered terminal

dones = dones.cumsum(dim = -1) > 0
dones = F.pad(dones, (1, -1), value = False)

not_terminal = (~dones).float()

# rewards should not be given on and after terminal step

rewards = rewards * not_terminal

# because greek unicode is nice to look at

γ = self.discount_factor_gamma

-raise NotImplementedError
# get predicted Q for each action
# unpack back to (b, t, n)

q_pred_all_actions = self.model(states, repeated_instructions, actions = actions)

q_pred_all_actions, btn_ps = pack_one(q_pred_all_actions, '* a')
flattened_actions, _ = pack_one(actions, '*')

q_pred = batch_select_indices(q_pred_all_actions, flattened_actions)

q_pred = unpack_one(q_pred, btn_ps, '*')
q_pred = unpack_one(q_pred, time_ps, '* n')

# get q_next

q_next = self.ema_model(next_states, instructions)
q_next = q_next.max(dim = -1).values
q_next.clamp_(min = monte_carlo_return)

# get target Q
# unpack back to - (b, t, n)

q_target_all_actions = self.ema_model(states, repeated_instructions, actions = actions)
q_target = q_target_all_actions.max(dim = -1).values

q_target.clamp_(min = monte_carlo_return)
q_target = unpack_one(q_target, time_ps, '* n')

# main contribution of the paper is the following logic
# section 4.1 - eq. 1

# first take care of the loss for all actions except for the very last one

q_pred_rest_actions, q_pred_last_action = q_pred[..., :-1], q_pred[..., -1]
q_target_first_action, q_target_rest_actions = q_target[..., 0], q_target[..., 1:]

losses_all_actions_but_last = F.mse_loss(q_pred_rest_actions, q_target_rest_actions, reduction = 'none')

# next take care of the very last action, which incorporates the rewards

q_target_last_action, _ = pack([q_target_first_action[..., 1:], q_next], 'b *')

q_target_last_action = rewards + γ * q_target_last_action

losses_last_action = F.mse_loss(q_pred_last_action, q_target_last_action, reduction = 'none')

# flatten and average

losses, _ = pack([losses_all_actions_but_last, losses_last_action], '*')

return losses.mean(), QIntermediates(q_pred_all_actions, q_pred, q_next, q_target)

def learn(
self,
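Written out, the targets constructed in autoregressive_q_learn above correspond to the per-action-dimension Bellman backup of section 4.1 (the notation below is a paraphrase read off the code, not a quotation of the paper's eq. 1): every action dimension except the last bootstraps off the maximum over the next dimension within the same timestep, and only the last dimension receives the reward plus the discounted maximum over the first dimension of the next timestep.

```latex
\begin{aligned}
\text{for } i < d_{\mathcal{A}}:\quad
Q^{\text{target}}\!\left(s_t,\, a_t^{1:i-1},\, a_t^{i}\right)
  &= \max_{a_t^{i+1}} Q_{\bar{\theta}}\!\left(s_t,\, a_t^{1:i},\, a_t^{i+1}\right) \\
\text{for } i = d_{\mathcal{A}}:\quad
Q^{\text{target}}\!\left(s_t,\, a_t^{1:d_{\mathcal{A}}-1},\, a_t^{d_{\mathcal{A}}}\right)
  &= r(s_t, a_t) + \gamma \max_{a_{t+1}^{1}} Q_{\bar{\theta}}\!\left(s_{t+1},\, a_{t+1}^{1}\right)
\end{aligned}
```

Here Q with a bar denotes the EMA target network and d_A the number of action dimensions; the first line is what the q_pred[..., :-1] / q_target[..., 1:] pairing computes, and the second is the rewards plus discounted shifted-first-dimension term applied to the last action.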
@@ -434,10 +498,11 @@ def learn(
batch = actions.shape[0]

q_preds = q_intermediates.q_pred_all_actions

num_action_bins = q_preds.shape[-1]
num_non_dataset_actions = num_action_bins - 1

-actions = rearrange(actions, '... -> ... 1')
+actions = rearrange(actions, '... -> (...) 1')

dataset_action_mask = torch.zeros_like(q_preds).scatter_(-1, actions, torch.ones_like(q_preds))
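For context on the conservative regularization this feeds (toy shapes, illustrative only): the scatter builds a one-hot mask over the action bins actually taken in the dataset, and its complement picks out the unseen bins, which the regularization presumably pushes toward a low value as in the paper.

```python
import torch

# toy shapes: 2 flattened (batch * timestep * action-dim) rows, 5 action bins
q_preds = torch.randn(2, 5)
actions = torch.tensor([[3], [0]])    # bin taken in the dataset, shape (..., 1)

dataset_action_mask = torch.zeros_like(q_preds).scatter_(-1, actions, torch.ones_like(q_preds))

print(dataset_action_mask)
# tensor([[0., 0., 0., 1., 0.],
#         [1., 0., 0., 0., 0.]])
```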

2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'q-transformer',
packages = find_packages(exclude=[]),
-version = '0.0.21',
+version = '0.0.22',
license='MIT',
description = 'Q-Transformer',
author = 'Phil Wang',

