Browse Source

Create matplotlib window for human rendering mode

Maxime Chevalier-Boisvert 5 years ago
parent
commit
ff9583669b
6 changed files with 102 additions and 41 deletions
  1. 1 1
      README.md
  2. 15 11
      gym_minigrid/minigrid.py
  3. 1 0
      gym_minigrid/rendering.py
  4. 76 0
      gym_minigrid/window.py
  5. 8 27
      manual_control.py
  6. 1 2
      setup.py

+ 1 - 1
README.md

@@ -13,7 +13,7 @@ Requirements:
 - Python 3.5+
 - OpenAI Gym
 - NumPy
-- PyQT 5 for graphics
+- Matplotlib (optional, only needed for display)
 
 Please use this bibtex if you want to cite this repository in your publications:
 

+ 15 - 11
gym_minigrid/minigrid.py

@@ -288,7 +288,7 @@ class Key(WorldObj):
         c = COLORS[self.color]
 
         # Vertical quad
-        fill_coords(img, point_in_rect(0.50, 0.63, 0.31, 0.89), c)
+        fill_coords(img, point_in_rect(0.50, 0.63, 0.31, 0.88), c)
 
         # Teeth
         fill_coords(img, point_in_rect(0.38, 0.50, 0.59, 0.66), c)
@@ -619,7 +619,7 @@ class MiniGridEnv(gym.Env):
     """
 
     metadata = {
-        'render.modes': ['human', 'rgb_array', 'pixmap'],
+        'render.modes': ['human', 'rgb_array'],
         'video.frames_per_second' : 10
     }
 
@@ -680,11 +680,8 @@ class MiniGridEnv(gym.Env):
         # Range of possible rewards
         self.reward_range = (0, 1)
 
-        # Renderer object used to render the whole grid (full-scale)
-        self.grid_render = None
-
-        # Renderer used to render observations (small-scale agent view)
-        self.obs_render = None
+        # Window to use for human rendering mode
+        self.window = None
 
         # Environment configuration
         self.width = width
@@ -1226,12 +1223,15 @@ class MiniGridEnv(gym.Env):
         Render the whole-grid human view
         """
 
-        """
         if close:
-            if self.grid_render:
-                self.grid_render.close()
+            if self.window:
+                self.window.close()
             return
-        """
+
+        if mode == 'human' and not self.window:
+            import gym_minigrid.window
+            self.window = gym_minigrid.window.Window('gym_minigrid')
+            self.window.show(block=False)
 
         # Compute which cells are visible to the agent
         _, vis_mask = self.gen_obs_grid()
@@ -1271,4 +1271,8 @@ class MiniGridEnv(gym.Env):
             highlight_mask=highlight_mask if highlight else None
         )
 
+        if mode == 'human':
+            self.window.show_img(img)
+            self.window.set_caption(self.mission)
+
         return img

+ 1 - 0
gym_minigrid/rendering.py

@@ -41,6 +41,7 @@ def point_in_line(x0, y0, x1, y1, r):
     ymax = max(y0, y1) + r
 
     def fn(x, y):
+        # Fast, early escape test
         if x < xmin or x > xmax or y < ymin or y > ymax:
             return False
 

+ 76 - 0
gym_minigrid/window.py

@@ -0,0 +1,76 @@
+import sys
+import numpy as np
+
+# Only ask users to install matplotlib if they actually need it
+try:
+    import matplotlib.pyplot as plt
+except:
+    print('To display the environment in a window, please install matplotlib, eg:')
+    print('pip3 install --user matplotlib')
+    sys.exit(-1)
+
+class Window:
+    """
+    Window to draw a gridworld instance using Matplotlib
+    """
+
+    def __init__(self, title):
+        self.fig = None
+
+        self.imshow_obj = None
+
+        # Create the figure and axes
+        self.fig, self.ax = plt.subplots()
+
+        # Show the env name in the window title
+        self.fig.canvas.set_window_title(title)
+
+        # Turn off x/y axis numbering/ticks
+        self.ax.set_xticks([], [])
+        self.ax.set_yticks([], [])
+
+    def show_img(self, img):
+        """
+        Show an image or update the image being shown
+        """
+
+        # Show the first image of the environment
+        if self.imshow_obj is None:
+            self.imshow_obj = self.ax.imshow(img, interpolation='bilinear')
+
+        self.imshow_obj.set_data(img)
+        self.fig.canvas.draw()
+
+    def set_caption(self, text):
+        """
+        Set/update the caption text below the image
+        """
+
+        plt.xlabel(text)
+
+    def reg_key_handler(self, key_handler):
+        """
+        Register a keyboard event handler
+        """
+
+        # Keyboard handler
+        self.fig.canvas.mpl_connect('key_press_event', key_handler)
+
+    def show(self, block=True):
+        """
+        Show the window, and start an event loop
+        """
+
+        # If not blocking, trigger interactive mode
+        if not block:
+            plt.ion()
+
+        # Show the plot, enter the matplotlib event loop
+        plt.show(block=block)
+
+    def close(self):
+        """
+        Close the window
+        """
+
+        plt.close()

+ 8 - 27
manual_control.py

@@ -2,34 +2,24 @@
 
 import time
 import argparse
-import matplotlib.pyplot as plt
 import numpy as np
 import gym
 import gym_minigrid
 from gym_minigrid.wrappers import *
-
-fig = None
-imshow_obj = None
+from gym_minigrid.window import Window
 
 def redraw(img):
-    global imshow_obj
-
     if not args.agent_view:
         img = env.render('rgb_array', tile_size=args.tile_size)
 
-    # Show the first image of the environment
-    if imshow_obj is None:
-        imshow_obj = ax.imshow(img, interpolation='bilinear')
-
-    imshow_obj.set_data(img)
-    fig.canvas.draw()
+    window.show_img(img)
 
 def reset():
     obs = env.reset()
 
     if hasattr(env, 'mission'):
         print('Mission: %s' % env.mission)
-        plt.xlabel(env.mission)
+        window.set_caption(env.mission)
 
     redraw(obs)
 
@@ -47,7 +37,7 @@ def key_handler(event):
     print('pressed', event.key)
 
     if event.key == 'escape':
-        plt.close()
+        window.close()
         return
 
     if event.key == 'backspace':
@@ -107,19 +97,10 @@ if args.agent_view:
     env = RGBImgPartialObsWrapper(env)
     env = ImgObsWrapper(env)
 
-fig, ax = plt.subplots()
-
-# Keyboard handler
-fig.canvas.mpl_connect('key_press_event', key_handler)
-
-# Show the env name in the window title
-fig.canvas.set_window_title('gym_minigrid - ' + args.env_name)
-
-# Turn off x/y axis numbering/ticks
-ax.set_xticks([], [])
-ax.set_yticks([], [])
+window = Window('gym_minigrid - ' + args.env_name)
+window.reg_key_handler(key_handler)
 
 reset()
 
-# Show the plot, enter the matplotlib event loop
-plt.show()
+# Blocking event loop
+window.show(block=True)

+ 1 - 2
setup.py

@@ -9,7 +9,6 @@ setup(
     packages=['gym_minigrid', 'gym_minigrid.envs'],
     install_requires=[
         'gym>=0.9.6',
-        'numpy>=1.15.0',
-        'matplotlib'
+        'numpy>=1.15.0'
     ]
 )