Browse Source

Testing one-hot encoding

Maxime Chevalier-Boisvert 6 years ago
parent
commit
eb04bf87b5
1 changed files with 11 additions and 7 deletions
  1. 11 7
      gym_minigrid/minigrid.py

+ 11 - 7
gym_minigrid/minigrid.py

@@ -12,7 +12,8 @@ CELL_PIXELS = 32
 AGENT_VIEW_SIZE = 7
 
 # Size of the array given as an observation to the agent
-OBS_ARRAY_SIZE = (AGENT_VIEW_SIZE, AGENT_VIEW_SIZE, 3)
+#OBS_ARRAY_SIZE = (AGENT_VIEW_SIZE, AGENT_VIEW_SIZE, 3)
+OBS_ARRAY_SIZE = (AGENT_VIEW_SIZE, AGENT_VIEW_SIZE, 17)
 
 # Map of color names to RGB values
 COLORS = {
@@ -556,9 +557,7 @@ class Grid:
         Produce a compact numpy encoding of the grid
         """
 
-        codeSize = self.width * self.height * 3
-
-        array = np.zeros(shape=(self.width, self.height, 3), dtype='uint8')
+        array = np.zeros(shape=(self.width, self.height, 17), dtype='uint8')
 
         for j in range(0, self.height):
             for i in range(0, self.width):
@@ -568,14 +567,19 @@ class Grid:
                 if v == None:
                     continue
 
-                array[i, j, 0] = OBJECT_TO_IDX[v.type]
-                array[i, j, 1] = COLOR_TO_IDX[v.color]
+                type_idx = OBJECT_TO_IDX[v.type]
+                color_idx = COLOR_TO_IDX[v.color]
+
+                array[i, j, type_idx] = 1
+                array[i, j, 10 + color_idx] = 1
 
                 if hasattr(v, 'is_open') and v.is_open:
-                    array[i, j, 2] = 1
+                    array[i, j, 16] = 1
 
         return array
 
+
+
     def decode(array):
         """
         Decode an array grid encoding back into a grid