Skip to content

Commit

Permalink
setup atari experiments
Browse files Browse the repository at this point in the history
Summary: Setup a DQN experiment in an atari env in benchmark.

Reviewed By: rodrigodesalvobraz

Differential Revision: D66204020

fbshipit-source-id: 639a51529fb70cacf9f9471ec85c566a58c1f86f
  • Loading branch information
yiwan-rl authored and facebook-github-bot committed Dec 10, 2024
1 parent 61eaf77 commit f435791
Show file tree
Hide file tree
Showing 4 changed files with 278 additions and 2 deletions.
5 changes: 5 additions & 0 deletions pearl/user_envs/wrappers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

# pyre-strict

from .atari_wrappers import EpisodicLifeEnv, FireResetEnv, MaxAndSkipEnv, NoopResetEnv
from .dynamic_action_env import DynamicActionSpaceWrapper
from .gym_avg_torque_cost import GymAvgTorqueWrapper
from .partial_observability import (
Expand Down Expand Up @@ -39,4 +40,8 @@
"PartialObservableWrapper",
"GymAvgTorqueWrapper",
"DynamicActionSpaceWrapper",
"NoopResetEnv",
"FireResetEnv",
"EpisodicLifeEnv",
"MaxAndSkipEnv",
]
185 changes: 185 additions & 0 deletions pearl/user_envs/wrappers/atari_wrappers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# pyre-ignore-all-errors

"""
The code from this file is copied from https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/atari_wrappers.py
"""

from typing import Any, Dict, SupportsFloat, Tuple

import gymnasium as gym
import numpy as np
from gymnasium import spaces


AtariResetReturn = Tuple[np.ndarray, Dict[str, Any]]
AtariStepReturn = Tuple[np.ndarray, SupportsFloat, bool, bool, Dict[str, Any]]

try:
import cv2

cv2.ocl.setUseOpenCL(False)
except ImportError:
cv2 = None # type: ignore[assignment]


class NoopResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
"""
Sample initial states by taking random number of no-ops on reset.
No-op is assumed to be action 0.
:param env: Environment to wrap
:param noop_max: Maximum value of no-ops to run
"""

def __init__(self, env: gym.Env, noop_max: int = 30) -> None:
super().__init__(env)
self.noop_max = noop_max
self.override_num_noops = None
self.noop_action = 0
assert env.unwrapped.get_action_meanings()[0] == "NOOP" # type: ignore[attr-defined]

def reset(self, **kwargs) -> AtariResetReturn:
self.env.reset(**kwargs)
if self.override_num_noops is not None:
noops = self.override_num_noops
else:
noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
assert noops > 0
obs = np.zeros(0)
info: Dict = {}
for _ in range(noops):
obs, _, terminated, truncated, info = self.env.step(self.noop_action)
if terminated or truncated:
obs, info = self.env.reset(**kwargs)
return obs, info


class FireResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
"""
Take action on reset for environments that are fixed until firing.
:param env: Environment to wrap
"""

def __init__(self, env: gym.Env) -> None:
super().__init__(env)
assert env.unwrapped.get_action_meanings()[1] == "FIRE" # type: ignore[attr-defined]
assert len(env.unwrapped.get_action_meanings()) >= 3 # type: ignore[attr-defined]

def reset(self, **kwargs) -> AtariResetReturn:
self.env.reset(**kwargs)
obs, _, terminated, truncated, _ = self.env.step(1)
if terminated or truncated:
self.env.reset(**kwargs)
obs, _, terminated, truncated, _ = self.env.step(2)
if terminated or truncated:
self.env.reset(**kwargs)
return obs, {}


class EpisodicLifeEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
"""
Make end-of-life == end-of-episode, but only reset on true game over.
Done by DeepMind for the DQN and co. since it helps value estimation.
:param env: Environment to wrap
"""

def __init__(self, env: gym.Env) -> None:
super().__init__(env)
self.lives = 0
self.was_real_done = True

def step(self, action: int) -> AtariStepReturn:
obs, reward, terminated, truncated, info = self.env.step(action)
self.was_real_done = terminated or truncated
# check current lives, make loss of life terminal,
# then update lives to handle bonus lives
lives = self.env.unwrapped.ale.lives() # type: ignore[attr-defined]
if 0 < lives < self.lives:
# for Qbert sometimes we stay in lives == 0 condition for a few frames
# so its important to keep lives > 0, so that we only reset once
# the environment advertises done.
terminated = True
self.lives = lives
return obs, reward, terminated, truncated, info

def reset(self, **kwargs) -> AtariResetReturn:
"""
Calls the Gym environment reset, only when lives are exhausted.
This way all states are still reachable even though lives are episodic,
and the learner need not know about any of this behind-the-scenes.
:param kwargs: Extra keywords passed to env.reset() call
:return: the first observation of the environment
"""
if self.was_real_done:
obs, info = self.env.reset(**kwargs)
else:
# no-op step to advance from terminal/lost life state
obs, _, terminated, truncated, info = self.env.step(0)

# The no-op step can lead to a game over, so we need to check it again
# to see if we should reset the environment and avoid the
# monitor.py `RuntimeError: Tried to step environment that needs reset`
if terminated or truncated:
obs, info = self.env.reset(**kwargs)
self.lives = self.env.unwrapped.ale.lives() # type: ignore[attr-defined]
return obs, info


class MaxAndSkipEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
"""
Return only every ``skip``-th frame (frameskipping)
and return the max between the two last frames.
:param env: Environment to wrap
:param skip: Number of ``skip``-th frame
The same action will be taken ``skip`` times.
"""

def __init__(self, env: gym.Env, skip: int = 4) -> None:
super().__init__(env)
# most recent raw observations (for max pooling across time steps)
assert (
env.observation_space.dtype is not None
), "No dtype specified for the observation space"
assert (
env.observation_space.shape is not None
), "No shape defined for the observation space"
self._obs_buffer = np.zeros(
(2, *env.observation_space.shape), dtype=env.observation_space.dtype
)
self._skip = skip

def step(self, action: int) -> AtariStepReturn:
"""
Step the environment with the given action
Repeat action, sum reward, and max over last observations.
:param action: the action
:return: observation, reward, terminated, truncated, information
"""
total_reward = 0.0
terminated = truncated = False
for i in range(self._skip):
obs, reward, terminated, truncated, info = self.env.step(action)
done = terminated or truncated
if i == self._skip - 2:
self._obs_buffer[0] = obs
if i == self._skip - 1:
self._obs_buffer[1] = obs
total_reward += float(reward)
if done:
break
# Note that the observation on the done=True frame
# doesn't matter
max_frame = self._obs_buffer.max(axis=0)

return max_frame, total_reward, terminated, truncated, info
30 changes: 28 additions & 2 deletions pearl/utils/scripts/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,16 @@
import warnings
from typing import List

import ale_py
import matplotlib.pyplot as plt
import numpy as np
import torch.multiprocessing as mp
from pearl.action_representation_modules.identity_action_representation_module import (
IdentityActionRepresentationModule,
)
from pearl.neural_networks.sequential_decision_making.q_value_networks import (
CNNQValueNetwork,
)
from pearl.pearl_agent import PearlAgent
from pearl.utils.functional_utils.experimentation.set_seed import set_seed

Expand All @@ -30,6 +37,7 @@
benchmark_acrobot_v1_part_1,
benchmark_acrobot_v1_part_2,
benchmark_ant_v4,
benchmark_atari,
benchmark_cartpole_v1_part_1,
benchmark_cartpole_v1_part_2,
benchmark_halfcheetah_v4,
Expand Down Expand Up @@ -166,6 +174,11 @@ def evaluate_single(
"action_representation_module"
](**method["action_representation_module_args"])

else:
policy_learner_args["action_representation_module"] = (
IdentityActionRepresentationModule()
)

if (
"history_summarization_module" in method
and "history_summarization_module_args" in method
Expand Down Expand Up @@ -206,6 +219,18 @@ def evaluate_single(
"history_summarization_module"
](**method["history_summarization_module_args"])

if "network_module" in method and method["network_module"] is CNNQValueNetwork:
policy_learner_args["network_instance"] = method["network_module"](
input_width=env.observation_space.shape[2],
input_height=env.observation_space.shape[1],
input_channels_count=env.observation_space.shape[0],
action_dim=policy_learner_args[
"action_representation_module"
].representation_dim,
output_dim=1,
**method["network_args"],
)

if method["name"] == "DuelingDQN": # only for Dueling DQN
assert "network_module" in method and "network_args" in method
policy_learner_args["network_instance"] = method["network_module"](
Expand Down Expand Up @@ -309,8 +334,9 @@ def generate_one_plot(experiment, attributes):


if __name__ == "__main__":
run(benchmark_pendulum_v1_lstm)
generate_plots(benchmark_pendulum_v1_lstm, ["return"])
run(benchmark_atari)
# run(benchmark_pendulum_v1_lstm)
# generate_plots(benchmark_pendulum_v1_lstm, ["return"])
# run(benchmark_pendulum_v1_lstm2)
# generate_plots(benchmark_pendulum_v1_lstm2, ["return"])
# run(benchmark_pendulum_v1_lstm3)
Expand Down
60 changes: 60 additions & 0 deletions pearl/utils/scripts/benchmark_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
VanillaContinuousActorNetwork,
)
from pearl.neural_networks.sequential_decision_making.q_value_networks import (
CNNQValueNetwork,
DuelingQValueNetwork,
EnsembleQValueNetwork,
VanillaQValueNetwork,
Expand Down Expand Up @@ -91,9 +92,13 @@
AcrobotPartialObservableWrapper,
AcrobotSparseRewardWrapper,
CartPolePartialObservableWrapper,
EpisodicLifeEnv,
FireResetEnv,
GymAvgTorqueWrapper,
MaxAndSkipEnv,
MountainCarPartialObservableWrapper,
MountainCarSparseRewardWrapper,
NoopResetEnv,
PendulumPartialObservableWrapper,
PendulumSparseRewardWrapper,
PuckWorldPartialObservableWrapper,
Expand Down Expand Up @@ -140,6 +145,29 @@
"num_layers": 1,
},
}
DQN_Atari_method = {
"name": "DQN",
"policy_learner": DeepQLearning,
"policy_learner_args": {
"training_rounds": 1,
"target_update_freq": 250,
"batch_size": 32,
},
"network_module": CNNQValueNetwork,
"network_args": {
"hidden_dims_fully_connected": [512],
"kernel_sizes": [8, 4, 3],
"output_channels_list": [32, 64, 64],
"strides": [4, 2, 1],
"paddings": [0, 0, 0],
},
"exploration_module": EGreedyExploration,
"exploration_module_args": {"epsilon": 0.1},
"replay_buffer": BasicReplayBuffer,
"replay_buffer_args": {"capacity": 50000},
"action_representation_module": OneHotActionTensorRepresentationModule,
"action_representation_module_args": {},
}
CDQN_method = {
"name": "Conservative DQN",
"policy_learner": DeepQLearning,
Expand Down Expand Up @@ -1309,6 +1337,24 @@
]
]

benchmark_atari = [
{
"exp_name": "benchmark_atari",
"env_name": env_name,
"num_runs": 1,
"num_steps": classic_control_steps,
"print_every_x_steps": print_every_x_steps,
"record_period": 10000,
"methods": [
DQN_Atari_method,
],
"device_id": 0,
}
for env_name in [
"PongNoFrameskip-v4",
]
]

benchmark_cartpole_v1_part_2 = [
{
"exp_name": "benchmark_cartpole_v1_part_2",
Expand Down Expand Up @@ -1557,5 +1603,19 @@ def get_env(env_name: str) -> GymEnvironment:
elif env_name[-7:] == "_w_cost":
env_name = env_name[:-7]
return GymEnvironment(GymAvgTorqueWrapper(gym.make(env_name)))
elif "ALE/" in env_name or "NoFrameskip" in env_name:
# Atari envs
env = gym.make(
env_name,
)
env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
env = EpisodicLifeEnv(env)
if "FIRE" in env.unwrapped.get_action_meanings():
env = FireResetEnv(env)
env = gym.wrappers.ResizeObservation(env, (84, 84))
env = gym.wrappers.GrayscaleObservation(env)
env = gym.wrappers.FrameStackObservation(env, 4)
return GymEnvironment(env)
else:
return GymEnvironment(env_name)

0 comments on commit f435791

Please sign in to comment.