amend
vmoens committed Mar 5, 2024
1 parent f0d7c2f commit 8323736
Showing 11 changed files with 312 additions and 42 deletions.
41 changes: 41 additions & 0 deletions docs/source/reference/data.rst
@@ -658,12 +658,53 @@ Here's an example:
the latest wheels are not published on PyPI. For OpenML, `scikit-learn <https://pypi.org/project/scikit-learn/>`_ and
`pandas <https://pypi.org/project/pandas>`_ are required.

Transforming datasets
~~~~~~~~~~~~~~~~~~~~~

In many instances, the raw data will not be used as-is.
A natural solution is to pass a :class:`~torchrl.envs.transforms.Transform`
instance to the dataset constructor and modify the samples on-the-fly. This
works, but it incurs an extra runtime cost for the transform at every sampling
call (see the sketch below).
If the transformations can be (at least partly) applied to the dataset ahead of
time, considerable disk space and sampling-time overhead can be saved. To do
this, the
:meth:`~torchrl.data.datasets.BaseDatasetExperienceReplay.preprocess` method can
be used. This method runs a per-sample preprocessing pipeline on each element
of the dataset, and replaces the existing dataset with its transformed version.
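For comparison, here is a minimal sketch of the on-the-fly approach. It assumes
the dataset forwards the usual ``transform`` keyword argument of
:class:`~torchrl.data.TensorDictReplayBuffer`; the ``"observation"`` key and
the normalization constants are purely illustrative:

>>> from torchrl.data.datasets import RobosetExperienceReplay
>>> from torchrl.envs.transforms import ObservationNorm
>>>
>>> dataset = RobosetExperienceReplay(
...     "FK1-v4(expert)/FK1_MicroOpenRandom_v2d-v4",
...     batch_size=32,
...     transform=ObservationNorm(loc=0.0, scale=1.0, in_keys=["observation"]),
... )
>>> sample = dataset.sample()  # the transform executes here, at every call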

Once transformed, re-creating the same dataset will produce another object with
the same transformed storage (unless ``download="force"`` is used):

>>> import os
>>>
>>> dataset = RobosetExperienceReplay(
... "FK1-v4(expert)/FK1_MicroOpenRandom_v2d-v4", batch_size=32, download="force"
... )
>>>
>>> def func(data):
... return data.set("obs_norm", data.get("observation").norm(dim=-1))
...
>>> dataset.preprocess(
... func,
... num_workers=max(1, os.cpu_count() - 2),
... num_chunks=1000,
... mp_start_method="fork",
... )
>>> sample = dataset.sample()
>>> assert "obs_norm" in sample.keys()
>>> # re-creating the dataset gives us the transformed version back.
>>> dataset = RobosetExperienceReplay(
... "FK1-v4(expert)/FK1_MicroOpenRandom_v2d-v4", batch_size=32
... )
>>> sample = dataset.sample()
>>> assert "obs_norm" in sample.keys()


.. currentmodule:: torchrl.data.datasets

.. autosummary::
:toctree: generated/
:template: rl_template.rst

BaseDatasetExperienceReplay
AtariDQNExperienceReplay
D4RLExperienceReplay
GenDGRLExperienceReplay
171 changes: 169 additions & 2 deletions test/test_libs.py
@@ -2278,6 +2278,31 @@ def new_get_category_len(cls, category_name):
yield
GenDGRLExperienceReplay._get_category_len = _get_category_len

@pytest.mark.parametrize("dataset_num", [4])
def test_gen_dgrl_preproc(self, dataset_num, tmpdir, _patch_traj_len):
dataset_id = GenDGRLExperienceReplay.available_datasets[dataset_num]
dataset = GenDGRLExperienceReplay(
dataset_id, batch_size=32, root=tmpdir, download="force"
)
from torchrl.envs import Compose, GrayScale, Resize

t = Compose(
Resize(32, in_keys=["observation", ("next", "observation")]),
GrayScale(in_keys=["observation", ("next", "observation")]),
)

def fn(data):
return t(data)

dataset.preprocess(fn, num_workers=max(1, os.cpu_count() - 2), num_chunks=1000)
sample = dataset.sample()
assert sample["observation"].shape == torch.Size([32, 1, 32, 32])
assert sample["next", "observation"].shape == torch.Size([32, 1, 32, 32])
dataset = GenDGRLExperienceReplay(dataset_id, batch_size=32, root=tmpdir)
sample = dataset.sample()
assert sample["observation"].shape == torch.Size([32, 1, 32, 32])
assert sample["next", "observation"].shape == torch.Size([32, 1, 32, 32])

@pytest.mark.parametrize("dataset_num", [0, 4, 8])
def test_gen_dgrl(self, dataset_num, tmpdir, _patch_traj_len):
dataset_id = GenDGRLExperienceReplay.available_datasets[dataset_num]
@@ -2310,6 +2335,51 @@ def test_gen_dgrl(self, dataset_num, tmpdir, _patch_traj_len):
@pytest.mark.skipif(not _has_d4rl, reason="D4RL not found")
@pytest.mark.slow
class TestD4RL:
def test_d4rl_preproc(self, tmpdir):
dataset_id = "walker2d-medium-replay-v2"
dataset = D4RLExperienceReplay(
dataset_id,
batch_size=32,
root=tmpdir,
download="force",
direct_download=True,
)
from torchrl.envs import CatTensors, Compose

t = Compose(
CatTensors(
in_keys=["observation", ("info", "qpos"), ("info", "qvel")],
out_key="data",
),
CatTensors(
in_keys=[
("next", "observation"),
("next", "info", "qpos"),
("next", "info", "qvel"),
],
out_key=("next", "data"),
),
)

def fn(data):
return t(data)

dataset.preprocess(
fn,
num_workers=max(1, os.cpu_count() - 2),
num_chunks=1000,
mp_start_method="fork",
)
sample = dataset.sample()
assert sample["data"].shape == torch.Size([32, 35])
assert sample["next", "data"].shape == torch.Size([32, 35])
dataset = D4RLExperienceReplay(
dataset_id, batch_size=32, root=tmpdir, direct_download=True
)
sample = dataset.sample()
assert sample["data"].shape == torch.Size([32, 35])
assert sample["next", "data"].shape == torch.Size([32, 35])

@pytest.mark.parametrize("task", ["walker2d-medium-replay-v2"])
@pytest.mark.parametrize("use_truncated_as_done", [True, False])
@pytest.mark.parametrize("split_trajs", [True, False])
@@ -2496,10 +2566,10 @@ def _minari_selected_datasets():


@pytest.mark.skipif(not _has_minari or not _has_gymnasium, reason="Minari not found")
@pytest.mark.parametrize("split", [False, True])
@pytest.mark.parametrize("selected_dataset", _MINARI_DATASETS)
@pytest.mark.slow
class TestMinari:
@pytest.mark.parametrize("split", [False, True])
@pytest.mark.parametrize("selected_dataset", _MINARI_DATASETS)
def test_load(self, selected_dataset, split):
torchrl_logger.info(f"dataset {selected_dataset}")
data = MinariExperienceReplay(
@@ -2515,6 +2585,55 @@ def test_load(self, selected_dataset, split):
if i == 10:
break

def test_minari_preproc(self, tmpdir):
selected_dataset = _MINARI_DATASETS[0]
dataset = MinariExperienceReplay(
selected_dataset, batch_size=32, split_trajs=False, download="force"
)

from torchrl.envs import CatTensors, Compose

t = Compose(
CatTensors(
in_keys=[
("observation", "observation"),
("info", "qpos"),
("info", "qvel"),
],
out_key="data",
),
CatTensors(
in_keys=[
("next", "observation", "observation"),
("next", "info", "qpos"),
("next", "info", "qvel"),
],
out_key=("next", "data"),
),
)

def fn(data):
return t(data)

dataset.preprocess(
fn,
num_workers=max(1, os.cpu_count() - 2),
num_chunks=1000,
mp_start_method="fork",
)
sample = dataset.sample()
assert sample["data"].shape == torch.Size([32, 8])
assert sample["next", "data"].shape == torch.Size([32, 8])
dataset = MinariExperienceReplay(
selected_dataset,
batch_size=32,
split_trajs=False,
)
sample = dataset.sample()
assert sample["data"].shape == torch.Size([32, 8])
assert sample["next", "data"].shape == torch.Size([32, 8])


@pytest.mark.slow
class TestRoboset:
Expand All @@ -2532,6 +2651,28 @@ def test_load(self):
if i == 10:
break

def test_roboset_preproc(self):
dataset = RobosetExperienceReplay(
"FK1-v4(expert)/FK1_MicroOpenRandom_v2d-v4", batch_size=32, download="force"
)

def func(data):
return data.set("obs_norm", data.get("observation").norm(dim=-1))

dataset.preprocess(
func,
num_workers=max(1, os.cpu_count() - 2),
num_chunks=1000,
mp_start_method="fork",
)
sample = dataset.sample()
assert "obs_norm" in sample.keys()
dataset = RobosetExperienceReplay(
"FK1-v4(expert)/FK1_MicroOpenRandom_v2d-v4", batch_size=32
)
sample = dataset.sample()
assert "obs_norm" in sample.keys()


@pytest.mark.slow
class TestVD4RL:
@@ -2565,6 +2706,32 @@ def test_load(self, image_size):
if i == 10:
break

def test_vd4rl_preproc(self):
torch.manual_seed(0)
datasets = VD4RLExperienceReplay.available_datasets
dataset_id = list(datasets)[4]
dataset = VD4RLExperienceReplay(dataset_id, batch_size=32, download="force")
from torchrl.envs import Compose, GrayScale, ToTensorImage

func = Compose(
ToTensorImage(in_keys=["pixels", ("next", "pixels")]),
GrayScale(in_keys=["pixels", ("next", "pixels")]),
)
dataset.preprocess(
func,
num_workers=max(1, os.cpu_count() - 2),
num_chunks=1000,
mp_start_method="fork",
)
sample = dataset.sample()
assert sample["next", "pixels"].shape == torch.Size([32, 1, 64, 64])
dataset = VD4RLExperienceReplay(
dataset_id,
batch_size=32,
)
sample = dataset.sample()
assert sample["next", "pixels"].shape == torch.Size([32, 1, 64, 64])


@pytest.mark.slow
class TestAtariDQN:
6 changes: 6 additions & 0 deletions torchrl/data/datasets/__init__.py
@@ -1,4 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .atari_dqn import AtariDQNExperienceReplay
from .common import BaseDatasetExperienceReplay
from .d4rl import D4RLExperienceReplay
from .gen_dgrl import GenDGRLExperienceReplay
from .minari_data import MinariExperienceReplay
2 changes: 1 addition & 1 deletion torchrl/data/datasets/atari_dqn.py
@@ -65,7 +65,7 @@ class AtariDQNExperienceReplay(BaseDatasetExperienceReplay):
Has no effect whenever the data is already downloaded. Defaults to 0
(no multiprocessing used).
download (bool or str, optional): Whether the dataset should be downloaded if
not found. Defaults to ``True``. Download can also be passed as "force",
not found. Defaults to ``True``. Download can also be passed as ``"force"``,
in which case the downloaded data will be overwritten.
sampler (Sampler, optional): the sampler to be used. If none is provided
a default RandomSampler() will be used.
74 changes: 49 additions & 25 deletions torchrl/data/datasets/common.py
@@ -5,7 +5,9 @@
from __future__ import annotations

import abc
import pickle
import shutil
import tempfile
from pathlib import Path
from typing import Callable

@@ -14,21 +16,29 @@
from torch import multiprocessing as mp

from torchrl.data.replay_buffers import TensorDictReplayBuffer, TensorStorage
from torchrl.data.utils import CloudpickleWrapper


class BaseDatasetExperienceReplay(TensorDictReplayBuffer):
"""Parent class for offline datasets."""

@abc.abstractproperty
@property
@abc.abstractmethod
def data_path(self) -> Path:
"""Path to the dataset, including split."""
...

@abc.abstractproperty
@property
@abc.abstractmethod
def data_path_root(self) -> Path:
"""Path to the dataset root."""
...

@abc.abstractmethod
def _is_downloaded(self) -> bool:
"""Checks if the data has been downloaded."""
...

@property
def root(self) -> Path:
return self._root
@@ -210,29 +220,43 @@ def preprocess(
collate_fn=<function _collate_id at 0x120e21dc0>)
"""
if not _can_be_pickled(fn):
fn = CloudpickleWrapper(fn)
if isinstance(self._storage, TensorStorage):
example_data = fn(self._storage[0])
mock_folder = self.data_path.parent / "mock"
mmlike = example_data.expand(
(self._storage.shape[0], *example_data.shape)
).memmap_like(mock_folder, num_threads=32)
self._storage._storage.map(
fn=fn,
dim=dim,
num_workers=num_workers,
chunksize=chunksize,
num_chunks=num_chunks,
pool=pool,
generator=generator,
max_tasks_per_child=max_tasks_per_child,
worker_threads=worker_threads,
index_with_generator=index_with_generator,
pbar=pbar,
mp_start_method=mp_start_method,
out=mmlike,
)
self._storage._storage = mmlike
shutil.rmtree(self.data_path)
shutil.move(mock_folder, self.data_path)
item = self._storage[0]
with item.unlock_():
example_data = fn(item)
with tempfile.TemporaryDirectory() as mock_folder:
mmlike = example_data.expand(
(self._storage.shape[0], *example_data.shape)
).memmap_like(mock_folder, num_threads=32)
storage = self._storage._storage
with storage.unlock_():
storage.map(
fn=fn,
dim=dim,
num_workers=num_workers,
chunksize=chunksize,
num_chunks=num_chunks,
pool=pool,
generator=generator,
max_tasks_per_child=max_tasks_per_child,
worker_threads=worker_threads,
index_with_generator=index_with_generator,
pbar=pbar,
mp_start_method=mp_start_method,
out=mmlike,
)
self._storage._storage = mmlike
shutil.rmtree(self.data_path)
shutil.move(mock_folder, self.data_path)
else:
raise NotImplementedError
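
Read in isolation, the new ``preprocess`` body follows a three-step pattern:
transform one sample to obtain a template, preallocate a memory-mapped copy of
the full dataset with the template's structure, then ``map`` the transform over
the storage into it. Below is a minimal, self-contained sketch of that pattern
built on the public tensordict API; the data and function names are
illustrative, not part of this commit:

import tempfile

import torch
from tensordict import TensorDict


def add_norm(td):
    # per-sample transform: plays the role of the ``fn`` given to ``preprocess``
    return td.set("obs_norm", td.get("observation").norm(dim=-1))


if __name__ == "__main__":
    data = TensorDict({"observation": torch.randn(100, 3)}, batch_size=[100])
    # step 1: transform a single sample to discover the output structure
    example = add_norm(data[0])
    # step 2: preallocate a memory-mapped dataset with that structure on disk
    out = example.expand(data.shape[0], *example.shape).memmap_like(
        tempfile.mkdtemp()
    )
    # step 3: fill it by mapping the transform over the source data, in chunks
    data.map(add_norm, dim=0, num_workers=2, out=out)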


def _can_be_pickled(obj):
try:
pickle.dumps(obj)
return True
except (pickle.PickleError, AttributeError, TypeError):
return False
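
The ``_can_be_pickled`` guard above exists because ``map`` ships the
preprocessing function to worker processes: callables that the standard
``pickle`` module rejects, such as lambdas or locally defined functions, are
wrapped in ``CloudpickleWrapper`` (imported at the top of this file) so they
are serialized by value instead. A short sketch of that behaviour, with an
illustrative lambda:

import pickle

from torchrl.data.utils import CloudpickleWrapper

scale = 2.0
fn = lambda td: td * scale  # plain pickle cannot serialize a lambda by reference

try:
    pickle.dumps(fn)
except (pickle.PickleError, AttributeError, TypeError):
    fn = CloudpickleWrapper(fn)  # cloudpickle serializes the function by value

pickle.dumps(fn)  # the wrapped callable now round-trips through pickle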