Explorar o código

Rewrote broken env __str__ function. Added test.

Maxime Chevalier-Boisvert hai 6 anos
pai
achega
232318e798
Modificáronse 2 ficheiros con 38 adicións e 66 borrados
  1. 35 66
      gym_minigrid/minigrid.py
  2. 3 0
      run_tests.py

+ 35 - 66
gym_minigrid/minigrid.py

@@ -779,90 +779,59 @@ class MiniGridEnv(gym.Env):
     def __str__(self):
         """
         Produce a pretty string of the environment's grid along with the agent.
-        The agent is represented by `⏩`. A grid pixel is represented by 2-character
-        string, the first one for the object and the second one for the color.
+        A grid cell is represented by 2-character string, the first one for
+        the object and the second one for the color.
         """
 
-        from copy import deepcopy
-
-        def rotate_left(array):
-            new_array = deepcopy(array)
-            for i in range(len(array)):
-                for j in range(len(array[0])):
-                    new_array[j][len(array[0])-1-i] = array[i][j]
-            return new_array
-
-        def vertically_symmetrize(array):
-            new_array = deepcopy(array)
-            for i in range(len(array)):
-                for j in range(len(array[0])):
-                    new_array[i][len(array[0])-1-j] = array[i][j]
-            return new_array
-
-        # Map of object id to short string
-        OBJECT_IDX_TO_IDS = {
-            0: ' ',
-            1: 'W',
-            2: 'D',
-            3: 'L',
-            4: 'K',
-            5: 'B',
-            6: 'X',
-            7: 'G'
+        # Map of object types to short string
+        OBJECT_TO_STR = {
+            'wall'          : 'W',
+            'floor'         : 'F',
+            'door'          : 'D',
+            'locked_door'   : 'L',
+            'key'           : 'K',
+            'ball'          : 'A',
+            'box'           : 'B',
+            'goal'          : 'G',
+            'lava'          : 'V',
         }
 
         # Short string for opened door
         OPENDED_DOOR_IDS = '_'
 
-        # Map of color id to short string
-        COLOR_IDX_TO_IDS = {
-            0: 'R',
-            1: 'G',
-            2: 'B',
-            3: 'P',
-            4: 'Y',
-            5: 'E'
-        }
-
         # Map agent's direction to short string
-        AGENT_DIR_TO_IDS = {
-            0: '',
-            1: '',
-            2: '',
-            3: ''
+        AGENT_DIR_TO_STR = {
+            0: '>',
+            1: 'V',
+            2: '<',
+            3: '^'
         }
 
-        array = self.grid.encode()
-
-        array = rotate_left(array)
-        array = vertically_symmetrize(array)
+        str = ''
 
-        new_array = []
+        for j in range(self.grid.height):
 
-        for line in array:
-            new_line = []
+            for i in range(self.grid.width):
+                if i == self.agent_pos[0] and j == self.agent_pos[1]:
+                    str += 2 * AGENT_DIR_TO_STR[self.agent_dir]
+                    continue
 
-            for pixel in line:
-                # If the door is opened
-                if pixel[0] in [2, 3] and pixel[2] == 1:
-                    object_ids = OPENDED_DOOR_IDS
-                else:
-                    object_ids = OBJECT_IDX_TO_IDS[pixel[0]]
+                c = self.grid.get(i, j)
 
-                # If no object
-                if pixel[0] == 0:
-                    color_ids = ' '
-                else:
-                    color_ids = COLOR_IDX_TO_IDS[pixel[1]]
+                if c == None:
+                    str += '  '
+                    continue
 
-                new_line.append(object_ids + color_ids)
+                if c.type.startswith('door') and c.is_open:
+                    str += '__'
+                    continue
 
-            new_array.append(new_line)
+                str += OBJECT_TO_STR[c.type] + c.color[0].upper()
 
-        # Add the agent
-        new_array[self.agent_pos[1]][self.agent_pos[0]] = AGENT_DIR_TO_IDS[self.agent_dir]
+            if j < self.grid.height - 1:
+                str += '\n'
 
-        return "\n".join([" ".join(line) for line in new_array])
+        return str
 
     def _gen_grid(self, width, height):
         assert False, "_gen_grid needs to be implemented by each environment"

+ 3 - 0
run_tests.py

@@ -54,6 +54,9 @@ for envName in env_list:
         img2 = Grid.decode(img).encode(vis_mask=vis_mask)
         assert np.array_equal(img, img2)
 
+        # Test the env to string function
+        str(env)
+
         # Check that the reward is within the specified range
         assert reward >= env.reward_range[0], reward
         assert reward <= env.reward_range[1], reward