浏览代码

Agent Field View Customization (#54)

* Add agent view size and obs size array as Minigrid members

* Update AgentView wrapper

* Fix typo in wrapper name

* Fix observation space in Agentview wrapper

* change agent view wrapper name

* Pass agent_view_size as an arg to MiniGridEnv, remove unused obs_array_size

* Added some comments

* Revert back to self.observation_space
Abdelrahman Ahmed 6 年之前
父节点
当前提交
92caf4ad12
共有 2 个文件被更改,包括 52 次插入28 次删除
  1. 26 28
      gym_minigrid/minigrid.py
  2. 26 0
      gym_minigrid/wrappers.py

+ 26 - 28
gym_minigrid/minigrid.py

@@ -8,12 +8,6 @@ from gym.utils import seeding
 # Size in pixels of a cell in the full-scale human view
 # Size in pixels of a cell in the full-scale human view
 CELL_PIXELS = 32
 CELL_PIXELS = 32
 
 
-# Number of cells (width and height) in the agent view
-AGENT_VIEW_SIZE = 7
-
-# Size of the array given as an observation to the agent
-OBS_ARRAY_SIZE = (AGENT_VIEW_SIZE, AGENT_VIEW_SIZE, 3)
-
 # Map of color names to RGB values
 # Map of color names to RGB values
 COLORS = {
 COLORS = {
     'red'   : np.array([255, 0, 0]),
     'red'   : np.array([255, 0, 0]),
@@ -669,7 +663,8 @@ class MiniGridEnv(gym.Env):
         height=None,
         height=None,
         max_steps=100,
         max_steps=100,
         see_through_walls=False,
         see_through_walls=False,
-        seed=1337
+        seed=1337,
+        agent_view_size=7
     ):
     ):
         # Can't set both grid_size and width/height
         # Can't set both grid_size and width/height
         if grid_size:
         if grid_size:
@@ -683,12 +678,15 @@ class MiniGridEnv(gym.Env):
         # Actions are discrete integer values
         # Actions are discrete integer values
         self.action_space = spaces.Discrete(len(self.actions))
         self.action_space = spaces.Discrete(len(self.actions))
 
 
+        # Number of cells (width and height) in the agent view
+        self.agent_view_size = agent_view_size
+
         # Observations are dictionaries containing an
         # Observations are dictionaries containing an
         # encoding of the grid and a textual 'mission' string
         # encoding of the grid and a textual 'mission' string
         self.observation_space = spaces.Box(
         self.observation_space = spaces.Box(
             low=0,
             low=0,
             high=255,
             high=255,
-            shape=OBS_ARRAY_SIZE,
+            shape=(self.agent_view_size, self.agent_view_size, 3),
             dtype='uint8'
             dtype='uint8'
         )
         )
         self.observation_space = spaces.Dict({
         self.observation_space = spaces.Dict({
@@ -1009,8 +1007,8 @@ class MiniGridEnv(gym.Env):
         rx, ry = self.right_vec
         rx, ry = self.right_vec
 
 
         # Compute the absolute coordinates of the top-left view corner
         # Compute the absolute coordinates of the top-left view corner
-        sz = AGENT_VIEW_SIZE
-        hs = AGENT_VIEW_SIZE // 2
+        sz = self.agent_view_size
+        hs = self.agent_view_size // 2
         tx = ax + (dx * (sz-1)) - (rx * hs)
         tx = ax + (dx * (sz-1)) - (rx * hs)
         ty = ay + (dy * (sz-1)) - (ry * hs)
         ty = ay + (dy * (sz-1)) - (ry * hs)
 
 
@@ -1033,24 +1031,24 @@ class MiniGridEnv(gym.Env):
         # Facing right
         # Facing right
         if self.agent_dir == 0:
         if self.agent_dir == 0:
             topX = self.agent_pos[0]
             topX = self.agent_pos[0]
-            topY = self.agent_pos[1] - AGENT_VIEW_SIZE // 2
+            topY = self.agent_pos[1] - self.agent_view_size // 2
         # Facing down
         # Facing down
         elif self.agent_dir == 1:
         elif self.agent_dir == 1:
-            topX = self.agent_pos[0] - AGENT_VIEW_SIZE // 2
+            topX = self.agent_pos[0] - self.agent_view_size // 2
             topY = self.agent_pos[1]
             topY = self.agent_pos[1]
         # Facing left
         # Facing left
         elif self.agent_dir == 2:
         elif self.agent_dir == 2:
-            topX = self.agent_pos[0] - AGENT_VIEW_SIZE + 1
-            topY = self.agent_pos[1] - AGENT_VIEW_SIZE // 2
+            topX = self.agent_pos[0] - self.agent_view_size + 1
+            topY = self.agent_pos[1] - self.agent_view_size // 2
         # Facing up
         # Facing up
         elif self.agent_dir == 3:
         elif self.agent_dir == 3:
-            topX = self.agent_pos[0] - AGENT_VIEW_SIZE // 2
-            topY = self.agent_pos[1] - AGENT_VIEW_SIZE + 1
+            topX = self.agent_pos[0] - self.agent_view_size // 2
+            topY = self.agent_pos[1] - self.agent_view_size + 1
         else:
         else:
             assert False, "invalid agent direction"
             assert False, "invalid agent direction"
 
 
-        botX = topX + AGENT_VIEW_SIZE
-        botY = topY + AGENT_VIEW_SIZE
+        botX = topX + self.agent_view_size
+        botY = topY + self.agent_view_size
 
 
         return (topX, topY, botX, botY)
         return (topX, topY, botX, botY)
 
 
@@ -1061,7 +1059,7 @@ class MiniGridEnv(gym.Env):
 
 
         vx, vy = self.get_view_coords(x, y)
         vx, vy = self.get_view_coords(x, y)
 
 
-        if vx < 0 or vy < 0 or vx >= AGENT_VIEW_SIZE or vy >= AGENT_VIEW_SIZE:
+        if vx < 0 or vy < 0 or vx >= self.agent_view_size or vy >= self.agent_view_size:
             return None
             return None
 
 
         return vx, vy
         return vx, vy
@@ -1165,7 +1163,7 @@ class MiniGridEnv(gym.Env):
 
 
         topX, topY, botX, botY = self.get_view_exts()
         topX, topY, botX, botY = self.get_view_exts()
 
 
-        grid = self.grid.slice(topX, topY, AGENT_VIEW_SIZE, AGENT_VIEW_SIZE)
+        grid = self.grid.slice(topX, topY, self.agent_view_size, self.agent_view_size)
 
 
         for i in range(self.agent_dir + 1):
         for i in range(self.agent_dir + 1):
             grid = grid.rotate_left()
             grid = grid.rotate_left()
@@ -1173,7 +1171,7 @@ class MiniGridEnv(gym.Env):
         # Process occluders and visibility
         # Process occluders and visibility
         # Note that this incurs some performance cost
         # Note that this incurs some performance cost
         if not self.see_through_walls:
         if not self.see_through_walls:
-            vis_mask = grid.process_vis(agent_pos=(AGENT_VIEW_SIZE // 2 , AGENT_VIEW_SIZE - 1))
+            vis_mask = grid.process_vis(agent_pos=(self.agent_view_size // 2 , self.agent_view_size - 1))
         else:
         else:
             vis_mask = np.ones(shape=(grid.width, grid.height), dtype=np.bool)
             vis_mask = np.ones(shape=(grid.width, grid.height), dtype=np.bool)
 
 
@@ -1220,8 +1218,8 @@ class MiniGridEnv(gym.Env):
         if self.obs_render == None:
         if self.obs_render == None:
             from gym_minigrid.rendering import Renderer
             from gym_minigrid.rendering import Renderer
             self.obs_render = Renderer(
             self.obs_render = Renderer(
-                AGENT_VIEW_SIZE * tile_pixels,
-                AGENT_VIEW_SIZE * tile_pixels
+                self.agent_view_size * tile_pixels,
+                self.agent_view_size * tile_pixels
             )
             )
 
 
         r = self.obs_render
         r = self.obs_render
@@ -1238,8 +1236,8 @@ class MiniGridEnv(gym.Env):
         r.push()
         r.push()
         r.scale(ratio, ratio)
         r.scale(ratio, ratio)
         r.translate(
         r.translate(
-            CELL_PIXELS * (0.5 + AGENT_VIEW_SIZE // 2),
-            CELL_PIXELS * (AGENT_VIEW_SIZE - 0.5)
+            CELL_PIXELS * (0.5 + self.agent_view_size // 2),
+            CELL_PIXELS * (self.agent_view_size - 0.5)
         )
         )
         r.rotate(3 * 90)
         r.rotate(3 * 90)
         r.setLineColor(255, 0, 0)
         r.setLineColor(255, 0, 0)
@@ -1306,11 +1304,11 @@ class MiniGridEnv(gym.Env):
         # of the agent's view area
         # of the agent's view area
         f_vec = self.dir_vec
         f_vec = self.dir_vec
         r_vec = self.right_vec
         r_vec = self.right_vec
-        top_left = self.agent_pos + f_vec * (AGENT_VIEW_SIZE-1) - r_vec * (AGENT_VIEW_SIZE // 2)
+        top_left = self.agent_pos + f_vec * (self.agent_view_size-1) - r_vec * (self.agent_view_size // 2)
 
 
         # For each cell in the visibility mask
         # For each cell in the visibility mask
-        for vis_j in range(0, AGENT_VIEW_SIZE):
-            for vis_i in range(0, AGENT_VIEW_SIZE):
+        for vis_j in range(0, self.agent_view_size):
+            for vis_i in range(0, self.agent_view_size):
                 # If this cell is not visible, don't highlight it
                 # If this cell is not visible, don't highlight it
                 if not vis_mask[vis_i, vis_j]:
                 if not vis_mask[vis_i, vis_j]:
                     continue
                     continue

+ 26 - 0
gym_minigrid/wrappers.py

@@ -158,3 +158,29 @@ class FlatObsWrapper(gym.core.ObservationWrapper):
         obs = np.concatenate((image.flatten(), self.cachedArray.flatten()))
         obs = np.concatenate((image.flatten(), self.cachedArray.flatten()))
 
 
         return obs
         return obs
+
+
+class AgentViewWrapper(gym.core.Wrapper):
+    """
+    Wrapper to customize the agent's field of view.
+    """
+
+    def __init__(self, env, agent_view_size=7):
+        super(AgentViewWrapper, self).__init__(env)
+        self.__dict__.update(vars(env))  # Hack to pass values to super wrapper
+
+        # Override default view size
+        env.agent_view_size = agent_view_size
+
+        # Compute observation space with specified view size
+        observation_space = gym.spaces.Box(
+            low=0,
+            high=255,
+            shape=(agent_view_size, agent_view_size, 3),
+            dtype='uint8'
+        )
+
+        # Override the environment's observation space
+        self.observation_space = spaces.Dict({
+            'image': observation_space
+        })