From 3bb5ecb52b8b0a4daf8b0cac10bd7a66c8d329a1 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 5 Feb 2024 17:47:59 +0000 Subject: [PATCH] init --- test/test_utils.py | 20 +++++++++++++++++++- torchrl/_utils.py | 14 ++++++++++++++ torchrl/envs/utils.py | 20 +++++++++++++++----- 3 files changed, 48 insertions(+), 6 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 620149daeb6..dcf8fa71e50 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -12,7 +12,10 @@ import _utils_internal import pytest -from torchrl._utils import get_binary_env_var, implement_for +import torch + +from _utils_internal import get_default_devices +from torchrl._utils import _rng_decorator, get_binary_env_var, implement_for from torchrl.envs.libs.gym import gym_backend, GymWrapper, set_gym_backend @@ -358,6 +361,21 @@ class MockGym: ) # would break with gymnasium +@pytest.mark.parametrize("device", get_default_devices()) +def test_rng_decorator(device): + torch.manual_seed(10) + with torch.device(device): + s0a = torch.randn(3) + with _rng_decorator(0): + torch.randn(3) + s0b = torch.randn(3) + torch.manual_seed(10) + s1a = torch.randn(3) + s1b = torch.randn(3) + torch.testing.assert_close(s0a, s1a) + torch.testing.assert_close(s0b, s1b) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 9538cecb026..4ce56fec0b2 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -704,3 +704,17 @@ def _replace_last(key: NestedKey, new_ending: str) -> NestedKey: return new_ending else: return key[:-1] + (new_ending,) + + +class _rng_decorator(_DecoratorContextManager): + """Temporarily sets the seed and sets back the rng state when exiting.""" + + def __init__(self, seed): + self.seed = seed + + def __enter__(self): + self._state = torch.random.get_rng_state() + torch.manual_seed(self.seed) + + def __exit__(self, exc_type, exc_val, exc_tb): + torch.random.set_rng_state(self._state) diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index ebb9100655c..96a86ce75d6 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -31,7 +31,7 @@ set_interaction_type as set_exploration_type, ) from tensordict.utils import NestedKey -from torchrl._utils import _replace_last, logger as torchrl_logger +from torchrl._utils import _replace_last, _rng_decorator, logger as torchrl_logger from torchrl.data.tensor_specs import ( CompositeSpec, @@ -419,7 +419,9 @@ def _per_level_env_check(data0, data1, check_dtype): ) -def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0): +def check_env_specs( + env, return_contiguous=True, check_dtype=True, seed: int | None = None +): """Tests an environment specs against the results of short rollout. This test function should be used as a sanity check for an env wrapped with @@ -436,7 +438,12 @@ def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0): of inputs/outputs). Defaults to True. check_dtype (bool, optional): if False, dtype checks will be skipped. Defaults to True. - seed (int, optional): for reproducibility, a seed is set. + seed (int, optional): for reproducibility, a seed can be set. + The seed will be set in pytorch temporarily, then the RNG state will + be reverted to what it was before. For the env, we set the seed but since + setting the rng state back to what is was isn't a feature of most environment, + we leave it to the user to accomplish that. + Defaults to ``None``. Caution: this function resets the env seed. It should be used "offline" to check that an env is adequately constructed, but it may affect the seeding @@ -444,8 +451,11 @@ def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0): """ if seed is not None: - torch.manual_seed(seed) - env.set_seed(seed) + with _rng_decorator(seed): + env.set_seed(seed) + return check_env_specs( + env, return_contiguous=return_contiguous, check_dtype=check_dtype + ) fake_tensordict = env.fake_tensordict() real_tensordict = env.rollout(3, return_contiguous=return_contiguous)