|
@@ -3,34 +3,36 @@
|
|
|
from __future__ import annotations
|
|
|
|
|
|
import gymnasium as gym
|
|
|
+import pygame
|
|
|
+from gymnasium import Env
|
|
|
|
|
|
from minigrid.core.actions import Actions
|
|
|
from minigrid.minigrid_env import MiniGridEnv
|
|
|
-from minigrid.utils.window import Window
|
|
|
from minigrid.wrappers import ImgObsWrapper, RGBImgPartialObsWrapper
|
|
|
|
|
|
|
|
|
class ManualControl:
|
|
|
def __init__(
|
|
|
self,
|
|
|
- env: MiniGridEnv,
|
|
|
- agent_view: bool = False,
|
|
|
- window: Window = None,
|
|
|
+ env: Env,
|
|
|
seed=None,
|
|
|
) -> None:
|
|
|
self.env = env
|
|
|
- self.agent_view = agent_view
|
|
|
self.seed = seed
|
|
|
-
|
|
|
- if window is None:
|
|
|
- window = Window("minigrid - " + str(env.__class__))
|
|
|
- self.window = window
|
|
|
- self.window.reg_key_handler(self.key_handler)
|
|
|
+ self.closed = False
|
|
|
|
|
|
def start(self):
|
|
|
"""Start the window display with blocking event loop"""
|
|
|
self.reset(self.seed)
|
|
|
- self.window.show(block=True)
|
|
|
+
|
|
|
+ while not self.closed:
|
|
|
+ for event in pygame.event.get():
|
|
|
+ if event.type == pygame.QUIT:
|
|
|
+ self.env.close()
|
|
|
+ break
|
|
|
+ if event.type == pygame.KEYDOWN:
|
|
|
+ event.key = pygame.key.name(int(event.key))
|
|
|
+ self.key_handler(event)
|
|
|
|
|
|
def step(self, action: Actions):
|
|
|
_, reward, terminated, truncated, _ = self.env.step(action)
|
|
@@ -43,27 +45,18 @@ class ManualControl:
|
|
|
print("truncated!")
|
|
|
self.reset(self.seed)
|
|
|
else:
|
|
|
- self.redraw()
|
|
|
-
|
|
|
- def redraw(self):
|
|
|
- frame = self.env.get_frame(agent_pov=self.agent_view)
|
|
|
- self.window.show_img(frame)
|
|
|
+ self.env.render()
|
|
|
|
|
|
def reset(self, seed=None):
|
|
|
self.env.reset(seed=seed)
|
|
|
-
|
|
|
- if hasattr(self.env, "mission"):
|
|
|
- print("Mission: %s" % self.env.mission)
|
|
|
- self.window.set_caption(self.env.mission)
|
|
|
-
|
|
|
- self.redraw()
|
|
|
+ self.env.render()
|
|
|
|
|
|
def key_handler(self, event):
|
|
|
key: str = event.key
|
|
|
print("pressed", key)
|
|
|
|
|
|
if key == "escape":
|
|
|
- self.window.close()
|
|
|
+ self.env.close()
|
|
|
return
|
|
|
if key == "backspace":
|
|
|
self.reset()
|
|
@@ -73,14 +66,18 @@ class ManualControl:
|
|
|
"left": Actions.left,
|
|
|
"right": Actions.right,
|
|
|
"up": Actions.forward,
|
|
|
- " ": Actions.toggle,
|
|
|
+ "space": Actions.toggle,
|
|
|
"pageup": Actions.pickup,
|
|
|
"pagedown": Actions.drop,
|
|
|
+ "tab": Actions.pickup,
|
|
|
+ "left shift": Actions.drop,
|
|
|
"enter": Actions.done,
|
|
|
}
|
|
|
-
|
|
|
- action = key_to_action[key]
|
|
|
- self.step(action)
|
|
|
+ if key in key_to_action.keys():
|
|
|
+ action = key_to_action[key]
|
|
|
+ self.step(action)
|
|
|
+ else:
|
|
|
+ print(key)
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
@@ -88,7 +85,11 @@ if __name__ == "__main__":
|
|
|
|
|
|
parser = argparse.ArgumentParser()
|
|
|
parser.add_argument(
|
|
|
- "--env", help="gym environment to load", default="MiniGrid-MultiRoom-N6-v0"
|
|
|
+ "--env-id",
|
|
|
+ type=str,
|
|
|
+ help="gym environment to load",
|
|
|
+ choices=gym.envs.registry.keys(),
|
|
|
+ default="MiniGrid-MultiRoom-N6-v0",
|
|
|
)
|
|
|
parser.add_argument(
|
|
|
"--seed",
|
|
@@ -101,19 +102,38 @@ if __name__ == "__main__":
|
|
|
)
|
|
|
parser.add_argument(
|
|
|
"--agent-view",
|
|
|
- default=False,
|
|
|
- help="draw the agent sees (partially observable view)",
|
|
|
action="store_true",
|
|
|
+ help="draw the agent sees (partially observable view)",
|
|
|
+ )
|
|
|
+ parser.add_argument(
|
|
|
+ "--agent-view-size",
|
|
|
+ type=int,
|
|
|
+ default=7,
|
|
|
+ help="set the number of grid spaces visible in agent-view ",
|
|
|
+ )
|
|
|
+ parser.add_argument(
|
|
|
+ "--screen-size",
|
|
|
+ type=int,
|
|
|
+ default="640",
|
|
|
+ help="set the resolution for pygame rendering (width and height)",
|
|
|
)
|
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
- env: MiniGridEnv = gym.make(args.env, tile_size=args.tile_size)
|
|
|
+ env: MiniGridEnv = gym.make(
|
|
|
+ args.env_id,
|
|
|
+ tile_size=args.tile_size,
|
|
|
+ render_mode="human",
|
|
|
+ agent_pov=args.agent_view,
|
|
|
+ agent_view_size=args.agent_view_size,
|
|
|
+ screen_size=args.screen_size,
|
|
|
+ )
|
|
|
|
|
|
+ # TODO: check if this can be removed
|
|
|
if args.agent_view:
|
|
|
print("Using agent view")
|
|
|
- env = RGBImgPartialObsWrapper(env, env.tile_size)
|
|
|
+ env = RGBImgPartialObsWrapper(env, args.tile_size)
|
|
|
env = ImgObsWrapper(env)
|
|
|
|
|
|
- manual_control = ManualControl(env, agent_view=args.agent_view, seed=args.seed)
|
|
|
+ manual_control = ManualControl(env, seed=args.seed)
|
|
|
manual_control.start()
|