|
@@ -1213,7 +1213,7 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
return obs
|
|
|
|
|
|
- def get_obs_render(self, obs, tile_pixels=CELL_PIXELS//2):
|
|
|
+ def get_obs_render(self, obs, tile_pixels=CELL_PIXELS//2, mode='pixmap'):
|
|
|
"""
|
|
|
Render an agent observation for visualization
|
|
|
"""
|
|
@@ -1253,8 +1253,12 @@ class MiniGridEnv(gym.Env):
|
|
|
r.pop()
|
|
|
|
|
|
r.endFrame()
|
|
|
-
|
|
|
- return r.getPixmap()
|
|
|
+
|
|
|
+ if mode == 'rgb_array':
|
|
|
+ return r.getArray()
|
|
|
+ elif mode == 'pixmap':
|
|
|
+ return r.getPixmap()
|
|
|
+ return r
|
|
|
|
|
|
def render(self, mode='human', close=False, highlight=True):
|
|
|
"""
|
|
@@ -1335,5 +1339,4 @@ class MiniGridEnv(gym.Env):
|
|
|
return r.getArray()
|
|
|
elif mode == 'pixmap':
|
|
|
return r.getPixmap()
|
|
|
-
|
|
|
return r
|