Skip to content

Commit

Permalink
[BugFix] check_env_specs seeding logic (#1872)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Feb 5, 2024
1 parent 19a920e commit 528faa1
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 6 deletions.
20 changes: 19 additions & 1 deletion test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
import _utils_internal
import pytest

from torchrl._utils import get_binary_env_var, implement_for
import torch

from _utils_internal import get_default_devices
from torchrl._utils import _rng_decorator, get_binary_env_var, implement_for

from torchrl.envs.libs.gym import gym_backend, GymWrapper, set_gym_backend

Expand Down Expand Up @@ -358,6 +361,21 @@ class MockGym:
) # would break with gymnasium


@pytest.mark.parametrize("device", get_default_devices())
def test_rng_decorator(device):
    """Check that ``_rng_decorator`` restores the torch RNG stream on exit."""
    with torch.device(device):
        torch.manual_seed(10)
        before_ctx = torch.randn(3)
        with _rng_decorator(0):
            # Consume draws under the temporary seed; these must not
            # disturb the outer RNG stream.
            torch.randn(3)
        after_ctx = torch.randn(3)

        # Replaying the same seed without the decorator must reproduce the
        # exact same draws, proving the context left the RNG state intact.
        torch.manual_seed(10)
        replay_first = torch.randn(3)
        replay_second = torch.randn(3)
        torch.testing.assert_close(before_ctx, replay_first)
        torch.testing.assert_close(after_ctx, replay_second)


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
37 changes: 37 additions & 0 deletions torchrl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,3 +704,40 @@ def _replace_last(key: NestedKey, new_ending: str) -> NestedKey:
return new_ending
else:
return key[:-1] + (new_ending,)


class _rng_decorator(_DecoratorContextManager):
"""Temporarily sets the seed and sets back the rng state when exiting."""

def __init__(self, seed, device=None):
self.seed = seed
self.device = device
self.has_cuda = torch.cuda.is_available()

def __enter__(self):
self._get_state()
torch.manual_seed(self.seed)

def _get_state(self):
if self.has_cuda:
if self.device is None:
self._state = (torch.random.get_rng_state(), torch.cuda.get_rng_state())
else:
self._state = (
torch.random.get_rng_state(),
torch.cuda.get_rng_state(self.device),
)

else:
self.state = torch.random.get_rng_state()

def __exit__(self, exc_type, exc_val, exc_tb):
if self.has_cuda:
torch.random.set_rng_state(self._state[0])
if self.device is not None:
torch.cuda.set_rng_state(self._state[1], device=self.device)
else:
torch.cuda.set_rng_state(self._state[1])

else:
torch.random.set_rng_state(self._state)
23 changes: 18 additions & 5 deletions torchrl/envs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
set_interaction_type as set_exploration_type,
)
from tensordict.utils import NestedKey
from torchrl._utils import _replace_last, logger as torchrl_logger
from torchrl._utils import _replace_last, _rng_decorator, logger as torchrl_logger

from torchrl.data.tensor_specs import (
CompositeSpec,
Expand Down Expand Up @@ -419,7 +419,9 @@ def _per_level_env_check(data0, data1, check_dtype):
)


def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0):
def check_env_specs(
env, return_contiguous=True, check_dtype=True, seed: int | None = None
):
"""Tests an environment specs against the results of short rollout.
This test function should be used as a sanity check for an env wrapped with
Expand All @@ -436,16 +438,27 @@ def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0):
of inputs/outputs). Defaults to True.
check_dtype (bool, optional): if False, dtype checks will be skipped.
Defaults to True.
seed (int, optional): for reproducibility, a seed is set.
seed (int, optional): for reproducibility, a seed can be set.
The seed will be set in pytorch temporarily, then the RNG state will
be reverted to what it was before. For the env, we set the seed but since
        setting the rng state back to what it was isn't a feature of most environments,
we leave it to the user to accomplish that.
Defaults to ``None``.
Caution: this function resets the env seed. It should be used "offline" to
check that an env is adequately constructed, but it may affect the seeding
of an experiment and as such should be kept out of training scripts.
"""
if seed is not None:
torch.manual_seed(seed)
env.set_seed(seed)
device = (
env.device if env.device is not None and env.device.type == "cuda" else None
)
with _rng_decorator(seed, device=device):
env.set_seed(seed)
return check_env_specs(
env, return_contiguous=return_contiguous, check_dtype=check_dtype
)

fake_tensordict = env.fake_tensordict()
real_tensordict = env.rollout(3, return_contiguous=return_contiguous)
Expand Down

0 comments on commit 528faa1

Please sign in to comment.