diff --git a/test/test_utils.py b/test/test_utils.py index 620149daeb6..c2ce2eae6b9 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -12,7 +12,10 @@ import _utils_internal import pytest -from torchrl._utils import get_binary_env_var, implement_for +import torch + +from _utils_internal import get_default_devices +from torchrl._utils import _rng_decorator, get_binary_env_var, implement_for from torchrl.envs.libs.gym import gym_backend, GymWrapper, set_gym_backend @@ -358,6 +361,21 @@ class MockGym: ) # would break with gymnasium +@pytest.mark.parametrize("device", get_default_devices()) +def test_rng_decorator(device): + with torch.device(device): + torch.manual_seed(10) + s0a = torch.randn(3) + with _rng_decorator(0): + torch.randn(3) + s0b = torch.randn(3) + torch.manual_seed(10) + s1a = torch.randn(3) + s1b = torch.randn(3) + torch.testing.assert_close(s0a, s1a) + torch.testing.assert_close(s0b, s1b) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 9538cecb026..6c52b1d66e7 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -704,3 +704,40 @@ def _replace_last(key: NestedKey, new_ending: str) -> NestedKey: return new_ending else: return key[:-1] + (new_ending,) + + +class _rng_decorator(_DecoratorContextManager): + """Temporarily sets the seed and sets back the rng state when exiting.""" + + def __init__(self, seed, device=None): + self.seed = seed + self.device = device + self.has_cuda = torch.cuda.is_available() + + def __enter__(self): + self._get_state() + torch.manual_seed(self.seed) + + def _get_state(self): + if self.has_cuda: + if self.device is None: + self._state = (torch.random.get_rng_state(), torch.cuda.get_rng_state()) + else: + self._state = ( + torch.random.get_rng_state(), + torch.cuda.get_rng_state(self.device), + ) + + else: + self.state = torch.random.get_rng_state() + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.has_cuda: + torch.random.set_rng_state(self._state[0]) + if self.device is not None: + torch.cuda.set_rng_state(self._state[1], device=self.device) + else: + torch.cuda.set_rng_state(self._state[1]) + + else: + torch.random.set_rng_state(self._state) diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index ebb9100655c..a3aeecbebbb 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -31,7 +31,7 @@ set_interaction_type as set_exploration_type, ) from tensordict.utils import NestedKey -from torchrl._utils import _replace_last, logger as torchrl_logger +from torchrl._utils import _replace_last, _rng_decorator, logger as torchrl_logger from torchrl.data.tensor_specs import ( CompositeSpec, @@ -419,7 +419,9 @@ def _per_level_env_check(data0, data1, check_dtype): ) -def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0): +def check_env_specs( + env, return_contiguous=True, check_dtype=True, seed: int | None = None +): """Tests an environment specs against the results of short rollout. This test function should be used as a sanity check for an env wrapped with @@ -436,7 +438,12 @@ def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0): of inputs/outputs). Defaults to True. check_dtype (bool, optional): if False, dtype checks will be skipped. Defaults to True. - seed (int, optional): for reproducibility, a seed is set. + seed (int, optional): for reproducibility, a seed can be set. + The seed will be set in pytorch temporarily, then the RNG state will + be reverted to what it was before. For the env, we set the seed but since + setting the rng state back to what is was isn't a feature of most environment, + we leave it to the user to accomplish that. + Defaults to ``None``. Caution: this function resets the env seed. It should be used "offline" to check that an env is adequately constructed, but it may affect the seeding @@ -444,8 +451,14 @@ def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0): """ if seed is not None: - torch.manual_seed(seed) - env.set_seed(seed) + device = ( + env.device if env.device is not None and env.device.type == "cuda" else None + ) + with _rng_decorator(seed, device=device): + env.set_seed(seed) + return check_env_specs( + env, return_contiguous=return_contiguous, check_dtype=check_dtype + ) fake_tensordict = env.fake_tensordict() real_tensordict = env.rollout(3, return_contiguous=return_contiguous)