From ada2b4dd0c2049503dfa98fc7a3b56a9eddbf77e Mon Sep 17 00:00:00 2001
From: vmoens <vincentmoens@gmail.com>
Date: Tue, 5 Mar 2024 13:05:00 +0000
Subject: [PATCH] amend

---
 test/test_libs.py                  | 101 ++++++++++++++++++++++++++++-
 torchrl/data/datasets/atari_dqn.py |  19 ++++--
 2 files changed, 112 insertions(+), 8 deletions(-)

diff --git a/test/test_libs.py b/test/test_libs.py
index 2dfad22bae3..81c6254b9dd 100644
--- a/test/test_libs.py
+++ b/test/test_libs.py
@@ -73,6 +73,7 @@
 from torchrl.data.datasets.roboset import RobosetExperienceReplay
 from torchrl.data.datasets.vd4rl import VD4RLExperienceReplay
 from torchrl.data.replay_buffers import SamplerWithoutReplacement
+from torchrl.data.utils import CloudpickleWrapper
 from torchrl.envs import (
     CatTensors,
     Compose,
@@ -2627,10 +2628,13 @@ def test_atari_preproc(self, dataset_id):
             num_slices=8,
             batch_size=64,
             num_procs=max(0, os.cpu_count() - 4),
+            download="force",
         )
 
         t = Compose(
-            UnsqueezeTransform(unsqueeze_dim=-3, in_keys=["observation", ("next", "observation")]),
+            UnsqueezeTransform(
+                unsqueeze_dim=-3, in_keys=["observation", ("next", "observation")]
+            ),
             Resize(32, in_keys=["observation", ("next", "observation")]),
             RenameTransform(in_keys=["action"], out_keys=["other_action"]),
         )
@@ -2638,8 +2642,22 @@ def test_atari_preproc(self, dataset_id):
         def preproc(data):
             return t(data)
 
-        dataset.preprocess(preproc, num_workers=max(1, os.cpu_count()-4), num_chunks=1000, mp_start_method="fork", pbar=True)
-        print(dataset)
+        dataset.preprocess(
+            preproc,
+            num_workers=max(1, os.cpu_count() - 4),
+            num_chunks=1000,
+            mp_start_method="fork",
+            pbar=True,
+        )
+
+        dataset = AtariDQNExperienceReplay(
+            dataset_id,
+            slice_len=None,
+            num_slices=8,
+            batch_size=64,
+            num_procs=max(0, os.cpu_count() - 4),
+            download=True,
+        )
 
 
 @pytest.mark.slow
@@ -2754,6 +2772,83 @@ def test_openx(
         elif num_slices is not None:
             assert sample.get(("next", "done")).sum() == num_slices
 
+    def test_openx_preproc(self):
+        dataset = OpenXExperienceReplay(
+            "cmu_stretch",
+            download="force",
+            streaming=False,
+            batch_size=64,
+            shuffle=True,
+            num_slices=8,
+            slice_len=None,
+        )
+        from torchrl.envs import Compose, RenameTransform, Resize
+
+        t = Compose(
+            Resize(
+                64,
+                64,
+                in_keys=[("observation", "image"), ("next", "observation", "image")],
+            ),
+            RenameTransform(
+                in_keys=[
+                    ("observation", "image"),
+                    ("next", "observation", "image"),
+                    ("observation", "state"),
+                    ("next", "observation", "state"),
+                ],
+                out_keys=["pixels", ("next", "pixels"), "state", ("next", "state")],
+            ),
+        )
+
+        def fn(data: TensorDict):
+            data.unlock_()
+            data = data.select(
+                "action",
+                "done",
+                "episode",
+                ("next", "done"),
+                ("next", "observation"),
+                ("next", "reward"),
+                ("next", "terminated"),
+                ("next", "truncated"),
+                "observation",
+                "terminated",
+                "truncated",
+            )
+            data = t(data)
+            data = data.select(*data.keys(True, True))
+            return data
+
+        dataset.preprocess(
+            CloudpickleWrapper(fn),
+            num_workers=max(1, os.cpu_count() - 2),
+            num_chunks=500,
+            mp_start_method="fork",
+        )
+        sample = dataset.sample(32)
+        assert "observation" not in sample.keys()
+        assert "pixels" in sample.keys()
+        assert ("next", "pixels") in sample.keys(True)
+        assert "state" in sample.keys()
+        assert ("next", "state") in sample.keys(True)
+        assert sample["pixels"].shape == torch.Size([32, 3, 64, 64])
+        dataset = OpenXExperienceReplay(
+            "cmu_stretch",
+            download=True,
+            streaming=False,
+            batch_size=64,
+            shuffle=True,
+            num_slices=8,
+            slice_len=None,
+        )
+        sample = dataset.sample(32)
+        assert "observation" not in sample.keys()
+        assert "pixels" in sample.keys()
+        assert ("next", "pixels") in sample.keys(True)
+        assert "state" in sample.keys()
+        assert ("next", "state") in sample.keys(True)
+
 
 @pytest.mark.skipif(not _has_sklearn, reason="Scikit-learn not found")
 @pytest.mark.parametrize(
diff --git a/torchrl/data/datasets/atari_dqn.py b/torchrl/data/datasets/atari_dqn.py
index 0cc25ad396b..ab8b1c00e93 100644
--- a/torchrl/data/datasets/atari_dqn.py
+++ b/torchrl/data/datasets/atari_dqn.py
@@ -10,7 +10,6 @@
 import importlib.util
 import io
 import json
-from torchrl.data.utils import CloudpickleWrapper
 import os
 import shutil
 import subprocess
@@ -33,6 +32,7 @@
 )
 from torchrl.data.replay_buffers.storages import Storage, TensorStorage
 from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter
+from torchrl.data.utils import CloudpickleWrapper
 from torchrl.envs.utils import _classproperty
 
 
@@ -535,7 +535,7 @@ def _download_and_preproc(self):
                             dataset_path=self.dataset_path,
                             total_episodes=total_runs,
                             max_runs=self._max_runs,
-                        multithreaded=True,
+                            multithreaded=True,
                         )
                 else:
                     func = functools.partial(
@@ -560,7 +560,15 @@ def _download_and_preproc(self):
 
     @classmethod
     def _download_and_proc_split(
-        cls, run, run_files, *, tempdir, dataset_path, total_episodes, max_runs, multithreaded=True
+        cls,
+        run,
+        run_files,
+        *,
+        tempdir,
+        dataset_path,
+        total_episodes,
+        max_runs,
+        multithreaded=True,
     ):
         if (max_runs is not None) and (run >= max_runs):
             return
@@ -757,7 +765,8 @@ def get_folders(path):
         frames_per_split[:, 1] = frames_per_split[:, 1].cumsum(0)
         self.frames_per_split = torch.cat(
             # [torch.tensor([[-1, 0]]), frames_per_split], 0
-            [torch.tensor([[-1, 0]]), frames_per_split], 0
+            [torch.tensor([[-1, 0]]), frames_per_split],
+            0,
         )
 
         # retrieve episodes
@@ -783,7 +792,7 @@ def _read_from_splits(self, item: int | torch.Tensor):
         else:
             is_int = False
         split = (item < self.frames_per_split[1:, 1].unsqueeze(1)) & (
-                item >= self.frames_per_split[:-1, 1].unsqueeze(1)
+            item >= self.frames_per_split[:-1, 1].unsqueeze(1)
         )
         # split_tmp, idx = split.squeeze().nonzero().unbind(-1)
         split_tmp, idx = split.nonzero().unbind(-1)