Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] batched trajectories - SliceSampler compatibility #1775

Merged
merged 34 commits into from
Feb 22, 2024
Merged
Changes from 1 commit
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
ffcdf51
init
vmoens Jan 8, 2024
d567665
Merge remote-tracking branch 'origin/main' into extend-rb-dim1
vmoens Feb 6, 2024
a05eaae
amend
vmoens Feb 6, 2024
34fadac
Merge remote-tracking branch 'origin/main' into extend-rb-dim1
vmoens Feb 7, 2024
fe960fa
amend
vmoens Feb 7, 2024
91941ea
amend
vmoens Feb 7, 2024
5e06cd4
amend
vmoens Feb 7, 2024
2523ab0
amend
vmoens Feb 8, 2024
d1e8878
amend
vmoens Feb 12, 2024
238e9b3
amend
vmoens Feb 12, 2024
eb731a4
Merge remote-tracking branch 'origin/main' into extend-rb-dim1
vmoens Feb 13, 2024
cdd5f9a
amend
vmoens Feb 14, 2024
775ffb6
amend
vmoens Feb 14, 2024
01bcc79
amend
vmoens Feb 15, 2024
136ad70
amend
vmoens Feb 15, 2024
2fbb08b
Merge remote-tracking branch 'origin/main' into extend-rb-dim1
vmoens Feb 15, 2024
2c717c8
amend
vmoens Feb 15, 2024
94a9b61
amend
vmoens Feb 16, 2024
358a15f
amend
vmoens Feb 16, 2024
1495c37
amend
vmoens Feb 16, 2024
08e4084
Merge remote-tracking branch 'origin/main' into extend-rb-dim1
vmoens Feb 20, 2024
b049c49
Merge remote-tracking branch 'origin/main' into extend-rb-dim1
vmoens Feb 21, 2024
253565d
fix
vmoens Feb 21, 2024
fb9e562
fix
vmoens Feb 21, 2024
b349086
fix
vmoens Feb 21, 2024
9094798
fix
vmoens Feb 21, 2024
49f32b8
fix
vmoens Feb 21, 2024
9aa80ee
fix
vmoens Feb 21, 2024
ff535a6
fix
vmoens Feb 21, 2024
ee61ed5
Merge remote-tracking branch 'origin/main' into extend-rb-dim1
vmoens Feb 21, 2024
15f3d53
amend
vmoens Feb 21, 2024
d7118d0
amend
vmoens Feb 22, 2024
9612977
amend
vmoens Feb 22, 2024
36ea29f
amend
vmoens Feb 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix
vmoens committed Feb 21, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
commit 9aa80ee4d903d0ab0f8bedd7dee0f93ba82204fd
3 changes: 1 addition & 2 deletions torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
@@ -1350,7 +1350,7 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
preceding_stop_idx = stop_idx[..., 0, None] - subtractive_idx[None, ...]
preceding_stop_idx = preceding_stop_idx.reshape(-1, 1)
preceding_stop_idx = torch.cat(
[preceding_stop_idx, stop_idx[:, 1:].repeat_interleave(seq_length - 1)], -1
[preceding_stop_idx, stop_idx[:, 1:].repeat_interleave(seq_length - 1, dim=0)], -1
)
if storage.ndim > 1:
# convert the 2d index into a flat one to accomodate the _sum_tree
@@ -1376,7 +1376,6 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]

# extends starting indices of each slice with sequence_length to get indices of all steps
index = self._tensor_slices_from_startend(seq_length, starts)
assert index.ndim == 2

# repeat the weight of each slice to match the number of steps
info["_weight"] = torch.repeat_interleave(info["_weight"], seq_length)