From 24624fc0f02b6629a4728f803a7fbf5b5f37d7d8 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Fri, 8 Dec 2023 09:43:46 -0800
Subject: [PATCH] agent now converts instructions to text embeddings, stored in memmap, and q-learner now trains off text embeddings

---
 README.md                  |  4 ++--
 q_transformer/agent.py     | 48 ++++++++++++++++++++++++++------------
 q_transformer/q_learner.py | 43 ++++++++++++++++------------------
 setup.py                   |  4 ++--
 4 files changed, 57 insertions(+), 42 deletions(-)

diff --git a/README.md b/README.md
index bb21847..623de39 100644
--- a/README.md
+++ b/README.md
@@ -27,12 +27,12 @@ I will be keeping around the logic for Q-learning on single action just for fina
 - [x] `ReplayDataset` that takes in folder
 - [x] 1 time step option
 - [x] n-time steps
+- [x] handle multiple instructions correctly
 - [ ] show a simple end-to-end example, in the same style as all other repos
 - [ ] consult some RL experts and figure out if there are any new headways into resolving delusional bias
-
-- [ ] handle multiple instructions correctly, but also handle no instructions, in case one wants to train a robot doing a single specialized task
+- [ ] handle no instructions, leverage null conditioner in CFG library
 - [ ] for exploration, allow for finely randomizing a subset of actions, and not all actions at once
 - [ ] figure out if one can train with randomized orders of actions - order could be sent as a conditioning that is concatted or summed before attention layers
 - [ ] offer an improvised variant where the first action token suggests the action ordering. all actions aren't made equal, and some may need to attend to past actions more than others
diff --git a/q_transformer/agent.py b/q_transformer/agent.py
index c65677e..223927e 100644
--- a/q_transformer/agent.py
+++ b/q_transformer/agent.py
@@ -20,6 +20,7 @@
 
 # constants
 
+TEXT_EMBEDS_FILENAME = 'text_embeds.memmap.npy'
 STATES_FILENAME = 'states.memmap.npy'
 ACTIONS_FILENAME = 'actions.memmap.npy'
 REWARDS_FILENAME = 'rewards.memmap.npy'
@@ -35,14 +36,12 @@ def exists(v):
 # replay memory dataset
 
 class ReplayMemoryDataset(Dataset):
+    @beartype
     def __init__(
         self,
-        instruction: str,
         folder: str = DEFAULT_REPLAY_MEMORIES_FOLDER,
-        num_timesteps = 1
+        num_timesteps: int = 1
     ):
-        self.instruction = instruction
-
         assert num_timesteps >= 1
         self.is_single_timestep = num_timesteps == 1
         self.num_timesteps = num_timesteps
@@ -50,11 +49,13 @@ def __init__(
         folder = Path(folder)
         assert folder.exists() and folder.is_dir()
 
+        text_embeds_path = folder / TEXT_EMBEDS_FILENAME
         states_path = folder / STATES_FILENAME
         actions_path = folder / ACTIONS_FILENAME
         rewards_path = folder / REWARDS_FILENAME
         dones_path = folder / DONES_FILENAME
 
+        self.text_embeds = open_memmap(str(text_embeds_path), dtype = 'float32', mode = 'r')
         self.states = open_memmap(str(states_path), dtype = 'float32', mode = 'r')
         self.actions = open_memmap(str(actions_path), dtype = 'int', mode = 'r')
         self.rewards = open_memmap(str(rewards_path), dtype = 'float32', mode = 'r')
@@ -69,6 +70,7 @@ def __init__(
 
         trainable_episode_indices = self.episode_length >= num_timesteps
 
+        self.text_embeds = self.text_embeds[trainable_episode_indices]
         self.states = self.states[trainable_episode_indices]
         self.actions = self.actions[trainable_episode_indices]
         self.rewards = self.rewards[trainable_episode_indices]
@@ -98,6 +100,7 @@ def __getitem__(self, idx):
 
         timestep_slice = slice(timestep_index, (timestep_index + self.num_timesteps))
 
+        text_embeds = self.text_embeds[episode_index, timestep_slice]
         states = self.states[episode_index, timestep_slice]
         actions = self.actions[episode_index, timestep_slice]
         rewards = self.rewards[episode_index, timestep_slice]
@@ -105,17 +108,19 @@ def __getitem__(self, idx):
 
         next_state = self.states[episode_index, min(timestep_index, self.max_episode_len - 1)]
 
-        return self.instruction, states, actions, next_state, rewards, dones
+        return text_embeds, states, actions, next_state, rewards, dones
 
 # base environment class to extend
 
 class BaseEnvironment(Module):
     def __init__(
         self,
-        state_shape: Tuple[int, ...] = ()
+        state_shape: Tuple[int, ...] = tuple(),
+        text_embed_shape: Tuple[int, ...] = tuple()
     ):
         super().__init__()
         self.state_shape = state_shape
+        self.text_embed_shape = text_embed_shape
         self.register_buffer('dummy', torch.zeros(0), persistent = False)
 
     @property
@@ -155,6 +160,8 @@ def __init__(
         self.q_transformer = q_transformer
         self.environment = environment
 
+        assert hasattr(environment, 'state_shape') and hasattr(environment, 'text_embed_shape')
+
         assert 0. <= epsilon_start <= 1.
         assert 0. <= epsilon_end <= 1.
         assert epsilon_start >= epsilon_end
@@ -173,6 +180,7 @@ def __init__(
         mem_path.mkdir(exist_ok = True, parents = True)
         assert mem_path.is_dir()
 
+        text_embeds_path = mem_path / TEXT_EMBEDS_FILENAME
         states_path = mem_path / STATES_FILENAME
         actions_path = mem_path / ACTIONS_FILENAME
         rewards_path = mem_path / REWARDS_FILENAME
@@ -182,10 +190,14 @@ def __init__(
         num_actions = q_transformer.num_actions
         state_shape = environment.state_shape
 
-        self.states = open_memmap(str(states_path), dtype = 'float32', mode = 'w+', shape = (*prec_shape, *state_shape))
-        self.actions = open_memmap(str(actions_path), dtype = 'int', mode = 'w+', shape = (*prec_shape, num_actions))
-        self.rewards = open_memmap(str(rewards_path), dtype = 'float32', mode = 'w+', shape = prec_shape)
-        self.dones = open_memmap(str(dones_path), dtype = 'bool', mode = 'w+', shape = prec_shape)
+        text_embed_shape = environment.text_embed_shape
+        self.text_embed_shape = text_embed_shape
+
+        self.text_embeds = open_memmap(str(text_embeds_path), dtype = 'float32', mode = 'w+', shape = (*prec_shape, *text_embed_shape))
+        self.states = open_memmap(str(states_path), dtype = 'float32', mode = 'w+', shape = (*prec_shape, *state_shape))
+        self.actions = open_memmap(str(actions_path), dtype = 'int', mode = 'w+', shape = (*prec_shape, num_actions))
+        self.rewards = open_memmap(str(rewards_path), dtype = 'float32', mode = 'w+', shape = prec_shape)
+        self.dones = open_memmap(str(dones_path), dtype = 'bool', mode = 'w+', shape = prec_shape)
 
     def get_epsilon(self, step):
         return max(self.epsilon_end, self.epsilon_slope * float(step) + self.epsilon_start)
@@ -205,9 +217,11 @@ def forward(self):
 
                 epsilon = self.get_epsilon(step)
 
+                text_embed = self.q_transformer.embed_texts([instruction])
+
                 actions = self.q_transformer.get_actions(
                     rearrange(curr_state, '... -> 1 ...'),
-                    [instruction],
+                    text_embeds = text_embed,
                     prob_random_action = epsilon
                 )
 
@@ -217,10 +231,13 @@ def forward(self):
 
                 # store memories using memmap, for later reflection and learning
 
-                self.states[episode, step] = curr_state
-                self.actions[episode, step] = actions
-                self.rewards[episode, step] = reward
-                self.dones[episode, step] = done
+                assert text_embed.shape[1:] == self.text_embed_shape
+
+                self.text_embeds[episode, step] = text_embed
+                self.states[episode, step] = curr_state
+                self.actions[episode, step] = actions
+                self.rewards[episode, step] = reward
+                self.dones[episode, step] = done
 
                 # if done, move onto next episode
 
@@ -231,6 +248,7 @@ def forward(self):
 
                 curr_state = next_state
 
+        self.text_embeds.flush()
         self.states.flush()
         self.actions.flush()
         self.rewards.flush()
diff --git a/q_transformer/q_learner.py b/q_transformer/q_learner.py
index 2670f32..3e30625 100644
--- a/q_transformer/q_learner.py
+++ b/q_transformer/q_learner.py
@@ -52,13 +52,6 @@ def default(val, d):
 def is_divisible(num, den):
     return (num % den) == 0
 
-def repeat_tuple_el(t: Tuple, i: int) -> Tuple:
-    out = []
-    for el in t:
-        for _ in range(i):
-            out.append(el)
-    return tuple(out)
-
 def pack_one(t, pattern):
     return pack([t], pattern)
 
@@ -263,7 +256,7 @@ def get_discount_matrix(self, timestep):
 
     def q_learn(
         self,
-        instructions: Tuple[str],
+        text_embeds: TensorType['b', 'd', float],
        states: TensorType['b', 'c', 'f', 'h', 'w', float],
         actions: TensorType['b', int],
         next_states: TensorType['b', 'c', 'f', 'h', 'w', float],
@@ -281,13 +274,13 @@ def q_learn(
         # first make a prediction with online q robotic transformer
         # select out the q-values for the action that was taken
 
-        q_pred_all_actions = self.model(states, instructions)
+        q_pred_all_actions = self.model(states, text_embeds = text_embeds)
         q_pred = batch_select_indices(q_pred_all_actions, actions)
 
         # use an exponentially smoothed copy of model for the future q target. more stable than setting q_target to q_eval after each batch
         # the max Q value is taken as the optimal action is implicitly the one with the highest Q score
 
-        q_next = self.ema_model(next_states, instructions).amax(dim = -1)
+        q_next = self.ema_model(next_states, text_embeds = text_embeds).amax(dim = -1)
         q_next.clamp_(min = default(monte_carlo_return, -1e4))
 
         # Bellman's equation. most important line of code, hopefully done correctly
@@ -305,7 +298,7 @@ def q_learn(
 
     def n_step_q_learn(
         self,
-        instructions: Tuple[str],
+        text_embeds: TensorType['b', 'd', float],
         states: TensorType['b', 't', 'c', 'f', 'h', 'w', float],
         actions: TensorType['b', 't', int],
         next_states: TensorType['b', 'c', 'f', 'h', 'w', float],
@@ -326,6 +319,7 @@ def n_step_q_learn(
         t - timesteps
         a - action bins
         q - q values
+        d - text cond dimension
         """
 
         num_timesteps, device = states.shape[1], states.device
@@ -333,10 +327,11 @@ def n_step_q_learn(
 
         # fold time steps into batch
 
         states, time_ps = pack_one(states, '* c f h w')
+        text_embeds, _ = pack_one(text_embeds, '* d')
 
-        # repeat instructions per timestep
+        # repeat text embeds per timestep
 
-        repeated_instructions = repeat_tuple_el(instructions, num_timesteps)
+        repeated_text_embeds = repeat(text_embeds, 'b ... -> (b n) ...', n = num_timesteps)
 
         γ = self.discount_factor_gamma
@@ -351,11 +346,11 @@ def n_step_q_learn(
 
         actions = rearrange(actions, 'b t -> (b t)')
 
-        q_pred_all_actions = self.model(states, repeated_instructions)
+        q_pred_all_actions = self.model(states, text_embeds = repeated_text_embeds)
         q_pred = batch_select_indices(q_pred_all_actions, actions)
         q_pred = unpack_one(q_pred, time_ps, '*')
 
-        q_next = self.ema_model(next_states, instructions).amax(dim = -1)
+        q_next = self.ema_model(next_states, text_embeds = text_embeds).amax(dim = -1)
         q_next.clamp_(min = default(monte_carlo_return, -1e4))
 
         # prepare rewards and discount factors across timesteps
@@ -380,7 +375,7 @@ def n_step_q_learn(
 
     def autoregressive_q_learn_handle_single_timestep(
         self,
-        instructions,
+        text_embeds,
         states,
         actions,
         next_states,
@@ -405,11 +400,11 @@ def autoregressive_q_learn_handle_single_timestep(
         if dones.ndim == 1:
             dones = rearrange(dones, 'b -> b 1')
 
-        return self.autoregressive_q_learn(instructions, states, actions, next_states, rewards, dones, monte_carlo_return = monte_carlo_return)
+        return self.autoregressive_q_learn(text_embeds, states, actions, next_states, rewards, dones, monte_carlo_return = monte_carlo_return)
 
     def autoregressive_q_learn(
         self,
-        instructions: Tuple[str],
+        text_embeds: TensorType['b', 'd', float],
         states: TensorType['b', 't', 'c', 'f', 'h', 'w', float],
         actions: TensorType['b', 't', 'n', int],
         next_states: TensorType['b', 'c', 'f', 'h', 'w', float],
@@ -431,6 +426,7 @@ def autoregressive_q_learn(
         n - number of actions
         a - action bins
         q - q values
+        d - text cond dimension
         """
         monte_carlo_return = default(monte_carlo_return, -1e4)
         num_timesteps, device = states.shape[1], states.device
@@ -439,10 +435,11 @@ def autoregressive_q_learn(
 
         states, time_ps = pack_one(states, '* c f h w')
         actions, _ = pack_one(actions, '* n')
+        text_embeds, _ = pack_one(text_embeds, '* d')
 
-        # repeat instructions per timestep
+        # repeat text embeds per timestep
 
-        repeated_instructions = repeat_tuple_el(instructions, num_timesteps)
+        repeated_text_embeds = repeat(text_embeds, 'b ... -> (b n) ...', n = num_timesteps)
 
         # anything after the first done flag will be considered terminal
 
@@ -462,20 +459,20 @@ def autoregressive_q_learn(
 
         # get predicted Q for each action
         # unpack back to (b, t, n)
 
-        q_pred_all_actions = self.model(states, repeated_instructions, actions = actions)
+        q_pred_all_actions = self.model(states, text_embeds = repeated_text_embeds, actions = actions)
         q_pred = batch_select_indices(q_pred_all_actions, actions)
         q_pred = unpack_one(q_pred, time_ps, '* n')
 
         # get q_next
 
-        q_next = self.ema_model(next_states, instructions)
+        q_next = self.ema_model(next_states, text_embeds = text_embeds)
         q_next = q_next.max(dim = -1).values
         q_next.clamp_(min = monte_carlo_return)
 
         # get target Q
         # unpack back to - (b, t, n)
 
-        q_target_all_actions = self.ema_model(states, repeated_instructions, actions = actions)
+        q_target_all_actions = self.ema_model(states, text_embeds = repeated_text_embeds, actions = actions)
         q_target = q_target_all_actions.max(dim = -1).values
         q_target.clamp_(min = monte_carlo_return)
diff --git a/setup.py b/setup.py
index eff5cf0..53258e0 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'q-transformer',
   packages = find_packages(exclude=[]),
-  version = '0.0.45',
+  version = '0.1.0',
   license='MIT',
   description = 'Q-Transformer',
   author = 'Phil Wang',
@@ -20,7 +20,7 @@
   install_requires=[
     'accelerate',
     'beartype',
-    'classifier-free-guidance-pytorch>=0.4.0',
+    'classifier-free-guidance-pytorch>=0.4.1',
     'einops>=0.7.0',
     'ema-pytorch>=0.3.1',
     'numpy',
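
Note (not part of the patch): a minimal sketch of how the new text-embedding flow fits together after this change. The environment now advertises a text_embed_shape next to its state_shape, the agent stores embedded instructions in text_embeds.memmap.npy alongside the other memmaps, and ReplayMemoryDataset yields the cached embeddings instead of an instruction string. BaseEnvironment, ReplayMemoryDataset, TEXT_EMBEDS_FILENAME and DEFAULT_REPLAY_MEMORIES_FOLDER all come from q_transformer/agent.py in this patch; MockEnvironment, the 768-dim embedding, the example state shape, and the init()/step() signatures are assumptions for illustration only.

import torch
from numpy.lib.format import open_memmap

from q_transformer.agent import (
    BaseEnvironment,
    ReplayMemoryDataset,
    DEFAULT_REPLAY_MEMORIES_FOLDER,
    TEXT_EMBEDS_FILENAME
)

class MockEnvironment(BaseEnvironment):
    # assumed interface: init() returns (instruction, initial state),
    # step(actions) returns (reward, next state, done)

    def init(self):
        return 'pick up the apple', torch.randn(self.state_shape)

    def step(self, actions):
        return 0., torch.randn(self.state_shape), False

env = MockEnvironment(
    state_shape = (3, 6, 224, 224),  # (channels, frames, height, width) - assumed
    text_embed_shape = (768,)        # must match what q_transformer.embed_texts returns per instruction - assumed dim
)

# ... run the Agent with this environment to fill the replay memories folder ...

# the agent now writes one extra memmap holding the embedded instructions

text_embeds = open_memmap(
    f'{DEFAULT_REPLAY_MEMORIES_FOLDER}/{TEXT_EMBEDS_FILENAME}',
    dtype = 'float32',
    mode = 'r'
)

print(text_embeds.shape)  # presumably (num episodes, max steps per episode, 768) given the shapes above

# ReplayMemoryDataset no longer takes an instruction string - it yields the cached text embeddings

dataset = ReplayMemoryDataset(num_timesteps = 2)

text_embeds, states, actions, next_state, rewards, dones = dataset[0]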