From a6508669277d2c0026fe9ea7c355ec2c05b40fd9 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Wed, 29 Nov 2023 13:11:56 -0800 Subject: [PATCH] let us just finish it today, life is short --- q_transformer/mocks.py | 15 +++++++++++++-- q_transformer/q_transformer.py | 22 +++++++++++++++++----- 2 files changed, 30 insertions(+), 7 deletions(-) diff --git a/q_transformer/mocks.py b/q_transformer/mocks.py index 6652582..fd849de 100644 --- a/q_transformer/mocks.py +++ b/q_transformer/mocks.py @@ -7,10 +7,12 @@ class MockReplayDataset(Dataset): def __init__( self, length = 10000, + num_actions = 1, num_action_bins = 256, video_shape = (6, 224, 224) ): self.length = length + self.num_actions = num_actions self.num_action_bins = num_action_bins self.video_shape = video_shape @@ -21,7 +23,12 @@ def __getitem__(self, _): instruction = "please clean the kitchen" state = torch.randn(3, *self.video_shape) - action = torch.tensor(randrange(self.num_action_bins + 1)) + + if self.num_actions == 1: + action = torch.tensor(randrange(self.num_action_bins + 1)) + else: + action = torch.randint(0, self.num_action_bins + 1, (self.num_actions,)) + next_state = torch.randn(3, *self.video_shape) reward = torch.tensor(randrange(2)) done = torch.tensor(randrange(2), dtype = torch.bool) @@ -33,12 +40,14 @@ def __init__( self, length = 10000, num_steps = 2, + num_actions = 1, num_action_bins = 256, video_shape = (6, 224, 224) ): self.num_steps = num_steps self.time_shape = (num_steps,) self.length = length + self.num_actions = num_actions self.num_action_bins = num_action_bins self.video_shape = video_shape @@ -47,9 +56,11 @@ def __len__(self): def __getitem__(self, _): + action_dims = (self.num_actions,) if self.num_actions > 1 else tuple() + instruction = "please clean the kitchen" state = torch.randn(*self.time_shape, 3, *self.video_shape) - action = torch.randint(0, self.num_action_bins + 1, self.time_shape) + action = torch.randint(0, self.num_action_bins + 1, (*action_dims, *self.time_shape)) next_state = torch.randn(3, *self.video_shape) reward = torch.randint(0, 2, self.time_shape) done = torch.zeros(self.time_shape, dtype = torch.bool) diff --git a/q_transformer/q_transformer.py b/q_transformer/q_transformer.py index 24892e2..b33f557 100644 --- a/q_transformer/q_transformer.py +++ b/q_transformer/q_transformer.py @@ -252,7 +252,7 @@ def q_learn( reward: TensorType['b', float], done: TensorType['b', bool], *, - monte_carlo_return = None + monte_carlo_return = -1e4 ) -> Tuple[Tensor, QIntermediates]: # 'next' stands for the very next time step (whether state, q, actions etc) @@ -270,7 +270,7 @@ def q_learn( # the max Q value is taken as the optimal action is implicitly the one with the highest Q score q_next = self.ema_model(next_states, instructions).amax(dim = -1) - q_next = q_next.clamp(min = default(monte_carlo_return, -1e4)) + q_next = q_next.clamp(min = monte_carlo_return) # Bellman's equation. most important line of code, hopefully done correctly @@ -294,7 +294,7 @@ def n_step_q_learn( rewards: TensorType['b', 't', float], dones: TensorType['b', 't', bool], *, - monte_carlo_return = None + monte_carlo_return = -1e4 ) -> Tuple[Tensor, QIntermediates]: """ @@ -338,7 +338,7 @@ def n_step_q_learn( q_pred = unpack_one(q_pred, time_ps, '*') q_next = self.ema_model(next_states, instructions).amax(dim = -1) - q_next = q_next.clamp(min = default(monte_carlo_return, -1e4)) + q_next = q_next.clamp(min = monte_carlo_return) # prepare rewards and discount factors across timesteps @@ -369,7 +369,7 @@ def autoregressive_q_learn( rewards: TensorType['b', 't', float], dones: TensorType['b', 't', bool], *, - monte_carlo_return = None + monte_carlo_return = -1e4 ) -> Tuple[Tensor, QIntermediates]: """ @@ -386,6 +386,18 @@ def autoregressive_q_learn( q - q values """ + num_timesteps, device = states.shape[1], states.device + + # fold time steps into batch + + states, time_ps = pack_one(states, '* c f h w') + + # repeat instructions per timestep + + repeated_instructions = repeat_tuple_el(instructions, num_timesteps) + + γ = self.discount_factor_gamma + raise NotImplementedError def learn(