From 3bb5ecb52b8b0a4daf8b0cac10bd7a66c8d329a1 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 5 Feb 2024 17:47:59 +0000 Subject: [PATCH 1/8] init --- test/test_utils.py | 20 +++++++++++++++++++- torchrl/_utils.py | 14 ++++++++++++++ torchrl/envs/utils.py | 20 +++++++++++++++----- 3 files changed, 48 insertions(+), 6 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 620149daeb6..dcf8fa71e50 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -12,7 +12,10 @@ import _utils_internal import pytest -from torchrl._utils import get_binary_env_var, implement_for +import torch + +from _utils_internal import get_default_devices +from torchrl._utils import _rng_decorator, get_binary_env_var, implement_for from torchrl.envs.libs.gym import gym_backend, GymWrapper, set_gym_backend @@ -358,6 +361,21 @@ class MockGym: ) # would break with gymnasium +@pytest.mark.parametrize("device", get_default_devices()) +def test_rng_decorator(device): + torch.manual_seed(10) + with torch.device(device): + s0a = torch.randn(3) + with _rng_decorator(0): + torch.randn(3) + s0b = torch.randn(3) + torch.manual_seed(10) + s1a = torch.randn(3) + s1b = torch.randn(3) + torch.testing.assert_close(s0a, s1a) + torch.testing.assert_close(s0b, s1b) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 9538cecb026..4ce56fec0b2 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -704,3 +704,17 @@ def _replace_last(key: NestedKey, new_ending: str) -> NestedKey: return new_ending else: return key[:-1] + (new_ending,) + + +class _rng_decorator(_DecoratorContextManager): + """Temporarily sets the seed and sets back the rng state when exiting.""" + + def __init__(self, seed): + self.seed = seed + + def __enter__(self): + self._state = torch.random.get_rng_state() + torch.manual_seed(self.seed) + + def __exit__(self, exc_type, exc_val, exc_tb): + torch.random.set_rng_state(self._state) diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index ebb9100655c..96a86ce75d6 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -31,7 +31,7 @@ set_interaction_type as set_exploration_type, ) from tensordict.utils import NestedKey -from torchrl._utils import _replace_last, logger as torchrl_logger +from torchrl._utils import _replace_last, _rng_decorator, logger as torchrl_logger from torchrl.data.tensor_specs import ( CompositeSpec, @@ -419,7 +419,9 @@ def _per_level_env_check(data0, data1, check_dtype): ) -def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0): +def check_env_specs( + env, return_contiguous=True, check_dtype=True, seed: int | None = None +): """Tests an environment specs against the results of short rollout. This test function should be used as a sanity check for an env wrapped with @@ -436,7 +438,12 @@ def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0): of inputs/outputs). Defaults to True. check_dtype (bool, optional): if False, dtype checks will be skipped. Defaults to True. - seed (int, optional): for reproducibility, a seed is set. + seed (int, optional): for reproducibility, a seed can be set. + The seed will be set in pytorch temporarily, then the RNG state will + be reverted to what it was before. For the env, we set the seed but since + setting the rng state back to what is was isn't a feature of most environment, + we leave it to the user to accomplish that. + Defaults to ``None``. Caution: this function resets the env seed. It should be used "offline" to check that an env is adequately constructed, but it may affect the seeding @@ -444,8 +451,11 @@ def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0): """ if seed is not None: - torch.manual_seed(seed) - env.set_seed(seed) + with _rng_decorator(seed): + env.set_seed(seed) + return check_env_specs( + env, return_contiguous=return_contiguous, check_dtype=check_dtype + ) fake_tensordict = env.fake_tensordict() real_tensordict = env.rollout(3, return_contiguous=return_contiguous) From 2ea60a985719d66e214f2162c0574034acf0dd7c Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 5 Feb 2024 20:33:36 +0000 Subject: [PATCH 2/8] amend --- test/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_utils.py b/test/test_utils.py index dcf8fa71e50..c2ce2eae6b9 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -363,8 +363,8 @@ class MockGym: @pytest.mark.parametrize("device", get_default_devices()) def test_rng_decorator(device): - torch.manual_seed(10) with torch.device(device): + torch.manual_seed(10) s0a = torch.randn(3) with _rng_decorator(0): torch.randn(3) From 81539eccb19a46fa84b2a206a99442a93563e974 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 5 Feb 2024 20:37:47 +0000 Subject: [PATCH 3/8] amend --- torchrl/_utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 4ce56fec0b2..2f562451adb 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -711,10 +711,15 @@ class _rng_decorator(_DecoratorContextManager): def __init__(self, seed): self.seed = seed + self.stream = torch.cuda.Stream() + self.event = self.stream.record_event() def __enter__(self): self._state = torch.random.get_rng_state() torch.manual_seed(self.seed) + return torch.cuda.stream(self.stream) def __exit__(self, exc_type, exc_val, exc_tb): torch.random.set_rng_state(self._state) + self.event.wait() + self.event.synchronize() From 5756a85b0ccd24050ab7aeca88f2f0f7007df3f5 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 5 Feb 2024 20:41:10 +0000 Subject: [PATCH 4/8] amend --- test/test_utils.py | 2 ++ torchrl/_utils.py | 5 ----- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index c2ce2eae6b9..88d7efed993 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -5,6 +5,7 @@ import argparse import os import sys +import time from copy import copy from importlib import import_module from unittest import mock @@ -368,6 +369,7 @@ def test_rng_decorator(device): s0a = torch.randn(3) with _rng_decorator(0): torch.randn(3) + time.sleep(4) s0b = torch.randn(3) torch.manual_seed(10) s1a = torch.randn(3) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 2f562451adb..4ce56fec0b2 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -711,15 +711,10 @@ class _rng_decorator(_DecoratorContextManager): def __init__(self, seed): self.seed = seed - self.stream = torch.cuda.Stream() - self.event = self.stream.record_event() def __enter__(self): self._state = torch.random.get_rng_state() torch.manual_seed(self.seed) - return torch.cuda.stream(self.stream) def __exit__(self, exc_type, exc_val, exc_tb): torch.random.set_rng_state(self._state) - self.event.wait() - self.event.synchronize() From 1813c5bbb4343172b3a8dee3d346698f63d47728 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 5 Feb 2024 20:47:05 +0000 Subject: [PATCH 5/8] amend --- test/test_utils.py | 1 - torchrl/_utils.py | 19 ++++++++++++++++--- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 88d7efed993..65a5ae352c7 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -369,7 +369,6 @@ def test_rng_decorator(device): s0a = torch.randn(3) with _rng_decorator(0): torch.randn(3) - time.sleep(4) s0b = torch.randn(3) torch.manual_seed(10) s1a = torch.randn(3) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 4ce56fec0b2..6395f7bb301 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -709,12 +709,25 @@ def _replace_last(key: NestedKey, new_ending: str) -> NestedKey: class _rng_decorator(_DecoratorContextManager): """Temporarily sets the seed and sets back the rng state when exiting.""" - def __init__(self, seed): + def __init__(self, seed, device=None): self.seed = seed + self.device = device + self.has_cuda = torch.cuda.is_available(): def __enter__(self): - self._state = torch.random.get_rng_state() + self._get_state() torch.manual_seed(self.seed) + def _get_state(self): + if self.has_cuda: + self._state = ( + torch.random.get_rng_state(), torch.cuda.get_rng_state(self.device)) + else: + self.state = torch.random.get_rng_state() + def __exit__(self, exc_type, exc_val, exc_tb): - torch.random.set_rng_state(self._state) + if self.has_cuda: + torch.random.set_rng_state(self._state[0]) + torch.cuda.set_rng_state(self._state[1], device=self.device) + else: + torch.random.set_rng_state(self._state) From 318a6a236092846a87ca170632d9aa7772b49d8a Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 5 Feb 2024 20:48:19 +0000 Subject: [PATCH 6/8] amend --- torchrl/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 6395f7bb301..bf9feaab76b 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -712,7 +712,7 @@ class _rng_decorator(_DecoratorContextManager): def __init__(self, seed, device=None): self.seed = seed self.device = device - self.has_cuda = torch.cuda.is_available(): + self.has_cuda = torch.cuda.is_available() def __enter__(self): self._get_state() From 5215904d03134bfe118a6d6de7812621a0f9e708 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 5 Feb 2024 20:50:08 +0000 Subject: [PATCH 7/8] amend --- torchrl/_utils.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index bf9feaab76b..9ff6a65ef0c 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -720,14 +720,23 @@ def __enter__(self): def _get_state(self): if self.has_cuda: - self._state = ( - torch.random.get_rng_state(), torch.cuda.get_rng_state(self.device)) + if self.device is None: + self._state = ( + torch.random.get_rng_state(), torch.cuda.get_rng_state()) + else: + self._state = ( + torch.random.get_rng_state(), torch.cuda.get_rng_state(self.device)) + else: self.state = torch.random.get_rng_state() def __exit__(self, exc_type, exc_val, exc_tb): if self.has_cuda: torch.random.set_rng_state(self._state[0]) - torch.cuda.set_rng_state(self._state[1], device=self.device) + if self.device is not None: + torch.cuda.set_rng_state(self._state[1], device=self.device) + else: + torch.cuda.set_rng_state(self._state[1]) + else: torch.random.set_rng_state(self._state) From b29cf5ba843e803618e5f81eefe74f73d895bcb8 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 5 Feb 2024 12:52:53 -0800 Subject: [PATCH 8/8] amend --- test/test_utils.py | 1 - torchrl/_utils.py | 7 ++++--- torchrl/envs/utils.py | 5 ++++- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 65a5ae352c7..c2ce2eae6b9 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -5,7 +5,6 @@ import argparse import os import sys -import time from copy import copy from importlib import import_module from unittest import mock diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 9ff6a65ef0c..6c52b1d66e7 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -721,11 +721,12 @@ def __enter__(self): def _get_state(self): if self.has_cuda: if self.device is None: - self._state = ( - torch.random.get_rng_state(), torch.cuda.get_rng_state()) + self._state = (torch.random.get_rng_state(), torch.cuda.get_rng_state()) else: self._state = ( - torch.random.get_rng_state(), torch.cuda.get_rng_state(self.device)) + torch.random.get_rng_state(), + torch.cuda.get_rng_state(self.device), + ) else: self.state = torch.random.get_rng_state() diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 96a86ce75d6..a3aeecbebbb 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -451,7 +451,10 @@ def check_env_specs( """ if seed is not None: - with _rng_decorator(seed): + device = ( + env.device if env.device is not None and env.device.type == "cuda" else None + ) + with _rng_decorator(seed, device=device): env.set_seed(seed) return check_env_specs( env, return_contiguous=return_contiguous, check_dtype=check_dtype