From 8c92d0100aa764070734edf3d466999e44538daa Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Thu, 15 Feb 2024 14:46:16 +0000 Subject: [PATCH 1/3] Move gym (and gymnasium) env to root and rename `gym.py` for registration to `gym_registration.py` --- pyproject.toml | 5 +- src/python/env/__init__.py | 0 src/python/{env/gym.py => gym_env.py} | 8 +- src/python/gym_registration.py | 187 ++++++++++++++++++ src/python/{gymnasium.py => gymnasium_env.py} | 0 .../{gym.py => gymnasium_registration.py} | 6 +- 6 files changed, 198 insertions(+), 8 deletions(-) delete mode 100644 src/python/env/__init__.py rename src/python/{env/gym.py => gym_env.py} (99%) create mode 100644 src/python/gym_registration.py rename src/python/{gymnasium.py => gymnasium_env.py} (100%) rename src/python/{gym.py => gymnasium_registration.py} (98%) diff --git a/pyproject.toml b/pyproject.toml index aa2443541..c33286668 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,14 +54,13 @@ changelog = "https://github.com/mgbellemare/Arcade-Learning-Environment/blob/mas ale-import-roms = "ale_py.scripts.import_roms:main" [project.entry-points."gym.envs"] -ALE = "ale_py.gym:register_gym_envs" -__internal__ = "ale_py.gym:register_legacy_gym_envs" +ALE = "ale_py.gym_registration:register_gym_envs" +__internal__ = "ale_py.gym_registration:register_legacy_gym_envs" [tool.setuptools] packages = [ "ale_py", "ale_py.roms", - "ale_py.env", "ale_py.scripts" ] package-dir = {ale_py = "src/python", gym = "src/gym"} diff --git a/src/python/env/__init__.py b/src/python/env/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/python/env/gym.py b/src/python/gym_env.py similarity index 99% rename from src/python/env/gym.py rename to src/python/gym_env.py index 44353418c..33d616c12 100644 --- a/src/python/env/gym.py +++ b/src/python/gym_env.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import sys from typing import Any, Dict, List, Optional, Sequence, Tuple, Union @@ -6,9 +8,9 @@ import ale_py.roms.utils as rom_utils import numpy as np -import gym -import gym.logger as logger -from gym import error, spaces, utils +import gym_registration +import gym_registration.logger as logger +from gym_registration import error, spaces, utils if sys.version_info < (3, 11): from typing_extensions import NotRequired, TypedDict diff --git a/src/python/gym_registration.py b/src/python/gym_registration.py new file mode 100644 index 000000000..1e5022f35 --- /dev/null +++ b/src/python/gym_registration.py @@ -0,0 +1,187 @@ +from __future__ import annotations + +from collections import defaultdict +from typing import Any, Callable, Mapping, NamedTuple, Sequence, Text, Union + +import ale_py.roms as roms +from ale_py.roms import utils as rom_utils + +from gym.envs.registration import register + + +class GymFlavour(NamedTuple): + suffix: str + kwargs: Union[Mapping[Text, Any], Callable[[str], Mapping[Text, Any]]] + + +class GymConfig(NamedTuple): + version: str + kwargs: Mapping[Text, Any] + flavours: Sequence[GymFlavour] + + +def _register_gym_configs( + roms: Sequence[str], + obs_types: Sequence[str], + configs: Sequence[GymConfig], + prefix: str = "", +) -> None: + if len(prefix) > 0 and prefix[-1] != "/": + prefix += "/" + + for rom in roms: + for obs_type in obs_types: + for config in configs: + for flavour in config.flavours: + name = rom_utils.rom_id_to_name(rom) + name = f"{name}-ram" if obs_type == "ram" else name + + # Parse config kwargs + config_kwargs = ( + config.kwargs(rom) if callable(config.kwargs) else config.kwargs + ) + # Parse flavour kwargs + flavour_kwargs = ( + flavour.kwargs(rom) + if callable(flavour.kwargs) + else flavour.kwargs + ) + + # Register the environment + register( + id=f"{prefix}{name}{flavour.suffix}-{config.version}", + entry_point="ale_py.gym_env:AtariEnv", + kwargs=dict( + game=rom, + obs_type=obs_type, + **config_kwargs, + **flavour_kwargs, + ), + ) + + +def register_legacy_gym_envs() -> None: + legacy_games = [ + "adventure", + "air_raid", + "alien", + "amidar", + "assault", + "asterix", + "asteroids", + "atlantis", + "bank_heist", + "battle_zone", + "beam_rider", + "berzerk", + "bowling", + "boxing", + "breakout", + "carnival", + "centipede", + "chopper_command", + "crazy_climber", + "defender", + "demon_attack", + "double_dunk", + "elevator_action", + "enduro", + "fishing_derby", + "freeway", + "frostbite", + "gopher", + "gravitar", + "hero", + "ice_hockey", + "jamesbond", + "journey_escape", + "kangaroo", + "krull", + "kung_fu_master", + "montezuma_revenge", + "ms_pacman", + "name_this_game", + "phoenix", + "pitfall", + "pong", + "pooyan", + "private_eye", + "qbert", + "riverraid", + "road_runner", + "robotank", + "seaquest", + "skiing", + "solaris", + "space_invaders", + "star_gunner", + "tennis", + "time_pilot", + "tutankham", + "up_n_down", + "venture", + "video_pinball", + "wizard_of_wor", + "yars_revenge", + "zaxxon", + ] + obs_types = ["rgb", "ram"] + frameskip = defaultdict(lambda: 4, [("space_invaders", 3)]) + + versions = [ + GymConfig( + version="v0", + kwargs={ + "repeat_action_probability": 0.25, + "full_action_space": False, + "max_num_frames_per_episode": 108_000, + }, + flavours=[ + # Default for v0 has 10k steps, no idea why... + GymFlavour("", {"frameskip": (2, 5)}), + # Deterministic has 100k steps, close to the standard of 108k (30 mins gameplay) + GymFlavour("Deterministic", lambda rom: {"frameskip": frameskip[rom]}), + # NoFrameSkip imposes a max episode steps of frameskip * 100k, weird... + GymFlavour("NoFrameskip", {"frameskip": 1}), + ], + ), + GymConfig( + version="v4", + kwargs={ + "repeat_action_probability": 0.0, + "full_action_space": False, + "max_num_frames_per_episode": 108_000, + }, + flavours=[ + # Unlike v0, v4 has 100k max episode steps + GymFlavour("", {"frameskip": (2, 5)}), + GymFlavour("Deterministic", lambda rom: {"frameskip": frameskip[rom]}), + # Same weird frameskip * 100k max steps for v4? + GymFlavour("NoFrameskip", {"frameskip": 1}), + ], + ), + ] + + _register_gym_configs(legacy_games, obs_types, versions) + + +def register_gym_envs(): + all_games = list(map(rom_utils.rom_name_to_id, dir(roms))) + obs_types = ["rgb", "ram"] + + # max_episode_steps is 108k frames which is 30 mins of gameplay. + # This corresponds to 108k / 4 = 27,000 steps + versions = [ + GymConfig( + version="v5", + kwargs={ + "repeat_action_probability": 0.25, + "full_action_space": False, + "frameskip": 4, + "max_num_frames_per_episode": 108_000, + }, + flavours=[GymFlavour("", {})], + ) + ] + + _register_gym_configs(all_games, obs_types, versions) diff --git a/src/python/gymnasium.py b/src/python/gymnasium_env.py similarity index 100% rename from src/python/gymnasium.py rename to src/python/gymnasium_env.py diff --git a/src/python/gym.py b/src/python/gymnasium_registration.py similarity index 98% rename from src/python/gym.py rename to src/python/gymnasium_registration.py index 62aec69cf..5021d4289 100644 --- a/src/python/gym.py +++ b/src/python/gymnasium_registration.py @@ -1,10 +1,12 @@ +from __future__ import annotations + from collections import defaultdict from typing import Any, Callable, Mapping, NamedTuple, Sequence, Text, Union import ale_py.roms as roms from ale_py.roms import utils as rom_utils -from gym.envs.registration import register +import gymnasium class GymFlavour(NamedTuple): @@ -46,7 +48,7 @@ def _register_gym_configs( ) # Register the environment - register( + gymnasium.register( id=f"{prefix}{name}{flavour.suffix}-{config.version}", entry_point="ale_py.env.gym:AtariEnv", kwargs=dict( From 7ab6c47c087a3cc4640f4c6e6c783263799c40f0 Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Thu, 15 Feb 2024 14:58:41 +0000 Subject: [PATCH 2/3] Fix tests --- src/python/gym_env.py | 5 ++--- tests/fixtures.py | 2 +- tests/python/gym/test_gym_interface.py | 6 +++--- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/python/gym_env.py b/src/python/gym_env.py index 33d616c12..efc7875f8 100644 --- a/src/python/gym_env.py +++ b/src/python/gym_env.py @@ -8,9 +8,8 @@ import ale_py.roms.utils as rom_utils import numpy as np -import gym_registration -import gym_registration.logger as logger -from gym_registration import error, spaces, utils +import gym +from gym import error, spaces, utils, logger if sys.version_info < (3, 11): from typing_extensions import NotRequired, TypedDict diff --git a/tests/fixtures.py b/tests/fixtures.py index d868365ca..56f57e8b7 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -33,7 +33,7 @@ def tetris_gym(request, test_rom_path): ): register( id="TetrisTest-v0", - entry_point="ale_py.env.gym:AtariEnv", + entry_point="ale_py.gym_env:AtariEnv", kwargs={"game": "tetris_test"}, ) diff --git a/tests/python/gym/test_gym_interface.py b/tests/python/gym/test_gym_interface.py index cbc370834..10c4e081a 100644 --- a/tests/python/gym/test_gym_interface.py +++ b/tests/python/gym/test_gym_interface.py @@ -7,8 +7,8 @@ from unittest.mock import patch import numpy as np -from ale_py.env.gym import AtariEnv -from ale_py.gym import ( +from ale_py.gym_env import AtariEnv +from ale_py.gym_registration import ( _register_gym_configs, register_gym_envs, register_legacy_gym_envs, @@ -30,7 +30,7 @@ def test_register_legacy_env_id(): def _mocked_register_gym_configs(*args, **kwargs): return _original_register_gym_configs(*args, **kwargs, prefix=prefix) - with patch("ale_py.gym._register_gym_configs", new=_mocked_register_gym_configs): + with patch("ale_py.gym_registration._register_gym_configs", new=_mocked_register_gym_configs): # Register internal IDs register_legacy_gym_envs() From c334a51fc7df15f223f25c59792153ca3ff41afa Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Thu, 15 Feb 2024 15:00:11 +0000 Subject: [PATCH 3/3] pre-commit --- tests/python/gym/test_gym_interface.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/python/gym/test_gym_interface.py b/tests/python/gym/test_gym_interface.py index 10c4e081a..f25a58f81 100644 --- a/tests/python/gym/test_gym_interface.py +++ b/tests/python/gym/test_gym_interface.py @@ -30,7 +30,10 @@ def test_register_legacy_env_id(): def _mocked_register_gym_configs(*args, **kwargs): return _original_register_gym_configs(*args, **kwargs, prefix=prefix) - with patch("ale_py.gym_registration._register_gym_configs", new=_mocked_register_gym_configs): + with patch( + "ale_py.gym_registration._register_gym_configs", + new=_mocked_register_gym_configs, + ): # Register internal IDs register_legacy_gym_envs()