From 4a4364c07fbca359d43730c44df40f2d7abdf541 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?= Date: Tue, 7 Jan 2025 16:10:31 -0500 Subject: [PATCH] Update Gymnasium checking for vectorized environments --- skrl/envs/wrappers/jax/gymnasium_envs.py | 14 +++++++++----- skrl/envs/wrappers/torch/gymnasium_envs.py | 14 +++++++++----- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/skrl/envs/wrappers/jax/gymnasium_envs.py b/skrl/envs/wrappers/jax/gymnasium_envs.py index 9f836fc3..30ad6cb7 100644 --- a/skrl/envs/wrappers/jax/gymnasium_envs.py +++ b/skrl/envs/wrappers/jax/gymnasium_envs.py @@ -26,13 +26,17 @@ def __init__(self, env: Any) -> None: self._vectorized = False try: - if isinstance(env, gymnasium.vector.VectorEnv) or isinstance(env, gymnasium.experimental.vector.VectorEnv): - self._vectorized = True - self._reset_once = True - self._observation = None - self._info = None + self._vectorized = self._vectorized or isinstance(env, gymnasium.vector.VectorEnv) + except Exception as e: + pass + try: + self._vectorized = self._vectorized or isinstance(env, gymnasium.experimental.vector.VectorEnv) except Exception as e: logger.warning(f"Failed to check for a vectorized environment: {e}") + if self._vectorized: + self._reset_once = True + self._observation = None + self._info = None @property def observation_space(self) -> gymnasium.Space: diff --git a/skrl/envs/wrappers/torch/gymnasium_envs.py b/skrl/envs/wrappers/torch/gymnasium_envs.py index 1708756e..49e43e0b 100644 --- a/skrl/envs/wrappers/torch/gymnasium_envs.py +++ b/skrl/envs/wrappers/torch/gymnasium_envs.py @@ -25,13 +25,17 @@ def __init__(self, env: Any) -> None: self._vectorized = False try: - if isinstance(env, gymnasium.vector.VectorEnv) or isinstance(env, gymnasium.experimental.vector.VectorEnv): - self._vectorized = True - self._reset_once = True - self._observation = None - self._info = None + self._vectorized = self._vectorized or isinstance(env, gymnasium.vector.VectorEnv) + except Exception as e: + pass + try: + self._vectorized = self._vectorized or isinstance(env, gymnasium.experimental.vector.VectorEnv) except Exception as e: logger.warning(f"Failed to check for a vectorized environment: {e}") + if self._vectorized: + self._reset_once = True + self._observation = None + self._info = None @property def observation_space(self) -> gymnasium.Space: