Rodrigo de Lazcano 2 năm trước cách đây
mục cha
commit
301102a05c
1 tập tin đã thay đổi với 2 bổ sung0 xóa
  1. 2 0
      minigrid/wrappers.py

+ 2 - 0
minigrid/wrappers.py

@@ -519,9 +519,11 @@ class SymbolicObsWrapper(ObservationWrapper):
         objects = np.array(
             [OBJECT_TO_IDX[o.type] if o is not None else -1 for o in self.grid.grid]
         )
+        agent_pos = self.env.agent_pos
         w, h = self.width, self.height
         grid = np.mgrid[:w, :h]
         grid = np.concatenate([grid, objects.reshape(1, w, h)])
         grid = np.transpose(grid, (1, 2, 0))
+        grid[agent_pos[0], agent_pos[1], 2] = OBJECT_TO_IDX["agent"]
         obs["image"] = grid
         return obs