31
31
set_interaction_type as set_exploration_type ,
32
32
)
33
33
from tensordict .utils import NestedKey
34
- from torchrl ._utils import _replace_last , logger as torchrl_logger
34
+ from torchrl ._utils import _replace_last , _rng_decorator , logger as torchrl_logger
35
35
36
36
from torchrl .data .tensor_specs import (
37
37
CompositeSpec ,
@@ -419,7 +419,9 @@ def _per_level_env_check(data0, data1, check_dtype):
419
419
)
420
420
421
421
422
- def check_env_specs (env , return_contiguous = True , check_dtype = True , seed = 0 ):
422
+ def check_env_specs (
423
+ env , return_contiguous = True , check_dtype = True , seed : int | None = None
424
+ ):
423
425
"""Tests an environment specs against the results of short rollout.
424
426
425
427
This test function should be used as a sanity check for an env wrapped with
@@ -436,16 +438,27 @@ def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0):
436
438
of inputs/outputs). Defaults to True.
437
439
check_dtype (bool, optional): if False, dtype checks will be skipped.
438
440
Defaults to True.
439
- seed (int, optional): for reproducibility, a seed is set.
441
+ seed (int, optional): for reproducibility, a seed can be set.
442
+ The seed will be set in pytorch temporarily, then the RNG state will
443
+ be reverted to what it was before. For the env, we set the seed but since
444
+ setting the rng state back to what it was isn't a feature of most environments,
445
+ we leave it to the user to accomplish that.
446
+ Defaults to ``None``.
440
447
441
448
Caution: this function resets the env seed. It should be used "offline" to
442
449
check that an env is adequately constructed, but it may affect the seeding
443
450
of an experiment and as such should be kept out of training scripts.
444
451
445
452
"""
446
453
if seed is not None :
447
- torch .manual_seed (seed )
448
- env .set_seed (seed )
454
+ device = (
455
+ env .device if env .device is not None and env .device .type == "cuda" else None
456
+ )
457
+ with _rng_decorator (seed , device = device ):
458
+ env .set_seed (seed )
459
+ return check_env_specs (
460
+ env , return_contiguous = return_contiguous , check_dtype = check_dtype
461
+ )
449
462
450
463
fake_tensordict = env .fake_tensordict ()
451
464
real_tensordict = env .rollout (3 , return_contiguous = return_contiguous )
0 commit comments