agent now converts instructions to text embeddings, stored in memmap, and q-learner now trains off text embeddings
lucidrains committed Dec 8, 2023
1 parent a1fbe3e commit 24624fc
Showing 4 changed files with 57 additions and 42 deletions.
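The gist of the change: the agent embeds each instruction once while gathering experience, persists the resulting vector in a memory-mapped numpy array alongside states, actions, rewards and dones, and the Q-learner later conditions on those stored embeddings instead of re-encoding raw strings every batch. A minimal, self-contained sketch of the storage mechanism only (the file name matches the constant added below; the array sizes and 768-dim embedding are illustrative assumptions, not values from this commit):

```python
# illustrative sketch: pre-allocate a (num_episodes, max_steps, embed_dim) buffer
# on disk, write one embedding during a rollout, then re-open it lazily for training

import numpy as np
from numpy.lib.format import open_memmap

num_episodes, max_steps, embed_dim = 4, 16, 768   # assumed sizes

# 'w+' creates the .npy file and memory-maps it for read/write
text_embeds = open_memmap(
    'text_embeds.memmap.npy',
    dtype = 'float32',
    mode = 'w+',
    shape = (num_episodes, max_steps, embed_dim)
)

fake_embed = np.random.randn(embed_dim).astype('float32')   # stand-in for embed_texts(...)
text_embeds[0, 0] = fake_embed
text_embeds.flush()                                          # persist to disk

# a dataset can later re-open the same file read-only without loading it into RAM
reader = open_memmap('text_embeds.memmap.npy', dtype = 'float32', mode = 'r')
assert np.allclose(reader[0, 0], fake_embed)
```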
4 changes: 2 additions & 2 deletions README.md
@@ -27,12 +27,12 @@ I will be keeping around the logic for Q-learning on single action just for fina
- [x] `ReplayDataset` that takes in folder
- [x] 1 time step option
- [x] n-time steps
- [x] handle multiple instructions correctly

- [ ] show a simple end-to-end example, in the same style as all other repos

- [ ] consult some RL experts and figure out if there are any new headways into resolving <a href="https://www.cs.toronto.edu/~cebly/Papers/CONQUR_ICML_2020_camera_ready.pdf">delusional bias</a>

- [ ] handle multiple instructions correctly, but also handle no instructions, in case one wants to train a robot doing a single specialized task
- [ ] handle no instructions, leverage null conditioner in CFG library
- [ ] for exploration, allow for finely randomizing a subset of actions, and not all actions at once
- [ ] figure out if one can train with randomized orders of actions - order could be sent as a conditioning that is concatted or summed before attention layers
- [ ] offer an improvised variant where the first action token suggests the action ordering. all actions aren't made equal, and some may need to attend to past actions more than others
48 changes: 33 additions & 15 deletions q_transformer/agent.py
@@ -20,6 +20,7 @@

# constants

TEXT_EMBEDS_FILENAME = 'text_embeds.memmap.npy'
STATES_FILENAME = 'states.memmap.npy'
ACTIONS_FILENAME = 'actions.memmap.npy'
REWARDS_FILENAME = 'rewards.memmap.npy'
@@ -35,26 +36,26 @@ def exists(v):
# replay memory dataset

class ReplayMemoryDataset(Dataset):
@beartype
def __init__(
self,
instruction: str,
folder: str = DEFAULT_REPLAY_MEMORIES_FOLDER,
num_timesteps = 1
num_timesteps: int = 1
):
self.instruction = instruction

assert num_timesteps >= 1
self.is_single_timestep = num_timesteps == 1
self.num_timesteps = num_timesteps

folder = Path(folder)
assert folder.exists() and folder.is_dir()

text_embeds_path = folder / TEXT_EMBEDS_FILENAME
states_path = folder / STATES_FILENAME
actions_path = folder / ACTIONS_FILENAME
rewards_path = folder / REWARDS_FILENAME
dones_path = folder / DONES_FILENAME

self.text_embeds = open_memmap(str(text_embeds_path), dtype = 'float32', mode = 'r')
self.states = open_memmap(str(states_path), dtype = 'float32', mode = 'r')
self.actions = open_memmap(str(actions_path), dtype = 'int', mode = 'r')
self.rewards = open_memmap(str(rewards_path), dtype = 'float32', mode = 'r')
@@ -69,6 +70,7 @@ def __init__(

trainable_episode_indices = self.episode_length >= num_timesteps

self.text_embeds = self.text_embeds[trainable_episode_indices]
self.states = self.states[trainable_episode_indices]
self.actions = self.actions[trainable_episode_indices]
self.rewards = self.rewards[trainable_episode_indices]
@@ -98,24 +100,27 @@ def __getitem__(self, idx):

timestep_slice = slice(timestep_index, (timestep_index + self.num_timesteps))

text_embeds = self.text_embeds[episode_index, timestep_slice]
states = self.states[episode_index, timestep_slice]
actions = self.actions[episode_index, timestep_slice]
rewards = self.rewards[episode_index, timestep_slice]
dones = self.dones[episode_index, timestep_slice]

next_state = self.states[episode_index, min(timestep_index, self.max_episode_len - 1)]

return self.instruction, states, actions, next_state, rewards, dones
return text_embeds, states, actions, next_state, rewards, dones

# base environment class to extend

class BaseEnvironment(Module):
def __init__(
self,
state_shape: Tuple[int, ...] = ()
state_shape: Tuple[int, ...] = tuple(),
text_embed_shape: Tuple[int, ...] = tuple()
):
super().__init__()
self.state_shape = state_shape
self.text_embed_shape = text_embed_shape
self.register_buffer('dummy', torch.zeros(0), persistent = False)

@property
@@ -155,6 +160,8 @@ def __init__(
self.q_transformer = q_transformer
self.environment = environment

assert hasattr(environment, 'state_shape') and hasattr(environment, 'text_embed_shape')

assert 0. <= epsilon_start <= 1.
assert 0. <= epsilon_end <= 1.
assert epsilon_start >= epsilon_end
@@ -173,6 +180,7 @@ def __init__(
mem_path.mkdir(exist_ok = True, parents = True)
assert mem_path.is_dir()

text_embeds_path = mem_path / TEXT_EMBEDS_FILENAME
states_path = mem_path / STATES_FILENAME
actions_path = mem_path / ACTIONS_FILENAME
rewards_path = mem_path / REWARDS_FILENAME
@@ -182,10 +190,14 @@ def __init__(
num_actions = q_transformer.num_actions
state_shape = environment.state_shape

self.states = open_memmap(str(states_path), dtype = 'float32', mode = 'w+', shape = (*prec_shape, *state_shape))
self.actions = open_memmap(str(actions_path), dtype = 'int', mode = 'w+', shape = (*prec_shape, num_actions))
self.rewards = open_memmap(str(rewards_path), dtype = 'float32', mode = 'w+', shape = prec_shape)
self.dones = open_memmap(str(dones_path), dtype = 'bool', mode = 'w+', shape = prec_shape)
text_embed_shape = environment.text_embed_shape
self.text_embed_shape = text_embed_shape

self.text_embeds = open_memmap(str(text_embeds_path), dtype = 'float32', mode = 'w+', shape = (*prec_shape, *text_embed_shape))
self.states = open_memmap(str(states_path), dtype = 'float32', mode = 'w+', shape = (*prec_shape, *state_shape))
self.actions = open_memmap(str(actions_path), dtype = 'int', mode = 'w+', shape = (*prec_shape, num_actions))
self.rewards = open_memmap(str(rewards_path), dtype = 'float32', mode = 'w+', shape = prec_shape)
self.dones = open_memmap(str(dones_path), dtype = 'bool', mode = 'w+', shape = prec_shape)

def get_epsilon(self, step):
return max(self.epsilon_end, self.epsilon_slope * float(step) + self.epsilon_start)
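For reference, the epsilon schedule above is a simple linear decay from `epsilon_start` toward `epsilon_end`, clamped at the floor. A standalone illustration with assumed numbers (the commit does not show how `epsilon_slope` is derived, so the slope below is a made-up example):

```python
# hypothetical values; only the shape of the schedule mirrors get_epsilon above
epsilon_start, epsilon_end = 1.0, 0.05
decay_steps = 1000
epsilon_slope = (epsilon_end - epsilon_start) / decay_steps   # negative slope

def get_epsilon(step):
    # linear decay, never dipping below epsilon_end
    return max(epsilon_end, epsilon_slope * float(step) + epsilon_start)

assert get_epsilon(0) == 1.0                    # fully random actions at the start
assert abs(get_epsilon(500) - 0.525) < 1e-9     # halfway through the decay
assert get_epsilon(10_000) == 0.05              # clamped at the floor afterwards
```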
@@ -205,9 +217,11 @@ def forward(self):

epsilon = self.get_epsilon(step)

text_embed = self.q_transformer.embed_texts([instruction])

actions = self.q_transformer.get_actions(
rearrange(curr_state, '... -> 1 ...'),
[instruction],
text_embeds = text_embed,
prob_random_action = epsilon
)

@@ -217,10 +231,13 @@ def forward(self):

# store memories using memmap, for later reflection and learning

self.states[episode, step] = curr_state
self.actions[episode, step] = actions
self.rewards[episode, step] = reward
self.dones[episode, step] = done
assert text_embed.shape[1:] == self.text_embed_shape

self.text_embeds[episode, step] = text_embed
self.states[episode, step] = curr_state
self.actions[episode, step] = actions
self.rewards[episode, step] = reward
self.dones[episode, step] = done

# if done, move onto next episode

@@ -231,6 +248,7 @@ def forward(self):

curr_state = next_state

self.text_embeds.flush()
self.states.flush()
self.actions.flush()
self.rewards.flush()
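With the changes above, consumers of the replay folder receive the stored text embedding in place of a raw instruction string. A hypothetical consumption sketch (the constructor arguments, default folder name and import path are not fully shown in this diff, so treat them as assumptions):

```python
from torch.utils.data import DataLoader
from q_transformer import ReplayMemoryDataset   # assumed import path

# folder the agent wrote its memmaps into (text embeds, states, actions, rewards, dones)
dataset = ReplayMemoryDataset(
    folder = './replay_memories_data',           # assumed path
    num_timesteps = 1
)

loader = DataLoader(dataset, batch_size = 16, shuffle = True)

# note the first field is now a float tensor of text embeddings, not an instruction string
for text_embeds, states, actions, next_state, rewards, dones in loader:
    break
```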
43 changes: 20 additions & 23 deletions q_transformer/q_learner.py
@@ -52,13 +52,6 @@ def default(val, d):
def is_divisible(num, den):
return (num % den) == 0

def repeat_tuple_el(t: Tuple, i: int) -> Tuple:
out = []
for el in t:
for _ in range(i):
out.append(el)
return tuple(out)

def pack_one(t, pattern):
return pack([t], pattern)

@@ -263,7 +256,7 @@ def get_discount_matrix(self, timestep):

def q_learn(
self,
instructions: Tuple[str],
text_embeds: TensorType['b', 'd', float],
states: TensorType['b', 'c', 'f', 'h', 'w', float],
actions: TensorType['b', int],
next_states: TensorType['b', 'c', 'f', 'h', 'w', float],
@@ -281,13 +274,13 @@
# first make a prediction with online q robotic transformer
# select out the q-values for the action that was taken

q_pred_all_actions = self.model(states, instructions)
q_pred_all_actions = self.model(states, text_embeds = text_embeds)
q_pred = batch_select_indices(q_pred_all_actions, actions)

# use an exponentially smoothed copy of model for the future q target. more stable than setting q_target to q_eval after each batch
# the max Q value is taken, as the optimal action is implicitly the one with the highest Q score

q_next = self.ema_model(next_states, instructions).amax(dim = -1)
q_next = self.ema_model(next_states, text_embeds = text_embeds).amax(dim = -1)
q_next.clamp_(min = default(monte_carlo_return, -1e4))

# Bellman's equation. most important line of code, hopefully done correctly
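As a reminder of what the "Bellman's equation" comment refers to: the single-step TD target is the reward plus the discounted best future Q value, with the bootstrap term zeroed on terminal transitions. A toy, self-contained version of that computation (values are arbitrary, and this is the textbook form rather than the literal line the diff truncates below):

```python
import torch
import torch.nn.functional as F

gamma   = 0.98                          # assumed discount factor
rewards = torch.tensor([1.0, 0.0])
dones   = torch.tensor([False, True])
q_pred  = torch.tensor([0.50, 0.20])    # Q(s, a) from the online model, for the taken actions
q_next  = torch.tensor([0.80, 0.90])    # max_a' Q_target(s', a') from the EMA model

# Bellman target: reward now + discounted best future value, zeroed when the episode ended
q_target = rewards + gamma * q_next * (1.0 - dones.float())

loss = F.mse_loss(q_pred, q_target)     # these toy targets carry no grad, so no detach needed
print(q_target)                         # tensor([1.7840, 0.0000])
```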
@@ -305,7 +298,7 @@

def n_step_q_learn(
self,
instructions: Tuple[str],
text_embeds: TensorType['b', 'd', float],
states: TensorType['b', 't', 'c', 'f', 'h', 'w', float],
actions: TensorType['b', 't', int],
next_states: TensorType['b', 'c', 'f', 'h', 'w', float],
@@ -326,17 +319,19 @@ def n_step_q_learn(
t - timesteps
a - action bins
q - q values
d - text cond dimension
"""

num_timesteps, device = states.shape[1], states.device

# fold time steps into batch

states, time_ps = pack_one(states, '* c f h w')
text_embeds, _ = pack_one(text_embeds, '* d')

# repeat instructions per timestep
# repeat text embeds per timestep

repeated_instructions = repeat_tuple_el(instructions, num_timesteps)
repeated_text_embeds = repeat(text_embeds, 'b ... -> (b n) ...', n = num_timesteps)

γ = self.discount_factor_gamma

@@ -351,11 +346,11 @@

actions = rearrange(actions, 'b t -> (b t)')

q_pred_all_actions = self.model(states, repeated_instructions)
q_pred_all_actions = self.model(states, text_embeds = repeated_text_embeds)
q_pred = batch_select_indices(q_pred_all_actions, actions)
q_pred = unpack_one(q_pred, time_ps, '*')

q_next = self.ema_model(next_states, instructions).amax(dim = -1)
q_next = self.ema_model(next_states, text_embeds = text_embeds).amax(dim = -1)
q_next.clamp_(min = default(monte_carlo_return, -1e4))

# prepare rewards and discount factors across timesteps
@@ -380,7 +375,7 @@

def autoregressive_q_learn_handle_single_timestep(
self,
instructions,
text_embeds,
states,
actions,
next_states,
@@ -405,11 +400,11 @@
if dones.ndim == 1:
dones = rearrange(dones, 'b -> b 1')

return self.autoregressive_q_learn(instructions, states, actions, next_states, rewards, dones, monte_carlo_return = monte_carlo_return)
return self.autoregressive_q_learn(text_embeds, states, actions, next_states, rewards, dones, monte_carlo_return = monte_carlo_return)

def autoregressive_q_learn(
self,
instructions: Tuple[str],
text_embeds: TensorType['b', 'd', float],
states: TensorType['b', 't', 'c', 'f', 'h', 'w', float],
actions: TensorType['b', 't', 'n', int],
next_states: TensorType['b', 'c', 'f', 'h', 'w', float],
@@ -431,6 +426,7 @@ def autoregressive_q_learn(
n - number of actions
a - action bins
q - q values
d - text cond dimension
"""
monte_carlo_return = default(monte_carlo_return, -1e4)
num_timesteps, device = states.shape[1], states.device
@@ -439,10 +435,11 @@

states, time_ps = pack_one(states, '* c f h w')
actions, _ = pack_one(actions, '* n')
text_embeds, _ = pack_one(text_embeds, '* d')

# repeat instructions per timestep
# repeat text embeds per timestep

repeated_instructions = repeat_tuple_el(instructions, num_timesteps)
repeated_text_embeds = repeat(text_embeds, 'b ... -> (b n) ...', n = num_timesteps)

# anything after the first done flag will be considered terminal

@@ -462,20 +459,20 @@
# get predicted Q for each action
# unpack back to (b, t, n)

q_pred_all_actions = self.model(states, repeated_instructions, actions = actions)
q_pred_all_actions = self.model(states, text_embeds = repeated_text_embeds, actions = actions)
q_pred = batch_select_indices(q_pred_all_actions, actions)
q_pred = unpack_one(q_pred, time_ps, '* n')

# get q_next

q_next = self.ema_model(next_states, instructions)
q_next = self.ema_model(next_states, text_embeds = text_embeds)
q_next = q_next.max(dim = -1).values
q_next.clamp_(min = monte_carlo_return)

# get target Q
# unpack back to - (b, t, n)

q_target_all_actions = self.ema_model(states, repeated_instructions, actions = actions)
q_target_all_actions = self.ema_model(states, text_embeds = repeated_text_embeds, actions = actions)
q_target = q_target_all_actions.max(dim = -1).values

q_target.clamp_(min = monte_carlo_return)
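One detail worth calling out in the n-step and autoregressive paths above: after the time dimension is folded into the batch with `pack_one`, each sample's single text embedding is repeated once per timestep so that row i of the packed states lines up with row i of the conditioning. A standalone illustration of that einops pattern (shapes are arbitrary, and the frame dims are collapsed into one stand-in feature axis):

```python
import torch
from einops import pack, repeat

b, t, d = 2, 3, 4                          # batch, timesteps, text embed dim (arbitrary)
states      = torch.randn(b, t, 5)         # stand-in for (b, t, c, f, h, w) frames
text_embeds = torch.randn(b, d)            # one conditioning vector per sample

# fold timesteps into the batch, as pack_one(states, '* c f h w') does in the learner
packed_states, ps = pack([states], '* f')                  # shape ((b t), 5)

# repeat each sample's embedding once per timestep so rows stay aligned
repeated = repeat(text_embeds, 'b d -> (b t) d', t = t)    # shape ((b t), d)

assert packed_states.shape[0] == repeated.shape[0]
# row k of the packed states belongs to sample k // t; check the alignment for sample 1
assert torch.equal(repeated[t], text_embeds[1])
```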
4 changes: 2 additions & 2 deletions setup.py
@@ -3,7 +3,7 @@
setup(
name = 'q-transformer',
packages = find_packages(exclude=[]),
version = '0.0.45',
version = '0.1.0',
license='MIT',
description = 'Q-Transformer',
author = 'Phil Wang',
@@ -20,7 +20,7 @@
install_requires=[
'accelerate',
'beartype',
'classifier-free-guidance-pytorch>=0.4.0',
'classifier-free-guidance-pytorch>=0.4.1',
'einops>=0.7.0',
'ema-pytorch>=0.3.1',
'numpy',
