Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Feb 5, 2024
1 parent 5215904 commit b29cf5b
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
1 change: 0 additions & 1 deletion test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import argparse
import os
import sys
import time
from copy import copy
from importlib import import_module
from unittest import mock
Expand Down
7 changes: 4 additions & 3 deletions torchrl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,11 +721,12 @@ def __enter__(self):
def _get_state(self):
if self.has_cuda:
if self.device is None:
self._state = (
torch.random.get_rng_state(), torch.cuda.get_rng_state())
self._state = (torch.random.get_rng_state(), torch.cuda.get_rng_state())
else:
self._state = (
torch.random.get_rng_state(), torch.cuda.get_rng_state(self.device))
torch.random.get_rng_state(),
torch.cuda.get_rng_state(self.device),
)

else:
self.state = torch.random.get_rng_state()
Expand Down
5 changes: 4 additions & 1 deletion torchrl/envs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,10 @@ def check_env_specs(
"""
if seed is not None:
with _rng_decorator(seed):
device = (
env.device if env.device is not None and env.device.type == "cuda" else None
)
with _rng_decorator(seed, device=device):
env.set_seed(seed)
return check_env_specs(
env, return_contiguous=return_contiguous, check_dtype=check_dtype
Expand Down

0 comments on commit b29cf5b

Please sign in to comment.