Selaa lähdekoodia

make observation image have different values for empty and unseen cells (#35) (#36)

* add ('unseen', 0) to `OBJECT_TO_IDX` dictionary;  all other values are
 shifted up by 1 to make space for the new value.

 * update `Grid.encode` to take an optional `vis_mask` argument, and
 return an image which now distinguishes between unseen and empty cells
 (values respectively 0 and 1).

 * update `Grid.decode` to handle the above changes;  values 0 and 1 in
 input array are both mapped to None in output grid.

 * make Grid.decode() a static method;  it was already used as a static
 method but was missing the appropriate decorator.

 * update "observation encode/decode roundtrip" test in run_tests.py;
 the visibility mask needed to perform the test is recovered by
 checking the first channel of the image for value 0 (`unseen`).
A. Baisero 6 vuotta sitten
vanhempi
commit
a6678d060d
2 muutettua tiedostoa jossa 41 lisäystä ja 44 poistoa
  1. 38 41
      gym_minigrid/minigrid.py
  2. 3 3
      run_tests.py

+ 38 - 41
gym_minigrid/minigrid.py

@@ -40,16 +40,17 @@ IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
 
 # Map of object type to integers
 OBJECT_TO_IDX = {
-    'empty'         : 0,
-    'wall'          : 1,
-    'floor'         : 2,
-    'door'          : 3,
-    'locked_door'   : 4,
-    'key'           : 5,
-    'ball'          : 6,
-    'box'           : 7,
-    'goal'          : 8,
-    'lava'          : 9
+    'unseen'        : 0,
+    'empty'         : 1,
+    'wall'          : 2,
+    'floor'         : 3,
+    'door'          : 4,
+    'locked_door'   : 5,
+    'key'           : 6,
+    'ball'          : 7,
+    'box'           : 8,
+    'goal'          : 9,
+    'lava'          : 10
 }
 
 IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
@@ -551,55 +552,51 @@ class Grid:
 
         r.pop()
 
-    def encode(self):
+    def encode(self, vis_mask=None):
         """
         Produce a compact numpy encoding of the grid
         """
+        if vis_mask is None:
+            vis_mask = np.ones((self.width, self.height), dtype=bool)
 
-        codeSize = self.width * self.height * 3
-
-        array = np.zeros(shape=(self.width, self.height, 3), dtype='uint8')
-
-        for j in range(0, self.height):
-            for i in range(0, self.width):
-
-                v = self.get(i, j)
-
-                if v == None:
-                    continue
-
-                array[i, j, 0] = OBJECT_TO_IDX[v.type]
-                array[i, j, 1] = COLOR_TO_IDX[v.color]
-
-                if hasattr(v, 'is_open') and v.is_open:
-                    array[i, j, 2] = 1
+        array = np.zeros((self.width, self.height, 3), dtype='uint8')
+        for i in range(self.width):
+            for j in range(self.height):
+                if vis_mask[i, j]:
+                    v = self.get(i, j)
+
+                    if v is None:
+                        array[i, j, 0] = OBJECT_TO_IDX['empty']
+                        array[i, j, 1] = 0
+                        array[i, j, 2] = 0
+                    else:
+                        array[i, j, 0] = OBJECT_TO_IDX[v.type]
+                        array[i, j, 1] = COLOR_TO_IDX[v.color]
+                        array[i, j, 2] = hasattr(v, 'is_open') and v.is_open
 
         return array
 
+    @staticmethod
     def decode(array):
         """
         Decode an array grid encoding back into a grid
         """
 
-        width = array.shape[0]
-        height = array.shape[1]
-        assert array.shape[2] == 3
+        width, height, channels = array.shape
+        assert channels == 3
 
         grid = Grid(width, height)
+        for i in range(width):
+            for j in range(height):
+                typeIdx, colorIdx, openIdx = array[i, j]
 
-        for j in range(0, height):
-            for i in range(0, width):
-
-                typeIdx  = array[i, j, 0]
-                colorIdx = array[i, j, 1]
-                openIdx  = array[i, j, 2]
-
-                if typeIdx == 0:
+                if typeIdx == OBJECT_TO_IDX['unseen'] or \
+                        typeIdx == OBJECT_TO_IDX['empty']:
                     continue
 
                 objType = IDX_TO_OBJECT[typeIdx]
                 color = IDX_TO_COLOR[colorIdx]
-                is_open = True if openIdx == 1 else 0
+                is_open = openIdx == 1
 
                 if objType == 'wall':
                     v = Wall(color)
@@ -1244,7 +1241,7 @@ class MiniGridEnv(gym.Env):
         grid, vis_mask = self.gen_obs_grid()
 
         # Encode the partially observable view into a numpy array
-        image = grid.encode()
+        image = grid.encode(vis_mask)
 
         assert hasattr(self, 'mission'), "environments must define a textual mission string"
 

+ 3 - 3
run_tests.py

@@ -4,7 +4,7 @@ import random
 import numpy as np
 import gym
 from gym_minigrid.register import env_list
-from gym_minigrid.minigrid import Grid
+from gym_minigrid.minigrid import Grid, OBJECT_TO_IDX
 
 # Test specifically importing a specific environment
 from gym_minigrid.envs import DoorKeyEnv
@@ -50,8 +50,8 @@ for envName in env_list:
 
         # Test observation encode/decode roundtrip
         img = obs['image']
-        grid = Grid.decode(img)
-        img2 = grid.encode()
+        vis_mask = img[:, :, 0] != OBJECT_TO_IDX['unseen']  # hackish
+        img2 = Grid.decode(img).encode(vis_mask=vis_mask)
         assert np.array_equal(img, img2)
 
         # Check that the reward is within the specified range