Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Mar 22, 2024
1 parent 0de72c9 commit 60ca98c
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
2 changes: 1 addition & 1 deletion torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def _sync_w2m(self) -> Callable:
return sync_func

def _find_sync_values(self):
"""Returns the m2w and w2m sync values, in that order"""
"""Returns the m2w and w2m sync values, in that order."""
# Simplest case: everything is on the same device
worker_device = self.shared_tensordict_parent.device
self_device = self.device
Expand Down
20 changes: 16 additions & 4 deletions torchrl/envs/libs/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,8 +769,9 @@ def _get_batch_size(self, env):

@implement_for("gymnasium")  # gymnasium wants the unwrapped env
def _get_batch_size(self, env):  # noqa: F811
    """Return the batch size to use for ``env``.

    ``num_envs`` is looked up on ``env.unwrapped`` rather than on ``env``
    itself, so the presence check and the value read target the same object.

    Args:
        env: the environment being wrapped; must expose ``unwrapped``.

    Returns:
        ``torch.Size([num_envs, *self.batch_size])`` when the unwrapped env
        has a ``num_envs`` attribute, otherwise ``self.batch_size`` unchanged.
    """
    env_unwrapped = env.unwrapped
    if hasattr(env_unwrapped, "num_envs"):
        return torch.Size([env_unwrapped.num_envs, *self.batch_size])
    return self.batch_size
Expand Down Expand Up @@ -929,6 +930,16 @@ def _set_seed_initial(self, seed: int) -> None: # noqa: F811
self._seed_calls_reset = False
self._env.seed(seed=seed)

@implement_for("gym")
def _reward_space(self, env):
    """Return ``env.reward_space`` if the env defines one, else ``None``.

    ``None`` is returned both when the attribute is missing and when it is
    present but set to ``None``.
    """
    space = getattr(env, "reward_space", None)
    if space is None:
        return None
    return space

@implement_for("gymnasium")
def _reward_space(self, env):  # noqa: F811
    """Return the reward space of the unwrapped env, or ``None``.

    The attribute is read from ``env.unwrapped`` rather than from ``env``.
    ``None`` is returned both when ``reward_space`` is missing and when it
    is present but set to ``None``.
    """
    # Use a separate local instead of rebinding the ``env`` parameter.
    env_unwrapped = env.unwrapped
    reward_space = getattr(env_unwrapped, "reward_space", None)
    if reward_space is None:
        return None
    return reward_space

def _make_specs(self, env: "gym.Env", batch_size=None) -> None: # noqa: F821
action_spec = _gym_to_torchrl_spec_transform(
env.action_space,
Expand All @@ -952,9 +963,10 @@ def _make_specs(self, env: "gym.Env", batch_size=None) -> None: # noqa: F821
elif observation_spec.shape[: len(self.batch_size)] != self.batch_size:
observation_spec.shape = self.batch_size

if hasattr(env, "reward_space") and env.reward_space is not None:
reward_space = self._reward_spec(env)
if reward_space is not None:
reward_spec = _gym_to_torchrl_spec_transform(
env.reward_space,
reward_space,
device=self.device,
categorical_action_encoding=self._categorical_action_encoding,
)
Expand Down

0 comments on commit 60ca98c

Please sign in to comment.