Kaynağa Gözat

Completed port of manual_control script to use matplotlib

Maxime Chevalier-Boisvert 5 yıl önce
ebeveyn
işleme
3fdf885346
3 değiştirilmiş dosya ile 107 ekleme ve 142 silme
  1. 1 1
      gym_minigrid/minigrid.py
  2. 106 90
      manual_control.py
  3. 0 51
      manual_control_matplotlib.py

+ 1 - 1
gym_minigrid/minigrid.py

@@ -578,7 +578,7 @@ class Grid:
 
 
                 tile_img = Grid.render_tile(
                 tile_img = Grid.render_tile(
                     cell,
                     cell,
-                    agent_dir=agent_dir if agent_pos == (i, j) else None,
+                    agent_dir=agent_dir if np.array_equal(agent_pos, (i, j)) else None,
                     highlight=False,
                     highlight=False,
                     tile_size=tile_size
                     tile_size=tile_size
                 )
                 )

+ 106 - 90
manual_control.py

@@ -1,95 +1,111 @@
 #!/usr/bin/env python3
 #!/usr/bin/env python3
 
 
-from __future__ import division, print_function
-
-import sys
-import numpy
-import gym
 import time
 import time
-from optparse import OptionParser
-
+import argparse
+import matplotlib.pyplot as plt
+import numpy as np
+import gym
 import gym_minigrid
 import gym_minigrid
 
 
-def main():
-    parser = OptionParser()
-    parser.add_option(
-        "-e",
-        "--env-name",
-        dest="env_name",
-        help="gym environment to load",
-        default='MiniGrid-MultiRoom-N6-v0'
-    )
-    (options, args) = parser.parse_args()
-
-    # Load the gym environment
-    env = gym.make(options.env_name)
-
-    def resetEnv():
-        env.reset()
-        if hasattr(env, 'mission'):
-            print('Mission: %s' % env.mission)
-
-    resetEnv()
-
-    # Create a window to render into
-    renderer = env.render('human')
-
-    def keyDownCb(keyName):
-        if keyName == 'BACKSPACE':
-            resetEnv()
-            return
-
-        if keyName == 'ESCAPE':
-            sys.exit(0)
-
-        action = 0
-
-        if keyName == 'LEFT':
-            action = env.actions.left
-        elif keyName == 'RIGHT':
-            action = env.actions.right
-        elif keyName == 'UP':
-            action = env.actions.forward
-
-        elif keyName == 'SPACE':
-            action = env.actions.toggle
-        elif keyName == 'PAGE_UP':
-            action = env.actions.pickup
-        elif keyName == 'PAGE_DOWN':
-            action = env.actions.drop
-
-        elif keyName == 'RETURN':
-            action = env.actions.done
-
-        # Screenshot funcitonality
-        elif keyName == 'ALT':
-            screen_path = options.env_name + '.png'
-            print('saving screenshot "{}"'.format(screen_path))
-            pixmap = env.render('pixmap')
-            pixmap.save(screen_path)
-            return
-
-        else:
-            print("unknown key %s" % keyName)
-            return
-
-        obs, reward, done, info = env.step(action)
-
-        print('step=%s, reward=%.2f' % (env.step_count, reward))
-
-        if done:
-            print('done!')
-            resetEnv()
-
-    renderer.window.setKeyDownCb(keyDownCb)
-
-    while True:
-        env.render('human')
-        time.sleep(0.01)
-
-        # If the window was closed
-        if renderer.window == None:
-            break
-
-if __name__ == "__main__":
-    main()
+def reset():
+    env.reset()
+
+    if hasattr(env, 'mission'):
+        print('Mission: %s' % env.mission)
+        plt.xlabel(env.mission)
+
+    img = env.render('rgb_array')
+    imshow_obj.set_data(img)
+    fig.canvas.draw()
+
+def step(action):
+    obs, reward, done, info = env.step(action)
+    print('step=%s, reward=%.2f' % (env.step_count, reward))
+
+    if done:
+        print('done!')
+        reset()
+
+    img = env.render('rgb_array')
+    imshow_obj.set_data(img)
+    fig.canvas.draw()
+
+def key_handler(event):
+    print('pressed', event.key)
+
+    if event.key == 'escape':
+        plt.close()
+        return
+
+    if event.key == 'backspace':
+        reset()
+        return
+
+    if event.key == 'left':
+        step(env.actions.left)
+        return
+    if event.key == 'right':
+        step(env.actions.right)
+        return
+    if event.key == 'up':
+        step(env.actions.forward)
+        return
+
+    # Spacebar
+    if event.key == ' ':
+        step(env.actions.toggle)
+        return
+    if event.key == 'pageup':
+        step(env.actions.pickup)
+        return
+    if event.key == 'pagedown':
+        step(env.actions.drop)
+        return
+
+    if event.key == 'enter':
+        step(env.actions.done)
+        return
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "-e",
+    "--env-name",
+    dest="env_name",
+    help="gym environment to load",
+    #default='MiniGrid-MultiRoom-N6-v0'
+    default='MiniGrid-Empty-8x8-v0'
+)
+args = parser.parse_args()
+
+env = gym.make(args.env_name)
+
+"""
+t0 = time.time()
+
+for i in range(1000):
+    img = env.render('rgb_array')
+
+t1 = time.time()
+dt = int(1000 * (t1-t0))
+print(dt)
+"""
+
+fig, ax = plt.subplots()
+
+# Keyboard handler
+fig.canvas.mpl_connect('key_press_event', key_handler)
+
+# Show the env name in the window title
+fig.canvas.set_window_title('gym_minigrid - ' + args.env_name)
+
+# Turn off x/y axis numbering/ticks
+ax.set_xticks([], [])
+ax.set_yticks([], [])
+
+#plt.figure(num='gym-minigrid')
+imshow_obj = ax.imshow(img)
+
+reset()
+
+# Show the plot, enter the matplotlib event loop
+plt.show()

+ 0 - 51
manual_control_matplotlib.py

@@ -1,51 +0,0 @@
-import time
-import matplotlib.pyplot as plt
-import numpy as np
-import gym
-import gym_minigrid
-
-def key_handler(event):
-    print('pressed', event.key)
-
-    if event.key == 'escape':
-        plt.close()
-        return
-
-    if event.key == 'left':
-        env.step(env.actions.left)
-        img = env.render('rgb_array')
-
-        #img = np.zeros(shape=(256,256,3), dtype=np.uint8)
-        imshow_obj.set_data(img)
-        fig.canvas.draw()
-        #plt.show()
-
-        return
-
-env = gym.make('MiniGrid-Empty-8x8-v0')
-
-#env.step(env.actions.left)
-
-
-t0 = time.time()
-
-for i in range(1000):
-    img = env.render('rgb_array')
-
-t1 = time.time()
-dt = int(1000 * (t1-t0))
-
-print(dt)
-
-print(img.shape)
-
-fig, ax = plt.subplots()
-fig.canvas.mpl_connect('key_press_event', key_handler)
-
-#plt.figure(num='gym-minigrid')
-imshow_obj = ax.imshow(img)
-
-
-
-
-plt.show()