manual_control.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. #!/usr/bin/env python3
  2. import time
  3. import argparse
  4. import matplotlib.pyplot as plt
  5. import numpy as np
  6. import gym
  7. import gym_minigrid
  8. def reset():
  9. env.reset()
  10. if hasattr(env, 'mission'):
  11. print('Mission: %s' % env.mission)
  12. plt.xlabel(env.mission)
  13. img = env.render('rgb_array', tile_size=args.tile_size)
  14. imshow_obj.set_data(img)
  15. fig.canvas.draw()
  16. def step(action):
  17. obs, reward, done, info = env.step(action)
  18. print('step=%s, reward=%.2f' % (env.step_count, reward))
  19. if done:
  20. print('done!')
  21. reset()
  22. img = env.render('rgb_array', tile_size=args.tile_size)
  23. imshow_obj.set_data(img)
  24. fig.canvas.draw()
  25. def key_handler(event):
  26. print('pressed', event.key)
  27. if event.key == 'escape':
  28. plt.close()
  29. return
  30. if event.key == 'backspace':
  31. reset()
  32. return
  33. if event.key == 'left':
  34. step(env.actions.left)
  35. return
  36. if event.key == 'right':
  37. step(env.actions.right)
  38. return
  39. if event.key == 'up':
  40. step(env.actions.forward)
  41. return
  42. # Spacebar
  43. if event.key == ' ':
  44. step(env.actions.toggle)
  45. return
  46. if event.key == 'pageup':
  47. step(env.actions.pickup)
  48. return
  49. if event.key == 'pagedown':
  50. step(env.actions.drop)
  51. return
  52. if event.key == 'enter':
  53. step(env.actions.done)
  54. return
  55. parser = argparse.ArgumentParser()
  56. parser.add_argument(
  57. "--env_name",
  58. help="gym environment to load",
  59. #default='MiniGrid-MultiRoom-N6-v0'
  60. default='MiniGrid-KeyCorridorS3R3-v0'
  61. )
  62. parser.add_argument(
  63. "--tile_size",
  64. type=int,
  65. help="size at which to render tiles",
  66. default=32
  67. )
  68. args = parser.parse_args()
  69. env = gym.make(args.env_name)
  70. fig, ax = plt.subplots()
  71. # Keyboard handler
  72. fig.canvas.mpl_connect('key_press_event', key_handler)
  73. # Show the env name in the window title
  74. fig.canvas.set_window_title('gym_minigrid - ' + args.env_name)
  75. # Turn off x/y axis numbering/ticks
  76. ax.set_xticks([], [])
  77. ax.set_yticks([], [])
  78. print(args.tile_size)
  79. # Show the first image of the environment
  80. img = env.render('rgb_array', tile_size=args.tile_size)
  81. imshow_obj = ax.imshow(img, interpolation='bilinear')
  82. reset()
  83. # Show the plot, enter the matplotlib event loop
  84. plt.show()