manual_control.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. #!/usr/bin/env python3
  2. import argparse
  3. import gym
  4. from gym_minigrid.window import Window
  5. from gym_minigrid.wrappers import ImgObsWrapper, RGBImgPartialObsWrapper
  6. def redraw(img):
  7. if not args.agent_view:
  8. img = env.render(tile_size=args.tile_size)
  9. window.show_img(img)
  10. def reset():
  11. seed = None if args.seed == -1 else args.seed
  12. obs = env.reset(seed=seed)
  13. if hasattr(env, "mission"):
  14. print("Mission: %s" % env.mission)
  15. window.set_caption(env.mission)
  16. redraw(obs)
  17. def step(action):
  18. obs, reward, done, info = env.step(action)
  19. print(f"step={env.step_count}, reward={reward:.2f}")
  20. if done:
  21. print("done!")
  22. reset()
  23. else:
  24. redraw(obs)
  25. def key_handler(event):
  26. print("pressed", event.key)
  27. if event.key == "escape":
  28. window.close()
  29. return
  30. if event.key == "backspace":
  31. reset()
  32. return
  33. if event.key == "left":
  34. step(env.actions.left)
  35. return
  36. if event.key == "right":
  37. step(env.actions.right)
  38. return
  39. if event.key == "up":
  40. step(env.actions.forward)
  41. return
  42. # Spacebar
  43. if event.key == " ":
  44. step(env.actions.toggle)
  45. return
  46. if event.key == "pageup":
  47. step(env.actions.pickup)
  48. return
  49. if event.key == "pagedown":
  50. step(env.actions.drop)
  51. return
  52. if event.key == "enter":
  53. step(env.actions.done)
  54. return
  55. parser = argparse.ArgumentParser()
  56. parser.add_argument(
  57. "--env", help="gym environment to load", default="MiniGrid-MultiRoom-N6-v0"
  58. )
  59. parser.add_argument(
  60. "--seed", type=int, help="random seed to generate the environment with", default=-1
  61. )
  62. parser.add_argument(
  63. "--tile_size", type=int, help="size at which to render tiles", default=32
  64. )
  65. parser.add_argument(
  66. "--agent_view",
  67. default=False,
  68. help="draw the agent sees (partially observable view)",
  69. action="store_true",
  70. )
  71. args = parser.parse_args()
  72. env = gym.make(args.env, render_mode="rgb_array")
  73. if args.agent_view:
  74. env = RGBImgPartialObsWrapper(env)
  75. env = ImgObsWrapper(env)
  76. window = Window("gym_minigrid - " + args.env)
  77. window.reg_key_handler(key_handler)
  78. reset()
  79. # Blocking event loop
  80. window.show(block=True)