diff --git a/ding/envs/env/env_implementation_check.py b/ding/envs/env/env_implementation_check.py index 18b24edc81..e7f0dfe03b 100644 --- a/ding/envs/env/env_implementation_check.py +++ b/ding/envs/env/env_implementation_check.py @@ -1,15 +1,15 @@ -from tabnanny import check -from typing import Any, Callable, List, Tuple +from typing import Any, Callable, List, Tuple, Union, Dict, TYPE_CHECKING import numpy as np from collections.abc import Sequence -from easydict import EasyDict - -from ding.envs.env import BaseEnv, BaseEnvTimestep +from gym.spaces import Space, Box, Discrete, MultiDiscrete, MultiBinary from ding.envs.env.tests import DemoEnv # from dizoo.atari.envs import AtariEnv +if TYPE_CHECKING: + from ding.envs.env import BaseEnv + -def check_space_dtype(env: BaseEnv) -> None: +def check_space_dtype(env: 'BaseEnv') -> None: print("== 0. Test obs/act/rew space's dtype") env.reset() for name, space in zip(['obs', 'act', 'rew'], [env.observation_space, env.action_space, env.reward_space]): @@ -24,7 +24,7 @@ def check_space_dtype(env: BaseEnv) -> None: # Util function -def check_array_space(ndarray, space, name) -> bool: +def check_array_space(ndarray: Union[np.ndarray, Sequence, Dict], space: Union['Space', Dict], name: str) -> None: if isinstance(ndarray, np.ndarray): # print("{}'s type should be np.ndarray".format(name)) assert ndarray.dtype == space.dtype, "{}'s dtype is {}, but requires {}".format( @@ -33,14 +33,18 @@ def check_array_space(ndarray, space, name) -> bool: assert ndarray.shape == space.shape, "{}'s shape is {}, but requires {}".format( name, ndarray.shape, space.shape ) - assert (space.low <= ndarray).all() and (ndarray <= space.high).all( - ), "{}'s value is {}, but requires in range ({},{})".format(name, ndarray, space.low, space.high) + if isinstance(space, Box): + assert (space.low <= ndarray).all() and (ndarray <= space.high).all( + ), "{}'s value is {}, but requires in range ({},{})".format(name, ndarray, space.low, space.high) + elif isinstance(space, (Discrete, MultiDiscrete, MultiBinary)): + print(space.start, space.n) + assert (ndarray >= space.start) and (ndarray <= space.n) elif isinstance(ndarray, Sequence): for i in range(len(ndarray)): try: check_array_space(ndarray[i], space[i], name) except AssertionError as e: - print("The following error happens at {}-th index".format(i)) + print("The following error happens at {}-th index".format(i)) raise e elif isinstance(ndarray, dict): for k in ndarray.keys(): @@ -55,13 +59,13 @@ def check_array_space(ndarray, space, name) -> bool: ) -def check_reset(env: BaseEnv) -> None: +def check_reset(env: 'BaseEnv') -> None: print('== 1. Test reset method') obs = env.reset() check_array_space(obs, env.observation_space, 'obs') -def check_step(env: BaseEnv) -> None: +def check_step(env: 'BaseEnv') -> None: done_times = 0 print('== 2. Test step method') _ = env.reset() @@ -82,7 +86,9 @@ def check_step(env: BaseEnv) -> None: # Util function -def check_different_memory(array1, array2, step_times) -> None: +def check_different_memory( + array1: Union[np.ndarray, Sequence, Dict], array2: Union[np.ndarray, Sequence, Dict], step_times: int +) -> None: assert type(array1) == type( array2 ), "In step times {}, obs_last_frame({}) and obs_this_frame({}) are not of the same type".format( @@ -121,7 +127,7 @@ def check_different_memory(array1, array2, step_times) -> None: ) -def check_obs_deepcopy(env: BaseEnv) -> None: +def check_obs_deepcopy(env: 'BaseEnv') -> None: step_times = 0 print('== 3. Test observation deepcopy') @@ -139,14 +145,14 @@ def check_obs_deepcopy(env: BaseEnv) -> None: break -def check_all(env: BaseEnv) -> None: +def check_all(env: 'BaseEnv') -> None: check_space_dtype(env) check_reset(env) check_step(env) check_obs_deepcopy(env) -def demonstrate_correct_procedure(env_fn: Callable) -> None: +def demonstrate_correct_procedure(env_fn: Callable[[Dict], 'BaseEnv']) -> None: print('== 4. Demonstrate the correct procudures') done_times = 0 # Init the env. @@ -174,7 +180,7 @@ def demonstrate_correct_procedure(env_fn: Callable) -> None: if __name__ == "__main__": ''' - # Moethods `check_*` are for user to check whether their implemented env obeys DI-engine's rules. + # Methods `check_*` are for user to check whether their implemented env obeys DI-engine's rules. # You can replace `AtariEnv` with your own env. atari_env = AtariEnv(EasyDict(env_id='PongNoFrameskip-v4', frame_stack=4, is_train=False)) check_reset(atari_env) diff --git a/ding/envs/env/tests/test_env_implementation_check.py b/ding/envs/env/tests/test_env_implementation_check.py index fb413304ce..02fa79e064 100644 --- a/ding/envs/env/tests/test_env_implementation_check.py +++ b/ding/envs/env/tests/test_env_implementation_check.py @@ -1,5 +1,4 @@ import pytest -from easydict import EasyDict import numpy as np import gym from copy import deepcopy @@ -17,6 +16,12 @@ def test_an_implemented_env(): @pytest.mark.unittest def test_check_array_space(): + discrete_space = gym.spaces.Discrete(10) + discrete_array = np.array(2, dtype=np.int64) + check_array_space(discrete_array, discrete_space, 'test_discrete') + discrete_array = np.array(11, dtype=np.int64) + with pytest.raises(AssertionError): + check_array_space(discrete_array, discrete_space, 'test_discrete') seq_array = (np.array([1, 2, 3], dtype=np.int64), np.array([4., 5., 6.], dtype=np.float32)) seq_space = [gym.spaces.Box(low=0, high=10, shape=(3, ), dtype=np.int64) for _ in range(2)] with pytest.raises(AssertionError):