Bladeren bron

Added colored floor tile the agent can walk over

Maxime Chevalier-Boisvert 6 jaren geleden
bovenliggende
commit
b6ffc9ba38
2 gewijzigde bestanden met toevoegingen van 44 en 17 verwijderingen
  1. 1 1
      README.md
  2. 43 16
      gym_minigrid/minigrid.py

+ 1 - 1
README.md

@@ -103,7 +103,7 @@ Structure of the world:
   - Cells that do not contain an object have the value `None`
 - Each object has an associated discrete color (string)
 - Each object has an associated type (string)
-  - Provided object types are: wall, door, locked_doors, key, ball, box and goal
+  - Provided object types are: wall, floor, door, locked_doors, key, ball, box and goal
 - The agent can pick up and carry exactly one object (eg: ball or key)
 
 Actions in the basic environment:

+ 43 - 16
gym_minigrid/minigrid.py

@@ -17,12 +17,12 @@ OBS_ARRAY_SIZE = (AGENT_VIEW_SIZE, AGENT_VIEW_SIZE, 3)
 
 # Map of color names to RGB values
 COLORS = {
-    'red'   : (255, 0, 0),
-    'green' : (0, 255, 0),
-    'blue'  : (0, 0, 255),
-    'purple': (112, 39, 195),
-    'yellow': (255, 255, 0),
-    'grey'  : (100, 100, 100)
+    'red'   : np.array([255, 0, 0]),
+    'green' : np.array([0, 255, 0]),
+    'blue'  : np.array([0, 0, 255]),
+    'purple': np.array([112, 39, 195]),
+    'yellow': np.array([255, 255, 0]),
+    'grey'  : np.array([100, 100, 100])
 }
 
 COLOR_NAMES = sorted(list(COLORS.keys()))
@@ -43,12 +43,13 @@ IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
 OBJECT_TO_IDX = {
     'empty'         : 0,
     'wall'          : 1,
-    'door'          : 2,
-    'locked_door'   : 3,
-    'key'           : 4,
-    'ball'          : 5,
-    'box'           : 6,
-    'goal'          : 7
+    'floor'         : 2,
+    'door'          : 3,
+    'locked_door'   : 4,
+    'key'           : 5,
+    'ball'          : 6,
+    'box'           : 7,
+    'goal'          : 8
 }
 
 IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
@@ -115,7 +116,7 @@ class WorldObj:
 
 class Goal(WorldObj):
     def __init__(self):
-        super(Goal, self).__init__('goal', 'green')
+        super().__init__('goal', 'green')
 
     def can_overlap(self):
         return True
@@ -129,9 +130,32 @@ class Goal(WorldObj):
             (0          ,           0)
         ])
 
+class Floor(WorldObj):
+    """
+    Colored floor tile the agent can walk over
+    """
+
+    def __init__(self, color='blue'):
+        super().__init__('floor', color)
+
+    def can_overlap(self):
+        return True
+
+    def render(self, r):
+        # Give the floor a pale color
+        c = COLORS[self.color]
+        r.setLineColor(100, 100, 100, 0)
+        r.setColor(*c/2)
+        r.drawPolygon([
+            (1          , CELL_PIXELS),
+            (CELL_PIXELS, CELL_PIXELS),
+            (CELL_PIXELS,           1),
+            (1          ,           1)
+        ])
+
 class Wall(WorldObj):
     def __init__(self, color='grey'):
-        super(Wall, self).__init__('wall', color)
+        super().__init__('wall', color)
 
     def see_behind(self):
         return False
@@ -147,7 +171,7 @@ class Wall(WorldObj):
 
 class Door(WorldObj):
     def __init__(self, color, is_open=False):
-        super(Door, self).__init__('door', color)
+        super().__init__('door', color)
         self.is_open = is_open
 
     def can_overlap(self):
@@ -532,6 +556,8 @@ class Grid:
 
                 if objType == 'wall':
                     v = Wall(color)
+                elif objType == 'floor':
+                    v = Floor(color)
                 elif objType == 'ball':
                     v = Ball(color)
                 elif objType == 'key':
@@ -675,7 +701,8 @@ class MiniGridEnv(gym.Env):
         assert self.start_dir is not None
 
         # Check that the agent doesn't overlap with an object
-        assert self.grid.get(*self.start_pos) is None
+        start_cell = self.grid.get(*self.start_pos)
+        assert start_cell is None or start_cell.can_overlap()
 
         # Place the agent in the starting position and direction
         self.agent_pos = self.start_pos