瀏覽代碼

Merge pull request #220 from rodrigodelazcano/render_api

Render API
Mark Towers 2 年之前
父節點
當前提交
704c046cee
共有 6 個文件被更改,包括 74 次插入43 次删除
  1. 6 8
      gym_minigrid/manual_control.py
  2. 62 22
      gym_minigrid/minigrid.py
  3. 0 1
      gym_minigrid/window.py
  4. 3 8
      gym_minigrid/wrappers.py
  5. 2 3
      tests/test_envs.py
  6. 1 1
      tests/test_scripts.py

+ 6 - 8
gym_minigrid/manual_control.py

@@ -10,14 +10,14 @@ def redraw(window, img):
     window.show_img(img)
 
 
-def reset(env, window):
-    _ = env.reset()
+def reset(env, window, seed=None):
+    _ = env.reset(seed=seed)
 
     if hasattr(env, "mission"):
         print("Mission: %s" % env.mission)
         window.set_caption(env.mission)
 
-    img = env.get_render()
+    img = env.get_frame()
 
     redraw(window, img)
 
@@ -33,7 +33,7 @@ def step(env, window, action):
         print("truncated!")
         reset(env, window)
     else:
-        img = env.get_full_render()
+        img = env.get_frame()
         redraw(window, img)
 
 
@@ -99,12 +99,9 @@ if __name__ == "__main__":
 
     args = parser.parse_args()
 
-    seed = None if args.seed == -1 else args.seed
     env = gym.make(
         args.env,
-        seed=seed,
         new_step_api=True,
-        render_mode="human",  # Note that we do not need to use "human", as Window handles human rendering.
         tile_size=args.tile_size,
     )
 
@@ -115,7 +112,8 @@ if __name__ == "__main__":
     window = Window("gym_minigrid - " + args.env)
     window.reg_key_handler(lambda event: key_handler(env, window, event))
 
-    reset(env, window)
+    seed = None if args.seed == -1 else args.seed
+    reset(env, window, seed)
 
     # Blocking event loop
     window.show(block=True)

+ 62 - 22
gym_minigrid/minigrid.py

@@ -2,12 +2,14 @@ import hashlib
 import math
 from abc import abstractmethod
 from enum import IntEnum
+from functools import partial
 from typing import Any, Callable, Optional, Union
 
 import gym
 import numpy as np
 from gym import spaces
 from gym.utils import seeding
+from gym.utils.renderer import Renderer
 
 # Size in pixels of a tile in the full-scale human view
 from gym_minigrid.rendering import (
@@ -825,9 +827,6 @@ class MiniGridEnv(gym.Env):
     """
 
     metadata = {
-        # Deprecated: use 'render_modes' instead
-        "render.modes": ["human", "rgb_array"],
-        "video.frames_per_second": 10,  # Deprecated: use 'render_fps' instead
         "render_modes": ["human", "rgb_array", "single_rgb_array"],
         "render_fps": 10,
     }
@@ -862,13 +861,9 @@ class MiniGridEnv(gym.Env):
         highlight: bool = True,
         tile_size: int = TILE_PIXELS,
         agent_pov: bool = False,
-        **kwargs,
     ):
         # Rendering attributes
         self.render_mode = render_mode
-        self.highlight = highlight
-        self.tile_size = tile_size
-        self.agent_pov = agent_pov
 
         # Initialize mission
         self.mission = mission_space.sample()
@@ -925,8 +920,11 @@ class MiniGridEnv(gym.Env):
         self.grid = Grid(width, height)
         self.carrying = None
 
-        # Initialize the state
-        self.reset()
+        frame_rendering = partial(
+            self._render, highlight=highlight, tile_size=tile_size, agent_pov=agent_pov
+        )
+
+        self.renderer = Renderer(self.render_mode, frame_rendering)
 
     def reset(self, *, seed=None, return_info=False, options=None):
         super().reset(seed=seed)
@@ -958,6 +956,10 @@ class MiniGridEnv(gym.Env):
         # Return first observation
         obs = self.gen_obs()
 
+        # Reset Renderer
+        self.renderer.reset()
+        self.renderer.render_step()
+
         if not return_info:
             return obs
         else:
@@ -1378,6 +1380,9 @@ class MiniGridEnv(gym.Env):
 
         obs = self.gen_obs()
 
+        # Reset Renderer
+        self.renderer.render_step()
+
         return obs, reward, terminated, truncated, {}
 
     def gen_obs_grid(self, agent_view_size=None):
@@ -1435,7 +1440,7 @@ class MiniGridEnv(gym.Env):
 
         return obs
 
-    def get_pov_render(self):
+    def get_pov_render(self, tile_size):
         """
         Render an agent's POV observation for visualization
         """
@@ -1443,7 +1448,7 @@ class MiniGridEnv(gym.Env):
 
         # Render the whole grid
         img = grid.render(
-            self.tile_size,
+            tile_size,
             agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1),
             agent_dir=3,
             highlight_mask=vis_mask,
@@ -1451,13 +1456,10 @@ class MiniGridEnv(gym.Env):
 
         return img
 
-    def get_full_render(self):
+    def get_full_render(self, highlight, tile_size):
         """
         Render a non-paratial observation for visualization
         """
-        tile_size = self.tile_size
-        highlight = self.highlight
-
         # Compute which cells are visible to the agent
         _, vis_mask = self.gen_obs_grid()
 
@@ -1502,18 +1504,36 @@ class MiniGridEnv(gym.Env):
 
         return img
 
-    def get_render(self):
-        "Returns an image corresponding to the whole environment or the agent's pov"
-        if self.agent_pov:
-            return self.get_pov_render()
+    def get_frame(
+        self,
+        highlight: bool = True,
+        tile_size: int = TILE_PIXELS,
+        agent_pov: bool = False,
+    ):
+        """Returns an RGB image corresponding to the whole environment or the agent's point of view.
+
+        Args:
+
+            highlight (bool): If true, the agent's field of view or point of view is highlighted with a lighter gray color.
+            tile_size (int): How many pixels will form a tile from the NxM grid.
+            agent_pov (bool): If true, the rendered frame will only contain the point of view of the agent.
+
+        Returns:
+
+            frame (np.ndarray): A frame of type numpy.ndarray with shape (x, y, 3) representing RGB values for the x-by-y pixel image.
+
+        """
+
+        if agent_pov:
+            return self.get_pov_render(tile_size)
         else:
-            return self.get_full_render()
+            return self.get_full_render(highlight, tile_size)
 
-    def render(self):
+    def _render(self, mode: str = "human", **kwargs):
         """
         Render the whole-grid human view
         """
-        img = self.get_render()
+        img = self.get_frame(**kwargs)
 
         mode = self.render_mode
         if mode == "human":
@@ -1525,6 +1545,26 @@ class MiniGridEnv(gym.Env):
         else:
             return img
 
+    def render(
+        self,
+        mode: str = "human",
+        highlight: Optional[bool] = None,
+        tile_size: Optional[int] = None,
+        agent_pov: Optional[bool] = None,
+    ):
+        if self.render_mode is not None:
+            assert (
+                highlight is None and tile_size is None and agent_pov is None
+            ), "Unexpected argument for render. Specify render arguments at environment initialization."
+            return self.renderer.get_renders()
+        else:
+            highlight = highlight if highlight is not None else True
+            tile_size = tile_size if tile_size is not None else TILE_PIXELS
+            agent_pov = agent_pov if agent_pov is not None else False
+            return self._render(
+                mode=mode, highlight=highlight, tile_size=tile_size, agent_pov=agent_pov
+            )
+
     def close(self):
         if self.window:
             self.window.close()

+ 0 - 1
gym_minigrid/window.py

@@ -88,6 +88,5 @@ class Window:
         """
         Close the window
         """
-
         plt.close()
         self.closed = True

+ 3 - 8
gym_minigrid/wrappers.py

@@ -164,10 +164,6 @@ class RGBImgObsWrapper(ObservationWrapper):
     def __init__(self, env, tile_size=8):
         super().__init__(env, new_step_api=env.new_step_api)
 
-        # Rendering attributes
-        self.highlight = True
-        self.tile_size = tile_size
-
         self.tile_size = tile_size
 
         new_image_space = spaces.Box(
@@ -182,7 +178,7 @@ class RGBImgObsWrapper(ObservationWrapper):
         )
 
     def observation(self, obs):
-        rgb_img = self.get_full_render()
+        rgb_img = self.get_frame(highlight=True, tile_size=self.tile_size)
 
         return {**obs, "image": rgb_img}
 
@@ -196,8 +192,7 @@ class RGBImgPartialObsWrapper(ObservationWrapper):
     def __init__(self, env, tile_size=8):
         super().__init__(env, new_step_api=env.new_step_api)
 
-        # Rendering attributes
-        self.unwrapped.agent_pov = True
+        # Rendering attributes for observations
         self.tile_size = tile_size
 
         obs_shape = env.observation_space.spaces["image"].shape
@@ -213,7 +208,7 @@ class RGBImgPartialObsWrapper(ObservationWrapper):
         )
 
     def observation(self, obs):
-        rgb_img_partial = self.get_pov_render()
+        rgb_img_partial = self.get_frame(tile_size=self.tile_size, agent_pov=True)
 
         return {**obs, "image": rgb_img_partial}
 

+ 2 - 3
tests/test_envs.py

@@ -121,12 +121,11 @@ def test_agent_sees_method(env_id):
     env = gym.make(env_id, new_step_api=True)
     goal_pos = (env.grid.width - 2, env.grid.height - 2)
 
+    # Test the env.agent_sees() function
+    env.reset()
     # Test the "in" operator on grid objects
     assert ("green", "goal") in env.grid
     assert ("blue", "key") not in env.grid
-
-    # Test the env.agent_sees() function
-    env.reset()
     for i in range(0, 500):
         action = env.action_space.sample()
         obs, reward, terminated, truncated, info = env.step(action)

+ 1 - 1
tests/test_scripts.py

@@ -23,7 +23,7 @@ def test_window():
     caption = "testing caption"
     window.set_caption(caption)
 
-    window.show()
+    window.show(block=False)
 
     window.close()