Skip to content

Commit

Permalink
fix(nyz): fix env check bugs (#852)
Browse files Browse the repository at this point in the history
  • Loading branch information
PaParaZz1 committed Jan 13, 2025
1 parent f60b377 commit f5157c7
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 18 deletions.
40 changes: 23 additions & 17 deletions ding/envs/env/env_implementation_check.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from tabnanny import check
from typing import Any, Callable, List, Tuple
from typing import Any, Callable, List, Tuple, Union, Dict, TYPE_CHECKING
import numpy as np
from collections.abc import Sequence
from easydict import EasyDict

from ding.envs.env import BaseEnv, BaseEnvTimestep
from gym.spaces import Space, Box, Discrete, MultiDiscrete, MultiBinary
from ding.envs.env.tests import DemoEnv
# from dizoo.atari.envs import AtariEnv

if TYPE_CHECKING:
from ding.envs.env import BaseEnv


def check_space_dtype(env: BaseEnv) -> None:
def check_space_dtype(env: 'BaseEnv') -> None:
print("== 0. Test obs/act/rew space's dtype")
env.reset()
for name, space in zip(['obs', 'act', 'rew'], [env.observation_space, env.action_space, env.reward_space]):
Expand All @@ -24,7 +24,7 @@ def check_space_dtype(env: BaseEnv) -> None:


# Util function
def check_array_space(ndarray, space, name) -> bool:
def check_array_space(ndarray: Union[np.ndarray, Sequence, Dict], space: Union['Space', Dict], name: str) -> None:
if isinstance(ndarray, np.ndarray):
# print("{}'s type should be np.ndarray".format(name))
assert ndarray.dtype == space.dtype, "{}'s dtype is {}, but requires {}".format(
Expand All @@ -33,14 +33,18 @@ def check_array_space(ndarray, space, name) -> bool:
assert ndarray.shape == space.shape, "{}'s shape is {}, but requires {}".format(
name, ndarray.shape, space.shape
)
assert (space.low <= ndarray).all() and (ndarray <= space.high).all(
), "{}'s value is {}, but requires in range ({},{})".format(name, ndarray, space.low, space.high)
if isinstance(space, Box):
assert (space.low <= ndarray).all() and (ndarray <= space.high).all(
), "{}'s value is {}, but requires in range ({},{})".format(name, ndarray, space.low, space.high)
elif isinstance(space, (Discrete, MultiDiscrete, MultiBinary)):
print(space.start, space.n)
assert (ndarray >= space.start) and (ndarray <= space.n)
elif isinstance(ndarray, Sequence):
for i in range(len(ndarray)):
try:
check_array_space(ndarray[i], space[i], name)
except AssertionError as e:
print("The following error happens at {}-th index".format(i))
print("The following error happens at {}-th index".format(i))
raise e
elif isinstance(ndarray, dict):
for k in ndarray.keys():
Expand All @@ -55,13 +59,13 @@ def check_array_space(ndarray, space, name) -> bool:
)


def check_reset(env: BaseEnv) -> None:
def check_reset(env: 'BaseEnv') -> None:
print('== 1. Test reset method')
obs = env.reset()
check_array_space(obs, env.observation_space, 'obs')


def check_step(env: BaseEnv) -> None:
def check_step(env: 'BaseEnv') -> None:
done_times = 0
print('== 2. Test step method')
_ = env.reset()
Expand All @@ -82,7 +86,9 @@ def check_step(env: BaseEnv) -> None:


# Util function
def check_different_memory(array1, array2, step_times) -> None:
def check_different_memory(
array1: Union[np.ndarray, Sequence, Dict], array2: Union[np.ndarray, Sequence, Dict], step_times: int
) -> None:
assert type(array1) == type(
array2
), "In step times {}, obs_last_frame({}) and obs_this_frame({}) are not of the same type".format(
Expand Down Expand Up @@ -121,7 +127,7 @@ def check_different_memory(array1, array2, step_times) -> None:
)


def check_obs_deepcopy(env: BaseEnv) -> None:
def check_obs_deepcopy(env: 'BaseEnv') -> None:

step_times = 0
print('== 3. Test observation deepcopy')
Expand All @@ -139,14 +145,14 @@ def check_obs_deepcopy(env: BaseEnv) -> None:
break


def check_all(env: BaseEnv) -> None:
def check_all(env: 'BaseEnv') -> None:
check_space_dtype(env)
check_reset(env)
check_step(env)
check_obs_deepcopy(env)


def demonstrate_correct_procedure(env_fn: Callable) -> None:
def demonstrate_correct_procedure(env_fn: Callable[[Dict], 'BaseEnv']) -> None:
print('== 4. Demonstrate the correct procudures')
done_times = 0
# Init the env.
Expand Down Expand Up @@ -174,7 +180,7 @@ def demonstrate_correct_procedure(env_fn: Callable) -> None:

if __name__ == "__main__":
'''
# Moethods `check_*` are for user to check whether their implemented env obeys DI-engine's rules.
# Methods `check_*` are for user to check whether their implemented env obeys DI-engine's rules.
# You can replace `AtariEnv` with your own env.
atari_env = AtariEnv(EasyDict(env_id='PongNoFrameskip-v4', frame_stack=4, is_train=False))
check_reset(atari_env)
Expand Down
7 changes: 6 additions & 1 deletion ding/envs/env/tests/test_env_implementation_check.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import pytest
from easydict import EasyDict
import numpy as np
import gym
from copy import deepcopy
Expand All @@ -17,6 +16,12 @@ def test_an_implemented_env():

@pytest.mark.unittest
def test_check_array_space():
discrete_space = gym.spaces.Discrete(10)
discrete_array = np.array(2, dtype=np.int64)
check_array_space(discrete_array, discrete_space, 'test_discrete')
discrete_array = np.array(11, dtype=np.int64)
with pytest.raises(AssertionError):
check_array_space(discrete_array, discrete_space, 'test_discrete')
seq_array = (np.array([1, 2, 3], dtype=np.int64), np.array([4., 5., 6.], dtype=np.float32))
seq_space = [gym.spaces.Box(low=0, high=10, shape=(3, ), dtype=np.int64) for _ in range(2)]
with pytest.raises(AssertionError):
Expand Down

0 comments on commit f5157c7

Please sign in to comment.