Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into prioritized_slice_sam…
Browse files Browse the repository at this point in the history
…pler
  • Loading branch information
vmoens committed Feb 6, 2024
2 parents 74a4bee + ff3a350 commit ffba3dc
Show file tree
Hide file tree
Showing 17 changed files with 387 additions and 186 deletions.
35 changes: 34 additions & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,14 @@ TorchRL

TorchRL is an open-source Reinforcement Learning (RL) library for PyTorch.

It provides pytorch and python-first, low and high level abstractions for RL that are intended to be efficient, modular, documented and properly tested.
You can install TorchRL directly from PyPI (see more about installation
instructions in the dedicated section below):

.. code-block::
$ pip install torchrl
TorchRL provides pytorch and python-first, low and high level abstractions for RL that are intended to be efficient, modular, documented and properly tested.
The code is aimed at supporting research in RL. Most of it is written in python in a highly modular way, such that researchers can easily swap components, transform them or write new ones with little effort.

This repo attempts to align with the existing pytorch ecosystem libraries in that it has a "dataset pillar"
Expand All @@ -30,6 +37,32 @@ TorchRL aims at a high modularity and good runtime performance.
To read more about TorchRL philosophy and capabilities beyond this API reference,
check the `TorchRL paper <https://arxiv.org/abs/2306.00577>`__.

Installation
============

TorchRL releases are synced with PyTorch, so make sure you always enjoy the latest
features of the library with the `most recent version of PyTorch <https://pytorch.org/get-started/locally/>`__ (although core features
are guaranteed to be backward compatible with pytorch>=1.13).
Nightly releases can be installed via

.. code-block::
$ pip install tensordict-nightly
$ pip install torchrl-nightly
or via a ``git clone`` if you're willing to contribute to the library:

.. code-block::
$ cd path/to/root
$ git clone https://github.com/pytorch/tensordict
$ git clone https://github.com/pytorch/rl
$ cd tensordict
$ python setup.py develop
$ cd ../rl
$ python setup.py develop
Tutorials
=========

Expand Down
19 changes: 13 additions & 6 deletions test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1072,7 +1072,7 @@ def _step(
tensordict: TensorDictBase,
) -> TensorDictBase:
action = tensordict.get(self.action_key)
self.count += action.to(torch.int).to(self.device)
self.count += action.to(dtype=torch.int, device=self.device)
tensordict = TensorDict(
source={
"observation": self.count.clone(),
Expand Down Expand Up @@ -1426,10 +1426,12 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
3,
)
),
device=self.device,
)

self.unbatched_action_spec = CompositeSpec(
lazy=action_specs,
device=self.device,
)
self.unbatched_reward_spec = CompositeSpec(
{
Expand All @@ -1441,7 +1443,8 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
},
shape=(self.n_nested_dim,),
)
}
},
device=self.device,
)
self.unbatched_done_spec = CompositeSpec(
{
Expand All @@ -1455,7 +1458,8 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
},
shape=(self.n_nested_dim,),
)
}
},
device=self.device,
)

self.action_spec = self.unbatched_action_spec.expand(
Expand Down Expand Up @@ -1488,7 +1492,8 @@ def get_agent_obs_spec(self, i):
"lidar": lidar,
"vector": vector_3d,
"tensor_0": tensor_0,
}
},
device=self.device,
)
elif i == 1:
return CompositeSpec(
Expand All @@ -1497,15 +1502,17 @@ def get_agent_obs_spec(self, i):
"lidar": lidar,
"vector": vector_2d,
"tensor_1": tensor_1,
}
},
device=self.device,
)
elif i == 2:
return CompositeSpec(
{
"camera": camera,
"vector": vector_2d,
"tensor_2": tensor_2,
}
},
device=self.device,
)
else:
raise ValueError(f"Index {i} undefined for index 3")
Expand Down
18 changes: 13 additions & 5 deletions test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1675,8 +1675,12 @@ def test_maxframes_error():
@pytest.mark.parametrize("policy_device", [None, *get_available_devices()])
@pytest.mark.parametrize("env_device", [None, *get_available_devices()])
@pytest.mark.parametrize("storing_device", [None, *get_available_devices()])
@pytest.mark.parametrize("parallel", [False, True])
def test_reset_heterogeneous_envs(
policy_device: torch.device, env_device: torch.device, storing_device: torch.device
policy_device: torch.device,
env_device: torch.device,
storing_device: torch.device,
parallel,
):
if (
policy_device is not None
Expand All @@ -1686,9 +1690,13 @@ def test_reset_heterogeneous_envs(
env_device = torch.device("cpu") # explicit mapping
elif env_device is not None and env_device.type == "cuda" and policy_device is None:
policy_device = torch.device("cpu")
env1 = lambda: TransformedEnv(CountingEnv(), StepCounter(2))
env2 = lambda: TransformedEnv(CountingEnv(), StepCounter(3))
env = SerialEnv(2, [env1, env2], device=env_device)
env1 = lambda: TransformedEnv(CountingEnv(device="cpu"), StepCounter(2))
env2 = lambda: TransformedEnv(CountingEnv(device="cpu"), StepCounter(3))
if parallel:
cls = ParallelEnv
else:
cls = SerialEnv
env = cls(2, [env1, env2], device=env_device)
collector = SyncDataCollector(
env,
RandomPolicy(env.action_spec),
Expand All @@ -1705,7 +1713,7 @@ def test_reset_heterogeneous_envs(
assert (
data[0]["next", "truncated"].squeeze()
== torch.tensor([False, True], device=data_device).repeat(25)[:50]
).all(), data[0]["next", "truncated"][:10]
).all(), data[0]["next", "truncated"]
assert (
data[1]["next", "truncated"].squeeze()
== torch.tensor([False, False, True], device=data_device).repeat(17)[:50]
Expand Down
7 changes: 5 additions & 2 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2095,7 +2095,10 @@ def test_rollout_policy(self, batch_size, rollout_steps, count):

@pytest.mark.parametrize("batch_size", [(1, 2)])
@pytest.mark.parametrize("env_type", ["serial", "parallel"])
def test_vec_env(self, batch_size, env_type, rollout_steps=4, n_workers=2):
@pytest.mark.parametrize("break_when_any_done", [False, True])
def test_vec_env(
self, batch_size, env_type, break_when_any_done, rollout_steps=4, n_workers=2
):
env_fun = lambda: HeterogeneousCountingEnv(batch_size=batch_size)
if env_type == "serial":
vec_env = SerialEnv(n_workers, env_fun)
Expand All @@ -2109,7 +2112,7 @@ def test_vec_env(self, batch_size, env_type, rollout_steps=4, n_workers=2):
rollout_steps,
policy=policy,
return_contiguous=False,
break_when_any_done=False,
break_when_any_done=break_when_any_done,
)
td = dense_stack_tds(td)
for i in range(env_fun().n_nested_dim):
Expand Down
35 changes: 31 additions & 4 deletions test/test_tensordictmodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
CompositeSpec,
UnboundedContinuousTensorSpec,
)
from torchrl.envs import EnvCreator, SerialEnv
from torchrl.envs.utils import set_exploration_type, step_mdp
from torchrl.modules import (
AdditiveGaussianWrapper,
Expand Down Expand Up @@ -1782,9 +1783,12 @@ def test_multi_consecutive(self, shape, python_based):
)

@pytest.mark.parametrize("python_based", [True, False])
def test_lstm_parallel_env(self, python_based):
@pytest.mark.parametrize("parallel", [True, False])
@pytest.mark.parametrize("heterogeneous", [True, False])
def test_lstm_parallel_env(self, python_based, parallel, heterogeneous):
from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv

torch.manual_seed(0)
device = "cuda" if torch.cuda.device_count() else "cpu"
# tests that hidden states are carried over with parallel envs
lstm_module = LSTMModule(
Expand All @@ -1796,6 +1800,10 @@ def test_lstm_parallel_env(self, python_based):
device=device,
python_based=python_based,
)
if parallel:
cls = ParallelEnv
else:
cls = SerialEnv

def create_transformed_env():
primer = lstm_module.make_tensordict_primer()
Expand All @@ -1807,7 +1815,12 @@ def create_transformed_env():
env.append_transform(primer)
return env

env = ParallelEnv(
if heterogeneous:
create_transformed_env = [
EnvCreator(create_transformed_env),
EnvCreator(create_transformed_env),
]
env = cls(
create_env_fn=create_transformed_env,
num_workers=2,
)
Expand Down Expand Up @@ -2109,9 +2122,13 @@ def test_multi_consecutive(self, shape, python_based):
)

@pytest.mark.parametrize("python_based", [True, False])
def test_gru_parallel_env(self, python_based):
@pytest.mark.parametrize("parallel", [True, False])
@pytest.mark.parametrize("heterogeneous", [True, False])
def test_gru_parallel_env(self, python_based, parallel, heterogeneous):
from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv

torch.manual_seed(0)

device = "cuda" if torch.cuda.device_count() else "cpu"
# tests that hidden states are carried over with parallel envs
gru_module = GRUModule(
Expand All @@ -2134,7 +2151,17 @@ def create_transformed_env():
env.append_transform(primer)
return env

env = ParallelEnv(
if parallel:
cls = ParallelEnv
else:
cls = SerialEnv
if heterogeneous:
create_transformed_env = [
EnvCreator(create_transformed_env),
EnvCreator(create_transformed_env),
]

env = cls(
create_env_fn=create_transformed_env,
num_workers=2,
)
Expand Down
20 changes: 19 additions & 1 deletion test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
import _utils_internal
import pytest

from torchrl._utils import get_binary_env_var, implement_for
import torch

from _utils_internal import get_default_devices
from torchrl._utils import _rng_decorator, get_binary_env_var, implement_for

from torchrl.envs.libs.gym import gym_backend, GymWrapper, set_gym_backend

Expand Down Expand Up @@ -358,6 +361,21 @@ class MockGym:
) # would break with gymnasium


@pytest.mark.parametrize("device", get_default_devices())
def test_rng_decorator(device):
with torch.device(device):
torch.manual_seed(10)
s0a = torch.randn(3)
with _rng_decorator(0):
torch.randn(3)
s0b = torch.randn(3)
torch.manual_seed(10)
s1a = torch.randn(3)
s1b = torch.randn(3)
torch.testing.assert_close(s0a, s1a)
torch.testing.assert_close(s0b, s1b)


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
37 changes: 37 additions & 0 deletions torchrl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,3 +704,40 @@ def _replace_last(key: NestedKey, new_ending: str) -> NestedKey:
return new_ending
else:
return key[:-1] + (new_ending,)


class _rng_decorator(_DecoratorContextManager):
"""Temporarily sets the seed and sets back the rng state when exiting."""

def __init__(self, seed, device=None):
self.seed = seed
self.device = device
self.has_cuda = torch.cuda.is_available()

def __enter__(self):
self._get_state()
torch.manual_seed(self.seed)

def _get_state(self):
if self.has_cuda:
if self.device is None:
self._state = (torch.random.get_rng_state(), torch.cuda.get_rng_state())
else:
self._state = (
torch.random.get_rng_state(),
torch.cuda.get_rng_state(self.device),
)

else:
self._state = torch.random.get_rng_state()

def __exit__(self, exc_type, exc_val, exc_tb):
if self.has_cuda:
torch.random.set_rng_state(self._state[0])
if self.device is not None:
torch.cuda.set_rng_state(self._state[1], device=self.device)
else:
torch.cuda.set_rng_state(self._state[1])

else:
torch.random.set_rng_state(self._state)
2 changes: 1 addition & 1 deletion torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,7 +1077,7 @@ def rollout(self) -> TensorDictBase:

if self.storing_device is not None:
tensordicts.append(
self._shuttle.to(self.storing_device, non_blocking=False)
self._shuttle.to(self.storing_device, non_blocking=True)
)
else:
tensordicts.append(self._shuttle)
Expand Down
2 changes: 1 addition & 1 deletion torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,7 +790,7 @@ def _get_stop_and_length(self, storage, fallback=True):
raise RuntimeError(
"Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories."
)
vals = self._find_start_stop_traj(end=done.squeeze())[: len(storage)]
vals = self._find_start_stop_traj(end=done.squeeze()[: len(storage)])
if self.cache_values:
self._cache["stop-and-length"] = vals
return vals
Expand Down
2 changes: 1 addition & 1 deletion torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,7 +894,7 @@ def get(self, index: Union[int, Sequence[int], slice]) -> Any:
# to be deprecated in v0.4
def map_device(tensor):
if tensor.device != self.device:
return tensor.to(self.device, non_blocking=False)
return tensor.to(self.device, non_blocking=True)
return tensor

if is_tensor_collection(result):
Expand Down
2 changes: 1 addition & 1 deletion torchrl/data/rlhf/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ def get_dataloader(
)
out = TensorDictReplayBuffer(
storage=TensorStorage(data),
collate_fn=lambda x: x.as_tensor().to(device, non_blocking=False),
collate_fn=lambda x: x.as_tensor().to(device, non_blocking=True),
sampler=SamplerWithoutReplacement(drop_last=True),
batch_size=batch_size,
prefetch=prefetch,
Expand Down
Loading

0 comments on commit ffba3dc

Please sign in to comment.