Parcourir la source

Fix ViewSizeWrapper that changed the observation space of the unwrapped env !

saleml il y a 2 ans
Parent
commit
e90cdf1be9
2 fichiers modifiés avec 32 ajouts et 19 suppressions
  1. 19 13
      gym_minigrid/minigrid.py
  2. 13 6
      gym_minigrid/wrappers.py

+ 19 - 13
gym_minigrid/minigrid.py

@@ -1050,33 +1050,36 @@ class MiniGridEnv(gym.Env):
 
         return vx, vy
 
-    def get_view_exts(self):
+    def get_view_exts(self, agent_view_size=None):
         """
         Get the extents of the square set of tiles visible to the agent
         Note: the bottom extent indices are not included in the set
+        if agent_view_size is None, use self.agent_view_size
         """
 
+        agent_view_size = agent_view_size or self.agent_view_size
+
         # Facing right
         if self.agent_dir == 0:
             topX = self.agent_pos[0]
-            topY = self.agent_pos[1] - self.agent_view_size // 2
+            topY = self.agent_pos[1] - agent_view_size // 2
         # Facing down
         elif self.agent_dir == 1:
-            topX = self.agent_pos[0] - self.agent_view_size // 2
+            topX = self.agent_pos[0] - agent_view_size // 2
             topY = self.agent_pos[1]
         # Facing left
         elif self.agent_dir == 2:
-            topX = self.agent_pos[0] - self.agent_view_size + 1
-            topY = self.agent_pos[1] - self.agent_view_size // 2
+            topX = self.agent_pos[0] - agent_view_size + 1
+            topY = self.agent_pos[1] - agent_view_size // 2
         # Facing up
         elif self.agent_dir == 3:
-            topX = self.agent_pos[0] - self.agent_view_size // 2
-            topY = self.agent_pos[1] - self.agent_view_size + 1
+            topX = self.agent_pos[0] - agent_view_size // 2
+            topY = self.agent_pos[1] - agent_view_size + 1
         else:
             assert False, "invalid agent direction"
 
-        botX = topX + self.agent_view_size
-        botY = topY + self.agent_view_size
+        botX = topX + agent_view_size
+        botY = topY + agent_view_size
 
         return (topX, topY, botX, botY)
 
@@ -1182,16 +1185,19 @@ class MiniGridEnv(gym.Env):
 
         return obs, reward, done, {}
 
-    def gen_obs_grid(self):
+    def gen_obs_grid(self, agent_view_size=None):
         """
         Generate the sub-grid observed by the agent.
         This method also outputs a visibility mask telling us which grid
         cells the agent can actually see.
+        if agent_view_size is None, self.agent_view_size is used
         """
 
-        topX, topY, botX, botY = self.get_view_exts()
+        topX, topY, botX, botY = self.get_view_exts(agent_view_size)
+        
+        agent_view_size = agent_view_size or self.agent_view_size
 
-        grid = self.grid.slice(topX, topY, self.agent_view_size, self.agent_view_size)
+        grid = self.grid.slice(topX, topY, agent_view_size, agent_view_size)
 
         for i in range(self.agent_dir + 1):
             grid = grid.rotate_left()
@@ -1199,7 +1205,7 @@ class MiniGridEnv(gym.Env):
         # Process occluders and visibility
         # Note that this incurs some performance cost
         if not self.see_through_walls:
-            vis_mask = grid.process_vis(agent_pos=(self.agent_view_size // 2 , self.agent_view_size - 1))
+            vis_mask = grid.process_vis(agent_pos=(agent_view_size // 2 , agent_view_size - 1))
         else:
             vis_mask = np.ones(shape=(grid.width, grid.height), dtype=bool)
 

+ 13 - 6
gym_minigrid/wrappers.py

@@ -387,8 +387,7 @@ class ViewSizeWrapper(gym.core.Wrapper):
         assert agent_view_size % 2 == 1
         assert agent_view_size >= 3
 
-        # Override default view size
-        env.unwrapped.agent_view_size = agent_view_size
+        self.agent_view_size = agent_view_size
 
         # Compute observation space with specified view size
         new_image_space = gym.spaces.Box(
@@ -401,11 +400,19 @@ class ViewSizeWrapper(gym.core.Wrapper):
         # Override the environment's observation spaceexit
         self.observation_space = spaces.Dict({**self.observation_space, 'image':new_image_space})
 
-    def reset(self, **kwargs):
-        return self.env.reset(**kwargs)
+    def observation(self, obs):
+        env = self.unwrapped
 
-    def step(self, action):
-        return self.env.step(action)
+        grid, vis_mask = env.gen_obs_grid(self.agent_view_size)
+
+        # Encode the partially observable view into a numpy array
+        image = grid.encode(vis_mask)
+
+
+        return {
+            **obs,
+            'image': image
+        }
 
 class DirectionObsWrapper(gym.core.ObservationWrapper):
     """