Commit

amend
vmoens committed Jul 9, 2024
1 parent 4cb6162 · commit eab40f9
Showing 4 changed files with 8 additions and 9 deletions.
4 changes: 2 additions & 2 deletions test/test_libs.py
@@ -1101,7 +1101,7 @@ def test_gym_gymnasium_parallel(self, maybe_fork_ParallelEnv):
     def test_vecenvs_nan(self):  # noqa: F811
         # old versions of gym must return nan for next values when there is a done state
         torch.manual_seed(0)
-        env = GymEnv("CartPole-v0", num_envs=2)
+        env = GymEnv("CartPole-v0", num_envs=2, device="cpu")
         env.set_seed(0)
         rollout = env.rollout(200)
         assert torch.isfinite(rollout.get("observation")).all()
@@ -1110,7 +1110,7 @@ def test_vecenvs_nan(self):  # noqa: F811
         del env
 
         # same with collector
-        env = GymEnv("CartPole-v0", num_envs=2)
+        env = GymEnv("CartPole-v0", num_envs=2)  # , device="cpu")
         env.set_seed(0)
         c = SyncDataCollector(
             env, RandomPolicy(env.action_spec), total_frames=2000, frames_per_batch=200
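For context, a standalone sketch of the vectorized-env pattern this test exercises; it mirrors the updated lines above and assumes a gym/gymnasium backend with CartPole-v0 installed:

import torch
from torchrl.envs.libs.gym import GymEnv

torch.manual_seed(0)
# num_envs > 1 builds a batched (vectorized) env; device pins outputs to CPU.
env = GymEnv("CartPole-v0", num_envs=2, device="cpu")
env.set_seed(0)
rollout = env.rollout(200)
# Old gym versions may emit nan next-values at done states; check finiteness.
assert torch.isfinite(rollout.get("observation")).all()
env.close()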
2 changes: 2 additions & 0 deletions torchrl/data/tensor_specs.py
@@ -1143,6 +1143,7 @@ def __eq__(self, other):
         if not isinstance(other, LazyStackedTensorSpec):
             return False
         if self.device != other.device:
+            raise RuntimeError((self, other))
             return False
         if len(self._specs) != len(other._specs):
             return False
@@ -4778,6 +4779,7 @@ def _stack_specs(list_of_spec, dim, out=None):
                 dim += len(shape) + 1
             shape.insert(dim, len(list_of_spec))
             return spec0.clone().unsqueeze(dim).expand(shape)
+        raise RuntimeError(list_of_spec)
         return LazyStackedTensorSpec(*list_of_spec, dim=dim)
     else:
         raise NotImplementedError
6 changes: 2 additions & 4 deletions torchrl/envs/gym_like.py
@@ -348,8 +348,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
             batch_size=tensordict.batch_size,
         )
         if self.device is not None:
-            tensordict_out = tensordict_out.to(self.device, non_blocking=True)
-            self._sync_device()
+            tensordict_out = tensordict_out.to(self.device)
 
         if self.info_dict_reader and (info_dict is not None):
             if not isinstance(info_dict, dict):
@@ -393,8 +392,7 @@ def _reset(
             if key not in tensordict_out.keys(True, True):
                 tensordict_out[key] = item.zero()
         if self.device is not None:
-            tensordict_out = tensordict_out.to(self.device, non_blocking=True)
-            self._sync_device()
+            tensordict_out = tensordict_out.to(self.device)
         return tensordict_out
 
     @abc.abstractmethod
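As a side note, a minimal sketch of the two device-transfer idioms that differ in the hunks above, written against a plain TensorDict rather than the env; torch.cuda.synchronize() stands in here for the env's internal _sync_device(), and the CUDA branch is purely illustrative:

import torch
from tensordict import TensorDict

td = TensorDict({"observation": torch.randn(4)}, batch_size=[])

# Blocking transfer: .to() only returns once the copy has completed.
td_cpu = td.to("cpu")

# Non-blocking transfer: the copy may still be in flight when .to() returns,
# so the caller must synchronize before the data is consumed.
if torch.cuda.is_available():
    td_cuda = td.to("cuda", non_blocking=True)
    torch.cuda.synchronize()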
5 changes: 2 additions & 3 deletions torchrl/envs/libs/gym.py
@@ -27,7 +27,6 @@
     BoundedTensorSpec,
     CompositeSpec,
     DiscreteTensorSpec,
-    LazyStackedTensorSpec,
     MultiDiscreteTensorSpec,
     MultiOneHotDiscreteTensorSpec,
     OneHotDiscreteTensorSpec,
@@ -246,8 +245,8 @@ def _gym_to_torchrl_spec_transform(
         ).expand(batch_size)
     gym_spaces = gym_backend("spaces")
     if isinstance(spec, gym_spaces.tuple.Tuple):
-        result = LazyStackedTensorSpec(
-            *[
+        result = torch.stack(
+            [
                 _gym_to_torchrl_spec_transform(
                     s,
                     device=device,
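For reference, a small sketch of what torch.stack does on torchrl specs via the _stack_specs dispatch shown earlier; the class name follows this file's imports, and the described return types are an assumption read off that function rather than guaranteed API:

import torch
from torchrl.data import DiscreteTensorSpec

# Identical specs: _stack_specs expands the first spec along a new leading dim.
stacked = torch.stack([DiscreteTensorSpec(4), DiscreteTensorSpec(4)], 0)
print(stacked.shape)  # torch.Size([2])
# Non-identical specs of the same class take the LazyStackedTensorSpec branch
# instead (note the debugging RuntimeError added just before that branch in this commit).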
