|
@@ -1087,15 +1087,15 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
return obs
|
|
|
|
|
|
- def get_obs_render(self, obs):
|
|
|
+ def get_obs_render(self, obs, tile_pixels=CELL_PIXELS//2):
|
|
|
"""
|
|
|
Render an agent observation for visualization
|
|
|
"""
|
|
|
|
|
|
if self.obs_render == None:
|
|
|
self.obs_render = Renderer(
|
|
|
- AGENT_VIEW_SIZE * CELL_PIXELS // 2,
|
|
|
- AGENT_VIEW_SIZE * CELL_PIXELS // 2
|
|
|
+ AGENT_VIEW_SIZE * tile_pixels,
|
|
|
+ AGENT_VIEW_SIZE * tile_pixels
|
|
|
)
|
|
|
|
|
|
r = self.obs_render
|
|
@@ -1105,11 +1105,12 @@ class MiniGridEnv(gym.Env):
|
|
|
grid = Grid.decode(obs)
|
|
|
|
|
|
# Render the whole grid
|
|
|
- grid.render(r, CELL_PIXELS // 2)
|
|
|
+ grid.render(r, tile_pixels)
|
|
|
|
|
|
# Draw the agent
|
|
|
+ ratio = tile_pixels / CELL_PIXELS
|
|
|
r.push()
|
|
|
- r.scale(0.5, 0.5)
|
|
|
+ r.scale(ratio, ratio)
|
|
|
r.translate(
|
|
|
CELL_PIXELS * (0.5 + AGENT_VIEW_SIZE // 2),
|
|
|
CELL_PIXELS * (AGENT_VIEW_SIZE - 0.5)
|