Browse Source

View size bug (#305)

Joseph Bloom 2 years ago
parent
commit
0a6449d41e
3 changed files with 22 additions and 4 deletions
  1. 1 1
      minigrid/wrappers.py
  2. 11 3
      tests/test_envs.py
  3. 10 0
      tests/test_wrappers.py

+ 1 - 1
minigrid/wrappers.py

@@ -635,7 +635,7 @@ class FlatObsWrapper(ObservationWrapper):
         return obs
         return obs
 
 
 
 
-class ViewSizeWrapper(Wrapper):
+class ViewSizeWrapper(ObservationWrapper):
     """
     """
     Wrapper to customize the agent field of view size.
     Wrapper to customize the agent field of view size.
     This cannot be used with fully observable wrappers.
     This cannot be used with fully observable wrappers.

+ 11 - 3
tests/test_envs.py

@@ -304,18 +304,26 @@ def test_mission_space():
     assert mission_space.contains("get the green key and the green key.")
     assert mission_space.contains("get the green key and the green key.")
     assert mission_space.contains("go fetch the red ball and the green key.")
     assert mission_space.contains("go fetch the red ball and the green key.")
 
 
+
 # not reasonable to test for all environments, test for a few of them.
 # not reasonable to test for all environments, test for a few of them.
-@pytest.mark.parametrize("env_id", ["MiniGrid-Empty-8x8-v0", "MiniGrid-DoorKey-16x16-v0","MiniGrid-ObstructedMaze-1Dl-v0"])
+@pytest.mark.parametrize(
+    "env_id",
+    [
+        "MiniGrid-Empty-8x8-v0",
+        "MiniGrid-DoorKey-16x16-v0",
+        "MiniGrid-ObstructedMaze-1Dl-v0",
+    ],
+)
 def test_env_sync_vectorization(env_id):
 def test_env_sync_vectorization(env_id):
-    
     def env_maker(env_id, **kwargs):
     def env_maker(env_id, **kwargs):
         def env_func():
         def env_func():
             env = gym.make(env_id, **kwargs)
             env = gym.make(env_id, **kwargs)
             return env
             return env
+
         return env_func
         return env_func
 
 
     num_envs = 4
     num_envs = 4
     env = gym.vector.SyncVectorEnv([env_maker(env_id) for _ in range(num_envs)])
     env = gym.vector.SyncVectorEnv([env_maker(env_id) for _ in range(num_envs)])
     env.reset()
     env.reset()
     env.step(env.action_space.sample())
     env.step(env.action_space.sample())
-    env.close()
+    env.close()

+ 10 - 0
tests/test_wrappers.py

@@ -250,3 +250,13 @@ def test_agent_sees_method(wrapper):
     assert (obs1["size"] == [5, 5]).all()
     assert (obs1["size"] == [5, 5]).all()
     for key in obs2:
     for key in obs2:
         assert np.array_equal(obs1[key], obs2[key])
         assert np.array_equal(obs1[key], obs2[key])
+
+
+@pytest.mark.parametrize("view_size", [5, 7, 9])
+def test_viewsize_wrapper(view_size):
+    env = gym.make("MiniGrid-Empty-5x5-v0")
+    env = ViewSizeWrapper(env, agent_view_size=view_size)
+    env.reset()
+    obs, _, _, _, _ = env.step(0)
+    assert obs["image"].shape == (view_size, view_size, 3)
+    env.close()