manual_control.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. #!/usr/bin/env python3
  2. from __future__ import annotations
  3. import gymnasium as gym
  4. from minigrid.core.actions import Actions
  5. from minigrid.minigrid_env import MiniGridEnv
  6. from minigrid.utils.window import Window
  7. from minigrid.wrappers import ImgObsWrapper, RGBImgPartialObsWrapper
  8. class ManualControl:
  9. def __init__(
  10. self,
  11. env: MiniGridEnv,
  12. agent_view: bool = False,
  13. window: Window = None,
  14. seed=None,
  15. ) -> None:
  16. self.env = env
  17. self.agent_view = agent_view
  18. self.seed = seed
  19. if window is None:
  20. window = Window("minigrid - " + str(env.__class__))
  21. self.window = window
  22. self.window.reg_key_handler(self.key_handler)
  23. def start(self):
  24. """Start the window display with blocking event loop"""
  25. self.reset(self.seed)
  26. self.window.show(block=True)
  27. def step(self, action: Actions):
  28. _, reward, terminated, truncated, _ = self.env.step(action)
  29. print(f"step={self.env.step_count}, reward={reward:.2f}")
  30. if terminated:
  31. print("terminated!")
  32. self.reset(self.seed)
  33. elif truncated:
  34. print("truncated!")
  35. self.reset(self.seed)
  36. else:
  37. self.redraw()
  38. def redraw(self):
  39. frame = self.env.get_frame(agent_pov=self.agent_view)
  40. self.window.show_img(frame)
  41. def reset(self, seed=None):
  42. self.env.reset(seed=seed)
  43. if hasattr(self.env, "mission"):
  44. print("Mission: %s" % self.env.mission)
  45. self.window.set_caption(self.env.mission)
  46. self.redraw()
  47. def key_handler(self, event):
  48. key: str = event.key
  49. print("pressed", key)
  50. if key == "escape":
  51. self.window.close()
  52. return
  53. if key == "backspace":
  54. self.reset()
  55. return
  56. key_to_action = {
  57. "left": Actions.left,
  58. "right": Actions.right,
  59. "up": Actions.forward,
  60. " ": Actions.toggle,
  61. "pageup": Actions.pickup,
  62. "pagedown": Actions.drop,
  63. "enter": Actions.done,
  64. }
  65. action = key_to_action[key]
  66. self.step(action)
  67. if __name__ == "__main__":
  68. import argparse
  69. parser = argparse.ArgumentParser()
  70. parser.add_argument(
  71. "--env", help="gym environment to load", default="MiniGrid-MultiRoom-N6-v0"
  72. )
  73. parser.add_argument(
  74. "--seed",
  75. type=int,
  76. help="random seed to generate the environment with",
  77. default=None,
  78. )
  79. parser.add_argument(
  80. "--tile-size", type=int, help="size at which to render tiles", default=32
  81. )
  82. parser.add_argument(
  83. "--agent-view",
  84. default=False,
  85. help="draw the agent sees (partially observable view)",
  86. action="store_true",
  87. )
  88. args = parser.parse_args()
  89. env: MiniGridEnv = gym.make(args.env, tile_size=args.tile_size)
  90. if args.agent_view:
  91. print("Using agent view")
  92. env = RGBImgPartialObsWrapper(env, env.tile_size)
  93. env = ImgObsWrapper(env)
  94. manual_control = ManualControl(env, agent_view=args.agent_view, seed=args.seed)
  95. manual_control.start()