manual_control.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. #!/usr/bin/env python3
  2. import time
  3. import argparse
  4. import matplotlib.pyplot as plt
  5. import numpy as np
  6. import gym
  7. import gym_minigrid
  8. from gym_minigrid.wrappers import *
  9. fig = None
  10. imshow_obj = None
  11. def redraw(img):
  12. global imshow_obj
  13. if not args.agent_view:
  14. img = env.render('rgb_array', tile_size=args.tile_size)
  15. # Show the first image of the environment
  16. if imshow_obj is None:
  17. imshow_obj = ax.imshow(img, interpolation='bilinear')
  18. imshow_obj.set_data(img)
  19. fig.canvas.draw()
  20. def reset():
  21. obs = env.reset()
  22. if hasattr(env, 'mission'):
  23. print('Mission: %s' % env.mission)
  24. plt.xlabel(env.mission)
  25. redraw(obs)
  26. def step(action):
  27. obs, reward, done, info = env.step(action)
  28. print('step=%s, reward=%.2f' % (env.step_count, reward))
  29. if done:
  30. print('done!')
  31. reset()
  32. redraw(obs)
  33. def key_handler(event):
  34. print('pressed', event.key)
  35. if event.key == 'escape':
  36. plt.close()
  37. return
  38. if event.key == 'backspace':
  39. reset()
  40. return
  41. if event.key == 'left':
  42. step(env.actions.left)
  43. return
  44. if event.key == 'right':
  45. step(env.actions.right)
  46. return
  47. if event.key == 'up':
  48. step(env.actions.forward)
  49. return
  50. # Spacebar
  51. if event.key == ' ':
  52. step(env.actions.toggle)
  53. return
  54. if event.key == 'pageup':
  55. step(env.actions.pickup)
  56. return
  57. if event.key == 'pagedown':
  58. step(env.actions.drop)
  59. return
  60. if event.key == 'enter':
  61. step(env.actions.done)
  62. return
  63. parser = argparse.ArgumentParser()
  64. parser.add_argument(
  65. "--env_name",
  66. help="gym environment to load",
  67. #default='MiniGrid-MultiRoom-N6-v0'
  68. default='MiniGrid-KeyCorridorS3R3-v0'
  69. )
  70. parser.add_argument(
  71. "--tile_size",
  72. type=int,
  73. help="size at which to render tiles",
  74. default=32
  75. )
  76. parser.add_argument(
  77. '--agent_view',
  78. default=False,
  79. help="Draw the agent's partially observable view",
  80. action='store_true'
  81. )
  82. args = parser.parse_args()
  83. env = gym.make(args.env_name)
  84. if args.agent_view:
  85. env = RGBImgPartialObsWrapper(env)
  86. env = ImgObsWrapper(env)
  87. fig, ax = plt.subplots()
  88. # Keyboard handler
  89. fig.canvas.mpl_connect('key_press_event', key_handler)
  90. # Show the env name in the window title
  91. fig.canvas.set_window_title('gym_minigrid - ' + args.env_name)
  92. # Turn off x/y axis numbering/ticks
  93. ax.set_xticks([], [])
  94. ax.set_yticks([], [])
  95. reset()
  96. # Show the plot, enter the matplotlib event loop
  97. plt.show()