Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move gym (and gymnasium) env to root and rename gym.py for registration to gym_registration.py #509

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,13 @@ changelog = "https://github.com/mgbellemare/Arcade-Learning-Environment/blob/mas
ale-import-roms = "ale_py.scripts.import_roms:main"

[project.entry-points."gym.envs"]
ALE = "ale_py.gym:register_gym_envs"
__internal__ = "ale_py.gym:register_legacy_gym_envs"
ALE = "ale_py.gym_registration:register_gym_envs"
__internal__ = "ale_py.gym_registration:register_legacy_gym_envs"

[tool.setuptools]
packages = [
"ale_py",
"ale_py.roms",
"ale_py.env",
"ale_py.scripts"
]
package-dir = {ale_py = "src/python", gym = "src/gym"}
Expand Down
Empty file removed src/python/env/__init__.py
Empty file.
5 changes: 3 additions & 2 deletions src/python/env/gym.py → src/python/gym_env.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import sys
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

Expand All @@ -7,8 +9,7 @@
import numpy as np

import gym
import gym.logger as logger
from gym import error, spaces, utils
from gym import error, spaces, utils, logger

if sys.version_info < (3, 11):
from typing_extensions import NotRequired, TypedDict
Expand Down
187 changes: 187 additions & 0 deletions src/python/gym_registration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
from __future__ import annotations

from collections import defaultdict
from typing import Any, Callable, Mapping, NamedTuple, Sequence, Text, Union

import ale_py.roms as roms
from ale_py.roms import utils as rom_utils

from gym.envs.registration import register


class GymFlavour(NamedTuple):
suffix: str
kwargs: Union[Mapping[Text, Any], Callable[[str], Mapping[Text, Any]]]


class GymConfig(NamedTuple):
version: str
kwargs: Mapping[Text, Any]
flavours: Sequence[GymFlavour]


def _register_gym_configs(
roms: Sequence[str],
obs_types: Sequence[str],
configs: Sequence[GymConfig],
prefix: str = "",
) -> None:
if len(prefix) > 0 and prefix[-1] != "/":
prefix += "/"

for rom in roms:
for obs_type in obs_types:
for config in configs:
for flavour in config.flavours:
name = rom_utils.rom_id_to_name(rom)
name = f"{name}-ram" if obs_type == "ram" else name

# Parse config kwargs
config_kwargs = (
config.kwargs(rom) if callable(config.kwargs) else config.kwargs
)
# Parse flavour kwargs
flavour_kwargs = (
flavour.kwargs(rom)
if callable(flavour.kwargs)
else flavour.kwargs
)

# Register the environment
register(
id=f"{prefix}{name}{flavour.suffix}-{config.version}",
entry_point="ale_py.gym_env:AtariEnv",
kwargs=dict(
game=rom,
obs_type=obs_type,
**config_kwargs,
**flavour_kwargs,
),
)


def register_legacy_gym_envs() -> None:
legacy_games = [
"adventure",
"air_raid",
"alien",
"amidar",
"assault",
"asterix",
"asteroids",
"atlantis",
"bank_heist",
"battle_zone",
"beam_rider",
"berzerk",
"bowling",
"boxing",
"breakout",
"carnival",
"centipede",
"chopper_command",
"crazy_climber",
"defender",
"demon_attack",
"double_dunk",
"elevator_action",
"enduro",
"fishing_derby",
"freeway",
"frostbite",
"gopher",
"gravitar",
"hero",
"ice_hockey",
"jamesbond",
"journey_escape",
"kangaroo",
"krull",
"kung_fu_master",
"montezuma_revenge",
"ms_pacman",
"name_this_game",
"phoenix",
"pitfall",
"pong",
"pooyan",
"private_eye",
"qbert",
"riverraid",
"road_runner",
"robotank",
"seaquest",
"skiing",
"solaris",
"space_invaders",
"star_gunner",
"tennis",
"time_pilot",
"tutankham",
"up_n_down",
"venture",
"video_pinball",
"wizard_of_wor",
"yars_revenge",
"zaxxon",
]
obs_types = ["rgb", "ram"]
frameskip = defaultdict(lambda: 4, [("space_invaders", 3)])

versions = [
GymConfig(
version="v0",
kwargs={
"repeat_action_probability": 0.25,
"full_action_space": False,
"max_num_frames_per_episode": 108_000,
},
flavours=[
# Default for v0 has 10k steps, no idea why...
GymFlavour("", {"frameskip": (2, 5)}),
# Deterministic has 100k steps, close to the standard of 108k (30 mins gameplay)
GymFlavour("Deterministic", lambda rom: {"frameskip": frameskip[rom]}),
# NoFrameSkip imposes a max episode steps of frameskip * 100k, weird...
GymFlavour("NoFrameskip", {"frameskip": 1}),
],
),
GymConfig(
version="v4",
kwargs={
"repeat_action_probability": 0.0,
"full_action_space": False,
"max_num_frames_per_episode": 108_000,
},
flavours=[
# Unlike v0, v4 has 100k max episode steps
GymFlavour("", {"frameskip": (2, 5)}),
GymFlavour("Deterministic", lambda rom: {"frameskip": frameskip[rom]}),
# Same weird frameskip * 100k max steps for v4?
GymFlavour("NoFrameskip", {"frameskip": 1}),
],
),
]

_register_gym_configs(legacy_games, obs_types, versions)


def register_gym_envs():
all_games = list(map(rom_utils.rom_name_to_id, dir(roms)))
obs_types = ["rgb", "ram"]

# max_episode_steps is 108k frames which is 30 mins of gameplay.
# This corresponds to 108k / 4 = 27,000 steps
versions = [
GymConfig(
version="v5",
kwargs={
"repeat_action_probability": 0.25,
"full_action_space": False,
"frameskip": 4,
"max_num_frames_per_episode": 108_000,
},
flavours=[GymFlavour("", {})],
)
]

_register_gym_configs(all_games, obs_types, versions)
File renamed without changes.
6 changes: 4 additions & 2 deletions src/python/gym.py → src/python/gymnasium_registration.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

from collections import defaultdict
from typing import Any, Callable, Mapping, NamedTuple, Sequence, Text, Union

import ale_py.roms as roms
from ale_py.roms import utils as rom_utils

from gym.envs.registration import register
import gymnasium


class GymFlavour(NamedTuple):
Expand Down Expand Up @@ -46,7 +48,7 @@ def _register_gym_configs(
)

# Register the environment
register(
gymnasium.register(
id=f"{prefix}{name}{flavour.suffix}-{config.version}",
entry_point="ale_py.env.gym:AtariEnv",
kwargs=dict(
Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def tetris_gym(request, test_rom_path):
):
register(
id="TetrisTest-v0",
entry_point="ale_py.env.gym:AtariEnv",
entry_point="ale_py.gym_env:AtariEnv",
kwargs={"game": "tetris_test"},
)

Expand Down
9 changes: 6 additions & 3 deletions tests/python/gym/test_gym_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from unittest.mock import patch

import numpy as np
from ale_py.env.gym import AtariEnv
from ale_py.gym import (
from ale_py.gym_env import AtariEnv
from ale_py.gym_registration import (
_register_gym_configs,
register_gym_envs,
register_legacy_gym_envs,
Expand All @@ -30,7 +30,10 @@ def test_register_legacy_env_id():
def _mocked_register_gym_configs(*args, **kwargs):
return _original_register_gym_configs(*args, **kwargs, prefix=prefix)

with patch("ale_py.gym._register_gym_configs", new=_mocked_register_gym_configs):
with patch(
"ale_py.gym_registration._register_gym_configs",
new=_mocked_register_gym_configs,
):
# Register internal IDs
register_legacy_gym_envs()

Expand Down
Loading