Browse Source

Merge pull request #220 from rodrigodelazcano/render_api

Render API
Mark Towers 2 years ago
parent
commit
704c046cee

+ 6 - 8
gym_minigrid/manual_control.py

@@ -10,14 +10,14 @@ def redraw(window, img):
     window.show_img(img)
     window.show_img(img)
 
 
 
 
-def reset(env, window):
-    _ = env.reset()
+def reset(env, window, seed=None):
+    _ = env.reset(seed=seed)
 
 
     if hasattr(env, "mission"):
     if hasattr(env, "mission"):
         print("Mission: %s" % env.mission)
         print("Mission: %s" % env.mission)
         window.set_caption(env.mission)
         window.set_caption(env.mission)
 
 
-    img = env.get_render()
+    img = env.get_frame()
 
 
     redraw(window, img)
     redraw(window, img)
 
 
@@ -33,7 +33,7 @@ def step(env, window, action):
         print("truncated!")
         print("truncated!")
         reset(env, window)
         reset(env, window)
     else:
     else:
-        img = env.get_full_render()
+        img = env.get_frame()
         redraw(window, img)
         redraw(window, img)
 
 
 
 
@@ -99,12 +99,9 @@ if __name__ == "__main__":
 
 
     args = parser.parse_args()
     args = parser.parse_args()
 
 
-    seed = None if args.seed == -1 else args.seed
     env = gym.make(
     env = gym.make(
         args.env,
         args.env,
-        seed=seed,
         new_step_api=True,
         new_step_api=True,
-        render_mode="human",  # Note that we do not need to use "human", as Window handles human rendering.
         tile_size=args.tile_size,
         tile_size=args.tile_size,
     )
     )
 
 
@@ -115,7 +112,8 @@ if __name__ == "__main__":
     window = Window("gym_minigrid - " + args.env)
     window = Window("gym_minigrid - " + args.env)
     window.reg_key_handler(lambda event: key_handler(env, window, event))
     window.reg_key_handler(lambda event: key_handler(env, window, event))
 
 
-    reset(env, window)
+    seed = None if args.seed == -1 else args.seed
+    reset(env, window, seed)
 
 
     # Blocking event loop
     # Blocking event loop
     window.show(block=True)
     window.show(block=True)

+ 62 - 22
gym_minigrid/minigrid.py

@@ -2,12 +2,14 @@ import hashlib
 import math
 import math
 from abc import abstractmethod
 from abc import abstractmethod
 from enum import IntEnum
 from enum import IntEnum
+from functools import partial
 from typing import Any, Callable, Optional, Union
 from typing import Any, Callable, Optional, Union
 
 
 import gym
 import gym
 import numpy as np
 import numpy as np
 from gym import spaces
 from gym import spaces
 from gym.utils import seeding
 from gym.utils import seeding
+from gym.utils.renderer import Renderer
 
 
 # Size in pixels of a tile in the full-scale human view
 # Size in pixels of a tile in the full-scale human view
 from gym_minigrid.rendering import (
 from gym_minigrid.rendering import (
@@ -825,9 +827,6 @@ class MiniGridEnv(gym.Env):
     """
     """
 
 
     metadata = {
     metadata = {
-        # Deprecated: use 'render_modes' instead
-        "render.modes": ["human", "rgb_array"],
-        "video.frames_per_second": 10,  # Deprecated: use 'render_fps' instead
         "render_modes": ["human", "rgb_array", "single_rgb_array"],
         "render_modes": ["human", "rgb_array", "single_rgb_array"],
         "render_fps": 10,
         "render_fps": 10,
     }
     }
@@ -862,13 +861,9 @@ class MiniGridEnv(gym.Env):
         highlight: bool = True,
         highlight: bool = True,
         tile_size: int = TILE_PIXELS,
         tile_size: int = TILE_PIXELS,
         agent_pov: bool = False,
         agent_pov: bool = False,
-        **kwargs,
     ):
     ):
         # Rendering attributes
         # Rendering attributes
         self.render_mode = render_mode
         self.render_mode = render_mode
-        self.highlight = highlight
-        self.tile_size = tile_size
-        self.agent_pov = agent_pov
 
 
         # Initialize mission
         # Initialize mission
         self.mission = mission_space.sample()
         self.mission = mission_space.sample()
@@ -925,8 +920,11 @@ class MiniGridEnv(gym.Env):
         self.grid = Grid(width, height)
         self.grid = Grid(width, height)
         self.carrying = None
         self.carrying = None
 
 
-        # Initialize the state
-        self.reset()
+        frame_rendering = partial(
+            self._render, highlight=highlight, tile_size=tile_size, agent_pov=agent_pov
+        )
+
+        self.renderer = Renderer(self.render_mode, frame_rendering)
 
 
     def reset(self, *, seed=None, return_info=False, options=None):
     def reset(self, *, seed=None, return_info=False, options=None):
         super().reset(seed=seed)
         super().reset(seed=seed)
@@ -958,6 +956,10 @@ class MiniGridEnv(gym.Env):
         # Return first observation
         # Return first observation
         obs = self.gen_obs()
         obs = self.gen_obs()
 
 
+        # Reset Renderer
+        self.renderer.reset()
+        self.renderer.render_step()
+
         if not return_info:
         if not return_info:
             return obs
             return obs
         else:
         else:
@@ -1378,6 +1380,9 @@ class MiniGridEnv(gym.Env):
 
 
         obs = self.gen_obs()
         obs = self.gen_obs()
 
 
+        # Reset Renderer
+        self.renderer.render_step()
+
         return obs, reward, terminated, truncated, {}
         return obs, reward, terminated, truncated, {}
 
 
     def gen_obs_grid(self, agent_view_size=None):
     def gen_obs_grid(self, agent_view_size=None):
@@ -1435,7 +1440,7 @@ class MiniGridEnv(gym.Env):
 
 
         return obs
         return obs
 
 
-    def get_pov_render(self):
+    def get_pov_render(self, tile_size):
         """
         """
         Render an agent's POV observation for visualization
         Render an agent's POV observation for visualization
         """
         """
@@ -1443,7 +1448,7 @@ class MiniGridEnv(gym.Env):
 
 
         # Render the whole grid
         # Render the whole grid
         img = grid.render(
         img = grid.render(
-            self.tile_size,
+            tile_size,
             agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1),
             agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1),
             agent_dir=3,
             agent_dir=3,
             highlight_mask=vis_mask,
             highlight_mask=vis_mask,
@@ -1451,13 +1456,10 @@ class MiniGridEnv(gym.Env):
 
 
         return img
         return img
 
 
-    def get_full_render(self):
+    def get_full_render(self, highlight, tile_size):
         """
         """
         Render a non-paratial observation for visualization
         Render a non-paratial observation for visualization
         """
         """
-        tile_size = self.tile_size
-        highlight = self.highlight
-
         # Compute which cells are visible to the agent
         # Compute which cells are visible to the agent
         _, vis_mask = self.gen_obs_grid()
         _, vis_mask = self.gen_obs_grid()
 
 
@@ -1502,18 +1504,36 @@ class MiniGridEnv(gym.Env):
 
 
         return img
         return img
 
 
-    def get_render(self):
-        "Returns an image corresponding to the whole environment or the agent's pov"
-        if self.agent_pov:
-            return self.get_pov_render()
+    def get_frame(
+        self,
+        highlight: bool = True,
+        tile_size: int = TILE_PIXELS,
+        agent_pov: bool = False,
+    ):
+        """Returns an RGB image corresponding to the whole environment or the agent's point of view.
+
+        Args:
+
+            highlight (bool): If true, the agent's field of view or point of view is highlighted with a lighter gray color.
+            tile_size (int): How many pixels will form a tile from the NxM grid.
+            agent_pov (bool): If true, the rendered frame will only contain the point of view of the agent.
+
+        Returns:
+
+            frame (np.ndarray): A frame of type numpy.ndarray with shape (x, y, 3) representing RGB values for the x-by-y pixel image.
+
+        """
+
+        if agent_pov:
+            return self.get_pov_render(tile_size)
         else:
         else:
-            return self.get_full_render()
+            return self.get_full_render(highlight, tile_size)
 
 
-    def render(self):
+    def _render(self, mode: str = "human", **kwargs):
         """
         """
         Render the whole-grid human view
         Render the whole-grid human view
         """
         """
-        img = self.get_render()
+        img = self.get_frame(**kwargs)
 
 
         mode = self.render_mode
         mode = self.render_mode
         if mode == "human":
         if mode == "human":
@@ -1525,6 +1545,26 @@ class MiniGridEnv(gym.Env):
         else:
         else:
             return img
             return img
 
 
+    def render(
+        self,
+        mode: str = "human",
+        highlight: Optional[bool] = None,
+        tile_size: Optional[int] = None,
+        agent_pov: Optional[bool] = None,
+    ):
+        if self.render_mode is not None:
+            assert (
+                highlight is None and tile_size is None and agent_pov is None
+            ), "Unexpected argument for render. Specify render arguments at environment initialization."
+            return self.renderer.get_renders()
+        else:
+            highlight = highlight if highlight is not None else True
+            tile_size = tile_size if tile_size is not None else TILE_PIXELS
+            agent_pov = agent_pov if agent_pov is not None else False
+            return self._render(
+                mode=mode, highlight=highlight, tile_size=tile_size, agent_pov=agent_pov
+            )
+
     def close(self):
     def close(self):
         if self.window:
         if self.window:
             self.window.close()
             self.window.close()

+ 0 - 1
gym_minigrid/window.py

@@ -88,6 +88,5 @@ class Window:
         """
         """
         Close the window
         Close the window
         """
         """
-
         plt.close()
         plt.close()
         self.closed = True
         self.closed = True

+ 3 - 8
gym_minigrid/wrappers.py

@@ -164,10 +164,6 @@ class RGBImgObsWrapper(ObservationWrapper):
     def __init__(self, env, tile_size=8):
     def __init__(self, env, tile_size=8):
         super().__init__(env, new_step_api=env.new_step_api)
         super().__init__(env, new_step_api=env.new_step_api)
 
 
-        # Rendering attributes
-        self.highlight = True
-        self.tile_size = tile_size
-
         self.tile_size = tile_size
         self.tile_size = tile_size
 
 
         new_image_space = spaces.Box(
         new_image_space = spaces.Box(
@@ -182,7 +178,7 @@ class RGBImgObsWrapper(ObservationWrapper):
         )
         )
 
 
     def observation(self, obs):
     def observation(self, obs):
-        rgb_img = self.get_full_render()
+        rgb_img = self.get_frame(highlight=True, tile_size=self.tile_size)
 
 
         return {**obs, "image": rgb_img}
         return {**obs, "image": rgb_img}
 
 
@@ -196,8 +192,7 @@ class RGBImgPartialObsWrapper(ObservationWrapper):
     def __init__(self, env, tile_size=8):
     def __init__(self, env, tile_size=8):
         super().__init__(env, new_step_api=env.new_step_api)
         super().__init__(env, new_step_api=env.new_step_api)
 
 
-        # Rendering attributes
-        self.unwrapped.agent_pov = True
+        # Rendering attributes for observations
         self.tile_size = tile_size
         self.tile_size = tile_size
 
 
         obs_shape = env.observation_space.spaces["image"].shape
         obs_shape = env.observation_space.spaces["image"].shape
@@ -213,7 +208,7 @@ class RGBImgPartialObsWrapper(ObservationWrapper):
         )
         )
 
 
     def observation(self, obs):
     def observation(self, obs):
-        rgb_img_partial = self.get_pov_render()
+        rgb_img_partial = self.get_frame(tile_size=self.tile_size, agent_pov=True)
 
 
         return {**obs, "image": rgb_img_partial}
         return {**obs, "image": rgb_img_partial}
 
 

+ 2 - 3
tests/test_envs.py

@@ -121,12 +121,11 @@ def test_agent_sees_method(env_id):
     env = gym.make(env_id, new_step_api=True)
     env = gym.make(env_id, new_step_api=True)
     goal_pos = (env.grid.width - 2, env.grid.height - 2)
     goal_pos = (env.grid.width - 2, env.grid.height - 2)
 
 
+    # Test the env.agent_sees() function
+    env.reset()
     # Test the "in" operator on grid objects
     # Test the "in" operator on grid objects
     assert ("green", "goal") in env.grid
     assert ("green", "goal") in env.grid
     assert ("blue", "key") not in env.grid
     assert ("blue", "key") not in env.grid
-
-    # Test the env.agent_sees() function
-    env.reset()
     for i in range(0, 500):
     for i in range(0, 500):
         action = env.action_space.sample()
         action = env.action_space.sample()
         obs, reward, terminated, truncated, info = env.step(action)
         obs, reward, terminated, truncated, info = env.step(action)

+ 1 - 1
tests/test_scripts.py

@@ -23,7 +23,7 @@ def test_window():
     caption = "testing caption"
     caption = "testing caption"
     window.set_caption(caption)
     window.set_caption(caption)
 
 
-    window.show()
+    window.show(block=False)
 
 
     window.close()
     window.close()