Commit

amend
vmoens committed Mar 5, 2024
1 parent a92de2c commit ada2b4d
Showing 2 changed files with 112 additions and 8 deletions.
101 changes: 98 additions & 3 deletions test/test_libs.py
@@ -73,6 +73,7 @@
from torchrl.data.datasets.roboset import RobosetExperienceReplay
from torchrl.data.datasets.vd4rl import VD4RLExperienceReplay
from torchrl.data.replay_buffers import SamplerWithoutReplacement
+from torchrl.data.utils import CloudpickleWrapper
from torchrl.envs import (
CatTensors,
Compose,
@@ -2627,19 +2628,36 @@ def test_atari_preproc(self, dataset_id):
num_slices=8,
batch_size=64,
num_procs=max(0, os.cpu_count() - 4),
download="force",
)

t = Compose(
-            UnsqueezeTransform(unsqueeze_dim=-3, in_keys=["observation", ("next", "observation")]),
+            UnsqueezeTransform(
+                unsqueeze_dim=-3, in_keys=["observation", ("next", "observation")]
+            ),
Resize(32, in_keys=["observation", ("next", "observation")]),
RenameTransform(in_keys=["action"], out_keys=["other_action"]),
)

def preproc(data):
return t(data)

-        dataset.preprocess(preproc, num_workers=max(1, os.cpu_count()-4), num_chunks=1000, mp_start_method="fork", pbar=True)
-        print(dataset)
+        dataset.preprocess(
+            preproc,
+            num_workers=max(1, os.cpu_count() - 4),
+            num_chunks=1000,
+            mp_start_method="fork",
+            pbar=True,
+        )
+
+        dataset = AtariDQNExperienceReplay(
+            dataset_id,
+            slice_len=None,
+            num_slices=8,
+            batch_size=64,
+            num_procs=max(0, os.cpu_count() - 4),
+            download=True,
+        )


@pytest.mark.slow
@@ -2754,6 +2772,83 @@ def test_openx(
elif num_slices is not None:
assert sample.get(("next", "done")).sum() == num_slices

+    def test_openx_preproc(self):
+        dataset = OpenXExperienceReplay(
+            "cmu_stretch",
+            download="force",
+            streaming=False,
+            batch_size=64,
+            shuffle=True,
+            num_slices=8,
+            slice_len=None,
+        )
+        from torchrl.envs import Compose, RenameTransform, Resize
+
+        t = Compose(
+            Resize(
+                64,
+                64,
+                in_keys=[("observation", "image"), ("next", "observation", "image")],
+            ),
+            RenameTransform(
+                in_keys=[
+                    ("observation", "image"),
+                    ("next", "observation", "image"),
+                    ("observation", "state"),
+                    ("next", "observation", "state"),
+                ],
+                out_keys=["pixels", ("next", "pixels"), "state", ("next", "state")],
+            ),
+        )
+
+        def fn(data: TensorDict):
+            data.unlock_()
+            data = data.select(
+                "action",
+                "done",
+                "episode",
+                ("next", "done"),
+                ("next", "observation"),
+                ("next", "reward"),
+                ("next", "terminated"),
+                ("next", "truncated"),
+                "observation",
+                "terminated",
+                "truncated",
+            )
+            data = t(data)
+            data = data.select(*data.keys(True, True))
+            return data
+
+        dataset.preprocess(
+            CloudpickleWrapper(fn),
+            num_workers=max(1, os.cpu_count() - 2),
+            num_chunks=500,
+            mp_start_method="fork",
+        )
+        sample = dataset.sample(32)
+        assert "observation" not in sample.keys()
+        assert "pixels" in sample.keys()
+        assert ("next", "pixels") in sample.keys(True)
+        assert "state" in sample.keys()
+        assert ("next", "state") in sample.keys(True)
+        assert sample["pixels"].shape == torch.Size([32, 3, 64, 64])
+        dataset = OpenXExperienceReplay(
+            "cmu_stretch",
+            download=True,
+            streaming=False,
+            batch_size=64,
+            shuffle=True,
+            num_slices=8,
+            slice_len=None,
+        )
+        sample = dataset.sample(32)
+        assert "observation" not in sample.keys()
+        assert "pixels" in sample.keys()
+        assert ("next", "pixels") in sample.keys(True)
+        assert "state" in sample.keys()
+        assert ("next", "state") in sample.keys(True)
+

@pytest.mark.skipif(not _has_sklearn, reason="Scikit-learn not found")
@pytest.mark.parametrize(
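The new tests above exercise one pattern worth spelling out: build the dataset with download="force", run preprocess once over the stored data with multiprocessed workers, then re-instantiate with download=True so later runs reuse the preprocessed storage instead of re-downloading. The sketch below condenses that flow. It is a minimal illustration only: the OpenXExperienceReplay import path is assumed, and the real test additionally unlocks the sampled TensorDict and prunes keys with select, as the diff above shows.

# Minimal sketch of the preprocess-then-reload flow exercised by the tests.
# Assumes the cmu_stretch dataset and the keyword arguments shown above.
import os

from torchrl.data.datasets.openx import OpenXExperienceReplay  # assumed path
from torchrl.data.utils import CloudpickleWrapper
from torchrl.envs import Compose, RenameTransform, Resize

t = Compose(
    Resize(64, 64, in_keys=[("observation", "image"), ("next", "observation", "image")]),
    RenameTransform(
        in_keys=[("observation", "image"), ("next", "observation", "image")],
        out_keys=["pixels", ("next", "pixels")],
    ),
)


def fn(data):
    # The transform runs once, over the data on disk, not at sampling time.
    return t(data)


# First pass: force a fresh download, then preprocess the storage in place.
dataset = OpenXExperienceReplay(
    "cmu_stretch", download="force", streaming=False, batch_size=64,
    shuffle=True, num_slices=8, slice_len=None,
)
dataset.preprocess(
    CloudpickleWrapper(fn),  # makes the closure over `t` picklable for workers
    num_workers=max(1, os.cpu_count() - 2),
    num_chunks=500,
    mp_start_method="fork",
)

# Later runs: download=True picks up the already-preprocessed storage.
dataset = OpenXExperienceReplay(
    "cmu_stretch", download=True, streaming=False, batch_size=64,
    shuffle=True, num_slices=8, slice_len=None,
)
sample = dataset.sample(32)  # carries "pixels" instead of ("observation", "image")

Wrapping fn in CloudpickleWrapper matters because preprocess hands the callable to worker processes; see the sketch after the second file's diff.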
19 changes: 14 additions & 5 deletions torchrl/data/datasets/atari_dqn.py
@@ -10,7 +10,6 @@
import importlib.util
import io
import json
-from torchrl.data.utils import CloudpickleWrapper
import os
import shutil
import subprocess
@@ -33,6 +32,7 @@
)
from torchrl.data.replay_buffers.storages import Storage, TensorStorage
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter
+from torchrl.data.utils import CloudpickleWrapper
from torchrl.envs.utils import _classproperty


@@ -535,7 +535,7 @@ def _download_and_preproc(self):
dataset_path=self.dataset_path,
total_episodes=total_runs,
max_runs=self._max_runs,
-                        multithreaded=True,
+                    multithreaded=True,
)
else:
func = functools.partial(
@@ -560,7 +560,15 @@

@classmethod
def _download_and_proc_split(
-        cls, run, run_files, *, tempdir, dataset_path, total_episodes, max_runs, multithreaded=True
+        cls,
+        run,
+        run_files,
+        *,
+        tempdir,
+        dataset_path,
+        total_episodes,
+        max_runs,
+        multithreaded=True,
):
if (max_runs is not None) and (run >= max_runs):
return
@@ -757,7 +765,8 @@ def get_folders(path):
frames_per_split[:, 1] = frames_per_split[:, 1].cumsum(0)
self.frames_per_split = torch.cat(
# [torch.tensor([[-1, 0]]), frames_per_split], 0
-            [torch.tensor([[-1, 0]]), frames_per_split], 0
+            [torch.tensor([[-1, 0]]), frames_per_split],
+            0,
)

# retrieve episodes
@@ -783,7 +792,7 @@ def _read_from_splits(self, item: int | torch.Tensor):
else:
is_int = False
split = (item < self.frames_per_split[1:, 1].unsqueeze(1)) & (
-                item >= self.frames_per_split[:-1, 1].unsqueeze(1)
+            item >= self.frames_per_split[:-1, 1].unsqueeze(1)
)
# split_tmp, idx = split.squeeze().nonzero().unbind(-1)
split_tmp, idx = split.nonzero().unbind(-1)

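The substantive change in atari_dqn.py is relocating the CloudpickleWrapper import from the middle of the standard-library imports into the torchrl import block; the remaining hunks appear to be formatting-only line wraps. CloudpickleWrapper is the piece that lets preprocess hand locally defined callables, such as fn above, to worker processes, since plain pickle cannot serialize closures. A rough sketch of the idea follows, assuming cloudpickle is installed; it is illustrative only and not the class's actual implementation in torchrl.data.utils.

# Rough sketch of a cloudpickle-backed callable wrapper; illustrative only.
import cloudpickle


class CloudpickleWrapperSketch:
    """Wraps a callable so multiprocessing can pickle it by value."""

    def __init__(self, fn):
        self.fn = fn

    def __getstate__(self):
        # Serialize the wrapped callable by value, closures included.
        return cloudpickle.dumps(self.fn)

    def __setstate__(self, state):
        # Rebuild the callable inside the worker process.
        self.fn = cloudpickle.loads(state)

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)


# Usage mirrors the test above: pass the wrapped function to preprocess so
# that "fork"- or "spawn"-started workers can receive it, e.g.
# dataset.preprocess(CloudpickleWrapperSketch(fn), num_workers=4, num_chunks=500)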