[BugFix] check_env_specs seeding logic #1872

Merged · Feb 5, 2024 · 8 commits
20 changes: 19 additions & 1 deletion test/test_utils.py
@@ -12,7 +12,10 @@
 import _utils_internal
 import pytest
 
-from torchrl._utils import get_binary_env_var, implement_for
+import torch
+
+from _utils_internal import get_default_devices
+from torchrl._utils import _rng_decorator, get_binary_env_var, implement_for
 
 from torchrl.envs.libs.gym import gym_backend, GymWrapper, set_gym_backend
 
@@ -358,6 +361,21 @@ class MockGym:
 ) # would break with gymnasium
 
 
+@pytest.mark.parametrize("device", get_default_devices())
+def test_rng_decorator(device):
+    with torch.device(device):
+        torch.manual_seed(10)
+        s0a = torch.randn(3)
+        with _rng_decorator(0):
+            torch.randn(3)
+        s0b = torch.randn(3)
+        torch.manual_seed(10)
+        s1a = torch.randn(3)
+        s1b = torch.randn(3)
+        torch.testing.assert_close(s0a, s1a)
+        torch.testing.assert_close(s0b, s1b)
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
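
The new test above checks that drawing samples inside the _rng_decorator(0) block neither leaks the temporary seed nor advances the surrounding RNG stream. A minimal sketch of the same behaviour outside pytest (CPU only, assuming this branch of torchrl is installed):

    import torch
    from torchrl._utils import _rng_decorator

    torch.manual_seed(10)
    a = torch.randn(3)  # first draw from the seed-10 stream
    with _rng_decorator(0):
        torch.randn(3)  # draws from a temporary seed-0 stream
    b = torch.randn(3)  # continues the seed-10 stream as if nothing had happened

    # replaying the seed-10 stream reproduces both draws
    torch.manual_seed(10)
    torch.testing.assert_close(a, torch.randn(3))
    torch.testing.assert_close(b, torch.randn(3))
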
37 changes: 37 additions & 0 deletions torchrl/_utils.py
@@ -704,3 +704,40 @@ def _replace_last(key: NestedKey, new_ending: str) -> NestedKey:
         return new_ending
     else:
         return key[:-1] + (new_ending,)
+
+
+class _rng_decorator(_DecoratorContextManager):
+    """Temporarily sets the seed and sets back the rng state when exiting."""
+
+    def __init__(self, seed, device=None):
+        self.seed = seed
+        self.device = device
+        self.has_cuda = torch.cuda.is_available()
+
+    def __enter__(self):
+        self._get_state()
+        torch.manual_seed(self.seed)
+
+    def _get_state(self):
+        if self.has_cuda:
+            if self.device is None:
+                self._state = (torch.random.get_rng_state(), torch.cuda.get_rng_state())
+            else:
+                self._state = (
+                    torch.random.get_rng_state(),
+                    torch.cuda.get_rng_state(self.device),
+                )
+
+        else:
+            self._state = torch.random.get_rng_state()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if self.has_cuda:
+            torch.random.set_rng_state(self._state[0])
+            if self.device is not None:
+                torch.cuda.set_rng_state(self._state[1], device=self.device)
+            else:
+                torch.cuda.set_rng_state(self._state[1])
+
+        else:
+            torch.random.set_rng_state(self._state)
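
Since _rng_decorator subclasses _DecoratorContextManager (PyTorch's internal base class for objects usable as context managers) and snapshots the CPU generator state — plus the CUDA generator state when CUDA is available — in __enter__, exiting the context puts the caller's RNG back exactly where it was. A hedged sketch of how the device argument is meant to be used on a CUDA machine (names here are illustrative only):

    import torch
    from torchrl._utils import _rng_decorator

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        cpu_before = torch.random.get_rng_state()
        cuda_before = torch.cuda.get_rng_state(device)
        with _rng_decorator(0, device=device):
            torch.randn(3, device=device)  # consumes the temporary seed-0 CUDA stream
        # both generators are back where they started
        assert torch.equal(cpu_before, torch.random.get_rng_state())
        assert torch.equal(cuda_before, torch.cuda.get_rng_state(device))
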
23 changes: 18 additions & 5 deletions torchrl/envs/utils.py
@@ -31,7 +31,7 @@
     set_interaction_type as set_exploration_type,
 )
 from tensordict.utils import NestedKey
-from torchrl._utils import _replace_last, logger as torchrl_logger
+from torchrl._utils import _replace_last, _rng_decorator, logger as torchrl_logger
 
 from torchrl.data.tensor_specs import (
     CompositeSpec,
@@ -419,7 +419,9 @@ def _per_level_env_check(data0, data1, check_dtype):
 )
 
 
-def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0):
+def check_env_specs(
+    env, return_contiguous=True, check_dtype=True, seed: int | None = None
+):
     """Tests an environment's specs against the results of a short rollout.
 
     This test function should be used as a sanity check for an env wrapped with
@@ -436,16 +438,27 @@ def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0):
             of inputs/outputs). Defaults to True.
         check_dtype (bool, optional): if False, dtype checks will be skipped.
             Defaults to True.
-        seed (int, optional): for reproducibility, a seed is set.
+        seed (int, optional): for reproducibility, a seed can be set.
+            The seed will be set in PyTorch temporarily, then the RNG state will
+            be reverted to what it was before. For the env, we set the seed but since
+            setting the RNG state back to what it was isn't a feature of most environments,
+            we leave it to the user to accomplish that.
+            Defaults to ``None``.
 
     Caution: this function resets the env seed. It should be used "offline" to
     check that an env is adequately constructed, but it may affect the seeding
     of an experiment and as such should be kept out of training scripts.
 
     """
     if seed is not None:
-        torch.manual_seed(seed)
-        env.set_seed(seed)
+        device = (
+            env.device if env.device is not None and env.device.type == "cuda" else None
+        )
+        with _rng_decorator(seed, device=device):
+            env.set_seed(seed)
+            return check_env_specs(
+                env, return_contiguous=return_contiguous, check_dtype=check_dtype
+            )
 
     fake_tensordict = env.fake_tensordict()
     real_tensordict = env.rollout(3, return_contiguous=return_contiguous)
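
Taken together, the fix means check_env_specs no longer clobbers the caller's seeding by default: with seed=None nothing is reseeded, and when a seed is passed the torch RNG state is restored on exit, leaving only env.set_seed in effect. A hedged usage sketch (GymEnv and Pendulum-v1 are only an illustrative choice and require gymnasium to be installed):

    import torch
    from torchrl.envs import GymEnv
    from torchrl.envs.utils import check_env_specs

    env = GymEnv("Pendulum-v1")

    torch.manual_seed(42)
    check_env_specs(env, seed=0)  # the rollout runs under seed 0; torch RNG state is restored afterwards
    x = torch.randn(3)            # the global torch RNG continues as if the check had not run
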