manual_control.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. #!/usr/bin/env python3
  2. import time
  3. import argparse
  4. import matplotlib.pyplot as plt
  5. import numpy as np
  6. import gym
  7. import gym_minigrid
  8. def reset():
  9. env.reset()
  10. if hasattr(env, 'mission'):
  11. print('Mission: %s' % env.mission)
  12. plt.xlabel(env.mission)
  13. img = env.render('rgb_array')
  14. imshow_obj.set_data(img)
  15. fig.canvas.draw()
  16. def step(action):
  17. obs, reward, done, info = env.step(action)
  18. print('step=%s, reward=%.2f' % (env.step_count, reward))
  19. if done:
  20. print('done!')
  21. reset()
  22. img = env.render('rgb_array')
  23. imshow_obj.set_data(img)
  24. fig.canvas.draw()
  25. def key_handler(event):
  26. print('pressed', event.key)
  27. if event.key == 'escape':
  28. plt.close()
  29. return
  30. if event.key == 'backspace':
  31. reset()
  32. return
  33. if event.key == 'left':
  34. step(env.actions.left)
  35. return
  36. if event.key == 'right':
  37. step(env.actions.right)
  38. return
  39. if event.key == 'up':
  40. step(env.actions.forward)
  41. return
  42. # Spacebar
  43. if event.key == ' ':
  44. step(env.actions.toggle)
  45. return
  46. if event.key == 'pageup':
  47. step(env.actions.pickup)
  48. return
  49. if event.key == 'pagedown':
  50. step(env.actions.drop)
  51. return
  52. if event.key == 'enter':
  53. step(env.actions.done)
  54. return
  55. parser = argparse.ArgumentParser()
  56. parser.add_argument(
  57. "-e",
  58. "--env-name",
  59. dest="env_name",
  60. help="gym environment to load",
  61. #default='MiniGrid-MultiRoom-N6-v0'
  62. default='MiniGrid-KeyCorridorS3R3-v0'
  63. )
  64. args = parser.parse_args()
  65. env = gym.make(args.env_name)
  66. fig, ax = plt.subplots()
  67. # Keyboard handler
  68. fig.canvas.mpl_connect('key_press_event', key_handler)
  69. # Show the env name in the window title
  70. fig.canvas.set_window_title('gym_minigrid - ' + args.env_name)
  71. # Turn off x/y axis numbering/ticks
  72. ax.set_xticks([], [])
  73. ax.set_yticks([], [])
  74. # Show the first image of the environment
  75. img = env.render('rgb_array')
  76. imshow_obj = ax.imshow(img, interpolation='bilinear')
  77. reset()
  78. # Show the plot, enter the matplotlib event loop
  79. plt.show()