Browse Source

View size bug (#305)

Joseph Bloom 2 years ago
parent
commit
0a6449d41e
3 changed files with 22 additions and 4 deletions
  1. 1 1
      minigrid/wrappers.py
  2. 11 3
      tests/test_envs.py
  3. 10 0
      tests/test_wrappers.py

+ 1 - 1
minigrid/wrappers.py

@@ -635,7 +635,7 @@ class FlatObsWrapper(ObservationWrapper):
         return obs
 
 
-class ViewSizeWrapper(Wrapper):
+class ViewSizeWrapper(ObservationWrapper):
     """
     Wrapper to customize the agent field of view size.
     This cannot be used with fully observable wrappers.

+ 11 - 3
tests/test_envs.py

@@ -304,18 +304,26 @@ def test_mission_space():
     assert mission_space.contains("get the green key and the green key.")
     assert mission_space.contains("go fetch the red ball and the green key.")
 
+
 # not reasonable to test for all environments, test for a few of them.
-@pytest.mark.parametrize("env_id", ["MiniGrid-Empty-8x8-v0", "MiniGrid-DoorKey-16x16-v0","MiniGrid-ObstructedMaze-1Dl-v0"])
+@pytest.mark.parametrize(
+    "env_id",
+    [
+        "MiniGrid-Empty-8x8-v0",
+        "MiniGrid-DoorKey-16x16-v0",
+        "MiniGrid-ObstructedMaze-1Dl-v0",
+    ],
+)
 def test_env_sync_vectorization(env_id):
-    
     def env_maker(env_id, **kwargs):
         def env_func():
             env = gym.make(env_id, **kwargs)
             return env
+
         return env_func
 
     num_envs = 4
     env = gym.vector.SyncVectorEnv([env_maker(env_id) for _ in range(num_envs)])
     env.reset()
     env.step(env.action_space.sample())
-    env.close()
+    env.close()

+ 10 - 0
tests/test_wrappers.py

@@ -250,3 +250,13 @@ def test_agent_sees_method(wrapper):
     assert (obs1["size"] == [5, 5]).all()
     for key in obs2:
         assert np.array_equal(obs1[key], obs2[key])
+
+
+@pytest.mark.parametrize("view_size", [5, 7, 9])
+def test_viewsize_wrapper(view_size):
+    env = gym.make("MiniGrid-Empty-5x5-v0")
+    env = ViewSizeWrapper(env, agent_view_size=view_size)
+    env.reset()
+    obs, _, _, _, _ = env.step(0)
+    assert obs["image"].shape == (view_size, view_size, 3)
+    env.close()