Skip to content

Commit

Permalink
let us just finish it today, life is short
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 29, 2023
1 parent 4669c7c commit a650866
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 7 deletions.
15 changes: 13 additions & 2 deletions q_transformer/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ class MockReplayDataset(Dataset):
def __init__(
self,
length = 10000,
num_actions = 1,
num_action_bins = 256,
video_shape = (6, 224, 224)
):
self.length = length
self.num_actions = num_actions
self.num_action_bins = num_action_bins
self.video_shape = video_shape

Expand All @@ -21,7 +23,12 @@ def __getitem__(self, _):

instruction = "please clean the kitchen"
state = torch.randn(3, *self.video_shape)
action = torch.tensor(randrange(self.num_action_bins + 1))

if self.num_actions == 1:
action = torch.tensor(randrange(self.num_action_bins + 1))
else:
action = torch.randint(0, self.num_action_bins + 1, (self.num_actions,))

next_state = torch.randn(3, *self.video_shape)
reward = torch.tensor(randrange(2))
done = torch.tensor(randrange(2), dtype = torch.bool)
def __init__(
    self,
    length = 10000,
    num_steps = 2,
    num_actions = 1,
    num_action_bins = 256,
    video_shape = (6, 224, 224)
):
    """Configure a mock n-step replay dataset of random transitions.

    length          - number of samples the dataset reports
    num_steps       - number of timesteps per sample (n-step returns)
    num_actions     - how many discrete action dimensions a sample carries
    num_action_bins - number of bins each action is discretized into
    video_shape     - shape of the random video state tensors
                      (presumably (frames, height, width) — confirm with caller)
    """
    self.length = length
    self.num_steps = num_steps
    # cached shape tuple reused for per-timestep tensors in __getitem__
    self.time_shape = (num_steps,)
    self.video_shape = video_shape
    self.num_actions = num_actions
    self.num_action_bins = num_action_bins

Expand All @@ -47,9 +56,11 @@ def __len__(self):

def __getitem__(self, _):

action_dims = (self.num_actions,) if self.num_actions > 1 else tuple()

instruction = "please clean the kitchen"
state = torch.randn(*self.time_shape, 3, *self.video_shape)
action = torch.randint(0, self.num_action_bins + 1, self.time_shape)
action = torch.randint(0, self.num_action_bins + 1, (*action_dims, *self.time_shape))
next_state = torch.randn(3, *self.video_shape)
reward = torch.randint(0, 2, self.time_shape)
done = torch.zeros(self.time_shape, dtype = torch.bool)
Expand Down
22 changes: 17 additions & 5 deletions q_transformer/q_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def q_learn(
reward: TensorType['b', float],
done: TensorType['b', bool],
*,
monte_carlo_return = None
monte_carlo_return = -1e4

) -> Tuple[Tensor, QIntermediates]:
# 'next' stands for the very next time step (whether state, q, actions etc)
Expand All @@ -270,7 +270,7 @@ def q_learn(
# the max Q value is taken as the optimal action is implicitly the one with the highest Q score

q_next = self.ema_model(next_states, instructions).amax(dim = -1)
q_next = q_next.clamp(min = default(monte_carlo_return, -1e4))
q_next = q_next.clamp(min = monte_carlo_return)

# Bellman's equation. most important line of code, hopefully done correctly

Expand All @@ -294,7 +294,7 @@ def n_step_q_learn(
rewards: TensorType['b', 't', float],
dones: TensorType['b', 't', bool],
*,
monte_carlo_return = None
monte_carlo_return = -1e4

) -> Tuple[Tensor, QIntermediates]:
"""
Expand Down Expand Up @@ -338,7 +338,7 @@ def n_step_q_learn(
q_pred = unpack_one(q_pred, time_ps, '*')

q_next = self.ema_model(next_states, instructions).amax(dim = -1)
q_next = q_next.clamp(min = default(monte_carlo_return, -1e4))
q_next = q_next.clamp(min = monte_carlo_return)

# prepare rewards and discount factors across timesteps

Expand Down Expand Up @@ -369,7 +369,7 @@ def autoregressive_q_learn(
rewards: TensorType['b', 't', float],
dones: TensorType['b', 't', bool],
*,
monte_carlo_return = None
monte_carlo_return = -1e4

) -> Tuple[Tensor, QIntermediates]:
"""
Expand All @@ -386,6 +386,18 @@ def autoregressive_q_learn(
q - q values
"""

num_timesteps, device = states.shape[1], states.device

# fold time steps into batch

states, time_ps = pack_one(states, '* c f h w')

# repeat instructions per timestep

repeated_instructions = repeat_tuple_el(instructions, num_timesteps)

γ = self.discount_factor_gamma

raise NotImplementedError

def learn(
Expand Down

0 comments on commit a650866

Please sign in to comment.