Browse Source

Added tile_size argument to RGBImgObsWrapper

Maxime Chevalier-Boisvert 5 years ago
parent
commit
0f3abab084
2 changed files with 26 additions and 18 deletions
  1. 17 15
      gym_minigrid/minigrid.py
  2. 9 3
      gym_minigrid/wrappers.py

+ 17 - 15
gym_minigrid/minigrid.py

@@ -1213,7 +1213,7 @@ class MiniGridEnv(gym.Env):
 
         return obs
 
-    def get_obs_render(self, obs, tile_pixels=CELL_PIXELS//2, mode='pixmap'):
+    def get_obs_render(self, obs, tile_size=CELL_PIXELS//2, mode='pixmap'):
         """
         Render an agent observation for visualization
         """
@@ -1221,8 +1221,8 @@ class MiniGridEnv(gym.Env):
         if self.obs_render == None:
             from gym_minigrid.rendering import Renderer
             self.obs_render = Renderer(
-                self.agent_view_size * tile_pixels,
-                self.agent_view_size * tile_pixels
+                self.agent_view_size * tile_size,
+                self.agent_view_size * tile_size
             )
 
         r = self.obs_render
@@ -1232,10 +1232,10 @@ class MiniGridEnv(gym.Env):
         grid = Grid.decode(obs)
 
         # Render the whole grid
-        grid.render(r, tile_pixels)
+        grid.render(r, tile_size)
 
         # Draw the agent
-        ratio = tile_pixels / CELL_PIXELS
+        ratio = tile_size / CELL_PIXELS
         r.push()
         r.scale(ratio, ratio)
         r.translate(
@@ -1253,14 +1253,14 @@ class MiniGridEnv(gym.Env):
         r.pop()
 
         r.endFrame()
-    
+
         if mode == 'rgb_array':
             return r.getArray()
         elif mode == 'pixmap':
             return r.getPixmap()
         return r
 
-    def render(self, mode='human', close=False, highlight=True):
+    def render(self, mode='human', close=False, highlight=True, tile_size=CELL_PIXELS):
         """
         Render the whole-grid human view
         """
@@ -1270,11 +1270,11 @@ class MiniGridEnv(gym.Env):
                 self.grid_render.close()
             return
 
-        if self.grid_render is None or self.grid_render.window is None:
+        if self.grid_render is None or self.grid_render.window is None or (self.grid_render.width != self.width * tile_size):
             from gym_minigrid.rendering import Renderer
             self.grid_render = Renderer(
-                self.width * CELL_PIXELS,
-                self.height * CELL_PIXELS,
+                self.width * tile_size,
+                self.height * tile_size,
                 True if mode == 'human' else False
             )
 
@@ -1286,10 +1286,12 @@ class MiniGridEnv(gym.Env):
         r.beginFrame()
 
         # Render the whole grid
-        self.grid.render(r, CELL_PIXELS)
+        self.grid.render(r, tile_size)
 
         # Draw the agent
+        ratio = tile_size / CELL_PIXELS
         r.push()
+        r.scale(ratio, ratio)
         r.translate(
             CELL_PIXELS * (self.agent_pos[0] + 0.5),
             CELL_PIXELS * (self.agent_pos[1] + 0.5)
@@ -1326,10 +1328,10 @@ class MiniGridEnv(gym.Env):
 
                     # Highlight the cell
                     r.fillRect(
-                        abs_i * CELL_PIXELS,
-                        abs_j * CELL_PIXELS,
-                        CELL_PIXELS,
-                        CELL_PIXELS,
+                        abs_i * tile_size,
+                        abs_j * tile_size,
+                        tile_size,
+                        tile_size,
                         255, 255, 255, 75
                     )
 

+ 9 - 3
gym_minigrid/wrappers.py

@@ -119,19 +119,25 @@ class RGBImgObsWrapper(gym.core.ObservationWrapper):
     gridworld in pixel space.
     """
 
-    def __init__(self, env):
+    def __init__(self, env, tile_size=8):
         super().__init__(env)
 
+        self.tile_size = tile_size
+
         self.observation_space = spaces.Box(
             low=0,
             high=255,
-            shape=(self.env.width*CELL_PIXELS, self.env.height*CELL_PIXELS, 3),
+            shape=(self.env.width*tile_size, self.env.height*tile_size, 3),
             dtype='uint8'
         )
 
     def observation(self, obs):
         env = self.unwrapped
-        return env.render(mode = 'rgb_array', highlight = False)
+        return env.render(
+            mode='rgb_array',
+            highlight=False,
+            tile_size=self.tile_size
+        )
 
 class FullyObsWrapper(gym.core.ObservationWrapper):
     """