Sfoglia il codice sorgente

Added basic tests for wrappers. Fixed issues.

Maxime Chevalier-Boisvert 6 anni fa
parent
commit
3e9171e182
3 ha cambiato i file con 31 aggiunte e 15 eliminazioni
  1. 4 8
      gym_minigrid/envs/lockedroom.py
  2. 3 4
      gym_minigrid/wrappers.py
  3. 24 3
      run_tests.py

+ 4 - 8
gym_minigrid/envs/lockedroom.py

@@ -29,15 +29,11 @@ class LockedRoom(MiniGridEnv):
     """
 
     def __init__(
-        self
+        self,
+        size=19
     ):
-        size = 19
         super().__init__(grid_size=size, max_steps=10*size)
 
-        self.observation_space = spaces.Dict({
-            'image': self.observation_space
-        })
-
     def _gen_grid(self, width, height):
         # Create the grid
         self.grid = Grid(width, height)
@@ -114,8 +110,8 @@ class LockedRoom(MiniGridEnv):
         # Generate the mission string
         self.mission = (
             'get the %s key from the %s room, '
-            'then use it to unlock the %s door '
-            'so you can get to the goal'
+            'unlock the %s door and '
+            'go to the goal'
         ) % (lockedRoom.color, keyRoom.color, lockedRoom.color)
 
     def step(self, action):

+ 3 - 4
gym_minigrid/wrappers.py

@@ -120,7 +120,7 @@ class FlatObsWrapper(gym.core.ObservationWrapper):
     and combine these with observed images into one flat array
     """
 
-    def __init__(self, env, maxStrLen=64):
+    def __init__(self, env, maxStrLen=96):
         self.__dict__.update(vars(env))  # Pass values to super wrapper
         super().__init__(env)
 
@@ -146,7 +146,7 @@ class FlatObsWrapper(gym.core.ObservationWrapper):
 
         # Cache the last-encoded mission string
         if mission != self.cachedStr:
-            assert len(mission) <= self.maxStrLen, "mission string too long"
+            assert len(mission) <= self.maxStrLen, 'mission string too long ({} chars)'.format(len(mission))
             mission = mission.lower()
 
             strArray = np.zeros(shape=(self.maxStrLen, self.numCharCodes), dtype='float32')
@@ -166,7 +166,6 @@ class FlatObsWrapper(gym.core.ObservationWrapper):
 
         return obs
 
-
 class AgentViewWrapper(gym.core.Wrapper):
     """
     Wrapper to customize the agent's field of view.
@@ -196,4 +195,4 @@ class AgentViewWrapper(gym.core.Wrapper):
         return self.env.reset(**kwargs)
 
     def step(self, action):
-        return self.env.step(action)
+        return self.env.step(action)

+ 24 - 3
run_tests.py

@@ -16,11 +16,11 @@ from gym_minigrid.wrappers import *
 
 print('%d environments registered' % len(env_list))
 
-for envName in env_list:
-    print('testing "%s"' % envName)
+for env_name in env_list:
+    print('testing "%s"' % env_name)
 
     # Load the gym environment
-    env = gym.make(envName)
+    env = gym.make(env_name)
     env.max_steps = min(env.max_steps, 200)
     env.reset()
     env.render('rgb_array')
@@ -67,12 +67,33 @@ for envName in env_list:
 
         env.render('rgb_array')
 
+    # Test the close method
+    env.close()
+
+    env = gym.make(env_name)
+    env = ImgObsWrapper(env)
+    env.reset()
+    env.step(0)
+    env.close()
+
     # Test the fully observable wrapper
+    env = gym.make(env_name)
     env = FullyObsWrapper(env)
     env.reset()
     obs, _, _, _ = env.step(0)
     assert obs.shape == env.observation_space.shape
+    env.close()
 
+    env = gym.make(env_name)
+    env = FlatObsWrapper(env)
+    env.reset()
+    env.step(0)
+    env.close()
+
+    env = gym.make(env_name)
+    env = AgentViewWrapper(env, 5)
+    env.reset()
+    env.step(0)
     env.close()
 
 ##############################################################################