|
@@ -1046,9 +1046,11 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
return obs, reward, done, {}
|
|
|
|
|
|
- def gen_obs(self):
|
|
|
+ def gen_obs_grid(self):
|
|
|
"""
|
|
|
- Generate the agent's view (partially observable, low-resolution encoding)
|
|
|
+ Generate the sub-grid observed by the agent.
|
|
|
+ This method also outputs a visibility mask telling us which grid
|
|
|
+ cells the agent can actually see.
|
|
|
"""
|
|
|
|
|
|
topX, topY, botX, botY = self.get_view_exts()
|
|
@@ -1061,7 +1063,9 @@ class MiniGridEnv(gym.Env):
|
|
|
# Process occluders and visibility
|
|
|
# Note that this incurs some performance cost
|
|
|
if not self.see_through_walls:
|
|
|
- grid.process_vis(agent_pos=(3, 6))
|
|
|
+ vis_mask = grid.process_vis(agent_pos=(3, 6))
|
|
|
+ else:
|
|
|
+ vis_mask = np.ones(shape=(grid.width, grid.height), dtype=np.bool)
|
|
|
|
|
|
# Make it so the agent sees what it's carrying
|
|
|
# We do this by placing the carried object at the agent's position
|
|
@@ -1072,6 +1076,15 @@ class MiniGridEnv(gym.Env):
|
|
|
else:
|
|
|
grid.set(*agent_pos, None)
|
|
|
|
|
|
+ return grid, vis_mask
|
|
|
+
|
|
|
+ def gen_obs(self):
|
|
|
+ """
|
|
|
+ Generate the agent's view (partially observable, low-resolution encoding)
|
|
|
+ """
|
|
|
+
|
|
|
+ grid, vis_mask = self.gen_obs_grid()
|
|
|
+
|
|
|
# Encode the partially observable view into a numpy array
|
|
|
image = grid.encode()
|
|
|
|