Przeglądaj źródła

Refactor observation wrappers for consistency

Maxime Chevalier-Boisvert 5 lat temu
rodzic
commit
0403f4dc00
2 zmienionych plików z 32 dodań i 14 usunięć
  1. 30 12
      gym_minigrid/wrappers.py
  2. 2 2
      run_tests.py

+ 30 - 12
gym_minigrid/wrappers.py

@@ -106,7 +106,6 @@ class ImgObsWrapper(gym.core.ObservationWrapper):
 
     def __init__(self, env):
         super().__init__(env)
-
         self.observation_space = env.observation_space.spaces['image']
 
     def observation(self, obs):
@@ -125,14 +124,15 @@ class OneHotPartialObsWrapper(gym.core.ObservationWrapper):
 
         obs_shape = env.observation_space['image'].shape
 
+        # Number of bits per cell
         num_bits = len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + len(STATE_TO_IDX)
 
         self.observation_space.spaces["image"] = spaces.Box(
-                low=0,
-                high=255,
-                shape=(obs_shape[0], obs_shape[1], num_bits),
-                dtype='uint8'
-            )
+            low=0,
+            high=255,
+            shape=(obs_shape[0], obs_shape[1], num_bits),
+            dtype='uint8'
+        )
 
     def observation(self, obs):
         img = obs['image']
@@ -174,12 +174,19 @@ class RGBImgObsWrapper(gym.core.ObservationWrapper):
 
     def observation(self, obs):
         env = self.unwrapped
-        return env.render(
+
+        rgb_img = env.render(
             mode='rgb_array',
             highlight=False,
             tile_size=self.tile_size
         )
 
+        return {
+            'mission': obs['mission'],
+            'image': rgb_img
+        }
+
+
 class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
     """
     Wrapper to use partially observable RGB image as the only observation output
@@ -201,9 +208,16 @@ class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
 
     def observation(self, obs):
         env = self.unwrapped
+
+        rgb_img_partial = env.get_obs_render(
+            obs['image'],
+            tile_size=self.tile_size,
+            mode='rgb_array'
+        )
+
         return {
             'mission': obs['mission'],
-            'image': env.get_obs_render(obs['image'], tile_size=self.tile_size, mode='rgb_array')
+            'image': rgb_img_partial
         }
 
 class FullyObsWrapper(gym.core.ObservationWrapper):
@@ -214,7 +228,7 @@ class FullyObsWrapper(gym.core.ObservationWrapper):
     def __init__(self, env):
         super().__init__(env)
 
-        self.observation_space = spaces.Box(
+        self.observation_space.spaces["image"] = spaces.Box(
             low=0,
             high=255,
             shape=(self.env.width, self.env.height, 3),  # number of cells
@@ -230,7 +244,10 @@ class FullyObsWrapper(gym.core.ObservationWrapper):
             env.agent_dir
         ])
 
-        return full_grid
+        return {
+            'mission': obs['mission'],
+            'image': full_grid
+        }
 
 class FlatObsWrapper(gym.core.ObservationWrapper):
     """
@@ -283,13 +300,14 @@ class FlatObsWrapper(gym.core.ObservationWrapper):
 
         return obs
 
-class AgentViewWrapper(gym.core.Wrapper):
+class ViewSizeWrapper(gym.core.Wrapper):
     """
     Wrapper to customize the agent field of view size.
+    This cannot be used with fully observable wrappers.
     """
 
     def __init__(self, env, agent_view_size=7):
-        super(AgentViewWrapper, self).__init__(env)
+        super().__init__(env)
 
         # Override default view size
         env.unwrapped.agent_view_size = agent_view_size

+ 2 - 2
run_tests.py

@@ -88,7 +88,7 @@ for env_name in env_list:
     env = FullyObsWrapper(env)
     env.reset()
     obs, _, _, _ = env.step(0)
-    assert obs.shape == env.observation_space.shape
+    assert obs['image'].shape == env.observation_space.spaces['image'].shape
     env.close()
 
     env = gym.make(env_name)
@@ -98,7 +98,7 @@ for env_name in env_list:
     env.close()
 
     env = gym.make(env_name)
-    env = AgentViewWrapper(env, 5)
+    env = ViewSizeWrapper(env, 5)
     env.reset()
     env.step(0)
     env.close()