瀏覽代碼

Highlight visible cells in agent view

Maxime Chevalier-Boisvert 5 年之前
父節點
當前提交
9d1bc63bff
共有 2 個文件被更改,包括 15 次插入10 次删除
  1. 9 5
      gym_minigrid/minigrid.py
  2. 6 5
      run_tests.py

+ 9 - 5
gym_minigrid/minigrid.py

@@ -364,7 +364,7 @@ class Grid:
         return False
 
     def __eq__(self, other):
-        grid1 = self.encode()
+        grid1  = self.encode()
         grid2 = other.encode()
         return np.array_equal(grid2, grid1)
 
@@ -569,14 +569,17 @@ class Grid:
         width, height, channels = array.shape
         assert channels == 3
 
+        vis_mask = np.ones(shape=(width, height), dtype=np.bool)
+
         grid = Grid(width, height)
         for i in range(width):
             for j in range(height):
                 type_idx, color_idx, state = array[i, j]
                 v = WorldObj.decode(type_idx, color_idx, state)
                 grid.set(i, j, v)
+                vis_mask[i, j] = (type_idx != OBJECT_TO_IDX['unseen'])
 
-        return grid
+        return grid, vis_mask
 
     def process_vis(grid, agent_pos):
         mask = np.zeros(shape=(grid.width, grid.height), dtype=np.bool)
@@ -1078,7 +1081,7 @@ class MiniGridEnv(gym.Env):
         vx, vy = coordinates
 
         obs = self.gen_obs()
-        obs_grid = Grid.decode(obs['image'])
+        obs_grid, _ = Grid.decode(obs['image'])
         obs_cell = obs_grid.get(vx, vy)
         world_cell = self.grid.get(x, y)
 
@@ -1211,13 +1214,14 @@ class MiniGridEnv(gym.Env):
         Render an agent observation for visualization
         """
 
-        grid = Grid.decode(obs)
+        grid, vis_mask = Grid.decode(obs)
 
         # Render the whole grid
         img = grid.render(
             tile_size,
             agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1),
-            agent_dir = 3
+            agent_dir=3,
+            highlight_mask=vis_mask
         )
 
         return img

+ 6 - 5
run_tests.py

@@ -50,8 +50,8 @@ for env_name in env_list:
 
         # Test observation encode/decode roundtrip
         img = obs['image']
-        vis_mask = img[:, :, 0] != OBJECT_TO_IDX['unseen']  # hackish
-        img2 = Grid.decode(img).encode(vis_mask=vis_mask)
+        grid, vis_mask = Grid.decode(img)
+        img2 = grid.encode(vis_mask=vis_mask)
         assert np.array_equal(img, img2)
 
         # Test the env to string function
@@ -135,10 +135,11 @@ env.reset()
 for i in range(0, 500):
     action = random.randint(0, env.action_space.n - 1)
     obs, reward, done, info = env.step(action)
-    goal_visible = ('green', 'goal') in Grid.decode(obs['image'])
+
+    grid, _ = Grid.decode(obs['image'])
+    goal_visible = ('green', 'goal') in grid
+
     agent_sees_goal = env.agent_sees(*goal_pos)
     assert agent_sees_goal == goal_visible
     if done:
         env.reset()
-
-#############################################################################