Vectorized Trajectory Management #2660
-
I've got a band-aid solution:

```python
# Sort buffer so episodes are contiguous. Lets us sample continuous trajectories.
_, sorted_indices = torch.sort(buffer.storage._storage[:len(buffer)]["episode"])
buffer.storage._storage[:len(buffer)] = buffer.storage._storage[:len(buffer)][sorted_indices]
sample = buffer.sample()
```

Although if there's something faster, that'd be awesome.
-
Hey! Essentially what you want to do is store every trajectory coming from each env one after the other, i.e. the trajectories coming in from each env are concatenated along dim=-1 to get something like

```python
traj = tensor([
    [0, 0, 0, 2, 2, 2, 2, 5, 5, 5, 5, 7, 7, 7, 7, 9],
    [1, 1, 1, 1, 3, 3, 4, 4, 4, 6, 6, 6, 8, 8, 8, 8],
])
```

The core thing is to tell your storage that you have 2 dims:

```python
...
storage=LazyTensorStorage(100, ndim=2),
...
```

When you call extend, just pass a tensordict of shape [num_envs, time]. Then you can use a SliceSampler and it'll know how the trajectories are stored and sample things accordingly. LMK if anything's unclear!
-
oh nice okay! I missed that part of the doc. I think this part confused me a bit:

I'm not sure I fully understand what that means still, and dim_extend isn't mentioned elsewhere. Although what you suggested above works!

```python
action = torch.tensor(env.action_space.sample(), device=device)
next_obs, reward, terminated, truncated, info = env.step(action)
tdlist.append(TensorDict(
    batch_size=cfg.training.num_envs,
    episode=episode_idx,
    obs=obs, action=action, reward=reward, next_obs=next_obs,
    truncated=truncated, terminated=terminated,
    steps=info['elapsed_steps'],
))
obs = next_obs
```
I store the tds in a list, where episode looks like [1, 2, 3, 4, 5, 6, 7, 8, 9] (one id per env), and then when there are any done signals I add the accumulated list to the buffer and give the finished envs new episode ids:

```python
dones = torch.logical_or(terminated, truncated)
if dones.any():
    reset_obs, info = env.reset(options={'env_idx': torch.nonzero(dones).squeeze()})
    obs = torch.where(dones.unsqueeze(-1), reset_obs, obs)
    print(f"Resetting {dones.sum()} environments")
    # Stack to [T, num_envs], then transpose to [num_envs, T] for the ndim=2 storage.
    trajectories = torch.stack(tdlist, dim=0).transpose(0, 1)
    buffer.extend(trajectories)
    tdlist.clear()
    # Update episode_idx for done episodes to be a new unique id.
    # squeeze(-1) keeps a 1D index even when only one env is done.
    update_indices = torch.nonzero(dones).squeeze(-1)
    next_values = torch.arange(episode_idx.max() + 1, episode_idx.max() + 1 + len(update_indices), device=device)
    episode_idx[update_indices] = next_values
```

That works! Should I just extend on every step instead of accumulating, for speed?
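In case it helps to picture it, this is roughly what per-step extension would look like instead of accumulating in tdlist. This is an untested sketch reusing the variables from the snippet above, and I haven't checked whether it's actually faster:

```python
# Untested sketch: extend once per step instead of accumulating in tdlist.
# unsqueeze(-1) adds a singleton time dim so the data matches the storage's [env, time] layout.
step_td = TensorDict(
    batch_size=cfg.training.num_envs,
    episode=episode_idx,
    obs=obs, action=action, reward=reward, next_obs=next_obs,
    truncated=truncated, terminated=terminated,
)
buffer.extend(step_td.unsqueeze(-1))  # shape [num_envs, 1]
```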
-
Hi -- I'm unsure how to neatly manage vectorized trajectories with a replay buffer.
I was thinking something like this:
But it seems like the trajectories need to be contiguous. Since my environments might not end at the same timestep, is there a better way to manage that in a vectorized/fast way?