Skip to content

Commit

Permalink
Update DeepMind wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Aug 9, 2024
1 parent d88da15 commit 75a80b5
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 24 deletions.
2 changes: 1 addition & 1 deletion skrl/envs/wrappers/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _in(values, container):
return "isaacgym-preview2"
elif _in("robosuite.environments.", base_classes):
return "robosuite"
elif _in("dm_env._environment.Environment.", base_classes):
elif _in("dm_env..*", base_classes):
return "dm"
elif _in("pettingzoo.utils.env", base_classes) or _in("pettingzoo.utils.wrappers", base_classes):
return "pettingzoo"
Expand Down
34 changes: 11 additions & 23 deletions skrl/envs/wrappers/torch/deepmind_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import torch

from skrl import logger
from skrl.envs.wrappers.torch.base import Wrapper


Expand All @@ -21,29 +22,17 @@ def __init__(self, env: Any) -> None:
from dm_env import specs
self._specs = specs

# observation and action spaces
self._observation_space = self._spec_to_space(self._env.observation_spec())
self._action_space = self._spec_to_space(self._env.action_spec())

@property
def state_space(self) -> gym.Space:
"""State space
An alias for the ``observation_space`` property
"""
return self._observation_space

@property
def observation_space(self) -> gym.Space:
"""Observation space
"""
return self._observation_space
return self._spec_to_space(self._env.observation_spec())

@property
def action_space(self) -> gym.Space:
"""Action space
"""
return self._action_space
return self._spec_to_space(self._env.action_spec())

def _spec_to_space(self, spec: Any) -> gym.Space:
"""Convert the DeepMind spec to a Gym space
Expand Down Expand Up @@ -149,7 +138,7 @@ def reset(self) -> Tuple[torch.Tensor, Any]:
timestep = self._env.reset()
return self._observation_to_tensor(timestep.observation), {}

def render(self, *args, **kwargs) -> None:
def render(self, *args, **kwargs) -> np.ndarray:
"""Render the environment
OpenCV is used to render the environment.
Expand All @@ -158,11 +147,10 @@ def render(self, *args, **kwargs) -> None:
frame = self._env.physics.render(480, 640, camera_id=0)

# render the frame using OpenCV
import cv2
cv2.imshow("env", cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
cv2.waitKey(1)

def close(self) -> None:
"""Close the environment
"""
self._env.close()
try:
import cv2
cv2.imshow("env", cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
cv2.waitKey(1)
except ImportError as e:
logger.warning(f"Unable to import opencv-python: {e}. Frame will not be rendered.")
return frame

0 comments on commit 75a80b5

Please sign in to comment.