
Commit

Enhance wrappers doc and tests
ffelten committed Aug 7, 2024
1 parent f72773a commit f442ea4
Showing 3 changed files with 29 additions and 12 deletions.
13 changes: 12 additions & 1 deletion mo_gymnasium/wrappers/wrappers.py
@@ -59,6 +59,17 @@ class MONormalizeReward(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""Wrapper to normalize the reward component at index idx. Does not touch other reward components.
This code is heavily inspired on Gymnasium's except that it extracts the reward component at given idx, normalizes it, and reinjects it.
(!) This smoothes the moving average of the reward, which can be useful for training stability. But it does not "normalize" the reward in the sense of making it have a mean of 0 and a standard deviation of 1.
Example:
>>> import mo_gymnasium as mo_gym
>>> from mo_gymnasium.wrappers import MONormalizeReward
>>> env = mo_gym.make("deep-sea-treasure-v0")
>>> norm_treasure_env = MONormalizeReward(env, idx=0)
>>> both_norm_env = MONormalizeReward(norm_treasure_env, idx=1)
>>> both_norm_env.reset() # This one normalizes both rewards
"""

def __init__(self, env: gym.Env, idx: int, gamma: float = 0.99, epsilon: float = 1e-8):
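For context, a minimal sketch of the normalization idea the docstring above describes, assuming a Gymnasium-style running estimate of the discounted return's variance; the helper names below are illustrative, not the wrapper's actual attributes:

import numpy as np


class RunningReturnStats:
    """Illustrative streaming variance of the discounted return (simplified Welford-style update)."""

    def __init__(self):
        self.mean, self.var, self.count = 0.0, 1.0, 1e-4

    def update(self, x: float):
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        self.var += (delta * (x - self.mean) - self.var) / self.count


def normalize_component(reward, idx, running_return, stats, gamma=0.99, epsilon=1e-8):
    """Scale reward[idx] by the std of its discounted return; leave the other components untouched."""
    running_return = running_return * gamma + reward[idx]
    stats.update(running_return)
    reward = np.array(reward, dtype=np.float64)
    reward[idx] = reward[idx] / np.sqrt(stats.var + epsilon)  # rescaled, not zero-mean/unit-std
    return reward, running_return

Stacking two such wrappers with different idx values, as in the docstring example, normalizes both reward components.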
@@ -193,12 +204,12 @@ def step(self, action):
info, dict
), f"`info` dtype is {type(info)} while supported dtype is `dict`. This may be due to usage of other wrappers in the wrong order."
self.episode_returns += rewards
self.episode_lengths += 1

# CHANGE: The discounted returns are also computed here, before episode_lengths is incremented,
# so the reward at step t is weighted by gamma**t
self.disc_episode_returns += rewards * np.repeat(self.gamma**self.episode_lengths, self.reward_dim).reshape(
self.episode_returns.shape
)
self.episode_lengths += 1

if terminated or truncated:
assert self._stats_key not in info
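A small standalone check of the ordering in the hunk above: because episode_lengths is incremented after the discounted update, the reward of step t is weighted by gamma**t. The reward sequence below is only an illustration (a 3-step episode with a terminal treasure of 8.2 and a -1 time penalty per step), shown for a single environment rather than the vectorized case:

import numpy as np

gamma = 0.99
rewards = [np.array([0.0, -1.0]), np.array([0.0, -1.0]), np.array([8.2, -1.0])]

episode_length = 0
disc_return = np.zeros(2)
for r in rewards:
    disc_return += r * gamma**episode_length  # step t contributes gamma**t * r_t
    episode_length += 1

# disc_return ≈ [8.2 * 0.99**2, -(1 + 0.99 + 0.99**2)] ≈ [8.04, -2.97]
print(disc_return)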
14 changes: 10 additions & 4 deletions tests/test_vector_wrappers.py
@@ -46,22 +46,28 @@ def test_mo_sync_autoreset():


def test_mo_record_ep_statistic_vector_env():
num_envs = 3
num_envs = 2
envs = MOSyncVectorEnv([lambda: mo_gym.make("deep-sea-treasure-v0") for _ in range(num_envs)])
envs = MORecordEpisodeStatistics(envs)
envs = MORecordEpisodeStatistics(envs, gamma=0.97)

envs.reset()
terminateds = np.array([False] * num_envs)
info = {}
while not np.any(terminateds):
obs, rewards, terminateds, _, info = envs.step(envs.action_space.sample())
obs, rewards, terminateds, _, info = envs.step([0, 3])
obs, rewards, terminateds, _, info = envs.step([0, 1])
obs, rewards, terminateds, _, info = envs.step([0, 1])

assert isinstance(info["episode"]["r"], np.ndarray)
assert isinstance(info["episode"]["dr"], np.ndarray)
# Episode records are vectorized because there are multiple environments
assert info["episode"]["r"].shape == (num_envs, 2)
np.testing.assert_almost_equal(info["episode"]["r"][0], np.array([0.0, 0.0], dtype=np.float32), decimal=2)
np.testing.assert_almost_equal(info["episode"]["r"][1], np.array([8.2, -3.0], dtype=np.float32), decimal=2)
assert info["episode"]["dr"].shape == (num_envs, 2)
np.testing.assert_almost_equal(info["episode"]["dr"][0], np.array([0.0, 0.0], dtype=np.float32), decimal=2)
np.testing.assert_almost_equal(info["episode"]["dr"][1], np.array([7.72, -2.91], dtype=np.float32), decimal=2)
assert isinstance(info["episode"]["l"], np.ndarray)
np.testing.assert_almost_equal(info["episode"]["l"], np.array([0, 3], dtype=np.float32), decimal=2)
assert isinstance(info["episode"]["t"], np.ndarray)
envs.close()

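The expected statistics in the test above can be checked by hand: per the asserted returns and lengths, env 1 finishes in 3 steps with episodic return [8.2, -3.0] (the treasure arrives on the third step, each step costs -1), while env 0 has not finished yet, so its record is still zero. With gamma=0.97 the discounted return for env 1 is:

gamma = 0.97
disc_treasure = 8.2 * gamma**2       # treasure collected on the third step -> ~7.72
disc_time = -(1 + gamma + gamma**2)  # three time penalties of -1 -> ~-2.91

which matches the asserted info["episode"]["dr"][1] of roughly [7.72, -2.91].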
14 changes: 7 additions & 7 deletions tests/test_wrappers.py
@@ -20,35 +20,35 @@ def go_to_8_3(env):


def test_normalization_wrapper():
# Note that the wrapper does not normalize the rewards to have a mean of 0 and a std of 1;
# instead, it smooths the moving average of the rewards
env = mo_gym.make("deep-sea-treasure-v0")
norm_treasure_env = MONormalizeReward(env, idx=0)
both_norm_env = MONormalizeReward(norm_treasure_env, idx=1)

# No normalization
env.reset(seed=0)
_, rewards, _, _, _ = env.step(1)
np.testing.assert_allclose(rewards, [0.7, -1.0], rtol=0, atol=1e-2)
np.testing.assert_almost_equal(rewards, [0.7, -1.0], decimal=2)

# Tests for both rewards normalized
for i in range(30):
go_to_8_3(both_norm_env)
both_norm_env.reset(seed=0)
_, rewards, _, _, _ = both_norm_env.step(1) # down
np.testing.assert_allclose(
rewards, [0.49, -1.24], rtol=0, atol=1e-2
) # TODO PR check why we had to change those values @Mark?
np.testing.assert_almost_equal(rewards, [0.5, -1.24], decimal=2)
rewards, _ = go_to_8_3(both_norm_env)
np.testing.assert_allclose(rewards, [4.73, -1.24], rtol=0, atol=1e-2)
np.testing.assert_almost_equal(rewards, [4.73, -1.24], decimal=2)

# Tests for only treasure normalized
for i in range(30):
go_to_8_3(norm_treasure_env)
norm_treasure_env.reset(seed=0)
_, rewards, _, _, _ = norm_treasure_env.step(1) # down
# The time reward is not normalized (stays at -1)
np.testing.assert_allclose(rewards, [0.51, -1.0], rtol=0, atol=1e-2)
np.testing.assert_almost_equal(rewards, [0.51, -1.0], decimal=2)
rewards, _ = go_to_8_3(norm_treasure_env)
np.testing.assert_allclose(rewards, [5.33, -1.0], rtol=0, atol=1e-2)
np.testing.assert_almost_equal(rewards, [5.33, -1.0], decimal=2)


def test_clip_wrapper():
