Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Feb 5, 2024
1 parent 80fc87f commit 3bb5ecb
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 6 deletions.
20 changes: 19 additions & 1 deletion test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
import _utils_internal
import pytest

from torchrl._utils import get_binary_env_var, implement_for
import torch

from _utils_internal import get_default_devices
from torchrl._utils import _rng_decorator, get_binary_env_var, implement_for

from torchrl.envs.libs.gym import gym_backend, GymWrapper, set_gym_backend

Expand Down Expand Up @@ -358,6 +361,21 @@ class MockGym:
) # would break with gymnasium


@pytest.mark.parametrize("device", get_default_devices())
def test_rng_decorator(device):
    """Check that ``_rng_decorator`` does not disturb the surrounding RNG stream.

    Two streams are drawn from the same seed: one with a seeded
    ``_rng_decorator`` block interleaved between the draws, one without.
    Both must produce identical samples.
    """
    torch.manual_seed(10)
    with torch.device(device):
        # Stream under test: a temporarily re-seeded block sits between draws.
        before_ctx = torch.randn(3)
        with _rng_decorator(0):
            torch.randn(3)
        after_ctx = torch.randn(3)

        # Reference stream: same seed, nothing interleaved.
        torch.manual_seed(10)
        ref_first = torch.randn(3)
        ref_second = torch.randn(3)

        for observed, expected in (
            (before_ctx, ref_first),
            (after_ctx, ref_second),
        ):
            torch.testing.assert_close(observed, expected)


if __name__ == "__main__":
    # Run this file's tests directly. Command-line flags that argparse does
    # not recognise are handed over to pytest unchanged.
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest_argv = [__file__, "--capture", "no", "--exitfirst", *unknown]
    pytest.main(pytest_argv)
14 changes: 14 additions & 0 deletions torchrl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,3 +704,17 @@ def _replace_last(key: NestedKey, new_ending: str) -> NestedKey:
return new_ending
else:
return key[:-1] + (new_ending,)


class _rng_decorator(_DecoratorContextManager):
"""Temporarily sets the seed and sets back the rng state when exiting."""

def __init__(self, seed):
self.seed = seed

def __enter__(self):
self._state = torch.random.get_rng_state()
torch.manual_seed(self.seed)

def __exit__(self, exc_type, exc_val, exc_tb):
torch.random.set_rng_state(self._state)
20 changes: 15 additions & 5 deletions torchrl/envs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
set_interaction_type as set_exploration_type,
)
from tensordict.utils import NestedKey
from torchrl._utils import _replace_last, logger as torchrl_logger
from torchrl._utils import _replace_last, _rng_decorator, logger as torchrl_logger

from torchrl.data.tensor_specs import (
CompositeSpec,
Expand Down Expand Up @@ -419,7 +419,9 @@ def _per_level_env_check(data0, data1, check_dtype):
)


def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0):
def check_env_specs(
env, return_contiguous=True, check_dtype=True, seed: int | None = None
):
"""Tests an environment specs against the results of short rollout.
This test function should be used as a sanity check for an env wrapped with
Expand All @@ -436,16 +438,24 @@ def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0):
of inputs/outputs). Defaults to True.
check_dtype (bool, optional): if False, dtype checks will be skipped.
Defaults to True.
seed (int, optional): for reproducibility, a seed is set.
seed (int, optional): for reproducibility, a seed can be set.
The seed will be set in pytorch temporarily, then the RNG state will
be reverted to what it was before. For the env, we set the seed but since
setting the rng state back to what it was isn't a feature of most environments,
we leave it to the user to accomplish that.
Defaults to ``None``.
Caution: this function resets the env seed. It should be used "offline" to
check that an env is adequately constructed, but it may affect the seeding
of an experiment and as such should be kept out of training scripts.
"""
if seed is not None:
torch.manual_seed(seed)
env.set_seed(seed)
with _rng_decorator(seed):
env.set_seed(seed)
return check_env_specs(
env, return_contiguous=return_contiguous, check_dtype=check_dtype
)

fake_tensordict = env.fake_tensordict()
real_tensordict = env.rollout(3, return_contiguous=return_contiguous)
Expand Down

0 comments on commit 3bb5ecb

Please sign in to comment.