Skip to content

Commit

Permalink
Tests/Added tests for env wrappers (#296)
Browse files Browse the repository at this point in the history
* Tests/Added tests for env wrappers

* Tests/Deleted unnecessary file

* Tests/Minor fixes to previous commit

---------

Co-authored-by: Locatelli Alex Giannino <[email protected]>
Co-authored-by: Michele Milesi <[email protected]>
  • Loading branch information
3 people authored May 31, 2024
1 parent 62b3da0 commit dee8c80
Show file tree
Hide file tree
Showing 4 changed files with 331 additions and 83 deletions.
44 changes: 31 additions & 13 deletions sheeprl/envs/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,18 @@ def __init__(
image_size: Tuple[int, int, int] = (3, 64, 64),
n_steps: int = 128,
vector_shape: Tuple[int] = (10,),
dict_obs_space: bool = True,
):
self.observation_space = gym.spaces.Dict(
{
"rgb": gym.spaces.Box(0, 256, shape=image_size, dtype=np.uint8),
"state": gym.spaces.Box(-20, 20, shape=vector_shape, dtype=np.float32),
}
)
self._dict_obs_space = dict_obs_space
if self._dict_obs_space:
self.observation_space = gym.spaces.Dict(
{
"rgb": gym.spaces.Box(0, 256, shape=image_size, dtype=np.uint8),
"state": gym.spaces.Box(-20, 20, shape=vector_shape, dtype=np.float32),
}
)
else:
self.observation_space = gym.spaces.Box(-20, 20, shape=vector_shape, dtype=np.float32)
self.reward_range = (-np.inf, np.inf)
self._current_step = 0
self._n_steps = n_steps
Expand All @@ -35,10 +40,14 @@ def step(self, action):
)

def get_obs(self) -> Dict[str, np.ndarray]:
return {
"rgb": np.zeros(self.observation_space["rgb"].shape, dtype=np.uint8),
"state": np.zeros(self.observation_space["state"].shape, dtype=np.float32),
}
if self._dict_obs_space:
return {
# da sostituire con np.random.rand
"rgb": np.full(self.observation_space["rgb"].shape, self._current_step % 256, dtype=np.uint8),
"state": np.full(self.observation_space["state"].shape, self._current_step, dtype=np.uint8),
}
else:
return np.full(self.observation_space.shape, self._current_step, dtype=np.uint8)

def reset(self, seed=None, options=None):
self._current_step = 0
Expand All @@ -61,9 +70,12 @@ def __init__(
n_steps: int = 128,
vector_shape: Tuple[int] = (10,),
action_dim: int = 2,
dict_obs_space: bool = True,
):
self.action_space = gym.spaces.Box(-np.inf, np.inf, shape=(action_dim,))
super().__init__(image_size=image_size, n_steps=n_steps, vector_shape=vector_shape)
super().__init__(
image_size=image_size, n_steps=n_steps, vector_shape=vector_shape, dict_obs_space=dict_obs_space
)


class DiscreteDummyEnv(BaseDummyEnv):
Expand All @@ -73,9 +85,12 @@ def __init__(
n_steps: int = 4,
vector_shape: Tuple[int] = (10,),
action_dim: int = 2,
dict_obs_space: bool = True,
):
self.action_space = gym.spaces.Discrete(action_dim)
super().__init__(image_size=image_size, n_steps=n_steps, vector_shape=vector_shape)
super().__init__(
image_size=image_size, n_steps=n_steps, vector_shape=vector_shape, dict_obs_space=dict_obs_space
)


class MultiDiscreteDummyEnv(BaseDummyEnv):
Expand All @@ -85,6 +100,9 @@ def __init__(
n_steps: int = 128,
vector_shape: Tuple[int] = (10,),
action_dims: List[int] = [2, 2],
dict_obs_space: bool = True,
):
self.action_space = gym.spaces.MultiDiscrete(action_dims)
super().__init__(image_size=image_size, n_steps=n_steps, vector_shape=vector_shape)
super().__init__(
image_size=image_size, n_steps=n_steps, vector_shape=vector_shape, dict_obs_space=dict_obs_space
)
126 changes: 126 additions & 0 deletions tests/test_envs/test_actions_as_observations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
from collections import deque

import gymnasium as gym
import numpy as np
import pytest

from sheeprl.envs.dummy import ContinuousDummyEnv, DiscreteDummyEnv, MultiDiscreteDummyEnv
from sheeprl.envs.wrappers import ActionsAsObservationWrapper

ENVIRONMENTS = {
"discrete_dummy": DiscreteDummyEnv,
"multidiscrete_dummy": MultiDiscreteDummyEnv,
"continuous_dummy": ContinuousDummyEnv,
}


@pytest.mark.parametrize("num_stack", [1, 4, 8])
@pytest.mark.parametrize("dilation", [1, 2, 4])
@pytest.mark.parametrize("env_id", ENVIRONMENTS.keys())
def test_actions_as_observation_wrapper(env_id: str, num_stack, dilation):
env = ENVIRONMENTS[env_id]()
if isinstance(env.action_space, gym.spaces.MultiDiscrete):
noop = [0, 0]
else:
noop = 0
env = ActionsAsObservationWrapper(env, num_stack=num_stack, noop=noop, dilation=dilation)

o = env.reset()[0]
assert len(o["action_stack"].shape) == len(env.observation_space["action_stack"].shape)
for d1, d2 in zip(o["action_stack"].shape, env.observation_space["action_stack"].shape):
assert d1 == d2

actions = []
for _ in range(8):
action = env.action_space.sample()
actions.append(action)
o = env.step(action)[0]

# Ensure the shapes match
assert len(o["action_stack"].shape) == len(env.observation_space["action_stack"].shape)
for d1, d2 in zip(o["action_stack"].shape, env.observation_space["action_stack"].shape):
assert d1 == d2

expected_actions = deque(maxlen=num_stack * dilation)
if len(actions) < num_stack * dilation:
for _ in range(num_stack * dilation - len(actions)):
expected_actions.append(env.noop)
for past_action in actions[-(num_stack * dilation) :]:
if isinstance(env.action_space, gym.spaces.Box):
expected_actions.append(past_action)
elif isinstance(env.action_space, gym.spaces.MultiDiscrete):
one_hot_actions = []
for act, n in zip(past_action, env.action_space.nvec):
one_hot_actions.append(np.zeros((n,), dtype=np.float32))
one_hot_actions[-1][act] = 1.0
expected_actions.append(np.concatenate(one_hot_actions, axis=-1))
else:
one_hot_action = np.zeros((env.action_space.n,), dtype=np.float32)
one_hot_action[past_action] = 1.0
expected_actions.append(one_hot_action)

expected_actions_stack = list(expected_actions)[dilation - 1 :: dilation]
expected_actions_stack = np.concatenate(expected_actions_stack, axis=-1).astype(np.float32)

np.testing.assert_array_equal(o["action_stack"], expected_actions_stack)


@pytest.mark.parametrize("num_stack", [-1, 0])
@pytest.mark.parametrize("env_id", ENVIRONMENTS.keys())
def test_actions_as_observation_wrapper_invalid_num_stack(env_id, num_stack):
env = ENVIRONMENTS[env_id]()
if isinstance(env.action_space, gym.spaces.MultiDiscrete):
noop = [0, 0]
else:
noop = 0
with pytest.raises(ValueError, match="The number of actions to the"):
env = ActionsAsObservationWrapper(env, num_stack=num_stack, noop=noop, dilation=3)


@pytest.mark.parametrize("dilation", [-1, 0])
@pytest.mark.parametrize("env_id", ENVIRONMENTS.keys())
def test_actions_as_observation_wrapper_invalid_dilation(env_id, dilation):
env = ENVIRONMENTS[env_id]()
if isinstance(env.action_space, gym.spaces.MultiDiscrete):
noop = [0, 0]
else:
noop = 0
with pytest.raises(ValueError, match="The actions stack dilation argument must be greater than zero"):
env = ActionsAsObservationWrapper(env, num_stack=3, noop=noop, dilation=dilation)


@pytest.mark.parametrize("noop", [set([0, 0, 0]), "this is an invalid type", np.array([0, 0, 0])])
@pytest.mark.parametrize("env_id", ENVIRONMENTS.keys())
def test_actions_as_observation_wrapper_invalid_noop_type(env_id, noop):
env = ENVIRONMENTS[env_id]()
with pytest.raises(ValueError, match="The noop action must be an integer or float or list"):
env = ActionsAsObservationWrapper(env, num_stack=3, noop=noop, dilation=2)


def test_actions_as_observation_wrapper_invalid_noop_continuous_type():
env = ContinuousDummyEnv()
with pytest.raises(ValueError, match="The noop actions must be a float for continuous action spaces"):
env = ActionsAsObservationWrapper(env, num_stack=3, noop=[0, 0, 0], dilation=2)


@pytest.mark.parametrize("noop", [[0, 0, 0], 0.0])
def test_actions_as_observation_wrapper_invalid_noop_discrete_type(noop):
env = DiscreteDummyEnv()
with pytest.raises(ValueError, match="The noop actions must be an integer for discrete action spaces"):
env = ActionsAsObservationWrapper(env, num_stack=3, noop=noop, dilation=2)


@pytest.mark.parametrize("noop", [0, 0.0])
def test_actions_as_observation_wrapper_invalid_noop_multidiscrete_type(noop):
env = MultiDiscreteDummyEnv()
with pytest.raises(ValueError, match="The noop actions must be a list for multi-discrete action spaces"):
env = ActionsAsObservationWrapper(env, num_stack=3, noop=noop, dilation=2)


@pytest.mark.parametrize("noop", [[0], [0, 0, 0]])
def test_actions_as_observation_wrapper_invalid_noop_multidiscrete_n_actions(noop):
env = MultiDiscreteDummyEnv()
with pytest.raises(
RuntimeError, match="The number of noop actions must be equal to the number of actions of the environment"
):
env = ActionsAsObservationWrapper(env, num_stack=3, noop=noop, dilation=2)
102 changes: 102 additions & 0 deletions tests/test_envs/test_frame_stack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import numpy as np
import pytest

from sheeprl.envs.dummy import ContinuousDummyEnv, DiscreteDummyEnv, MultiDiscreteDummyEnv
from sheeprl.envs.wrappers import FrameStack

ENVIRONMENTS = {
"discrete_dummy": DiscreteDummyEnv,
"multidiscrete_dummy": MultiDiscreteDummyEnv,
"continuous_dummy": ContinuousDummyEnv,
}


@pytest.mark.parametrize("dilation", [1, 2, 4])
@pytest.mark.parametrize("num_stack", [1, 2, 3])
@pytest.mark.parametrize("env_id", ENVIRONMENTS.keys())
def test_valid_initialization(env_id, num_stack, dilation):
env = ENVIRONMENTS[env_id]()

env = FrameStack(env, num_stack=num_stack, cnn_keys=["rgb"], dilation=dilation)
assert env._num_stack == num_stack
assert env._dilation == dilation
assert "rgb" in env._cnn_keys
assert "rgb" in env._frames


@pytest.mark.parametrize("num_stack", [-2.4, -1, 0])
@pytest.mark.parametrize("env_id", ENVIRONMENTS.keys())
def test_invalid_num_stack(env_id, num_stack):
env = ENVIRONMENTS[env_id]()

with pytest.raises(ValueError, match="Invalid value for num_stack, expected a value greater"):
FrameStack(env, num_stack=num_stack, cnn_keys=["rgb"], dilation=2)


@pytest.mark.parametrize("num_stack", [1, 3, 7])
@pytest.mark.parametrize("env_id", ENVIRONMENTS.keys())
def test_invalid_observation_space(env_id, num_stack):
env = ENVIRONMENTS[env_id](dict_obs_space=False)

with pytest.raises(RuntimeError, match="Expected an observation space of type gym.spaces.Dict"):
FrameStack(env, num_stack=num_stack, cnn_keys=["rgb"], dilation=2)


@pytest.mark.parametrize("cnn_keys", [[], None])
@pytest.mark.parametrize("num_stack", [1, 3, 7])
@pytest.mark.parametrize("env_id", ENVIRONMENTS.keys())
def test_invalid_cnn_keys(env_id, num_stack, cnn_keys):
env = ENVIRONMENTS[env_id]()

with pytest.raises(RuntimeError, match="Specify at least one valid cnn key"):
FrameStack(env, num_stack=num_stack, cnn_keys=cnn_keys, dilation=2)


@pytest.mark.parametrize("env_id", ENVIRONMENTS.keys())
@pytest.mark.parametrize("num_stack", [1, 3, 7])
def test_reset_method(env_id, num_stack):
env = ENVIRONMENTS[env_id]()

wrapper = FrameStack(env, num_stack=num_stack, cnn_keys=["rgb"])
obs, _ = wrapper.reset()

assert "rgb" in obs
assert obs["rgb"].shape == (num_stack, *env.observation_space["rgb"].shape)


@pytest.mark.parametrize("num_stack", [1, 2, 5])
@pytest.mark.parametrize("dilation", [1, 2, 3])
def test_framestack(num_stack, dilation):
env = DiscreteDummyEnv()
env = FrameStack(env, num_stack, cnn_keys=["rgb"], dilation=dilation)

# Reset the environment to initialize the frame stack
obs, _ = env.reset()

for step in range(1, 64):
obs = env.step(None)[0]

expected_frame = np.stack(
[
np.full(
env.env.observation_space["rgb"].shape,
max(0, (step - dilation * (num_stack - i - 1))) % 256,
dtype=np.uint8,
)
for i in range(num_stack)
],
axis=0,
)
np.testing.assert_array_equal(obs["rgb"], expected_frame)


@pytest.mark.parametrize("env_id", ENVIRONMENTS.keys())
@pytest.mark.parametrize("num_stack", [1, 3, 7])
def test_step_method(env_id, num_stack):
env = ENVIRONMENTS[env_id]()
wrapper = FrameStack(env, num_stack=num_stack, cnn_keys=["rgb"])
wrapper.reset()
action = wrapper.action_space.sample()
obs = wrapper.step(action)[0]
assert "rgb" in obs
assert obs["rgb"].shape == (num_stack, *env.observation_space["rgb"].shape)
Loading

0 comments on commit dee8c80

Please sign in to comment.