diff --git a/rllte/env/craftax/__init__.py b/rllte/env/craftax/__init__.py index 5a95468b..6604cd78 100644 --- a/rllte/env/craftax/__init__.py +++ b/rllte/env/craftax/__init__.py @@ -1,6 +1,6 @@ -from craftax.envs.craftax_pixels_env import CraftaxPixelsEnv -from craftax_classic.envs.craftax_pixels_env import CraftaxClassicPixelsEnv -from environment_base.wrappers import ( +from craftax.craftax.envs.craftax_pixels_env import CraftaxPixelsEnv +from craftax.craftax_classic.envs.craftax_pixels_env import CraftaxClassicPixelsEnv +from craftax.environment_base.wrappers import ( LogWrapper, BatchEnvWrapper, OptimisticResetVecEnvWrapper, @@ -8,20 +8,20 @@ from rllte.env.craftax.wrappers import TorchWrapper, ResizeTorchWrapper, RecordEpisodeStatistics4Craftax + def make_craftax_env( env_id: str = "Craftax-Classic", num_envs: int = 32, reset_ratio: int = 16, device: str = "cpu", - ): - +): if env_id == "Craftax-Classic": env = CraftaxClassicPixelsEnv() elif env_id == "Craftax": env = CraftaxPixelsEnv() else: raise ValueError(f"Unknown environment: {env_id}") - + env = LogWrapper(env) env = OptimisticResetVecEnvWrapper(env, num_envs=num_envs, reset_ratio=reset_ratio) env = TorchWrapper(env, device=device)