manual_control.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. #!/usr/bin/env python3
  2. from __future__ import annotations
  3. import gymnasium as gym
  4. import pygame
  5. from gymnasium import Env
  6. from minigrid.core.actions import Actions
  7. from minigrid.minigrid_env import MiniGridEnv
  8. from minigrid.wrappers import ImgObsWrapper, RGBImgPartialObsWrapper
  9. class ManualControl:
  10. def __init__(
  11. self,
  12. env: Env,
  13. seed=None,
  14. ) -> None:
  15. self.env = env
  16. self.seed = seed
  17. self.closed = False
  18. def start(self):
  19. """Start the window display with blocking event loop"""
  20. self.reset(self.seed)
  21. while not self.closed:
  22. for event in pygame.event.get():
  23. if event.type == pygame.QUIT:
  24. self.env.close()
  25. break
  26. if event.type == pygame.KEYDOWN:
  27. event.key = pygame.key.name(int(event.key))
  28. self.key_handler(event)
  29. def step(self, action: Actions):
  30. _, reward, terminated, truncated, _ = self.env.step(action)
  31. print(f"step={self.env.step_count}, reward={reward:.2f}")
  32. if terminated:
  33. print("terminated!")
  34. self.reset(self.seed)
  35. elif truncated:
  36. print("truncated!")
  37. self.reset(self.seed)
  38. else:
  39. self.env.render()
  40. def reset(self, seed=None):
  41. self.env.reset(seed=seed)
  42. self.env.render()
  43. def key_handler(self, event):
  44. key: str = event.key
  45. print("pressed", key)
  46. if key == "escape":
  47. self.env.close()
  48. return
  49. if key == "backspace":
  50. self.reset()
  51. return
  52. key_to_action = {
  53. "left": Actions.left,
  54. "right": Actions.right,
  55. "up": Actions.forward,
  56. "space": Actions.toggle,
  57. "pageup": Actions.pickup,
  58. "pagedown": Actions.drop,
  59. "tab": Actions.pickup,
  60. "left shift": Actions.drop,
  61. "enter": Actions.done,
  62. }
  63. if key in key_to_action.keys():
  64. action = key_to_action[key]
  65. self.step(action)
  66. else:
  67. print(key)
  68. if __name__ == "__main__":
  69. import argparse
  70. parser = argparse.ArgumentParser()
  71. parser.add_argument(
  72. "--env-id",
  73. type=str,
  74. help="gym environment to load",
  75. choices=gym.envs.registry.keys(),
  76. default="MiniGrid-MultiRoom-N6-v0",
  77. )
  78. parser.add_argument(
  79. "--seed",
  80. type=int,
  81. help="random seed to generate the environment with",
  82. default=None,
  83. )
  84. parser.add_argument(
  85. "--tile-size", type=int, help="size at which to render tiles", default=32
  86. )
  87. parser.add_argument(
  88. "--agent-view",
  89. action="store_true",
  90. help="draw the agent sees (partially observable view)",
  91. )
  92. parser.add_argument(
  93. "--agent-view-size",
  94. type=int,
  95. default=7,
  96. help="set the number of grid spaces visible in agent-view ",
  97. )
  98. parser.add_argument(
  99. "--screen-size",
  100. type=int,
  101. default="640",
  102. help="set the resolution for pygame rendering (width and height)",
  103. )
  104. args = parser.parse_args()
  105. env: MiniGridEnv = gym.make(
  106. args.env_id,
  107. tile_size=args.tile_size,
  108. render_mode="human",
  109. agent_pov=args.agent_view,
  110. agent_view_size=args.agent_view_size,
  111. screen_size=args.screen_size,
  112. )
  113. # TODO: check if this can be removed
  114. if args.agent_view:
  115. print("Using agent view")
  116. env = RGBImgPartialObsWrapper(env, args.tile_size)
  117. env = ImgObsWrapper(env)
  118. manual_control = ManualControl(env, seed=args.seed)
  119. manual_control.start()