ソースを参照

Merge branch 'master' of github.com:maximecb/gym-minigrid

Maxime Chevalier-Boisvert 6 年 前
コミット
dba71a2c63
3 ファイル変更49 行追加46 行削除
  1. 38 38
      gym_minigrid/minigrid.py
  2. 2 4
      gym_minigrid/wrappers.py
  3. 9 4
      run_tests.py

+ 38 - 38
gym_minigrid/minigrid.py

@@ -40,16 +40,17 @@ IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
 
 # Map of object type to integers
 OBJECT_TO_IDX = {
-    'empty'         : 0,
-    'wall'          : 1,
-    'floor'         : 2,
-    'door'          : 3,
-    'locked_door'   : 4,
-    'key'           : 5,
-    'ball'          : 6,
-    'box'           : 7,
-    'goal'          : 8,
-    'lava'          : 9
+    'unseen'        : 0,
+    'empty'         : 1,
+    'wall'          : 2,
+    'floor'         : 3,
+    'door'          : 4,
+    'locked_door'   : 5,
+    'key'           : 6,
+    'ball'          : 7,
+    'box'           : 8,
+    'goal'          : 9,
+    'lava'          : 10
 }
 
 IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
@@ -551,53 +552,52 @@ class Grid:
 
         r.pop()
 
-    def encode(self):
+    def encode(self, vis_mask=None):
         """
         Produce a compact numpy encoding of the grid
         """
 
-        array = np.zeros(shape=(self.width, self.height, 3), dtype='uint8')
+        if vis_mask is None:
+            vis_mask = np.ones((self.width, self.height), dtype=bool)
 
-        for j in range(0, self.height):
-            for i in range(0, self.width):
-
-                v = self.get(i, j)
-
-                if v == None:
-                    continue
-
-                array[i, j, 0] = OBJECT_TO_IDX[v.type]
-                array[i, j, 1] = COLOR_TO_IDX[v.color]
-
-                if hasattr(v, 'is_open') and v.is_open:
-                    array[i, j, 2] = 1
+        array = np.zeros((self.width, self.height, 3), dtype='uint8')
+        for i in range(self.width):
+            for j in range(self.height):
+                if vis_mask[i, j]:
+                    v = self.get(i, j)
+
+                    if v is None:
+                        array[i, j, 0] = OBJECT_TO_IDX['empty']
+                        array[i, j, 1] = 0
+                        array[i, j, 2] = 0
+                    else:
+                        array[i, j, 0] = OBJECT_TO_IDX[v.type]
+                        array[i, j, 1] = COLOR_TO_IDX[v.color]
+                        array[i, j, 2] = hasattr(v, 'is_open') and v.is_open
 
         return array
 
+    @staticmethod
     def decode(array):
         """
         Decode an array grid encoding back into a grid
         """
 
-        width = array.shape[0]
-        height = array.shape[1]
-        assert array.shape[2] == 3
+        width, height, channels = array.shape
+        assert channels == 3
 
         grid = Grid(width, height)
+        for i in range(width):
+            for j in range(height):
+                typeIdx, colorIdx, openIdx = array[i, j]
 
-        for j in range(0, height):
-            for i in range(0, width):
-
-                typeIdx  = array[i, j, 0]
-                colorIdx = array[i, j, 1]
-                openIdx  = array[i, j, 2]
-
-                if typeIdx == 0:
+                if typeIdx == OBJECT_TO_IDX['unseen'] or \
+                        typeIdx == OBJECT_TO_IDX['empty']:
                     continue
 
                 objType = IDX_TO_OBJECT[typeIdx]
                 color = IDX_TO_COLOR[colorIdx]
-                is_open = True if openIdx == 1 else 0
+                is_open = openIdx == 1
 
                 if objType == 'wall':
                     v = Wall(color)
@@ -1242,7 +1242,7 @@ class MiniGridEnv(gym.Env):
         grid, vis_mask = self.gen_obs_grid()
 
         # Encode the partially observable view into a numpy array
-        image = grid.encode()
+        image = grid.encode(vis_mask)
 
         assert hasattr(self, 'mission'), "environments must define a textual mission string"
 

+ 2 - 4
gym_minigrid/wrappers.py

@@ -74,7 +74,6 @@ class StateBonus(gym.core.Wrapper):
 
         return obs, reward, done, info
 
-
 class ImgObsWrapper(gym.core.ObservationWrapper):
     """
     Use rgb image as the only observation output
@@ -82,13 +81,13 @@ class ImgObsWrapper(gym.core.ObservationWrapper):
 
     def __init__(self, env):
         super().__init__(env)
-        self.__dict__.update(vars(env))  # hack to pass values to super wrapper
+        # Hack to pass values to super wrapper
+        self.__dict__.update(vars(env))
         self.observation_space = env.observation_space.spaces['image']
 
     def observation(self, obs):
         return obs['image']
 
-
 class FullyObsWrapper(gym.core.ObservationWrapper):
     """
     Fully observable gridworld using a compact grid encoding
@@ -109,7 +108,6 @@ class FullyObsWrapper(gym.core.ObservationWrapper):
         full_grid[self.env.agent_pos[0]][self.env.agent_pos[1]] = np.array([255, self.env.agent_dir, 0])
         return full_grid
 
-
 class FlatObsWrapper(gym.core.ObservationWrapper):
     """
     Encode mission strings using a one-hot scheme,

+ 9 - 4
run_tests.py

@@ -4,7 +4,7 @@ import random
 import numpy as np
 import gym
 from gym_minigrid.register import env_list
-from gym_minigrid.minigrid import Grid
+from gym_minigrid.minigrid import Grid, OBJECT_TO_IDX
 
 # Test specifically importing a specific environment
 from gym_minigrid.envs import DoorKeyEnv
@@ -50,8 +50,8 @@ for envName in env_list:
 
         # Test observation encode/decode roundtrip
         img = obs['image']
-        grid = Grid.decode(img)
-        img2 = grid.encode()
+        vis_mask = img[:, :, 0] != OBJECT_TO_IDX['unseen']  # hackish
+        img2 = Grid.decode(img).encode(vis_mask=vis_mask)
         assert np.array_equal(img, img2)
 
         # Check that the reward is within the specified range
@@ -64,6 +64,12 @@ for envName in env_list:
 
         env.render('rgb_array')
 
+    # Test the fully observable wrapper
+    env = FullyObsWrapper(env)
+    env.reset()
+    obs, _, _, _ = env.step(0)
+    assert obs.shape == env.observation_space.shape
+
     env.close()
 
 ##############################################################################
@@ -88,4 +94,3 @@ for i in range(0, 500):
         env.reset()
 
 #############################################################################
-