Browse Source

Add pygame rendering support (#313)

Elliot Tower 2 years ago
parent
commit
7178af1540
9 changed files with 211 additions and 166 deletions
  1. 26 0
      docker_entrypoint
  2. 71 3
      minigrid/benchmark.py
  3. 53 33
      minigrid/manual_control.py
  4. 46 7
      minigrid/minigrid_env.py
  5. 0 92
      minigrid/utils/window.py
  6. 8 1
      py.Dockerfile
  7. 2 1
      pyproject.toml
  8. 1 1
      requirements.txt
  9. 4 28
      tests/test_scripts.py

+ 26 - 0
docker_entrypoint

@@ -0,0 +1,26 @@
+#!/bin/bash
+# This script is the entrypoint for our Docker image.
+
+set -ex
+
+# Set up display; otherwise rendering will fail
+Xvfb -screen 0 1024x768x24 &
+export DISPLAY=:0
+
+# Wait for the file to come up
+display=0
+file="/tmp/.X11-unix/X$display"
+for i in $(seq 1 10); do
+    if [ -e "$file" ]; then
+	break
+    fi
+
+    echo "Waiting for $file to be created (try $i/10)"
+    sleep "$i"
+done
+if ! [ -e "$file" ]; then
+    echo "Timing out: $file was not created"
+    exit 1
+fi
+
+exec "$@"

+ 71 - 3
minigrid/benchmark.py

@@ -6,6 +6,7 @@ import time
 
 import gymnasium as gym
 
+from minigrid.manual_control import ManualControl
 from minigrid.wrappers import ImgObsWrapper, RGBImgPartialObsWrapper
 
 
@@ -33,7 +34,7 @@ def benchmark(env_id, num_resets, num_frames):
     env = ImgObsWrapper(env)
 
     env.reset()
-    # Benchmark rendering
+    # Benchmark rendering in agent view
     t0 = time.time()
     for i in range(num_frames):
         obs, reward, terminated, truncated, info = env.step(0)
@@ -48,6 +49,49 @@ def benchmark(env_id, num_resets, num_frames):
     env.close()
 
 
+def benchmark_manual_control(env_id, num_resets, num_frames, tile_size):
+    env = gym.make(env_id, tile_size=tile_size)
+    env = ManualControl(env, seed=args.seed)
+
+    # Benchmark env.reset
+    t0 = time.time()
+    for i in range(num_resets):
+        env.reset()
+    t1 = time.time()
+    dt = t1 - t0
+    reset_time = (1000 * dt) / num_resets
+
+    # Benchmark rendering
+    t0 = time.time()
+    for i in range(num_frames):
+        env.redraw()
+    t1 = time.time()
+    dt = t1 - t0
+    frames_per_sec = num_frames / dt
+
+    # Create an environment with an RGB agent observation
+    env = gym.make(env_id, tile_size=tile_size)
+    env = RGBImgPartialObsWrapper(env, env.tile_size)
+    env = ImgObsWrapper(env)
+
+    env = ManualControl(env, seed=args.seed)
+    env.reset()
+
+    # Benchmark rendering in agent view
+    t0 = time.time()
+    for i in range(num_frames):
+        env.step(0)
+    t1 = time.time()
+    dt = t1 - t0
+    agent_view_fps = num_frames / dt
+
+    print(f"Env reset time: {reset_time:.1f} ms")
+    print(f"Rendering FPS : {frames_per_sec:.0f}")
+    print(f"Agent view FPS: {agent_view_fps:.0f}")
+
+    env.close()
+
+
 if __name__ == "__main__":
     import argparse
 
@@ -58,7 +102,31 @@ if __name__ == "__main__":
         help="gym environment to load",
         default="MiniGrid-LavaGapS7-v0",
     )
-    parser.add_argument("--num_resets", default=200)
-    parser.add_argument("--num_frames", default=5000)
+    parser.add_argument(
+        "--seed",
+        type=int,
+        help="random seed to generate the environment with",
+        default=None,
+    )
+    parser.add_argument(
+        "--num-resets",
+        type=int,
+        help="number of times to reset the environment for benchmarking",
+        default=200,
+    )
+    parser.add_argument(
+        "--num-frames",
+        type=int,
+        help="number of frames to test rendering for",
+        default=5000,
+    )
+    parser.add_argument(
+        "--tile-size", type=int, help="size at which to render tiles", default=32
+    )
+
     args = parser.parse_args()
     benchmark(args.env_id, args.num_resets, args.num_frames)
+
+    benchmark_manual_control(
+        args.env_id, args.num_resets, args.num_frames, args.tile_size
+    )

+ 53 - 33
minigrid/manual_control.py

@@ -3,34 +3,36 @@
 from __future__ import annotations
 
 import gymnasium as gym
+import pygame
+from gymnasium import Env
 
 from minigrid.core.actions import Actions
 from minigrid.minigrid_env import MiniGridEnv
-from minigrid.utils.window import Window
 from minigrid.wrappers import ImgObsWrapper, RGBImgPartialObsWrapper
 
 
 class ManualControl:
     def __init__(
         self,
-        env: MiniGridEnv,
-        agent_view: bool = False,
-        window: Window = None,
+        env: Env,
         seed=None,
     ) -> None:
         self.env = env
-        self.agent_view = agent_view
         self.seed = seed
-
-        if window is None:
-            window = Window("minigrid - " + str(env.__class__))
-        self.window = window
-        self.window.reg_key_handler(self.key_handler)
+        self.closed = False
 
     def start(self):
         """Start the window display with blocking event loop"""
         self.reset(self.seed)
-        self.window.show(block=True)
+
+        while not self.closed:
+            for event in pygame.event.get():
+                if event.type == pygame.QUIT:
+                    self.env.close()
+                    break
+                if event.type == pygame.KEYDOWN:
+                    event.key = pygame.key.name(int(event.key))
+                    self.key_handler(event)
 
     def step(self, action: Actions):
         _, reward, terminated, truncated, _ = self.env.step(action)
@@ -43,27 +45,18 @@ class ManualControl:
             print("truncated!")
             self.reset(self.seed)
         else:
-            self.redraw()
-
-    def redraw(self):
-        frame = self.env.get_frame(agent_pov=self.agent_view)
-        self.window.show_img(frame)
+            self.env.render()
 
     def reset(self, seed=None):
         self.env.reset(seed=seed)
-
-        if hasattr(self.env, "mission"):
-            print("Mission: %s" % self.env.mission)
-            self.window.set_caption(self.env.mission)
-
-        self.redraw()
+        self.env.render()
 
     def key_handler(self, event):
         key: str = event.key
         print("pressed", key)
 
         if key == "escape":
-            self.window.close()
+            self.env.close()
             return
         if key == "backspace":
             self.reset()
@@ -73,14 +66,18 @@ class ManualControl:
             "left": Actions.left,
             "right": Actions.right,
             "up": Actions.forward,
-            " ": Actions.toggle,
+            "space": Actions.toggle,
             "pageup": Actions.pickup,
             "pagedown": Actions.drop,
+            "tab": Actions.pickup,
+            "left shift": Actions.drop,
             "enter": Actions.done,
         }
-
-        action = key_to_action[key]
-        self.step(action)
+        if key in key_to_action.keys():
+            action = key_to_action[key]
+            self.step(action)
+        else:
+            print(key)
 
 
 if __name__ == "__main__":
@@ -88,7 +85,11 @@ if __name__ == "__main__":
 
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "--env", help="gym environment to load", default="MiniGrid-MultiRoom-N6-v0"
+        "--env-id",
+        type=str,
+        help="gym environment to load",
+        choices=gym.envs.registry.keys(),
+        default="MiniGrid-MultiRoom-N6-v0",
     )
     parser.add_argument(
         "--seed",
@@ -101,19 +102,38 @@ if __name__ == "__main__":
     )
     parser.add_argument(
         "--agent-view",
-        default=False,
-        help="draw the agent sees (partially observable view)",
         action="store_true",
+        help="draw the agent sees (partially observable view)",
+    )
+    parser.add_argument(
+        "--agent-view-size",
+        type=int,
+        default=7,
+        help="set the number of grid spaces visible in agent-view ",
+    )
+    parser.add_argument(
+        "--screen-size",
+        type=int,
+        default="640",
+        help="set the resolution for pygame rendering (width and height)",
     )
 
     args = parser.parse_args()
 
-    env: MiniGridEnv = gym.make(args.env, tile_size=args.tile_size)
+    env: MiniGridEnv = gym.make(
+        args.env_id,
+        tile_size=args.tile_size,
+        render_mode="human",
+        agent_pov=args.agent_view,
+        agent_view_size=args.agent_view_size,
+        screen_size=args.screen_size,
+    )
 
+    # TODO: check if this can be removed
     if args.agent_view:
         print("Using agent view")
-        env = RGBImgPartialObsWrapper(env, env.tile_size)
+        env = RGBImgPartialObsWrapper(env, args.tile_size)
         env = ImgObsWrapper(env)
 
-    manual_control = ManualControl(env, agent_view=args.agent_view, seed=args.seed)
+    manual_control = ManualControl(env, seed=args.seed)
     manual_control.start()

+ 46 - 7
minigrid/minigrid_env.py

@@ -7,6 +7,8 @@ from typing import Any, Iterable, SupportsFloat, TypeVar
 
 import gymnasium as gym
 import numpy as np
+import pygame
+import pygame.freetype
 from gymnasium import spaces
 from gymnasium.core import ActType, ObsType
 
@@ -15,7 +17,6 @@ from minigrid.core.constants import COLOR_NAMES, DIR_TO_VEC, TILE_PIXELS
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
 from minigrid.core.world_object import Point, WorldObj
-from minigrid.utils.window import Window
 
 T = TypeVar("T")
 
@@ -40,6 +41,7 @@ class MiniGridEnv(gym.Env):
         see_through_walls: bool = False,
         agent_view_size: int = 7,
         render_mode: str | None = None,
+        screen_size: int | None = 1,
         highlight: bool = True,
         tile_size: int = TILE_PIXELS,
         agent_pov: bool = False,
@@ -84,7 +86,10 @@ class MiniGridEnv(gym.Env):
         # Range of possible rewards
         self.reward_range = (0, 1)
 
-        self.window: Window = None
+        self.screen_size = screen_size
+        self.render_size = None
+        self.window = None
+        self.clock = None
 
         # Environment configuration
         self.width = width
@@ -730,14 +735,48 @@ class MiniGridEnv(gym.Env):
         img = self.get_frame(self.highlight, self.tile_size, self.agent_pov)
 
         if self.render_mode == "human":
+            img = np.transpose(img, axes=(1, 0, 2))
+            if self.render_size is None:
+                self.render_size = img.shape[:2]
             if self.window is None:
-                self.window = Window("minigrid")
-                self.window.show(block=False)
-            self.window.set_caption(self.mission)
-            self.window.show_img(img)
+                pygame.init()
+                pygame.display.init()
+                self.window = pygame.display.set_mode(
+                    (self.screen_size, self.screen_size)
+                )
+                pygame.display.set_caption("minigrid")
+            if self.clock is None:
+                self.clock = pygame.time.Clock()
+            surf = pygame.surfarray.make_surface(img)
+
+            # Create background with mission description
+            offset = surf.get_size()[0] * 0.1
+            # offset = 32 if self.agent_pov else 64
+            bg = pygame.Surface(
+                (int(surf.get_size()[0] + offset), int(surf.get_size()[1] + offset))
+            )
+            bg.convert()
+            bg.fill((255, 255, 255))
+            bg.blit(surf, (offset / 2, 0))
+
+            bg = pygame.transform.smoothscale(bg, (self.screen_size, self.screen_size))
+
+            font_size = 22
+            text = self.mission
+            font = pygame.freetype.SysFont(pygame.font.get_default_font(), font_size)
+            text_rect = font.get_rect(text, size=font_size)
+            text_rect.center = bg.get_rect().center
+            text_rect.y = bg.get_height() - font_size * 1.5
+            font.render_to(bg, text_rect, text, size=font_size)
+
+            self.window.blit(bg, (0, 0))
+            pygame.event.pump()
+            self.clock.tick(self.metadata["render_fps"])
+            pygame.display.flip()
+
         elif self.render_mode == "rgb_array":
             return img
 
     def close(self):
         if self.window:
-            self.window.close()
+            pygame.quit()

+ 0 - 92
minigrid/utils/window.py

@@ -1,92 +0,0 @@
-# Only ask users to install matplotlib if they actually need it
-try:
-    import matplotlib.pyplot as plt
-except ImportError:
-    raise ImportError(
-        "To display the environment in a window, please install matplotlib, eg: `pip3 install --user matplotlib`"
-    )
-
-
-class Window:
-    """
-    Window to draw a gridworld instance using Matplotlib
-    """
-
-    def __init__(self, title):
-        self.no_image_shown = True
-
-        # Create the figure and axes
-        self.fig, self.ax = plt.subplots()
-
-        # Show the env name in the window title
-        self.fig.canvas.manager.set_window_title(title)
-
-        # Turn off x/y axis numbering/ticks
-        self.ax.xaxis.set_ticks_position("none")
-        self.ax.yaxis.set_ticks_position("none")
-        _ = self.ax.set_xticklabels([])
-        _ = self.ax.set_yticklabels([])
-
-        # Flag indicating the window was closed
-        self.closed = False
-
-        def close_handler(evt):
-            self.closed = True
-
-        self.fig.canvas.mpl_connect("close_event", close_handler)
-
-    def show_img(self, img):
-        """
-        Show an image or update the image being shown
-        """
-
-        # If no image has been shown yet,
-        # show the first image of the environment
-        if self.no_image_shown:
-            self.imshow_obj = self.ax.imshow(img, interpolation="bilinear")
-            self.no_image_shown = False
-        # Update the image data
-        self.imshow_obj.set_data(img)
-
-        # Request the window be redrawn
-        self.fig.canvas.draw_idle()
-        self.fig.canvas.flush_events()
-
-        # Let matplotlib process UI events
-        plt.pause(0.001)
-
-    def set_caption(self, text):
-        """
-        Set/update the caption text below the image
-        """
-
-        plt.xlabel(text)
-
-    def reg_key_handler(self, key_handler):
-        """
-        Register a keyboard event handler
-        """
-
-        # Keyboard handler
-        self.fig.canvas.mpl_connect("key_press_event", key_handler)
-
-    def show(self, block=True):
-        """
-        Show the window, and start an event loop
-        """
-
-        # If not blocking, trigger interactive mode
-        if not block:
-            plt.ion()
-
-        # Show the plot
-        # In non-interative mode, this enters the matplotlib event loop
-        # In interactive mode, this call does not block
-        plt.show()
-
-    def close(self):
-        """
-        Close the window
-        """
-        plt.close()
-        self.closed = True

+ 8 - 1
py.Dockerfile

@@ -4,9 +4,16 @@ FROM python:$PYTHON_VERSION
 
 SHELL ["/bin/bash", "-o", "pipefail", "-c"]
 
-RUN apt-get -y update
+RUN apt-get -y update \
+    && apt-get install --no-install-recommends -y \
+    xvfb
 
 COPY . /usr/local/minigrid/
 WORKDIR /usr/local/minigrid/
 
 RUN pip install .[testing] --no-cache-dir
+
+RUN ["chmod", "+x", "/usr/local/minigrid/docker_entrypoint"]
+
+ENTRYPOINT ["/usr/local/minigrid/docker_entrypoint"]
+

+ 2 - 1
pyproject.toml

@@ -26,7 +26,7 @@ classifiers = [
 dependencies = [
     "numpy>=1.18.0",
     "gymnasium>=0.26",
-    "matplotlib>=3.0",
+    "pygame>=2.2.0",
 ]
 dynamic = ["version"]
 
@@ -34,6 +34,7 @@ dynamic = ["version"]
 testing = [
     "pytest>=7.0.1",
     "pytest-mock>=3.10.0",
+    "matplotlib>=3.0"
 ]
 
 [project.urls]

+ 1 - 1
requirements.txt

@@ -1,3 +1,3 @@
 numpy>=1.18.0
 gymnasium>=0.26
-matplotlib>=3.0
+pygame>=2.2.0

+ 4 - 28
tests/test_scripts.py

@@ -7,7 +7,6 @@ from pytest_mock import MockerFixture
 from minigrid.benchmark import benchmark
 from minigrid.manual_control import ManualControl
 from minigrid.minigrid_env import MiniGridEnv
-from minigrid.utils.window import Window
 
 
 def test_benchmark():
@@ -16,22 +15,6 @@ def test_benchmark():
     benchmark(env_id, num_resets=10, num_frames=100)
 
 
-def test_window():
-    "Testing the class functions of window.Window. This should locally open a window !"
-    title = "testing window"
-    window = Window(title)
-
-    img = np.random.rand(100, 100, 3)
-    window.show_img(img)
-
-    caption = "testing caption"
-    window.set_caption(caption)
-
-    window.show(block=False)
-
-    window.close()
-
-
 def test_manual_control(mocker: MockerFixture):
     class FakeRandomKeyboardEvent:
         active_actions = ["left", "right", "up", " ", "pageup", "pagedown"]
@@ -48,26 +31,19 @@ def test_manual_control(mocker: MockerFixture):
             self.key = np.random.choice(self.active_actions)
 
     env_id = "MiniGrid-Empty-16x16-v0"
-    env: MiniGridEnv = gym.make(env_id)
-    window = mocker.MagicMock()
-    window.close = mocker.MagicMock()
-    window.set_caption = mocker.MagicMock()
-    manual_control = ManualControl(env, window=window)
+    env: MiniGridEnv = gym.make(env_id, render_mode="human")
+    manual_control = ManualControl(env)
 
     for i in range(3):  # 3 resets
-        mission = f"Mission {i}"
-        env.mission = mission
         manual_control.reset()
-        window.set_caption.assert_called_with(mission)
         for j in range(20):  # Do 20 steps
             manual_control.key_handler(FakeRandomKeyboardEvent())
 
         fake_event = FakeRandomKeyboardEvent(reset=True)
         manual_control.key_handler(fake_event)
 
-    window.close.assert_not_called()
-
     # Close the environment
+    mocked_quit = mocker.patch("pygame.quit")
     fake_event = FakeRandomKeyboardEvent(close=True)
     manual_control.key_handler(fake_event)
-    window.close.assert_called()
+    mocked_quit.assert_called_once()