Joseph Bloom 2 gadi atpakaļ
vecāks
revīzija
0a6449d41e
3 mainītis faili ar 22 papildinājumiem un 4 dzēšanām
  1. 1 1
      minigrid/wrappers.py
  2. 11 3
      tests/test_envs.py
  3. 10 0
      tests/test_wrappers.py

+ 1 - 1
minigrid/wrappers.py

@@ -635,7 +635,7 @@ class FlatObsWrapper(ObservationWrapper):
         return obs
         return obs
 
 
 
 
-class ViewSizeWrapper(Wrapper):
+class ViewSizeWrapper(ObservationWrapper):
     """
     """
     Wrapper to customize the agent field of view size.
     Wrapper to customize the agent field of view size.
     This cannot be used with fully observable wrappers.
     This cannot be used with fully observable wrappers.

+ 11 - 3
tests/test_envs.py

@@ -304,18 +304,26 @@ def test_mission_space():
     assert mission_space.contains("get the green key and the green key.")
     assert mission_space.contains("get the green key and the green key.")
     assert mission_space.contains("go fetch the red ball and the green key.")
     assert mission_space.contains("go fetch the red ball and the green key.")
 
 
+
 # not reasonable to test for all environments, test for a few of them.
 # not reasonable to test for all environments, test for a few of them.
-@pytest.mark.parametrize("env_id", ["MiniGrid-Empty-8x8-v0", "MiniGrid-DoorKey-16x16-v0","MiniGrid-ObstructedMaze-1Dl-v0"])
+@pytest.mark.parametrize(
+    "env_id",
+    [
+        "MiniGrid-Empty-8x8-v0",
+        "MiniGrid-DoorKey-16x16-v0",
+        "MiniGrid-ObstructedMaze-1Dl-v0",
+    ],
+)
 def test_env_sync_vectorization(env_id):
 def test_env_sync_vectorization(env_id):
-    
     def env_maker(env_id, **kwargs):
     def env_maker(env_id, **kwargs):
         def env_func():
         def env_func():
             env = gym.make(env_id, **kwargs)
             env = gym.make(env_id, **kwargs)
             return env
             return env
+
         return env_func
         return env_func
 
 
     num_envs = 4
     num_envs = 4
     env = gym.vector.SyncVectorEnv([env_maker(env_id) for _ in range(num_envs)])
     env = gym.vector.SyncVectorEnv([env_maker(env_id) for _ in range(num_envs)])
     env.reset()
     env.reset()
     env.step(env.action_space.sample())
     env.step(env.action_space.sample())
-    env.close()
+    env.close()

+ 10 - 0
tests/test_wrappers.py

@@ -250,3 +250,13 @@ def test_agent_sees_method(wrapper):
     assert (obs1["size"] == [5, 5]).all()
     assert (obs1["size"] == [5, 5]).all()
     for key in obs2:
     for key in obs2:
         assert np.array_equal(obs1[key], obs2[key])
         assert np.array_equal(obs1[key], obs2[key])
+
+
+@pytest.mark.parametrize("view_size", [5, 7, 9])
+def test_viewsize_wrapper(view_size):
+    env = gym.make("MiniGrid-Empty-5x5-v0")
+    env = ViewSizeWrapper(env, agent_view_size=view_size)
+    env.reset()
+    obs, _, _, _, _ = env.step(0)
+    assert obs["image"].shape == (view_size, view_size, 3)
+    env.close()