|
@@ -2,76 +2,82 @@
|
|
|
|
|
|
import gymnasium as gym
|
|
|
|
|
|
+from minigrid.minigrid_env import MiniGridEnv
|
|
|
from minigrid.utils.window import Window
|
|
|
from minigrid.wrappers import ImgObsWrapper, RGBImgPartialObsWrapper
|
|
|
|
|
|
|
|
|
-def redraw(window, img):
|
|
|
- window.show_img(img)
|
|
|
-
|
|
|
-
|
|
|
-def reset(env, window, seed=None):
|
|
|
- env.reset(seed=seed)
|
|
|
-
|
|
|
- if hasattr(env, "mission"):
|
|
|
- print("Mission: %s" % env.mission)
|
|
|
- window.set_caption(env.mission)
|
|
|
-
|
|
|
- img = env.get_frame()
|
|
|
-
|
|
|
- redraw(window, img)
|
|
|
-
|
|
|
-
|
|
|
-def step(env, window, action):
|
|
|
- obs, reward, terminated, truncated, info = env.step(action)
|
|
|
- print(f"step={env.step_count}, reward={reward:.2f}")
|
|
|
-
|
|
|
- if terminated:
|
|
|
- print("terminated!")
|
|
|
- reset(env, window)
|
|
|
- elif truncated:
|
|
|
- print("truncated!")
|
|
|
- reset(env, window)
|
|
|
- else:
|
|
|
- img = env.get_frame()
|
|
|
- redraw(window, img)
|
|
|
-
|
|
|
-
|
|
|
-def key_handler(env, window, event):
|
|
|
- print("pressed", event.key)
|
|
|
-
|
|
|
- if event.key == "escape":
|
|
|
- window.close()
|
|
|
- return
|
|
|
-
|
|
|
- if event.key == "backspace":
|
|
|
- reset(env, window)
|
|
|
- return
|
|
|
-
|
|
|
- if event.key == "left":
|
|
|
- step(env, window, env.actions.left)
|
|
|
- return
|
|
|
- if event.key == "right":
|
|
|
- step(env, window, env.actions.right)
|
|
|
- return
|
|
|
- if event.key == "up":
|
|
|
- step(env, window, env.actions.forward)
|
|
|
- return
|
|
|
-
|
|
|
- # Spacebar
|
|
|
- if event.key == " ":
|
|
|
- step(env, window, env.actions.toggle)
|
|
|
- return
|
|
|
- if event.key == "pageup":
|
|
|
- step(env, window, env.actions.pickup)
|
|
|
- return
|
|
|
- if event.key == "pagedown":
|
|
|
- step(env, window, env.actions.drop)
|
|
|
- return
|
|
|
-
|
|
|
- if event.key == "enter":
|
|
|
- step(env, window, env.actions.done)
|
|
|
- return
|
|
|
+class ManualControl:
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ env: MiniGridEnv,
|
|
|
+ agent_view: bool = False,
|
|
|
+ window: Window = None,
|
|
|
+ seed=None,
|
|
|
+ ) -> None:
|
|
|
+ self.env = env
|
|
|
+ self.agent_view = agent_view
|
|
|
+ self.seed = seed
|
|
|
+
|
|
|
+ if window is None:
|
|
|
+ window = Window("minigrid - " + str(env.__class__))
|
|
|
+ self.window = window
|
|
|
+ self.window.reg_key_handler(self.key_handler)
|
|
|
+
|
|
|
+ def start(self):
|
|
|
+ """Start the window display with blocking event loop"""
|
|
|
+ self.reset(self.seed)
|
|
|
+ self.window.show(block=True)
|
|
|
+
|
|
|
+ def step(self, action: MiniGridEnv.Actions):
|
|
|
+ _, reward, terminated, truncated, _ = self.env.step(action)
|
|
|
+ print(f"step={self.env.step_count}, reward={reward:.2f}")
|
|
|
+
|
|
|
+ if terminated:
|
|
|
+ print("terminated!")
|
|
|
+ self.reset(self.seed)
|
|
|
+ elif truncated:
|
|
|
+ print("truncated!")
|
|
|
+ self.reset(self.seed)
|
|
|
+ else:
|
|
|
+ self.redraw()
|
|
|
+
|
|
|
+ def redraw(self):
|
|
|
+ frame = self.env.get_frame(agent_pov=self.agent_view)
|
|
|
+ self.window.show_img(frame)
|
|
|
+
|
|
|
+ def reset(self, seed=None):
|
|
|
+ self.env.reset(seed=seed)
|
|
|
+
|
|
|
+ if hasattr(self.env, "mission"):
|
|
|
+ print("Mission: %s" % self.env.mission)
|
|
|
+ self.window.set_caption(self.env.mission)
|
|
|
+
|
|
|
+ self.redraw()
|
|
|
+
|
|
|
+ def key_handler(self, event):
|
|
|
+ key: str = event.key
|
|
|
+ print("pressed", key)
|
|
|
+
|
|
|
+ if key == "escape":
|
|
|
+ self.window.close()
|
|
|
+ return
|
|
|
+ if key == "backspace":
|
|
|
+ self.reset()
|
|
|
+ return
|
|
|
+
|
|
|
+ key_to_action = {
|
|
|
+ "left": MiniGridEnv.Actions.left,
|
|
|
+ "right": MiniGridEnv.Actions.right,
|
|
|
+ "up": MiniGridEnv.Actions.forward,
|
|
|
+ " ": MiniGridEnv.Actions.toggle,
|
|
|
+ "pageup": MiniGridEnv.Actions.pickup,
|
|
|
+ "pagedown": MiniGridEnv.Actions.drop,
|
|
|
+ "enter": MiniGridEnv.Actions.done,
|
|
|
+ }
|
|
|
+
|
|
|
+ action = key_to_action[key]
|
|
|
+ self.step(action)
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
@@ -85,13 +91,13 @@ if __name__ == "__main__":
|
|
|
"--seed",
|
|
|
type=int,
|
|
|
help="random seed to generate the environment with",
|
|
|
- default=-1,
|
|
|
+ default=None,
|
|
|
)
|
|
|
parser.add_argument(
|
|
|
- "--tile_size", type=int, help="size at which to render tiles", default=32
|
|
|
+ "--tile-size", type=int, help="size at which to render tiles", default=32
|
|
|
)
|
|
|
parser.add_argument(
|
|
|
- "--agent_view",
|
|
|
+ "--agent-view",
|
|
|
default=False,
|
|
|
help="draw the agent sees (partially observable view)",
|
|
|
action="store_true",
|
|
@@ -99,20 +105,12 @@ if __name__ == "__main__":
|
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
- env = gym.make(
|
|
|
- args.env,
|
|
|
- tile_size=args.tile_size,
|
|
|
- )
|
|
|
+ env: MiniGridEnv = gym.make(args.env, tile_size=args.tile_size)
|
|
|
|
|
|
if args.agent_view:
|
|
|
- env = RGBImgPartialObsWrapper(env)
|
|
|
+ print("Using agent view")
|
|
|
+ env = RGBImgPartialObsWrapper(env, env.tile_size)
|
|
|
env = ImgObsWrapper(env)
|
|
|
|
|
|
- window = Window("minigrid - " + args.env)
|
|
|
- window.reg_key_handler(lambda event: key_handler(env, window, event))
|
|
|
-
|
|
|
- seed = None if args.seed == -1 else args.seed
|
|
|
- reset(env, window, seed)
|
|
|
-
|
|
|
- # Blocking event loop
|
|
|
- window.show(block=True)
|
|
|
+ manual_control = ManualControl(env, agent_view=args.agent_view, seed=args.seed)
|
|
|
+ manual_control.start()
|