瀏覽代碼

RGB Image Observation Wrapper added. (#60)

* Added RGBImgObsWrapper

Use the full RGB image as the only observation output, no language/mission.

* Added Highlight option, render bug fix. 

render bug fix: The render does not have a window if rgb_array mode is called before human mode.

* rgbImageObservation wrapper class comment updated

* Update minigrid.py

* Update wrappers.py
idigitopia 6 年之前
父節點
當前提交
ea721d955b
共有 2 個文件被更改,包括 45 次插入19 次删除
  1. 20 19
      gym_minigrid/minigrid.py
  2. 25 0
      gym_minigrid/wrappers.py

+ 20 - 19
gym_minigrid/minigrid.py

@@ -1254,7 +1254,7 @@ class MiniGridEnv(gym.Env):
 
         return r.getPixmap()
 
-    def render(self, mode='human', close=False):
+    def render(self, mode='human', close=False, highlight=True):
         """
         Render the whole-grid human view
         """
@@ -1264,7 +1264,7 @@ class MiniGridEnv(gym.Env):
                 self.grid_render.close()
             return
 
-        if self.grid_render is None:
+        if self.grid_render is None or self.grid_render.window is None:
             from gym_minigrid.rendering import Renderer
             self.grid_render = Renderer(
                 self.width * CELL_PIXELS,
@@ -1308,23 +1308,24 @@ class MiniGridEnv(gym.Env):
         top_left = self.agent_pos + f_vec * (self.agent_view_size-1) - r_vec * (self.agent_view_size // 2)
 
         # For each cell in the visibility mask
-        for vis_j in range(0, self.agent_view_size):
-            for vis_i in range(0, self.agent_view_size):
-                # If this cell is not visible, don't highlight it
-                if not vis_mask[vis_i, vis_j]:
-                    continue
-
-                # Compute the world coordinates of this cell
-                abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)
-
-                # Highlight the cell
-                r.fillRect(
-                    abs_i * CELL_PIXELS,
-                    abs_j * CELL_PIXELS,
-                    CELL_PIXELS,
-                    CELL_PIXELS,
-                    255, 255, 255, 75
-                )
+        if highlight:
+            for vis_j in range(0, self.agent_view_size):
+                for vis_i in range(0, self.agent_view_size):
+                    # If this cell is not visible, don't highlight it
+                    if not vis_mask[vis_i, vis_j]:
+                        continue
+
+                    # Compute the world coordinates of this cell
+                    abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)
+
+                    # Highlight the cell
+                    r.fillRect(
+                        abs_i * CELL_PIXELS,
+                        abs_j * CELL_PIXELS,
+                        CELL_PIXELS,
+                        CELL_PIXELS,
+                        255, 255, 255, 75
+                    )
 
         r.endFrame()
 

+ 25 - 0
gym_minigrid/wrappers.py

@@ -6,6 +6,7 @@ import numpy as np
 import gym
 from gym import error, spaces, utils
 from .minigrid import OBJECT_TO_IDX, COLOR_TO_IDX
+from .minigrid import CELL_PIXELS
 
 class ReseedWrapper(gym.core.Wrapper):
     """
@@ -114,6 +115,30 @@ class ImgObsWrapper(gym.core.ObservationWrapper):
     def observation(self, obs):
         return obs['image']
 
+class RGBImgObsWrapper(gym.core.ObservationWrapper):
+    """
+    Wrapper to use fully observable RGB image as the only observation output,
+    no language/mission. This can be used to have the agent to solve the
+    gridworld in pixel space.
+    """
+
+    def __init__(self, env):
+        self.__dict__.update(vars(env))  # Pass values to super wrapper
+        super().__init__(env)
+
+
+        self.observation_space = spaces.Box(
+            low=0,
+            high=255,
+            shape=(self.env.width*CELL_PIXELS, self.env.height*CELL_PIXELS, 3),
+            dtype='uint8'
+        )
+
+    def observation(self, obs):
+        env = self.unwrapped
+        return env.render(mode = 'rgb_array', highlight = False)
+
+
 class FullyObsWrapper(gym.core.ObservationWrapper):
     """
     Fully observable gridworld using a compact grid encoding