From ada2b4dd0c2049503dfa98fc7a3b56a9eddbf77e Mon Sep 17 00:00:00 2001 From: vmoens <vincentmoens@gmail.com> Date: Tue, 5 Mar 2024 13:05:00 +0000 Subject: [PATCH] amend --- test/test_libs.py | 101 ++++++++++++++++++++++++++++- torchrl/data/datasets/atari_dqn.py | 19 ++++-- 2 files changed, 112 insertions(+), 8 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index 2dfad22bae3..81c6254b9dd 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -73,6 +73,7 @@ from torchrl.data.datasets.roboset import RobosetExperienceReplay from torchrl.data.datasets.vd4rl import VD4RLExperienceReplay from torchrl.data.replay_buffers import SamplerWithoutReplacement +from torchrl.data.utils import CloudpickleWrapper from torchrl.envs import ( CatTensors, Compose, @@ -2627,10 +2628,13 @@ def test_atari_preproc(self, dataset_id): num_slices=8, batch_size=64, num_procs=max(0, os.cpu_count() - 4), + download="force", ) t = Compose( - UnsqueezeTransform(unsqueeze_dim=-3, in_keys=["observation", ("next", "observation")]), + UnsqueezeTransform( + unsqueeze_dim=-3, in_keys=["observation", ("next", "observation")] + ), Resize(32, in_keys=["observation", ("next", "observation")]), RenameTransform(in_keys=["action"], out_keys=["other_action"]), ) @@ -2638,8 +2642,22 @@ def test_atari_preproc(self, dataset_id): def preproc(data): return t(data) - dataset.preprocess(preproc, num_workers=max(1, os.cpu_count()-4), num_chunks=1000, mp_start_method="fork", pbar=True) - print(dataset) + dataset.preprocess( + preproc, + num_workers=max(1, os.cpu_count() - 4), + num_chunks=1000, + mp_start_method="fork", + pbar=True, + ) + + dataset = AtariDQNExperienceReplay( + dataset_id, + slice_len=None, + num_slices=8, + batch_size=64, + num_procs=max(0, os.cpu_count() - 4), + download=True, + ) @pytest.mark.slow @@ -2754,6 +2772,83 @@ def test_openx( elif num_slices is not None: assert sample.get(("next", "done")).sum() == num_slices + def test_openx_preproc(self): + dataset = OpenXExperienceReplay( + "cmu_stretch", + download="force", + streaming=False, + batch_size=64, + shuffle=True, + num_slices=8, + slice_len=None, + ) + from torchrl.envs import Compose, RenameTransform, Resize + + t = Compose( + Resize( + 64, + 64, + in_keys=[("observation", "image"), ("next", "observation", "image")], + ), + RenameTransform( + in_keys=[ + ("observation", "image"), + ("next", "observation", "image"), + ("observation", "state"), + ("next", "observation", "state"), + ], + out_keys=["pixels", ("next", "pixels"), "state", ("next", "state")], + ), + ) + + def fn(data: TensorDict): + data.unlock_() + data = data.select( + "action", + "done", + "episode", + ("next", "done"), + ("next", "observation"), + ("next", "reward"), + ("next", "terminated"), + ("next", "truncated"), + "observation", + "terminated", + "truncated", + ) + data = t(data) + data = data.select(*data.keys(True, True)) + return data + + dataset.preprocess( + CloudpickleWrapper(fn), + num_workers=max(1, os.cpu_count() - 2), + num_chunks=500, + mp_start_method="fork", + ) + sample = dataset.sample(32) + assert "observation" not in sample.keys() + assert "pixels" in sample.keys() + assert ("next", "pixels") in sample.keys(True) + assert "state" in sample.keys() + assert ("next", "state") in sample.keys(True) + assert sample["pixels"].shape == torch.Size([32, 3, 64, 64]) + dataset = OpenXExperienceReplay( + "cmu_stretch", + download=True, + streaming=False, + batch_size=64, + shuffle=True, + num_slices=8, + slice_len=None, + ) + sample = dataset.sample(32) + assert "observation" not in sample.keys() + assert "pixels" in sample.keys() + assert ("next", "pixels") in sample.keys(True) + assert "state" in sample.keys() + assert ("next", "state") in sample.keys(True) + @pytest.mark.skipif(not _has_sklearn, reason="Scikit-learn not found") @pytest.mark.parametrize( diff --git a/torchrl/data/datasets/atari_dqn.py b/torchrl/data/datasets/atari_dqn.py index 0cc25ad396b..ab8b1c00e93 100644 --- a/torchrl/data/datasets/atari_dqn.py +++ b/torchrl/data/datasets/atari_dqn.py @@ -10,7 +10,6 @@ import importlib.util import io import json -from torchrl.data.utils import CloudpickleWrapper import os import shutil import subprocess @@ -33,6 +32,7 @@ ) from torchrl.data.replay_buffers.storages import Storage, TensorStorage from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter +from torchrl.data.utils import CloudpickleWrapper from torchrl.envs.utils import _classproperty @@ -535,7 +535,7 @@ def _download_and_preproc(self): dataset_path=self.dataset_path, total_episodes=total_runs, max_runs=self._max_runs, - multithreaded=True, + multithreaded=True, ) else: func = functools.partial( @@ -560,7 +560,15 @@ def _download_and_preproc(self): @classmethod def _download_and_proc_split( - cls, run, run_files, *, tempdir, dataset_path, total_episodes, max_runs, multithreaded=True + cls, + run, + run_files, + *, + tempdir, + dataset_path, + total_episodes, + max_runs, + multithreaded=True, ): if (max_runs is not None) and (run >= max_runs): return @@ -757,7 +765,8 @@ def get_folders(path): frames_per_split[:, 1] = frames_per_split[:, 1].cumsum(0) self.frames_per_split = torch.cat( # [torch.tensor([[-1, 0]]), frames_per_split], 0 - [torch.tensor([[-1, 0]]), frames_per_split], 0 + [torch.tensor([[-1, 0]]), frames_per_split], + 0, ) # retrieve episodes @@ -783,7 +792,7 @@ def _read_from_splits(self, item: int | torch.Tensor): else: is_int = False split = (item < self.frames_per_split[1:, 1].unsqueeze(1)) & ( - item >= self.frames_per_split[:-1, 1].unsqueeze(1) + item >= self.frames_per_split[:-1, 1].unsqueeze(1) ) # split_tmp, idx = split.squeeze().nonzero().unbind(-1) split_tmp, idx = split.nonzero().unbind(-1)