浏览代码

Added test for FullyObsWrapper

Maxime Chevalier-Boisvert 6 年之前
父节点
当前提交
544efb237a
共有 2 个文件被更改,包括 8 次插入5 次删除
  1. 2 4
      gym_minigrid/wrappers.py
  2. 6 1
      run_tests.py

+ 2 - 4
gym_minigrid/wrappers.py

@@ -74,7 +74,6 @@ class StateBonus(gym.core.Wrapper):
 
         return obs, reward, done, info
 
-
 class ImgObsWrapper(gym.core.ObservationWrapper):
     """
     Use rgb image as the only observation output
@@ -82,13 +81,13 @@ class ImgObsWrapper(gym.core.ObservationWrapper):
 
     def __init__(self, env):
         super().__init__(env)
-        self.__dict__.update(vars(env))  # hack to pass values to super wrapper
+        # Hack to pass values to super wrapper
+        self.__dict__.update(vars(env))
         self.observation_space = env.observation_space.spaces['image']
 
     def observation(self, obs):
         return obs['image']
 
-
 class FullyObsWrapper(gym.core.ObservationWrapper):
     """
     Fully observable gridworld using a compact grid encoding
@@ -109,7 +108,6 @@ class FullyObsWrapper(gym.core.ObservationWrapper):
         full_grid[self.env.agent_pos[0]][self.env.agent_pos[1]] = np.array([255, self.env.agent_dir, 0])
         return full_grid
 
-
 class FlatObsWrapper(gym.core.ObservationWrapper):
     """
     Encode mission strings using a one-hot scheme,

+ 6 - 1
run_tests.py

@@ -64,6 +64,12 @@ for envName in env_list:
 
         env.render('rgb_array')
 
+    # Test the fully observable wrapper
+    env = FullyObsWrapper(env)
+    env.reset()
+    obs, _, _, _ = env.step(0)
+    assert obs.shape == env.observation_space.shape
+
     env.close()
 
 ##############################################################################
@@ -88,4 +94,3 @@ for i in range(0, 500):
         env.reset()
 
 #############################################################################
-