Przeglądaj źródła

Added new wrapper to train from pixels

Maxime Chevalier-Boisvert 5 lat temu
rodzic
commit
c22d75053d
1 zmienionych plików z 26 dodań i 0 usunięć
  1. 26 0
      gym_minigrid/wrappers.py

+ 26 - 0
gym_minigrid/wrappers.py

@@ -139,6 +139,32 @@ class RGBImgObsWrapper(gym.core.ObservationWrapper):
             tile_size=self.tile_size
             tile_size=self.tile_size
         )
         )
 
 
+class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
+    """
+    Wrapper to use partially observable RGB image as the only observation output
+    This can be used to have the agent to solve the gridworld in pixel space.
+    """
+
+    def __init__(self, env, tile_size=8):
+        super().__init__(env)
+
+        self.tile_size = tile_size
+
+        obs_shape = env.observation_space['image'].shape
+        self.observation_space = spaces.Box(
+            low=0,
+            high=255,
+            shape=(obs_shape[0] * tile_size, obs_shape[1] * tile_size, 3),
+            dtype='uint8'
+        )
+
+    def observation(self, obs):
+        env = self.unwrapped
+        return {
+            'mission': obs['mission'],
+            'image': env.get_obs_render(obs['image'], tile_size=self.tile_size, mode='rgb_array')
+        }
+
 class FullyObsWrapper(gym.core.ObservationWrapper):
 class FullyObsWrapper(gym.core.ObservationWrapper):
     """
     """
     Fully observable gridworld using a compact grid encoding
     Fully observable gridworld using a compact grid encoding