# manual_control.py (2.7 KB)
  1. #!/usr/bin/env python3
  2. import gymnasium as gym
  3. from minigrid.utils.window import Window
  4. from minigrid.wrappers import ImgObsWrapper, RGBImgPartialObsWrapper
  5. def redraw(window, img):
  6. window.show_img(img)
  7. def reset(env, window, seed=None):
  8. env.reset(seed=seed)
  9. if hasattr(env, "mission"):
  10. print("Mission: %s" % env.mission)
  11. window.set_caption(env.mission)
  12. img = env.get_frame()
  13. redraw(window, img)
  14. def step(env, window, action):
  15. obs, reward, terminated, truncated, info = env.step(action)
  16. print(f"step={env.step_count}, reward={reward:.2f}")
  17. if terminated:
  18. print("terminated!")
  19. reset(env, window)
  20. elif truncated:
  21. print("truncated!")
  22. reset(env, window)
  23. else:
  24. img = env.get_frame()
  25. redraw(window, img)
  26. def key_handler(env, window, event):
  27. print("pressed", event.key)
  28. if event.key == "escape":
  29. window.close()
  30. return
  31. if event.key == "backspace":
  32. reset(env, window)
  33. return
  34. if event.key == "left":
  35. step(env, window, env.actions.left)
  36. return
  37. if event.key == "right":
  38. step(env, window, env.actions.right)
  39. return
  40. if event.key == "up":
  41. step(env, window, env.actions.forward)
  42. return
  43. # Spacebar
  44. if event.key == " ":
  45. step(env, window, env.actions.toggle)
  46. return
  47. if event.key == "pageup":
  48. step(env, window, env.actions.pickup)
  49. return
  50. if event.key == "pagedown":
  51. step(env, window, env.actions.drop)
  52. return
  53. if event.key == "enter":
  54. step(env, window, env.actions.done)
  55. return
  56. if __name__ == "__main__":
  57. import argparse
  58. parser = argparse.ArgumentParser()
  59. parser.add_argument(
  60. "--env", help="gym environment to load", default="MiniGrid-MultiRoom-N6-v0"
  61. )
  62. parser.add_argument(
  63. "--seed",
  64. type=int,
  65. help="random seed to generate the environment with",
  66. default=-1,
  67. )
  68. parser.add_argument(
  69. "--tile_size", type=int, help="size at which to render tiles", default=32
  70. )
  71. parser.add_argument(
  72. "--agent_view",
  73. default=False,
  74. help="draw the agent sees (partially observable view)",
  75. action="store_true",
  76. )
  77. args = parser.parse_args()
  78. env = gym.make(
  79. args.env,
  80. tile_size=args.tile_size,
  81. )
  82. if args.agent_view:
  83. env = RGBImgPartialObsWrapper(env)
  84. env = ImgObsWrapper(env)
  85. window = Window("minigrid - " + args.env)
  86. window.reg_key_handler(lambda event: key_handler(env, window, event))
  87. seed = None if args.seed == -1 else args.seed
  88. reset(env, window, seed)
  89. # Blocking event loop
  90. window.show(block=True)