|
@@ -1213,7 +1213,7 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
return obs
|
|
|
|
|
|
- def get_obs_render(self, obs, tile_pixels=CELL_PIXELS//2, mode='pixmap'):
|
|
|
+ def get_obs_render(self, obs, tile_size=CELL_PIXELS//2, mode='pixmap'):
|
|
|
"""
|
|
|
Render an agent observation for visualization
|
|
|
"""
|
|
@@ -1221,8 +1221,8 @@ class MiniGridEnv(gym.Env):
|
|
|
if self.obs_render == None:
|
|
|
from gym_minigrid.rendering import Renderer
|
|
|
self.obs_render = Renderer(
|
|
|
- self.agent_view_size * tile_pixels,
|
|
|
- self.agent_view_size * tile_pixels
|
|
|
+ self.agent_view_size * tile_size,
|
|
|
+ self.agent_view_size * tile_size
|
|
|
)
|
|
|
|
|
|
r = self.obs_render
|
|
@@ -1232,10 +1232,10 @@ class MiniGridEnv(gym.Env):
|
|
|
grid = Grid.decode(obs)
|
|
|
|
|
|
# Render the whole grid
|
|
|
- grid.render(r, tile_pixels)
|
|
|
+ grid.render(r, tile_size)
|
|
|
|
|
|
# Draw the agent
|
|
|
- ratio = tile_pixels / CELL_PIXELS
|
|
|
+ ratio = tile_size / CELL_PIXELS
|
|
|
r.push()
|
|
|
r.scale(ratio, ratio)
|
|
|
r.translate(
|
|
@@ -1253,14 +1253,14 @@ class MiniGridEnv(gym.Env):
|
|
|
r.pop()
|
|
|
|
|
|
r.endFrame()
|
|
|
-
|
|
|
+
|
|
|
if mode == 'rgb_array':
|
|
|
return r.getArray()
|
|
|
elif mode == 'pixmap':
|
|
|
return r.getPixmap()
|
|
|
return r
|
|
|
|
|
|
- def render(self, mode='human', close=False, highlight=True):
|
|
|
+ def render(self, mode='human', close=False, highlight=True, tile_size=CELL_PIXELS):
|
|
|
"""
|
|
|
Render the whole-grid human view
|
|
|
"""
|
|
@@ -1270,11 +1270,11 @@ class MiniGridEnv(gym.Env):
|
|
|
self.grid_render.close()
|
|
|
return
|
|
|
|
|
|
- if self.grid_render is None or self.grid_render.window is None:
|
|
|
+ if self.grid_render is None or self.grid_render.window is None or (self.grid_render.width != self.width * tile_size):
|
|
|
from gym_minigrid.rendering import Renderer
|
|
|
self.grid_render = Renderer(
|
|
|
- self.width * CELL_PIXELS,
|
|
|
- self.height * CELL_PIXELS,
|
|
|
+ self.width * tile_size,
|
|
|
+ self.height * tile_size,
|
|
|
True if mode == 'human' else False
|
|
|
)
|
|
|
|
|
@@ -1286,10 +1286,12 @@ class MiniGridEnv(gym.Env):
|
|
|
r.beginFrame()
|
|
|
|
|
|
# Render the whole grid
|
|
|
- self.grid.render(r, CELL_PIXELS)
|
|
|
+ self.grid.render(r, tile_size)
|
|
|
|
|
|
# Draw the agent
|
|
|
+ ratio = tile_size / CELL_PIXELS
|
|
|
r.push()
|
|
|
+ r.scale(ratio, ratio)
|
|
|
r.translate(
|
|
|
CELL_PIXELS * (self.agent_pos[0] + 0.5),
|
|
|
CELL_PIXELS * (self.agent_pos[1] + 0.5)
|
|
@@ -1326,10 +1328,10 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
# Highlight the cell
|
|
|
r.fillRect(
|
|
|
- abs_i * CELL_PIXELS,
|
|
|
- abs_j * CELL_PIXELS,
|
|
|
- CELL_PIXELS,
|
|
|
- CELL_PIXELS,
|
|
|
+ abs_i * tile_size,
|
|
|
+ abs_j * tile_size,
|
|
|
+ tile_size,
|
|
|
+ tile_size,
|
|
|
255, 255, 255, 75
|
|
|
)
|
|
|
|