Przeglądaj źródła

Implemented grid highlighting

Maxime Chevalier-Boisvert 5 lat temu
rodzic
commit
39f1c59e1c
3 zmienionych plików z 54 dodań i 58 usunięć
  1. 42 45
      gym_minigrid/minigrid.py
  2. 9 0
      gym_minigrid/rendering.py
  3. 3 13
      manual_control.py

+ 42 - 45
gym_minigrid/minigrid.py

@@ -533,7 +533,7 @@ class Grid:
         if obj != None:
         if obj != None:
             obj.render(img)
             obj.render(img)
 
 
-        # TODO: overlay agent on top
+        # Overlay the agent on top
         if agent_dir is not None:
         if agent_dir is not None:
             tri_fn = point_in_triangle(
             tri_fn = point_in_triangle(
                 (0.12, 0.19),
                 (0.12, 0.19),
@@ -541,12 +541,13 @@ class Grid:
                 (0.12, 0.81),
                 (0.12, 0.81),
             )
             )
 
 
+            # Rotate the agent based on its direction
             tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5*math.pi*agent_dir)
             tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5*math.pi*agent_dir)
             fill_coords(img, tri_fn, (255, 0, 0))
             fill_coords(img, tri_fn, (255, 0, 0))
 
 
-        # TODO: highlighting
+        # Highlight the cell if needed
         if highlight:
         if highlight:
-            pass
+            img = highlight_img(img)
 
 
         # Cache the rendered tile
         # Cache the rendered tile
         cls.tile_cache[key] = img
         cls.tile_cache[key] = img
@@ -557,7 +558,8 @@ class Grid:
         self,
         self,
         tile_size,
         tile_size,
         agent_pos=None,
         agent_pos=None,
-        agent_dir=None
+        agent_dir=None,
+        highlight_mask=None
     ):
     ):
         """
         """
         Render this grid at a given scale
         Render this grid at a given scale
@@ -565,6 +567,9 @@ class Grid:
         :param tile_size: tile size in pixels
         :param tile_size: tile size in pixels
         """
         """
 
 
+        if highlight_mask is None:
+            highlight_mask = np.zeros(shape=(self.width, self.height), dtype=np.bool)
+
         # Compute the total grid size
         # Compute the total grid size
         width_px = self.width * TILE_PIXELS
         width_px = self.width * TILE_PIXELS
         height_px = self.height * TILE_PIXELS
         height_px = self.height * TILE_PIXELS
@@ -576,10 +581,11 @@ class Grid:
             for i in range(0, self.width):
             for i in range(0, self.width):
                 cell = self.get(i, j)
                 cell = self.get(i, j)
 
 
+                agent_here = np.array_equal(agent_pos, (i, j))
                 tile_img = Grid.render_tile(
                 tile_img = Grid.render_tile(
                     cell,
                     cell,
-                    agent_dir=agent_dir if np.array_equal(agent_pos, (i, j)) else None,
-                    highlight=False,
+                    agent_dir=agent_dir if agent_here else None,
+                    highlight=highlight_mask[i, j],
                     tile_size=tile_size
                     tile_size=tile_size
                 )
                 )
 
 
@@ -1265,7 +1271,7 @@ class MiniGridEnv(gym.Env):
 
 
         return obs
         return obs
 
 
-    def get_obs_render(self, obs, tile_size=TILE_PIXELS//2, mode='pixmap'):
+    def get_obs_render(self, obs, tile_size=TILE_PIXELS//2):
         """
         """
         Render an agent observation for visualization
         Render an agent observation for visualization
         """
         """
@@ -1275,6 +1281,8 @@ class MiniGridEnv(gym.Env):
         # Render the whole grid
         # Render the whole grid
         img = grid.render(r, tile_size)
         img = grid.render(r, tile_size)
 
 
+        assert False
+
         """
         """
         # Draw the agent
         # Draw the agent
         ratio = tile_size / TILE_PIXELS
         ratio = tile_size / TILE_PIXELS
@@ -1293,14 +1301,10 @@ class MiniGridEnv(gym.Env):
             (-12, -10)
             (-12, -10)
         ])
         ])
         r.pop()
         r.pop()
-
-        if mode == 'rgb_array':
-            return r.getArray()
-        elif mode == 'pixmap':
-            return r.getPixmap()
-        return r
         """
         """
 
 
+        return img
+
     def render(self, mode='human', close=False, highlight=True, tile_size=TILE_PIXELS):
     def render(self, mode='human', close=False, highlight=True, tile_size=TILE_PIXELS):
         """
         """
         Render the whole-grid human view
         Render the whole-grid human view
@@ -1313,49 +1317,42 @@ class MiniGridEnv(gym.Env):
             return
             return
         """
         """
 
 
-        # Render the whole grid
-        img = self.grid.render(
-            tile_size,
-            self.agent_pos,
-            self.agent_dir
-        )
-
-        """
         # Compute which cells are visible to the agent
         # Compute which cells are visible to the agent
         _, vis_mask = self.gen_obs_grid()
         _, vis_mask = self.gen_obs_grid()
 
 
-        # Compute the absolute coordinates of the bottom-left corner
+        # Compute the world coordinates of the bottom-left corner
         # of the agent's view area
         # of the agent's view area
         f_vec = self.dir_vec
         f_vec = self.dir_vec
         r_vec = self.right_vec
         r_vec = self.right_vec
         top_left = self.agent_pos + f_vec * (self.agent_view_size-1) - r_vec * (self.agent_view_size // 2)
         top_left = self.agent_pos + f_vec * (self.agent_view_size-1) - r_vec * (self.agent_view_size // 2)
 
 
+        # Mask of which cells to highlight
+        highlight_mask = np.zeros(shape=(self.width, self.height), dtype=np.bool)
+
         # For each cell in the visibility mask
         # For each cell in the visibility mask
-        if highlight:
-            for vis_j in range(0, self.agent_view_size):
-                for vis_i in range(0, self.agent_view_size):
-                    # If this cell is not visible, don't highlight it
-                    if not vis_mask[vis_i, vis_j]:
-                        continue
+        for vis_j in range(0, self.agent_view_size):
+            for vis_i in range(0, self.agent_view_size):
+                # If this cell is not visible, don't highlight it
+                if not vis_mask[vis_i, vis_j]:
+                    continue
 
 
-                    # Compute the world coordinates of this cell
-                    abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)
+                # Compute the world coordinates of this cell
+                abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)
 
 
-                    # Highlight the cell
-                    r.fillRect(
-                        abs_i * tile_size,
-                        abs_j * tile_size,
-                        tile_size,
-                        tile_size,
-                        255, 255, 255, 75
-                    )
-        """
+                if abs_i < 0 or abs_i >= self.width:
+                    continue
+                if abs_j < 0 or abs_j >= self.height:
+                    continue
 
 
-        """
-        if mode == 'rgb_array':
-            return r.getArray()
-        elif mode == 'pixmap':
-            return r.getPixmap()
-        """
+                # Mark this cell to be highlighted
+                highlight_mask[abs_i, abs_j] = True
+
+        # Render the whole grid
+        img = self.grid.render(
+            tile_size,
+            self.agent_pos,
+            self.agent_dir,
+            highlight_mask=highlight_mask if highlight else None
+        )
 
 
         return img
         return img

+ 9 - 0
gym_minigrid/rendering.py

@@ -64,3 +64,12 @@ def point_in_triangle(a, b, c):
         return (u >= 0) and (v >= 0) and (u + v) < 1
         return (u >= 0) and (v >= 0) and (u + v) < 1
 
 
     return fn
     return fn
+
+def highlight_img(img, color=(255, 255, 255), alpha=0.30):
+    """
+    Add highlighting to an image
+    """
+
+    img = img + alpha * (np.array(color) - img)
+    img = img.clip(0, 255)
+    return img

+ 3 - 13
manual_control.py

@@ -79,17 +79,6 @@ args = parser.parse_args()
 
 
 env = gym.make(args.env_name)
 env = gym.make(args.env_name)
 
 
-"""
-t0 = time.time()
-
-for i in range(1000):
-    img = env.render('rgb_array')
-
-t1 = time.time()
-dt = int(1000 * (t1-t0))
-print(dt)
-"""
-
 fig, ax = plt.subplots()
 fig, ax = plt.subplots()
 
 
 # Keyboard handler
 # Keyboard handler
@@ -102,8 +91,9 @@ fig.canvas.set_window_title('gym_minigrid - ' + args.env_name)
 ax.set_xticks([], [])
 ax.set_xticks([], [])
 ax.set_yticks([], [])
 ax.set_yticks([], [])
 
 
-#plt.figure(num='gym-minigrid')
-imshow_obj = ax.imshow(img)
+# Show the first image of the environment
+img = env.render('rgb_array')
+imshow_obj = ax.imshow(img, interpolation='bilinear')
 
 
 reset()
 reset()