|
@@ -2,12 +2,14 @@ import hashlib
|
|
|
import math
|
|
|
from abc import abstractmethod
|
|
|
from enum import IntEnum
|
|
|
+from functools import partial
|
|
|
from typing import Any, Callable, Optional, Union
|
|
|
|
|
|
import gym
|
|
|
import numpy as np
|
|
|
from gym import spaces
|
|
|
from gym.utils import seeding
|
|
|
+from gym.utils.renderer import Renderer
|
|
|
|
|
|
# Size in pixels of a tile in the full-scale human view
|
|
|
from gym_minigrid.rendering import (
|
|
@@ -865,6 +867,7 @@ class MiniGridEnv(gym.Env):
|
|
|
self.render_mode = render_mode
|
|
|
self.highlight = highlight
|
|
|
self.tile_size = tile_size
|
|
|
+ # Agent's Point of View
|
|
|
self.agent_pov = agent_pov
|
|
|
|
|
|
# Initialize mission
|
|
@@ -922,6 +925,10 @@ class MiniGridEnv(gym.Env):
|
|
|
self.grid = Grid(width, height)
|
|
|
self.carrying = None
|
|
|
|
|
|
+ render_frame = partial(self._render, highlight=highlight, tile_size=tile_size)
|
|
|
+
|
|
|
+ self.renderer = Renderer(self.render_mode, render_frame)
|
|
|
+
|
|
|
# Initialize the state
|
|
|
self.reset()
|
|
|
|
|
@@ -955,6 +962,10 @@ class MiniGridEnv(gym.Env):
|
|
|
# Return first observation
|
|
|
obs = self.gen_obs()
|
|
|
|
|
|
+ # Reset Renderer
|
|
|
+ self.renderer.reset()
|
|
|
+ self.renderer.render_step()
|
|
|
+
|
|
|
if not return_info:
|
|
|
return obs
|
|
|
else:
|
|
@@ -1375,6 +1386,9 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
obs = self.gen_obs()
|
|
|
|
|
|
+ # Reset Renderer
|
|
|
+ self.renderer.render_step()
|
|
|
+
|
|
|
return obs, reward, terminated, truncated, {}
|
|
|
|
|
|
def gen_obs_grid(self, agent_view_size=None):
|
|
@@ -1432,7 +1446,7 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
return obs
|
|
|
|
|
|
- def get_pov_render(self):
|
|
|
+ def get_pov_render(self, tile_size):
|
|
|
"""
|
|
|
Render an agent's POV observation for visualization
|
|
|
"""
|
|
@@ -1440,7 +1454,7 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
# Render the whole grid
|
|
|
img = grid.render(
|
|
|
- self.tile_size,
|
|
|
+ tile_size,
|
|
|
agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1),
|
|
|
agent_dir=3,
|
|
|
highlight_mask=vis_mask,
|
|
@@ -1448,13 +1462,10 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
return img
|
|
|
|
|
|
- def get_full_render(self):
|
|
|
+ def get_full_render(self, highlight, tile_size):
|
|
|
"""
|
|
|
Render a non-paratial observation for visualization
|
|
|
"""
|
|
|
- tile_size = self.tile_size
|
|
|
- highlight = self.highlight
|
|
|
-
|
|
|
# Compute which cells are visible to the agent
|
|
|
_, vis_mask = self.gen_obs_grid()
|
|
|
|
|
@@ -1499,18 +1510,20 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
return img
|
|
|
|
|
|
- def get_render(self):
|
|
|
+ def get_render(self, highlight, tile_size):
|
|
|
"Returns an image corresponding to the whole environment or the agent's pov"
|
|
|
if self.agent_pov:
|
|
|
- return self.get_pov_render()
|
|
|
+ return self.get_pov_render(tile_size)
|
|
|
else:
|
|
|
- return self.get_full_render()
|
|
|
+ return self.get_full_render(highlight, tile_size)
|
|
|
|
|
|
- def render(self):
|
|
|
+ def _render(
|
|
|
+ self, mode: str = "human", highlight: bool = True, tile_size: int = TILE_PIXELS
|
|
|
+ ):
|
|
|
"""
|
|
|
Render the whole-grid human view
|
|
|
"""
|
|
|
- img = self.get_render()
|
|
|
+ img = self.get_render(highlight, tile_size)
|
|
|
|
|
|
mode = self.render_mode
|
|
|
if mode == "human":
|
|
@@ -1522,6 +1535,22 @@ class MiniGridEnv(gym.Env):
|
|
|
else:
|
|
|
return img
|
|
|
|
|
|
+ def render(
|
|
|
+ self,
|
|
|
+ mode: str = "human",
|
|
|
+ highlight: Optional[bool] = None,
|
|
|
+ tile_size: Optional[int] = None,
|
|
|
+ ):
|
|
|
+ if self.render_mode is not None:
|
|
|
+ assert (
|
|
|
+ highlight is None and tile_size is None
|
|
|
+ ), "Unexpected argument for render. Specify render arguments at environment initialization."
|
|
|
+ return self.renderer.get_renders()
|
|
|
+ else:
|
|
|
+ highlight = highlight if highlight is not None else True
|
|
|
+ tile_size = tile_size if tile_size is not None else TILE_PIXELS
|
|
|
+ return self._render(mode=mode, highlight=highlight, tile_size=tile_size)
|
|
|
+
|
|
|
def close(self):
|
|
|
if self.window:
|
|
|
self.window.close()
|