|
@@ -364,7 +364,7 @@ class Grid:
|
|
|
return False
|
|
|
|
|
|
def __eq__(self, other):
|
|
|
- grid1 = self.encode()
|
|
|
+ grid1 = self.encode()
|
|
|
grid2 = other.encode()
|
|
|
return np.array_equal(grid2, grid1)
|
|
|
|
|
@@ -569,14 +569,17 @@ class Grid:
|
|
|
width, height, channels = array.shape
|
|
|
assert channels == 3
|
|
|
|
|
|
+ vis_mask = np.ones(shape=(width, height), dtype=np.bool)
|
|
|
+
|
|
|
grid = Grid(width, height)
|
|
|
for i in range(width):
|
|
|
for j in range(height):
|
|
|
type_idx, color_idx, state = array[i, j]
|
|
|
v = WorldObj.decode(type_idx, color_idx, state)
|
|
|
grid.set(i, j, v)
|
|
|
+ vis_mask[i, j] = (type_idx != OBJECT_TO_IDX['unseen'])
|
|
|
|
|
|
- return grid
|
|
|
+ return grid, vis_mask
|
|
|
|
|
|
def process_vis(grid, agent_pos):
|
|
|
mask = np.zeros(shape=(grid.width, grid.height), dtype=np.bool)
|
|
@@ -1078,7 +1081,7 @@ class MiniGridEnv(gym.Env):
|
|
|
vx, vy = coordinates
|
|
|
|
|
|
obs = self.gen_obs()
|
|
|
- obs_grid = Grid.decode(obs['image'])
|
|
|
+ obs_grid, _ = Grid.decode(obs['image'])
|
|
|
obs_cell = obs_grid.get(vx, vy)
|
|
|
world_cell = self.grid.get(x, y)
|
|
|
|
|
@@ -1211,13 +1214,14 @@ class MiniGridEnv(gym.Env):
|
|
|
Render an agent observation for visualization
|
|
|
"""
|
|
|
|
|
|
- grid = Grid.decode(obs)
|
|
|
+ grid, vis_mask = Grid.decode(obs)
|
|
|
|
|
|
# Render the whole grid
|
|
|
img = grid.render(
|
|
|
tile_size,
|
|
|
agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1),
|
|
|
- agent_dir = 3
|
|
|
+ agent_dir=3,
|
|
|
+ highlight_mask=vis_mask
|
|
|
)
|
|
|
|
|
|
return img
|