소스 검색

`prettyencode`, `prettystring` and `__str__` added to Grid.

Lucas Willems 7 년 전
부모
커밋
063b2dda6c
1개의 변경된 파일98개의 추가작업 그리고 0개의 파일을 삭제
  1. 98 0
      gym_minigrid/minigrid.py

+ 98 - 0
gym_minigrid/minigrid.py

@@ -329,6 +329,9 @@ class Grid:
 
     def __ne__(self, other):
         return not self == other
+    
+    def __str__(self):
+        return self.prettystring()
 
     def copy(self):
         from copy import deepcopy
@@ -485,6 +488,101 @@ class Grid:
                     array[i, j, 2] = 1
 
         return array
+    
+    def prettyencode(self, env=None):
+        """
+        Produce a compact 2d-array encoding of the grid with pretty pixels
+        """
+
+        from copy import deepcopy
+
+        def rotate_left(array):
+            new_array = deepcopy(array)
+            for i in range(len(array)):
+                for j in range(len(array[0])):
+                    new_array[j][len(array[0])-1-i] = array[i][j]
+            return new_array
+
+        def vertically_symmetrize(array):
+            new_array = deepcopy(array)
+            for i in range(len(array)):
+                for j in range(len(array[0])):
+                    new_array[i][len(array[0])-1-j] = array[i][j]
+            return new_array
+
+        # Map of object id to short string
+        OBJECT_IDX_TO_IDS = {
+            0: ' ',
+            1: 'W',
+            2: 'D',
+            3: 'L',
+            4: 'K',
+            5: 'B',
+            6: 'X',
+            7: 'G'
+        }
+
+        # Short string for opened door
+        OPENDED_DOOR_IDS = '_'
+
+        # Map of color id to short string
+        COLOR_IDX_TO_IDS = {
+            0: 'R',
+            1: 'G',
+            2: 'B',
+            3: 'P',
+            4: 'Y',
+            5: 'E'
+        }
+
+        # Map agent's direction to short string
+        AGENT_DIR_TO_IDS = {
+            0: '⏩',
+            1: '⏬',
+            2: '⏪',
+            3: '⏫'
+        }
+        
+        array = self.encode()
+
+        array = rotate_left(array)
+        array = vertically_symmetrize(array)
+
+        new_array = []
+
+        for line in array:
+            new_line = []
+
+            for pixel in line:
+                # If the door is opened
+                if pixel[0] in [2, 3] and pixel[2] == 1:
+                    object_ids = OPENDED_DOOR_IDS
+                else:
+                    object_ids = OBJECT_IDX_TO_IDS[pixel[0]]
+
+                # If no object
+                if pixel[0] == 0:
+                    color_ids = ' '
+                else:
+                    color_ids = COLOR_IDX_TO_IDS[pixel[1]]
+
+                new_line.append(object_ids + color_ids)
+
+            new_array.append(new_line)
+        
+        if env != None:
+            # Add the agent
+            new_array[env.agentPos[1]][env.agentPos[0]] = AGENT_DIR_TO_IDS[env.agentDir]
+
+        return new_array
+    
+    def prettystring(self, env=None):
+        """
+        Produce a pretty string of the grid
+        """
+        
+        array = self.prettyencode(env)
+        return "\n".join([" ".join(line) for line in array])
 
     def decode(array):
         """