manual_control.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. #!/usr/bin/env python3
  2. import argparse
  3. import gym
  4. from gym_minigrid.window import Window
  5. from gym_minigrid.wrappers import ImgObsWrapper, RGBImgPartialObsWrapper
  6. def redraw(img):
  7. if not args.agent_view:
  8. img = env.render("rgb_array", tile_size=args.tile_size)
  9. window.show_img(img)
  10. def reset():
  11. if args.seed != -1:
  12. env.seed(args.seed)
  13. obs = env.reset()
  14. if hasattr(env, "mission"):
  15. print("Mission: %s" % env.mission)
  16. window.set_caption(env.mission)
  17. redraw(obs)
  18. def step(action):
  19. obs, reward, done, info = env.step(action)
  20. print(f"step={env.step_count}, reward={reward:.2f}")
  21. if done:
  22. print("done!")
  23. reset()
  24. else:
  25. redraw(obs)
  26. def key_handler(event):
  27. print("pressed", event.key)
  28. if event.key == "escape":
  29. window.close()
  30. return
  31. if event.key == "backspace":
  32. reset()
  33. return
  34. if event.key == "left":
  35. step(env.actions.left)
  36. return
  37. if event.key == "right":
  38. step(env.actions.right)
  39. return
  40. if event.key == "up":
  41. step(env.actions.forward)
  42. return
  43. # Spacebar
  44. if event.key == " ":
  45. step(env.actions.toggle)
  46. return
  47. if event.key == "pageup":
  48. step(env.actions.pickup)
  49. return
  50. if event.key == "pagedown":
  51. step(env.actions.drop)
  52. return
  53. if event.key == "enter":
  54. step(env.actions.done)
  55. return
  56. parser = argparse.ArgumentParser()
  57. parser.add_argument(
  58. "--env", help="gym environment to load", default="MiniGrid-MultiRoom-N6-v0"
  59. )
  60. parser.add_argument(
  61. "--seed", type=int, help="random seed to generate the environment with", default=-1
  62. )
  63. parser.add_argument(
  64. "--tile_size", type=int, help="size at which to render tiles", default=32
  65. )
  66. parser.add_argument(
  67. "--agent_view",
  68. default=False,
  69. help="draw the agent sees (partially observable view)",
  70. action="store_true",
  71. )
  72. args = parser.parse_args()
  73. env = gym.make(args.env)
  74. if args.agent_view:
  75. env = RGBImgPartialObsWrapper(env)
  76. env = ImgObsWrapper(env)
  77. window = Window("gym_minigrid - " + args.env)
  78. window.reg_key_handler(key_handler)
  79. reset()
  80. # Blocking event loop
  81. window.show(block=True)