瀏覽代碼

Added new wrapper to train from pixels

Maxime Chevalier-Boisvert 5 年之前
父節點
當前提交
c22d75053d
共有 1 個文件被更改,包括 26 次插入0 次删除
  1. 26 0
      gym_minigrid/wrappers.py

+ 26 - 0
gym_minigrid/wrappers.py

@@ -139,6 +139,32 @@ class RGBImgObsWrapper(gym.core.ObservationWrapper):
             tile_size=self.tile_size
         )
 
+class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
+    """
+    Wrapper to use partially observable RGB image as the only observation output
+    This can be used to have the agent to solve the gridworld in pixel space.
+    """
+
+    def __init__(self, env, tile_size=8):
+        super().__init__(env)
+
+        self.tile_size = tile_size
+
+        obs_shape = env.observation_space['image'].shape
+        self.observation_space = spaces.Box(
+            low=0,
+            high=255,
+            shape=(obs_shape[0] * tile_size, obs_shape[1] * tile_size, 3),
+            dtype='uint8'
+        )
+
+    def observation(self, obs):
+        env = self.unwrapped
+        return {
+            'mission': obs['mission'],
+            'image': env.get_obs_render(obs['image'], tile_size=self.tile_size, mode='rgb_array')
+        }
+
 class FullyObsWrapper(gym.core.ObservationWrapper):
     """
     Fully observable gridworld using a compact grid encoding