ソースを参照

Swap width and height in RGBImgObsWrapper, add test (#445)

Co-authored-by: Michael Catanzaro <michael.catanzaro@geomdata.com>
Michael Catanzaro 7 ヶ月 前
コミット
715394f7b6
2 ファイル変更11 行追加1 行削除
  1. 1 1
      minigrid/wrappers.py
  2. 10 0
      tests/test_wrappers.py

+ 1 - 1
minigrid/wrappers.py

@@ -313,8 +313,8 @@ class RGBImgObsWrapper(ObservationWrapper):
             low=0,
             high=255,
             shape=(
-                self.unwrapped.width * tile_size,
                 self.unwrapped.height * tile_size,
+                self.unwrapped.width * tile_size,
                 3,
             ),
             dtype="uint8",

+ 10 - 0
tests/test_wrappers.py

@@ -389,3 +389,13 @@ def test_no_death_wrapper():
     assert reward_wrap == reward + death_cost
     env.close()
     env_wrap.close()
+
+
+def test_non_square_RGBIMgObsWrapper():
+    """
+    Add test for non-square dimensions with RGBImgObsWrapper
+    (https://github.com/Farama-Foundation/Minigrid/issues/444).
+    """
+    env = RGBImgObsWrapper(gym.make("MiniGrid-BlockedUnlockPickup-v0"))
+    obs, info = env.reset()
+    assert env.observation_space["image"].shape == obs["image"].shape