diff --git a/minigrid/wrappers.py b/minigrid/wrappers.py index 43aa70feb..c87879b65 100644 --- a/minigrid/wrappers.py +++ b/minigrid/wrappers.py @@ -313,8 +313,8 @@ def __init__(self, env, tile_size=8): low=0, high=255, shape=( - self.unwrapped.width * tile_size, self.unwrapped.height * tile_size, + self.unwrapped.width * tile_size, 3, ), dtype="uint8", diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index 1fb0777da..9afd544f1 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -389,3 +389,13 @@ def test_no_death_wrapper(): assert reward_wrap == reward + death_cost env.close() env_wrap.close() + + +def test_non_square_RGBIMgObsWrapper(): + """ + Add test for non-square dimensions with RGBImgObsWrapper + (https://github.com/Farama-Foundation/Minigrid/issues/444). + """ + env = RGBImgObsWrapper(gym.make("MiniGrid-BlockedUnlockPickup-v0")) + obs, info = env.reset() + assert env.observation_space["image"].shape == obs["image"].shape