ソースを参照

Added gen_obs_grid method which outputs visibility mask

Maxime Chevalier-Boisvert 7 年 前
コミット
6b513362e6
1 ファイル変更16 行追加3 行削除
  1. 16 3
      gym_minigrid/minigrid.py

+ 16 - 3
gym_minigrid/minigrid.py

@@ -1046,9 +1046,11 @@ class MiniGridEnv(gym.Env):
 
         return obs, reward, done, {}
 
-    def gen_obs(self):
+    def gen_obs_grid(self):
         """
-        Generate the agent's view (partially observable, low-resolution encoding)
+        Generate the sub-grid observed by the agent.
+        This method also outputs a visibility mask telling us which grid
+        cells the agent can actually see.
         """
 
         topX, topY, botX, botY = self.get_view_exts()
@@ -1061,7 +1063,9 @@ class MiniGridEnv(gym.Env):
         # Process occluders and visibility
         # Note that this incurs some performance cost
         if not self.see_through_walls:
-            grid.process_vis(agent_pos=(3, 6))
+            vis_mask = grid.process_vis(agent_pos=(3, 6))
+        else:
+            vis_mask = np.ones(shape=(grid.width, grid.height), dtype=np.bool)
 
         # Make it so the agent sees what it's carrying
         # We do this by placing the carried object at the agent's position
@@ -1072,6 +1076,15 @@ class MiniGridEnv(gym.Env):
         else:
             grid.set(*agent_pos, None)
 
+        return grid, vis_mask
+
+    def gen_obs(self):
+        """
+        Generate the agent's view (partially observable, low-resolution encoding)
+        """
+
+        grid, vis_mask = self.gen_obs_grid()
+
         # Encode the partially observable view into a numpy array
         image = grid.encode()