Skip to content

Commit 528faa1

Browse files
author
Vincent Moens
authored
[BugFix] check_env_specs seeding logic (#1872)
1 parent 19a920e commit 528faa1

File tree

3 files changed

+74
-6
lines changed

3 files changed

+74
-6
lines changed

test/test_utils.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
import _utils_internal
1313
import pytest
1414

15-
from torchrl._utils import get_binary_env_var, implement_for
15+
import torch
16+
17+
from _utils_internal import get_default_devices
18+
from torchrl._utils import _rng_decorator, get_binary_env_var, implement_for
1619

1720
from torchrl.envs.libs.gym import gym_backend, GymWrapper, set_gym_backend
1821

@@ -358,6 +361,21 @@ class MockGym:
358361
) # would break with gymnasium
359362

360363

364+
@pytest.mark.parametrize("device", get_default_devices())
def test_rng_decorator(device):
    """Check that ``_rng_decorator`` leaves the surrounding RNG stream untouched.

    Two samples are drawn after seeding, with a ``_rng_decorator(0)`` block
    (which itself draws a sample) in between them. The same two samples are
    then drawn back-to-back after re-seeding; both pairs must match.
    """
    with torch.device(device):
        # Stream interrupted by the temporary re-seeding.
        torch.manual_seed(10)
        interrupted_first = torch.randn(3)
        with _rng_decorator(0):
            torch.randn(3)
        interrupted_second = torch.randn(3)

        # Reference stream: same seed, no interruption.
        torch.manual_seed(10)
        reference_first = torch.randn(3)
        reference_second = torch.randn(3)

        comparisons = (
            (interrupted_first, reference_first),
            (interrupted_second, reference_second),
        )
        for actual, expected in comparisons:
            torch.testing.assert_close(actual, expected)
377+
378+
361379
if __name__ == "__main__":
    # Forward any command-line arguments argparse does not recognise to pytest.
    _, extra_args = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst", *extra_args])

torchrl/_utils.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -704,3 +704,40 @@ def _replace_last(key: NestedKey, new_ending: str) -> NestedKey:
704704
return new_ending
705705
else:
706706
return key[:-1] + (new_ending,)
707+
708+
709+
class _rng_decorator(_DecoratorContextManager):
710+
"""Temporarily sets the seed and sets back the rng state when exiting."""
711+
712+
def __init__(self, seed, device=None):
713+
self.seed = seed
714+
self.device = device
715+
self.has_cuda = torch.cuda.is_available()
716+
717+
def __enter__(self):
718+
self._get_state()
719+
torch.manual_seed(self.seed)
720+
721+
def _get_state(self):
722+
if self.has_cuda:
723+
if self.device is None:
724+
self._state = (torch.random.get_rng_state(), torch.cuda.get_rng_state())
725+
else:
726+
self._state = (
727+
torch.random.get_rng_state(),
728+
torch.cuda.get_rng_state(self.device),
729+
)
730+
731+
else:
732+
self.state = torch.random.get_rng_state()
733+
734+
def __exit__(self, exc_type, exc_val, exc_tb):
735+
if self.has_cuda:
736+
torch.random.set_rng_state(self._state[0])
737+
if self.device is not None:
738+
torch.cuda.set_rng_state(self._state[1], device=self.device)
739+
else:
740+
torch.cuda.set_rng_state(self._state[1])
741+
742+
else:
743+
torch.random.set_rng_state(self._state)

torchrl/envs/utils.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
set_interaction_type as set_exploration_type,
3232
)
3333
from tensordict.utils import NestedKey
34-
from torchrl._utils import _replace_last, logger as torchrl_logger
34+
from torchrl._utils import _replace_last, _rng_decorator, logger as torchrl_logger
3535

3636
from torchrl.data.tensor_specs import (
3737
CompositeSpec,
@@ -419,7 +419,9 @@ def _per_level_env_check(data0, data1, check_dtype):
419419
)
420420

421421

422-
def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0):
422+
def check_env_specs(
423+
env, return_contiguous=True, check_dtype=True, seed: int | None = None
424+
):
423425
"""Tests an environment specs against the results of short rollout.
424426
425427
This test function should be used as a sanity check for an env wrapped with
@@ -436,16 +438,27 @@ def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0):
436438
of inputs/outputs). Defaults to True.
437439
check_dtype (bool, optional): if False, dtype checks will be skipped.
438440
Defaults to True.
439-
seed (int, optional): for reproducibility, a seed is set.
441+
seed (int, optional): for reproducibility, a seed can be set.
442+
The seed will be set in pytorch temporarily, then the RNG state will
443+
be reverted to what it was before. For the env, we set the seed but since
444+
setting the rng state back to what it was isn't a feature of most environments,
445+
we leave it to the user to accomplish that.
446+
Defaults to ``None``.
440447
441448
Caution: this function resets the env seed. It should be used "offline" to
442449
check that an env is adequately constructed, but it may affect the seeding
443450
of an experiment and as such should be kept out of training scripts.
444451
445452
"""
446453
if seed is not None:
447-
torch.manual_seed(seed)
448-
env.set_seed(seed)
454+
device = (
455+
env.device if env.device is not None and env.device.type == "cuda" else None
456+
)
457+
with _rng_decorator(seed, device=device):
458+
env.set_seed(seed)
459+
return check_env_specs(
460+
env, return_contiguous=return_contiguous, check_dtype=check_dtype
461+
)
449462

450463
fake_tensordict = env.fake_tensordict()
451464
real_tensordict = env.rollout(3, return_contiguous=return_contiguous)

0 commit comments

Comments
 (0)