Browse Source

Fixed issues with run_tests.py, grid encode/decode

Maxime Chevalier-Boisvert 7 years ago
parent
commit
28df92e70d
2 changed files with 8 additions and 5 deletions
  1. 3 1
      gym_minigrid/minigrid.py
  2. 5 4
      run_tests.py

+ 3 - 1
gym_minigrid/minigrid.py

@@ -464,11 +464,13 @@ class Grid:
                 isOpen = True if openIdx == 1 else 0
 
                 if objType == 'wall':
-                    v = Wall()
+                    v = Wall(color)
                 elif objType == 'ball':
                     v = Ball(color)
                 elif objType == 'key':
                     v = Key(color)
+                elif objType == 'box':
+                    v = Box(color)
                 elif objType == 'door':
                     v = Door(color, isOpen)
                 elif objType == 'locked_door':

+ 5 - 4
run_tests.py

@@ -30,11 +30,12 @@ for envName in sorted(envSet):
         obs, reward, done, info = env.step(action)
 
         # Test observation encode/decode roundtrip
-        if type(obs) is np.ndarray:
-            grid = Grid.decode(obs)
-            obs2 = grid.encode()
-            assert np.array_equal(obs2, obs)
+        img = obs if type(obs) is np.ndarray else obs['image']
+        grid = Grid.decode(img)
+        img2 = grid.encode()
+        assert np.array_equal(img2, img)
 
+        # Check that the reward is within the specified range
         assert reward >= env.reward_range[0], reward
         assert reward <= env.reward_range[1], reward