|
@@ -533,7 +533,7 @@ class Grid:
|
|
|
if obj != None:
|
|
|
obj.render(img)
|
|
|
|
|
|
- # TODO: overlay agent on top
|
|
|
+ # Overlay the agent on top
|
|
|
if agent_dir is not None:
|
|
|
tri_fn = point_in_triangle(
|
|
|
(0.12, 0.19),
|
|
@@ -541,12 +541,13 @@ class Grid:
|
|
|
(0.12, 0.81),
|
|
|
)
|
|
|
|
|
|
+ # Rotate the agent based on its direction
|
|
|
tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5*math.pi*agent_dir)
|
|
|
fill_coords(img, tri_fn, (255, 0, 0))
|
|
|
|
|
|
- # TODO: highlighting
|
|
|
+ # Highlight the cell if needed
|
|
|
if highlight:
|
|
|
- pass
|
|
|
+ img = highlight_img(img)
|
|
|
|
|
|
# Cache the rendered tile
|
|
|
cls.tile_cache[key] = img
|
|
@@ -557,7 +558,8 @@ class Grid:
|
|
|
self,
|
|
|
tile_size,
|
|
|
agent_pos=None,
|
|
|
- agent_dir=None
|
|
|
+ agent_dir=None,
|
|
|
+ highlight_mask=None
|
|
|
):
|
|
|
"""
|
|
|
Render this grid at a given scale
|
|
@@ -565,6 +567,9 @@ class Grid:
|
|
|
:param tile_size: tile size in pixels
|
|
|
"""
|
|
|
|
|
|
+ if highlight_mask is None:
|
|
|
+ highlight_mask = np.zeros(shape=(self.width, self.height), dtype=np.bool)
|
|
|
+
|
|
|
# Compute the total grid size
|
|
|
width_px = self.width * TILE_PIXELS
|
|
|
height_px = self.height * TILE_PIXELS
|
|
@@ -576,10 +581,11 @@ class Grid:
|
|
|
for i in range(0, self.width):
|
|
|
cell = self.get(i, j)
|
|
|
|
|
|
+ agent_here = np.array_equal(agent_pos, (i, j))
|
|
|
tile_img = Grid.render_tile(
|
|
|
cell,
|
|
|
- agent_dir=agent_dir if np.array_equal(agent_pos, (i, j)) else None,
|
|
|
- highlight=False,
|
|
|
+ agent_dir=agent_dir if agent_here else None,
|
|
|
+ highlight=highlight_mask[i, j],
|
|
|
tile_size=tile_size
|
|
|
)
|
|
|
|
|
@@ -1265,7 +1271,7 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
return obs
|
|
|
|
|
|
- def get_obs_render(self, obs, tile_size=TILE_PIXELS//2, mode='pixmap'):
|
|
|
+ def get_obs_render(self, obs, tile_size=TILE_PIXELS//2):
|
|
|
"""
|
|
|
Render an agent observation for visualization
|
|
|
"""
|
|
@@ -1275,6 +1281,8 @@ class MiniGridEnv(gym.Env):
|
|
|
# Render the whole grid
|
|
|
img = grid.render(r, tile_size)
|
|
|
|
|
|
+ assert False
|
|
|
+
|
|
|
"""
|
|
|
# Draw the agent
|
|
|
ratio = tile_size / TILE_PIXELS
|
|
@@ -1293,14 +1301,10 @@ class MiniGridEnv(gym.Env):
|
|
|
(-12, -10)
|
|
|
])
|
|
|
r.pop()
|
|
|
-
|
|
|
- if mode == 'rgb_array':
|
|
|
- return r.getArray()
|
|
|
- elif mode == 'pixmap':
|
|
|
- return r.getPixmap()
|
|
|
- return r
|
|
|
"""
|
|
|
|
|
|
+ return img
|
|
|
+
|
|
|
def render(self, mode='human', close=False, highlight=True, tile_size=TILE_PIXELS):
|
|
|
"""
|
|
|
Render the whole-grid human view
|
|
@@ -1313,49 +1317,42 @@ class MiniGridEnv(gym.Env):
|
|
|
return
|
|
|
"""
|
|
|
|
|
|
- # Render the whole grid
|
|
|
- img = self.grid.render(
|
|
|
- tile_size,
|
|
|
- self.agent_pos,
|
|
|
- self.agent_dir
|
|
|
- )
|
|
|
-
|
|
|
- """
|
|
|
# Compute which cells are visible to the agent
|
|
|
_, vis_mask = self.gen_obs_grid()
|
|
|
|
|
|
- # Compute the absolute coordinates of the bottom-left corner
|
|
|
+ # Compute the world coordinates of the bottom-left corner
|
|
|
# of the agent's view area
|
|
|
f_vec = self.dir_vec
|
|
|
r_vec = self.right_vec
|
|
|
top_left = self.agent_pos + f_vec * (self.agent_view_size-1) - r_vec * (self.agent_view_size // 2)
|
|
|
|
|
|
+ # Mask of which cells to highlight
|
|
|
+ highlight_mask = np.zeros(shape=(self.width, self.height), dtype=np.bool)
|
|
|
+
|
|
|
# For each cell in the visibility mask
|
|
|
- if highlight:
|
|
|
- for vis_j in range(0, self.agent_view_size):
|
|
|
- for vis_i in range(0, self.agent_view_size):
|
|
|
- # If this cell is not visible, don't highlight it
|
|
|
- if not vis_mask[vis_i, vis_j]:
|
|
|
- continue
|
|
|
+ for vis_j in range(0, self.agent_view_size):
|
|
|
+ for vis_i in range(0, self.agent_view_size):
|
|
|
+ # If this cell is not visible, don't highlight it
|
|
|
+ if not vis_mask[vis_i, vis_j]:
|
|
|
+ continue
|
|
|
|
|
|
- # Compute the world coordinates of this cell
|
|
|
- abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)
|
|
|
+ # Compute the world coordinates of this cell
|
|
|
+ abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)
|
|
|
|
|
|
- # Highlight the cell
|
|
|
- r.fillRect(
|
|
|
- abs_i * tile_size,
|
|
|
- abs_j * tile_size,
|
|
|
- tile_size,
|
|
|
- tile_size,
|
|
|
- 255, 255, 255, 75
|
|
|
- )
|
|
|
- """
|
|
|
+ if abs_i < 0 or abs_i >= self.width:
|
|
|
+ continue
|
|
|
+ if abs_j < 0 or abs_j >= self.height:
|
|
|
+ continue
|
|
|
|
|
|
- """
|
|
|
- if mode == 'rgb_array':
|
|
|
- return r.getArray()
|
|
|
- elif mode == 'pixmap':
|
|
|
- return r.getPixmap()
|
|
|
- """
|
|
|
+ # Mark this cell to be highlighted
|
|
|
+ highlight_mask[abs_i, abs_j] = True
|
|
|
+
|
|
|
+ # Render the whole grid
|
|
|
+ img = self.grid.render(
|
|
|
+ tile_size,
|
|
|
+ self.agent_pos,
|
|
|
+ self.agent_dir,
|
|
|
+ highlight_mask=highlight_mask if highlight else None
|
|
|
+ )
|
|
|
|
|
|
return img
|