manual_control.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. #!/usr/bin/env python3
  2. from __future__ import annotations
  3. import gymnasium as gym
  4. from minigrid.minigrid_env import MiniGridEnv
  5. from minigrid.utils.window import Window
  6. from minigrid.wrappers import ImgObsWrapper, RGBImgPartialObsWrapper
  7. class ManualControl:
  8. def __init__(
  9. self,
  10. env: MiniGridEnv,
  11. agent_view: bool = False,
  12. window: Window = None,
  13. seed=None,
  14. ) -> None:
  15. self.env = env
  16. self.agent_view = agent_view
  17. self.seed = seed
  18. if window is None:
  19. window = Window("minigrid - " + str(env.__class__))
  20. self.window = window
  21. self.window.reg_key_handler(self.key_handler)
  22. def start(self):
  23. """Start the window display with blocking event loop"""
  24. self.reset(self.seed)
  25. self.window.show(block=True)
  26. def step(self, action: MiniGridEnv.Actions):
  27. _, reward, terminated, truncated, _ = self.env.step(action)
  28. print(f"step={self.env.step_count}, reward={reward:.2f}")
  29. if terminated:
  30. print("terminated!")
  31. self.reset(self.seed)
  32. elif truncated:
  33. print("truncated!")
  34. self.reset(self.seed)
  35. else:
  36. self.redraw()
  37. def redraw(self):
  38. frame = self.env.get_frame(agent_pov=self.agent_view)
  39. self.window.show_img(frame)
  40. def reset(self, seed=None):
  41. self.env.reset(seed=seed)
  42. if hasattr(self.env, "mission"):
  43. print("Mission: %s" % self.env.mission)
  44. self.window.set_caption(self.env.mission)
  45. self.redraw()
  46. def key_handler(self, event):
  47. key: str = event.key
  48. print("pressed", key)
  49. if key == "escape":
  50. self.window.close()
  51. return
  52. if key == "backspace":
  53. self.reset()
  54. return
  55. key_to_action = {
  56. "left": MiniGridEnv.Actions.left,
  57. "right": MiniGridEnv.Actions.right,
  58. "up": MiniGridEnv.Actions.forward,
  59. " ": MiniGridEnv.Actions.toggle,
  60. "pageup": MiniGridEnv.Actions.pickup,
  61. "pagedown": MiniGridEnv.Actions.drop,
  62. "enter": MiniGridEnv.Actions.done,
  63. }
  64. action = key_to_action[key]
  65. self.step(action)
  66. if __name__ == "__main__":
  67. import argparse
  68. parser = argparse.ArgumentParser()
  69. parser.add_argument(
  70. "--env", help="gym environment to load", default="MiniGrid-MultiRoom-N6-v0"
  71. )
  72. parser.add_argument(
  73. "--seed",
  74. type=int,
  75. help="random seed to generate the environment with",
  76. default=None,
  77. )
  78. parser.add_argument(
  79. "--tile-size", type=int, help="size at which to render tiles", default=32
  80. )
  81. parser.add_argument(
  82. "--agent-view",
  83. default=False,
  84. help="draw the agent sees (partially observable view)",
  85. action="store_true",
  86. )
  87. args = parser.parse_args()
  88. env: MiniGridEnv = gym.make(args.env, tile_size=args.tile_size)
  89. if args.agent_view:
  90. print("Using agent view")
  91. env = RGBImgPartialObsWrapper(env, env.tile_size)
  92. env = ImgObsWrapper(env)
  93. manual_control = ManualControl(env, agent_view=args.agent_view, seed=args.seed)
  94. manual_control.start()