From b9b07eb694c65c02c8892520cefcd0af11beda73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?= Date: Fri, 21 Jun 2024 17:19:09 -0400 Subject: [PATCH] Improve auto wrapper detection implementation --- skrl/envs/wrappers/jax/__init__.py | 80 +++++++++++++--------------- skrl/envs/wrappers/torch/__init__.py | 80 +++++++++++++--------------- 2 files changed, 74 insertions(+), 86 deletions(-) diff --git a/skrl/envs/wrappers/jax/__init__.py b/skrl/envs/wrappers/jax/__init__.py index 70ac2a7c..fd0fdf6a 100644 --- a/skrl/envs/wrappers/jax/__init__.py +++ b/skrl/envs/wrappers/jax/__init__.py @@ -63,50 +63,44 @@ def wrap_env(env: Any, wrapper: str = "auto", verbose: bool = True) -> Union[Wra :return: Wrapped environment :rtype: Wrapper or MultiAgentEnvWrapper """ - if verbose: - logger.info("Environment class: {}".format(", ".join([str(base).replace("", "") \ - for base in env.__class__.__bases__]))) - if wrapper == "auto": - base_classes = [str(base) for base in env.__class__.__bases__] - if "" in base_classes or \ - "" in base_classes: - if verbose: - logger.info("Environment wrapper: Omniverse Isaac Gym") - return OmniverseIsaacGymWrapper(env) - elif isinstance(env, gym.core.Env) or isinstance(env, gym.core.Wrapper): - # isaaclab - if hasattr(env, "sim") and hasattr(env, "env_ns"): - if verbose: - logger.info("Environment wrapper: Isaac Lab") - return IsaacLabWrapper(env) - # gym - if verbose: - logger.info("Environment wrapper: Gym") - return GymWrapper(env) - elif isinstance(env, gymnasium.core.Env) or isinstance(env, gymnasium.core.Wrapper): - if verbose: - logger.info("Environment wrapper: Gymnasium") - return GymnasiumWrapper(env) - elif "" in base_classes: - if verbose: - logger.info("Environment wrapper: DeepMind") - return DeepMindWrapper(env) - elif "" in base_classes: - if verbose: - logger.info("Environment wrapper: Isaac Gym (preview 2)") - return IsaacGymPreview2Wrapper(env) + def _get_wrapper_name(env, verbose): + def _in(value, container): + for item in container: + if value in item: + return True + return False + + base_classes = [str(base).replace("", "") for base in env.__class__.__bases__] + try: + base_classes += [str(base).replace("", "") for base in env.unwrapped.__class__.__bases__] + except: + pass + base_classes = sorted(list(set(base_classes))) if verbose: - logger.info("Environment wrapper: Isaac Gym (preview 3/4)") - return IsaacGymPreview3Wrapper(env) # preview 4 is the same as 3 - elif wrapper == "gym": + logger.info(f"Environment wrapper: 'auto' (class: {', '.join(base_classes)})") + + if _in("omni.isaac.lab.envs.manager_based_env.ManagerBasedEnv", base_classes) or _in("omni.isaac.lab.envs.direct_rl_env.DirectRLEnv", base_classes): + return "isaaclab" + elif _in("omni.isaac.gym.vec_env.vec_env_base.VecEnvBase", base_classes) or _in("omni.isaac.gym.vec_env.vec_env_mt.VecEnvMT", base_classes): + return "omniverse-isaacgym" + elif _in("rlgpu.tasks.base.vec_task.VecTask", base_classes): + return "isaacgym-preview2" + elif _in("robosuite.environments.", base_classes): + return "robosuite" + elif _in("dm_env._environment.Environment.", base_classes): + return "dm" + elif _in("pettingzoo.utils.env", base_classes) or _in("pettingzoo.utils.wrappers", base_classes): + return "pettingzoo" + elif _in("gymnasium.core.Env", base_classes) or _in("gymnasium.core.Wrapper", base_classes): + return "gymnasium" + elif _in("gym.core.Env", base_classes) or _in("gym.core.Wrapper", base_classes): + return "gym" + return base_classes + + if wrapper == "auto": + wrapper = _get_wrapper_name(env, verbose) + + if wrapper == "gym": if verbose: logger.info("Environment wrapper: Gym") return GymWrapper(env) diff --git a/skrl/envs/wrappers/torch/__init__.py b/skrl/envs/wrappers/torch/__init__.py index c4fe61f2..1fc9898b 100644 --- a/skrl/envs/wrappers/torch/__init__.py +++ b/skrl/envs/wrappers/torch/__init__.py @@ -69,50 +69,44 @@ def wrap_env(env: Any, wrapper: str = "auto", verbose: bool = True) -> Union[Wra :return: Wrapped environment :rtype: Wrapper or MultiAgentEnvWrapper """ - if verbose: - logger.info("Environment class: {}".format(", ".join([str(base).replace("", "") \ - for base in env.__class__.__bases__]))) - if wrapper == "auto": - base_classes = [str(base) for base in env.__class__.__bases__] - if "" in base_classes or \ - "" in base_classes: - if verbose: - logger.info("Environment wrapper: Omniverse Isaac Gym") - return OmniverseIsaacGymWrapper(env) - elif isinstance(env, gym.core.Env) or isinstance(env, gym.core.Wrapper): - # isaaclab - if hasattr(env, "sim") and hasattr(env, "env_ns"): - if verbose: - logger.info("Environment wrapper: Isaac Lab") - return IsaacLabWrapper(env) - # gym - if verbose: - logger.info("Environment wrapper: Gym") - return GymWrapper(env) - elif isinstance(env, gymnasium.core.Env) or isinstance(env, gymnasium.core.Wrapper): - if verbose: - logger.info("Environment wrapper: Gymnasium") - return GymnasiumWrapper(env) - elif "" in base_classes: - if verbose: - logger.info("Environment wrapper: DeepMind") - return DeepMindWrapper(env) - elif "" in base_classes: - if verbose: - logger.info("Environment wrapper: Isaac Gym (preview 2)") - return IsaacGymPreview2Wrapper(env) + def _get_wrapper_name(env, verbose): + def _in(value, container): + for item in container: + if value in item: + return True + return False + + base_classes = [str(base).replace("", "") for base in env.__class__.__bases__] + try: + base_classes += [str(base).replace("", "") for base in env.unwrapped.__class__.__bases__] + except: + pass + base_classes = sorted(list(set(base_classes))) if verbose: - logger.info("Environment wrapper: Isaac Gym (preview 3/4)") - return IsaacGymPreview3Wrapper(env) # preview 4 is the same as 3 - elif wrapper == "gym": + logger.info(f"Environment wrapper: 'auto' (class: {', '.join(base_classes)})") + + if _in("omni.isaac.lab.envs.manager_based_env.ManagerBasedEnv", base_classes) or _in("omni.isaac.lab.envs.direct_rl_env.DirectRLEnv", base_classes): + return "isaaclab" + elif _in("omni.isaac.gym.vec_env.vec_env_base.VecEnvBase", base_classes) or _in("omni.isaac.gym.vec_env.vec_env_mt.VecEnvMT", base_classes): + return "omniverse-isaacgym" + elif _in("rlgpu.tasks.base.vec_task.VecTask", base_classes): + return "isaacgym-preview2" + elif _in("robosuite.environments.", base_classes): + return "robosuite" + elif _in("dm_env._environment.Environment.", base_classes): + return "dm" + elif _in("pettingzoo.utils.env", base_classes) or _in("pettingzoo.utils.wrappers", base_classes): + return "pettingzoo" + elif _in("gymnasium.core.Env", base_classes) or _in("gymnasium.core.Wrapper", base_classes): + return "gymnasium" + elif _in("gym.core.Env", base_classes) or _in("gym.core.Wrapper", base_classes): + return "gym" + return base_classes + + if wrapper == "auto": + wrapper = _get_wrapper_name(env, verbose) + + if wrapper == "gym": if verbose: logger.info("Environment wrapper: Gym") return GymWrapper(env)