|
@@ -2,12 +2,14 @@ import hashlib
|
|
|
import math
|
|
|
from abc import abstractmethod
|
|
|
from enum import IntEnum
|
|
|
+from functools import partial
|
|
|
from typing import Any, Callable, Optional, Union
|
|
|
|
|
|
import gym
|
|
|
import numpy as np
|
|
|
from gym import spaces
|
|
|
from gym.utils import seeding
|
|
|
+from gym.utils.renderer import Renderer
|
|
|
|
|
|
# Size in pixels of a tile in the full-scale human view
|
|
|
from gym_minigrid.rendering import (
|
|
@@ -825,9 +827,6 @@ class MiniGridEnv(gym.Env):
|
|
|
"""
|
|
|
|
|
|
metadata = {
|
|
|
- # Deprecated: use 'render_modes' instead
|
|
|
- "render.modes": ["human", "rgb_array"],
|
|
|
- "video.frames_per_second": 10, # Deprecated: use 'render_fps' instead
|
|
|
"render_modes": ["human", "rgb_array", "single_rgb_array"],
|
|
|
"render_fps": 10,
|
|
|
}
|
|
@@ -862,13 +861,9 @@ class MiniGridEnv(gym.Env):
|
|
|
highlight: bool = True,
|
|
|
tile_size: int = TILE_PIXELS,
|
|
|
agent_pov: bool = False,
|
|
|
- **kwargs,
|
|
|
):
|
|
|
# Rendering attributes
|
|
|
self.render_mode = render_mode
|
|
|
- self.highlight = highlight
|
|
|
- self.tile_size = tile_size
|
|
|
- self.agent_pov = agent_pov
|
|
|
|
|
|
# Initialize mission
|
|
|
self.mission = mission_space.sample()
|
|
@@ -925,8 +920,11 @@ class MiniGridEnv(gym.Env):
|
|
|
self.grid = Grid(width, height)
|
|
|
self.carrying = None
|
|
|
|
|
|
- # Initialize the state
|
|
|
- self.reset()
|
|
|
+ frame_rendering = partial(
|
|
|
+ self._render, highlight=highlight, tile_size=tile_size, agent_pov=agent_pov
|
|
|
+ )
|
|
|
+
|
|
|
+ self.renderer = Renderer(self.render_mode, frame_rendering)
|
|
|
|
|
|
def reset(self, *, seed=None, return_info=False, options=None):
|
|
|
super().reset(seed=seed)
|
|
@@ -958,6 +956,10 @@ class MiniGridEnv(gym.Env):
|
|
|
# Return first observation
|
|
|
obs = self.gen_obs()
|
|
|
|
|
|
+ # Reset Renderer
|
|
|
+ self.renderer.reset()
|
|
|
+ self.renderer.render_step()
|
|
|
+
|
|
|
if not return_info:
|
|
|
return obs
|
|
|
else:
|
|
@@ -1378,6 +1380,9 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
obs = self.gen_obs()
|
|
|
|
|
|
+ # Reset Renderer
|
|
|
+ self.renderer.render_step()
|
|
|
+
|
|
|
return obs, reward, terminated, truncated, {}
|
|
|
|
|
|
def gen_obs_grid(self, agent_view_size=None):
|
|
@@ -1435,7 +1440,7 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
return obs
|
|
|
|
|
|
- def get_pov_render(self):
|
|
|
+ def get_pov_render(self, tile_size):
|
|
|
"""
|
|
|
Render an agent's POV observation for visualization
|
|
|
"""
|
|
@@ -1443,7 +1448,7 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
# Render the whole grid
|
|
|
img = grid.render(
|
|
|
- self.tile_size,
|
|
|
+ tile_size,
|
|
|
agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1),
|
|
|
agent_dir=3,
|
|
|
highlight_mask=vis_mask,
|
|
@@ -1451,13 +1456,10 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
return img
|
|
|
|
|
|
- def get_full_render(self):
|
|
|
+ def get_full_render(self, highlight, tile_size):
|
|
|
"""
|
|
|
Render a non-paratial observation for visualization
|
|
|
"""
|
|
|
- tile_size = self.tile_size
|
|
|
- highlight = self.highlight
|
|
|
-
|
|
|
# Compute which cells are visible to the agent
|
|
|
_, vis_mask = self.gen_obs_grid()
|
|
|
|
|
@@ -1502,18 +1504,36 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
return img
|
|
|
|
|
|
- def get_render(self):
|
|
|
- "Returns an image corresponding to the whole environment or the agent's pov"
|
|
|
- if self.agent_pov:
|
|
|
- return self.get_pov_render()
|
|
|
+ def get_frame(
|
|
|
+ self,
|
|
|
+ highlight: bool = True,
|
|
|
+ tile_size: int = TILE_PIXELS,
|
|
|
+ agent_pov: bool = False,
|
|
|
+ ):
|
|
|
+ """Returns an RGB image corresponding to the whole environment or the agent's point of view.
|
|
|
+
|
|
|
+ Args:
|
|
|
+
|
|
|
+ highlight (bool): If true, the agent's field of view or point of view is highlighted with a lighter gray color.
|
|
|
+ tile_size (int): How many pixels will form a tile from the NxM grid.
|
|
|
+ agent_pov (bool): If true, the rendered frame will only contain the point of view of the agent.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+
|
|
|
+ frame (np.ndarray): A frame of type numpy.ndarray with shape (x, y, 3) representing RGB values for the x-by-y pixel image.
|
|
|
+
|
|
|
+ """
|
|
|
+
|
|
|
+ if agent_pov:
|
|
|
+ return self.get_pov_render(tile_size)
|
|
|
else:
|
|
|
- return self.get_full_render()
|
|
|
+ return self.get_full_render(highlight, tile_size)
|
|
|
|
|
|
- def render(self):
|
|
|
+ def _render(self, mode: str = "human", **kwargs):
|
|
|
"""
|
|
|
Render the whole-grid human view
|
|
|
"""
|
|
|
- img = self.get_render()
|
|
|
+ img = self.get_frame(**kwargs)
|
|
|
|
|
|
mode = self.render_mode
|
|
|
if mode == "human":
|
|
@@ -1525,6 +1545,26 @@ class MiniGridEnv(gym.Env):
|
|
|
else:
|
|
|
return img
|
|
|
|
|
|
+ def render(
|
|
|
+ self,
|
|
|
+ mode: str = "human",
|
|
|
+ highlight: Optional[bool] = None,
|
|
|
+ tile_size: Optional[int] = None,
|
|
|
+ agent_pov: Optional[bool] = None,
|
|
|
+ ):
|
|
|
+ if self.render_mode is not None:
|
|
|
+ assert (
|
|
|
+ highlight is None and tile_size is None and agent_pov is None
|
|
|
+ ), "Unexpected argument for render. Specify render arguments at environment initialization."
|
|
|
+ return self.renderer.get_renders()
|
|
|
+ else:
|
|
|
+ highlight = highlight if highlight is not None else True
|
|
|
+ tile_size = tile_size if tile_size is not None else TILE_PIXELS
|
|
|
+ agent_pov = agent_pov if agent_pov is not None else False
|
|
|
+ return self._render(
|
|
|
+ mode=mode, highlight=highlight, tile_size=tile_size, agent_pov=agent_pov
|
|
|
+ )
|
|
|
+
|
|
|
def close(self):
|
|
|
if self.window:
|
|
|
self.window.close()
|