浏览代码

Adding symbolic observation wrapper (#165)

* Fix numpy DeprecationWarning: replace `np.bool` with `bool`

https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations

* add symbolic obs wrapper

* eof

* keep obs fields

* moveaxis typo to transpose

* moveaxis typo to transpose
Eduardo Pignatelli 3 年之前
父节点
当前提交
6116191b15
共有 2 个文件被更改,包括 34 次插入7 次删除
  1. 5 5
      gym_minigrid/minigrid.py
  2. 29 2
      gym_minigrid/wrappers.py

+ 5 - 5
gym_minigrid/minigrid.py

@@ -501,7 +501,7 @@ class Grid:
         """
 
         if highlight_mask is None:
-            highlight_mask = np.zeros(shape=(self.width, self.height), dtype=np.bool)
+            highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
 
         # Compute the total grid size
         width_px = self.width * tile_size
@@ -564,7 +564,7 @@ class Grid:
         width, height, channels = array.shape
         assert channels == 3
 
-        vis_mask = np.ones(shape=(width, height), dtype=np.bool)
+        vis_mask = np.ones(shape=(width, height), dtype=bool)
 
         grid = Grid(width, height)
         for i in range(width):
@@ -577,7 +577,7 @@ class Grid:
         return grid, vis_mask
 
     def process_vis(grid, agent_pos):
-        mask = np.zeros(shape=(grid.width, grid.height), dtype=np.bool)
+        mask = np.zeros(shape=(grid.width, grid.height), dtype=bool)
 
         mask[agent_pos[0], agent_pos[1]] = True
 
@@ -1181,7 +1181,7 @@ class MiniGridEnv(gym.Env):
         if not self.see_through_walls:
             vis_mask = grid.process_vis(agent_pos=(self.agent_view_size // 2 , self.agent_view_size - 1))
         else:
-            vis_mask = np.ones(shape=(grid.width, grid.height), dtype=np.bool)
+            vis_mask = np.ones(shape=(grid.width, grid.height), dtype=bool)
 
         # Make it so the agent sees what it's carrying
         # We do this by placing the carried object at the agent's position
@@ -1260,7 +1260,7 @@ class MiniGridEnv(gym.Env):
         top_left = self.agent_pos + f_vec * (self.agent_view_size-1) - r_vec * (self.agent_view_size // 2)
 
         # Mask of which cells to highlight
-        highlight_mask = np.zeros(shape=(self.width, self.height), dtype=np.bool)
+        highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
 
         # For each cell in the visibility mask
         for vis_j in range(0, self.agent_view_size):

+ 29 - 2
gym_minigrid/wrappers.py

@@ -5,7 +5,7 @@ from functools import reduce
 import numpy as np
 import gym
 from gym import error, spaces, utils
-from .minigrid import OBJECT_TO_IDX, COLOR_TO_IDX, STATE_TO_IDX
+from .minigrid import OBJECT_TO_IDX, COLOR_TO_IDX, STATE_TO_IDX, Goal
 
 class ReseedWrapper(gym.core.Wrapper):
     """
@@ -331,7 +331,6 @@ class ViewSizeWrapper(gym.core.Wrapper):
     def step(self, action):
         return self.env.step(action)
 
-from .minigrid import Goal
 class DirectionObsWrapper(gym.core.ObservationWrapper):
     """
     Provides the slope/angular direction to the goal with the observations as modeled by (y2 - y2 )/( x2 - x1)
@@ -354,3 +353,31 @@ class DirectionObsWrapper(gym.core.ObservationWrapper):
         slope = np.divide( self.goal_position[1] - self.agent_pos[1] ,  self.goal_position[0] - self.agent_pos[0])
         obs['goal_direction'] = np.arctan( slope ) if self.type == 'angle' else slope
         return obs
+
+class SymbolicObsWrapper(gym.core.ObservationWrapper):
+    """
+    Fully observable grid with a symbolic state representation.
+    The symbol is a triple of (X, Y, IDX), where X and Y are
+    the coordinates on the grid, and IDX is the id of the object.
+
+    The "image" entry of the observation is replaced by an array of
+    shape (width, height, 3) holding one such triple per cell; IDX is
+    the OBJECT_TO_IDX value of the cell's object type, or -1 when the
+    cell is empty. All other observation fields are returned unchanged.
+    """
+
+    def __init__(self, env):
+        super().__init__(env)
+
+        # Redeclare only the "image" space; other entries are untouched.
+        # NOTE(review): observation() emits -1 for empty cells and X/Y
+        # coordinates up to width-1/height-1, which this space (low=0,
+        # high=max object id, uint8) does not cover -- confirm the bounds
+        # and dtype against what observation() actually returns.
+        self.observation_space.spaces["image"] = spaces.Box(
+            low=0,
+            high=max(OBJECT_TO_IDX.values()),
+            shape=(self.env.width, self.env.height, 3),  # number of cells
+            dtype="uint8",
+        )
+
+    def observation(self, obs):
+        # One object id per entry of the grid's cell list; -1 marks an
+        # empty (None) cell.
+        objects = np.array(
+            [OBJECT_TO_IDX[o.type] if o is not None else -1 for o in self.grid.grid]
+        )
+        w, h = self.width, self.height
+        # Coordinate planes: grid[0] holds X and grid[1] holds Y, each (w, h).
+        grid = np.mgrid[:w, :h]
+        # Stack the object-id plane as a third channel -> (3, w, h).
+        # NOTE(review): reshape(1, w, h) assumes self.grid.grid is ordered
+        # with X as the slow axis -- verify against Grid's storage layout.
+        grid = np.concatenate([grid, objects.reshape(1, w, h)])
+        # Move the channel axis last -> (w, h, 3).
+        grid = np.transpose(grid, (1, 2, 0))
+        obs['image'] = grid
+        return obs