manual_control.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. #!/usr/bin/env python3
  2. import time
  3. import argparse
  4. import numpy as np
  5. import gym
  6. import gym_minigrid
  7. from gym_minigrid.wrappers import *
  8. from gym_minigrid.window import Window
  9. def redraw(img):
  10. if not args.agent_view:
  11. img = env.render('rgb_array', tile_size=args.tile_size)
  12. window.show_img(img)
  13. def reset():
  14. obs = env.reset()
  15. if hasattr(env, 'mission'):
  16. print('Mission: %s' % env.mission)
  17. window.set_caption(env.mission)
  18. redraw(obs)
  19. def step(action):
  20. obs, reward, done, info = env.step(action)
  21. print('step=%s, reward=%.2f' % (env.step_count, reward))
  22. if done:
  23. print('done!')
  24. reset()
  25. redraw(obs)
  26. def key_handler(event):
  27. print('pressed', event.key)
  28. if event.key == 'escape':
  29. window.close()
  30. return
  31. if event.key == 'backspace':
  32. reset()
  33. return
  34. if event.key == 'left':
  35. step(env.actions.left)
  36. return
  37. if event.key == 'right':
  38. step(env.actions.right)
  39. return
  40. if event.key == 'up':
  41. step(env.actions.forward)
  42. return
  43. # Spacebar
  44. if event.key == ' ':
  45. step(env.actions.toggle)
  46. return
  47. if event.key == 'pageup':
  48. step(env.actions.pickup)
  49. return
  50. if event.key == 'pagedown':
  51. step(env.actions.drop)
  52. return
  53. if event.key == 'enter':
  54. step(env.actions.done)
  55. return
  56. parser = argparse.ArgumentParser()
  57. parser.add_argument(
  58. "--env",
  59. help="gym environment to load",
  60. default='MiniGrid-MultiRoom-N6-v0'
  61. )
  62. parser.add_argument(
  63. "--tile_size",
  64. type=int,
  65. help="size at which to render tiles",
  66. default=32
  67. )
  68. parser.add_argument(
  69. '--agent_view',
  70. default=False,
  71. help="draw the agent sees (partially observable view)",
  72. action='store_true'
  73. )
  74. args = parser.parse_args()
  75. env = gym.make(args.env)
  76. if args.agent_view:
  77. env = RGBImgPartialObsWrapper(env)
  78. env = ImgObsWrapper(env)
  79. window = Window('gym_minigrid - ' + args.env)
  80. window.reg_key_handler(key_handler)
  81. reset()
  82. # Blocking event loop
  83. window.show(block=True)