| 
					
				 | 
			
			
				@@ -533,7 +533,7 @@ class Grid: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if obj != None: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             obj.render(img) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # TODO: overlay agent on top 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # Overlay the agent on top 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if agent_dir is not None: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             tri_fn = point_in_triangle( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 (0.12, 0.19), 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -541,12 +541,13 @@ class Grid: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 (0.12, 0.81), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            # Rotate the agent based on its direction 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5*math.pi*agent_dir) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             fill_coords(img, tri_fn, (255, 0, 0)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # TODO: highlighting 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # Highlight the cell if needed 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if highlight: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            pass 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            img = highlight_img(img) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         # Cache the rendered tile 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         cls.tile_cache[key] = img 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -557,7 +558,8 @@ class Grid: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         self, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         tile_size, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         agent_pos=None, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        agent_dir=None 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        agent_dir=None, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        highlight_mask=None 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     ): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         Render this grid at a given scale 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -565,6 +567,9 @@ class Grid: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         :param tile_size: tile size in pixels 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if highlight_mask is None: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            highlight_mask = np.zeros(shape=(self.width, self.height), dtype=np.bool) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         # Compute the total grid size 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         width_px = self.width * TILE_PIXELS 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         height_px = self.height * TILE_PIXELS 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -576,10 +581,11 @@ class Grid: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             for i in range(0, self.width): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 cell = self.get(i, j) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                agent_here = np.array_equal(agent_pos, (i, j)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 tile_img = Grid.render_tile( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     cell, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    agent_dir=agent_dir if np.array_equal(agent_pos, (i, j)) else None, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    highlight=False, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    agent_dir=agent_dir if agent_here else None, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    highlight=highlight_mask[i, j], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     tile_size=tile_size 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -1265,7 +1271,7 @@ class MiniGridEnv(gym.Env): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return obs 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    def get_obs_render(self, obs, tile_size=TILE_PIXELS//2, mode='pixmap'): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def get_obs_render(self, obs, tile_size=TILE_PIXELS//2): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         Render an agent observation for visualization 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """ 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -1275,6 +1281,8 @@ class MiniGridEnv(gym.Env): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         # Render the whole grid 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         img = grid.render(r, tile_size) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        assert False 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         # Draw the agent 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         ratio = tile_size / TILE_PIXELS 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -1293,14 +1301,10 @@ class MiniGridEnv(gym.Env): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             (-12, -10) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         ]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         r.pop() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        if mode == 'rgb_array': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            return r.getArray() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        elif mode == 'pixmap': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            return r.getPixmap() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        return r 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        return img 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def render(self, mode='human', close=False, highlight=True, tile_size=TILE_PIXELS): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         Render the whole-grid human view 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -1313,49 +1317,42 @@ class MiniGridEnv(gym.Env): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             return 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # Render the whole grid 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        img = self.grid.render( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            tile_size, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            self.agent_pos, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            self.agent_dir 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         # Compute which cells are visible to the agent 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         _, vis_mask = self.gen_obs_grid() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # Compute the absolute coordinates of the bottom-left corner 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # Compute the world coordinates of the bottom-left corner 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         # of the agent's view area 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         f_vec = self.dir_vec 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         r_vec = self.right_vec 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         top_left = self.agent_pos + f_vec * (self.agent_view_size-1) - r_vec * (self.agent_view_size // 2) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # Mask of which cells to highlight 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        highlight_mask = np.zeros(shape=(self.width, self.height), dtype=np.bool) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         # For each cell in the visibility mask 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        if highlight: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            for vis_j in range(0, self.agent_view_size): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                for vis_i in range(0, self.agent_view_size): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    # If this cell is not visible, don't highlight it 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    if not vis_mask[vis_i, vis_j]: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                        continue 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        for vis_j in range(0, self.agent_view_size): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            for vis_i in range(0, self.agent_view_size): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                # If this cell is not visible, don't highlight it 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                if not vis_mask[vis_i, vis_j]: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    continue 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    # Compute the world coordinates of this cell 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                # Compute the world coordinates of this cell 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    # Highlight the cell 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    r.fillRect( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                        abs_i * tile_size, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                        abs_j * tile_size, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                        tile_size, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                        tile_size, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                        255, 255, 255, 75 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                if abs_i < 0 or abs_i >= self.width: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    continue 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                if abs_j < 0 or abs_j >= self.height: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    continue 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        if mode == 'rgb_array': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            return r.getArray() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        elif mode == 'pixmap': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            return r.getPixmap() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                # Mark this cell to be highlighted 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                highlight_mask[abs_i, abs_j] = True 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # Render the whole grid 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        img = self.grid.render( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            tile_size, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            self.agent_pos, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            self.agent_dir, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            highlight_mask=highlight_mask if highlight else None 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return img 
			 |