Преглед изворни кода

fix rendering and add tests for scripts

saleml пре 2 година
родитељ
комит
e8166b3dea

+ 53 - 46
gym_minigrid/benchmark.py

@@ -1,55 +1,62 @@
 #!/usr/bin/env python3
 #!/usr/bin/env python3
 
 
-import argparse
 import time
 import time
 
 
 import gym
 import gym
 
 
 from gym_minigrid.wrappers import ImgObsWrapper, RGBImgPartialObsWrapper
 from gym_minigrid.wrappers import ImgObsWrapper, RGBImgPartialObsWrapper
 
 
-parser = argparse.ArgumentParser()
-parser.add_argument(
-    "--env-name",
-    dest="env_name",
-    help="gym environment to load",
-    default="MiniGrid-LavaGapS7-v0",
-)
-parser.add_argument("--num_resets", default=200)
-parser.add_argument("--num_frames", default=5000)
-args = parser.parse_args()
-
-env = gym.make(args.env_name, new_step_api=True)
-
-# Benchmark env.reset
-t0 = time.time()
-for i in range(args.num_resets):
+
+def benchmark(env_id, num_resets, num_frames):
+    env = gym.make(env_id, new_step_api=True, render_mode="rgb_array")
+    # Benchmark env.reset
+    t0 = time.time()
+    for i in range(num_resets):
+        env.reset()
+    t1 = time.time()
+    dt = t1 - t0
+    reset_time = (1000 * dt) / num_resets
+
+    # Benchmark rendering
+    t0 = time.time()
+    for i in range(num_frames):
+        env.render()
+    t1 = time.time()
+    dt = t1 - t0
+    frames_per_sec = num_frames / dt
+
+    # Create an environment with an RGB agent observation
+    env = gym.make(env_id, new_step_api=True, render_mode="rgb_array")
+    env = RGBImgPartialObsWrapper(env)
+    env = ImgObsWrapper(env)
+
     env.reset()
     env.reset()
-t1 = time.time()
-dt = t1 - t0
-reset_time = (1000 * dt) / args.num_resets
-
-# Benchmark rendering
-t0 = time.time()
-for i in range(args.num_frames):
-    env.render("rgb_array")
-t1 = time.time()
-dt = t1 - t0
-frames_per_sec = args.num_frames / dt
-
-# Create an environment with an RGB agent observation
-env = gym.make(args.env_name, new_step_api=True)
-env = RGBImgPartialObsWrapper(env)
-env = ImgObsWrapper(env)
-
-env.reset()
-# Benchmark rendering
-t0 = time.time()
-for i in range(args.num_frames):
-    obs, reward, terminated, truncated, info = env.step(0)
-t1 = time.time()
-dt = t1 - t0
-agent_view_fps = args.num_frames / dt
-
-print(f"Env reset time: {reset_time:.1f} ms")
-print(f"Rendering FPS : {frames_per_sec:.0f}")
-print(f"Agent view FPS: {agent_view_fps:.0f}")
+    # Benchmark rendering
+    t0 = time.time()
+    for i in range(num_frames):
+        obs, reward, terminated, truncated, info = env.step(0)
+    t1 = time.time()
+    dt = t1 - t0
+    agent_view_fps = num_frames / dt
+
+    print(f"Env reset time: {reset_time:.1f} ms")
+    print(f"Rendering FPS : {frames_per_sec:.0f}")
+    print(f"Agent view FPS: {agent_view_fps:.0f}")
+
+    env.close()
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--env-id",
+        dest="env_id",
+        help="gym environment to load",
+        default="MiniGrid-LavaGapS7-v0",
+    )
+    parser.add_argument("--num_resets", default=200)
+    parser.add_argument("--num_frames", default=5000)
+    args = parser.parse_args()
+    benchmark(args.env_id, args.num_resets, args.num_frames)

+ 1 - 0
gym_minigrid/envs/lavagap.py

@@ -29,6 +29,7 @@ class LavaGapEnv(MiniGridEnv):
             max_steps=4 * size * size,
             max_steps=4 * size * size,
             # Set this to True for maximum speed
             # Set this to True for maximum speed
             see_through_walls=False,
             see_through_walls=False,
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):

+ 1 - 0
gym_minigrid/envs/multiroom.py

@@ -44,6 +44,7 @@ class MultiRoomEnv(MiniGridEnv):
             width=self.size,
             width=self.size,
             height=self.size,
             height=self.size,
             max_steps=self.maxNumRooms * 20,
             max_steps=self.maxNumRooms * 20,
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):

+ 1 - 0
gym_minigrid/envs/putnear.py

@@ -35,6 +35,7 @@ class PutNearEnv(MiniGridEnv):
             max_steps=5 * size,
             max_steps=5 * size,
             # Set this to True for maximum speed
             # Set this to True for maximum speed
             see_through_walls=True,
             see_through_walls=True,
+            **kwargs,
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):

+ 65 - 54
gym_minigrid/manual_control.py

@@ -1,44 +1,43 @@
 #!/usr/bin/env python3
 #!/usr/bin/env python3
 
 
-import argparse
-
 import gym
 import gym
 
 
 from gym_minigrid.window import Window
 from gym_minigrid.window import Window
 from gym_minigrid.wrappers import ImgObsWrapper, RGBImgPartialObsWrapper
 from gym_minigrid.wrappers import ImgObsWrapper, RGBImgPartialObsWrapper
 
 
 
 
-def redraw(img):
-    if not args.agent_view:
-        img = env.render(mode="rgb_array", tile_size=args.tile_size)
+def redraw(window, img):
     window.show_img(img)
     window.show_img(img)
 
 
 
 
-def reset():
-    obs = env.reset()
+def reset(env, window):
+    _ = env.reset()
 
 
     if hasattr(env, "mission"):
     if hasattr(env, "mission"):
         print("Mission: %s" % env.mission)
         print("Mission: %s" % env.mission)
         window.set_caption(env.mission)
         window.set_caption(env.mission)
 
 
-    redraw(obs)
+    img = env.get_render()
+
+    redraw(window, img)
 
 
 
 
-def step(action):
+def step(env, window, action):
     obs, reward, terminated, truncated, info = env.step(action)
     obs, reward, terminated, truncated, info = env.step(action)
     print(f"step={env.step_count}, reward={reward:.2f}")
     print(f"step={env.step_count}, reward={reward:.2f}")
 
 
     if terminated:
     if terminated:
         print("terminated!")
         print("terminated!")
-        reset()
+        reset(env, window)
     elif truncated:
     elif truncated:
         print("truncated!")
         print("truncated!")
-        reset()
+        reset(env, window)
     else:
     else:
-        redraw(obs)
+        img = env.get_full_render()
+        redraw(window, img)
 
 
 
 
-def key_handler(event):
+def key_handler(env, window, event):
     print("pressed", event.key)
     print("pressed", event.key)
 
 
     if event.key == "escape":
     if event.key == "escape":
@@ -46,65 +45,77 @@ def key_handler(event):
         return
         return
 
 
     if event.key == "backspace":
     if event.key == "backspace":
-        reset()
+        reset(env, window)
         return
         return
 
 
     if event.key == "left":
     if event.key == "left":
-        step(env.actions.left)
+        step(env, window, env.actions.left)
         return
         return
     if event.key == "right":
     if event.key == "right":
-        step(env.actions.right)
+        step(env, window, env.actions.right)
         return
         return
     if event.key == "up":
     if event.key == "up":
-        step(env.actions.forward)
+        step(env, window, env.actions.forward)
         return
         return
 
 
     # Spacebar
     # Spacebar
     if event.key == " ":
     if event.key == " ":
-        step(env.actions.toggle)
+        step(env, window, env.actions.toggle)
         return
         return
     if event.key == "pageup":
     if event.key == "pageup":
-        step(env.actions.pickup)
+        step(env, window, env.actions.pickup)
         return
         return
     if event.key == "pagedown":
     if event.key == "pagedown":
-        step(env.actions.drop)
+        step(env, window, env.actions.drop)
         return
         return
 
 
     if event.key == "enter":
     if event.key == "enter":
-        step(env.actions.done)
+        step(env, window, env.actions.done)
         return
         return
 
 
 
 
-parser = argparse.ArgumentParser()
-parser.add_argument(
-    "--env", help="gym environment to load", default="MiniGrid-MultiRoom-N6-v0"
-)
-parser.add_argument(
-    "--seed", type=int, help="random seed to generate the environment with", default=-1
-)
-parser.add_argument(
-    "--tile_size", type=int, help="size at which to render tiles", default=32
-)
-parser.add_argument(
-    "--agent_view",
-    default=False,
-    help="draw the agent sees (partially observable view)",
-    action="store_true",
-)
-
-args = parser.parse_args()
-
-seed = None if args.seed == -1 else args.seed
-env = gym.make(args.env, seed=seed, new_step_api=True)
-
-if args.agent_view:
-    env = RGBImgPartialObsWrapper(env)
-    env = ImgObsWrapper(env)
-
-window = Window("gym_minigrid - " + args.env)
-window.reg_key_handler(key_handler)
-
-reset()
-
-# Blocking event loop
-window.show(block=True)
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--env", help="gym environment to load", default="MiniGrid-MultiRoom-N6-v0"
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        help="random seed to generate the environment with",
+        default=-1,
+    )
+    parser.add_argument(
+        "--tile_size", type=int, help="size at which to render tiles", default=32
+    )
+    parser.add_argument(
+        "--agent_view",
+        default=False,
+        help="draw the agent sees (partially observable view)",
+        action="store_true",
+    )
+
+    args = parser.parse_args()
+
+    seed = None if args.seed == -1 else args.seed
+    env = gym.make(
+        args.env,
+        seed=seed,
+        new_step_api=True,
+        render_mode="human",  # Note that we do not need to use "human", as Window handles human rendering.
+        tile_size=args.tile_size,
+    )
+
+    if args.agent_view:
+        env = RGBImgPartialObsWrapper(env)
+        env = ImgObsWrapper(env)
+
+    window = Window("gym_minigrid - " + args.env)
+    window.reg_key_handler(lambda event: key_handler(env, window, event))
+
+    reset(env, window)
+
+    # Blocking event loop
+    window.show(block=True)

+ 34 - 12
gym_minigrid/minigrid.py

@@ -858,10 +858,17 @@ class MiniGridEnv(gym.Env):
         max_steps: int = 100,
         max_steps: int = 100,
         see_through_walls: bool = False,
         see_through_walls: bool = False,
         agent_view_size: int = 7,
         agent_view_size: int = 7,
+        render_mode: Optional[str] = None,
         highlight: bool = True,
         highlight: bool = True,
         tile_size: int = TILE_PIXELS,
         tile_size: int = TILE_PIXELS,
+        agent_pov: bool = False,
         **kwargs,
         **kwargs,
     ):
     ):
+        # Rendering attributes
+        self.render_mode = render_mode
+        self.highlight = highlight
+        self.tile_size = tile_size
+        self.agent_pov = agent_pov
 
 
         # Initialize mission
         # Initialize mission
         self.mission = mission_space.sample()
         self.mission = mission_space.sample()
@@ -1428,16 +1435,15 @@ class MiniGridEnv(gym.Env):
 
 
         return obs
         return obs
 
 
-    def get_obs_render(self, obs, tile_size=TILE_PIXELS // 2):
+    def get_pov_render(self):
         """
         """
-        Render an agent observation for visualization
+        Render an agent's POV observation for visualization
         """
         """
-
-        grid, vis_mask = Grid.decode(obs)
+        grid, vis_mask = self.gen_obs_grid()
 
 
         # Render the whole grid
         # Render the whole grid
         img = grid.render(
         img = grid.render(
-            tile_size,
+            self.tile_size,
             agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1),
             agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1),
             agent_dir=3,
             agent_dir=3,
             highlight_mask=vis_mask,
             highlight_mask=vis_mask,
@@ -1445,14 +1451,12 @@ class MiniGridEnv(gym.Env):
 
 
         return img
         return img
 
 
-    def render(self, mode="human", highlight=True, tile_size=TILE_PIXELS):
-        assert mode in self.metadata["render_modes"]
+    def get_full_render(self):
         """
         """
-        Render the whole-grid human view
+        Render a non-paratial observation for visualization
         """
         """
-        if mode == "human" and not self.window:
-            self.window = Window("gym_minigrid")
-            self.window.show(block=False)
+        tile_size = self.tile_size
+        highlight = self.highlight
 
 
         # Compute which cells are visible to the agent
         # Compute which cells are visible to the agent
         _, vis_mask = self.gen_obs_grid()
         _, vis_mask = self.gen_obs_grid()
@@ -1496,8 +1500,26 @@ class MiniGridEnv(gym.Env):
             highlight_mask=highlight_mask if highlight else None,
             highlight_mask=highlight_mask if highlight else None,
         )
         )
 
 
+        return img
+
+    def get_render(self):
+        "Returns an image corresponding to the whole environment or the agent's pov"
+        if self.agent_pov:
+            return self.get_pov_render()
+        else:
+            return self.get_full_render()
+
+    def render(self):
+        """
+        Render the whole-grid human view
+        """
+        img = self.get_render()
+
+        mode = self.render_mode
         if mode == "human":
         if mode == "human":
-            assert self.window is not None
+            if self.window is None:
+                self.window = Window("gym_minigrid")
+                self.window.show(block=False)
             self.window.set_caption(self.mission)
             self.window.set_caption(self.mission)
             self.window.show_img(img)
             self.window.show_img(img)
         else:
         else:

+ 21 - 19
gym_minigrid/wrappers.py

@@ -20,7 +20,7 @@ class ReseedWrapper(Wrapper):
     def __init__(self, env, seeds=[0], seed_idx=0):
     def __init__(self, env, seeds=[0], seed_idx=0):
         self.seeds = list(seeds)
         self.seeds = list(seeds)
         self.seed_idx = seed_idx
         self.seed_idx = seed_idx
-        super().__init__(env)
+        super().__init__(env, new_step_api=env.new_step_api)
 
 
     def reset(self, **kwargs):
     def reset(self, **kwargs):
         seed = self.seeds[self.seed_idx]
         seed = self.seeds[self.seed_idx]
@@ -39,7 +39,7 @@ class ActionBonus(gym.Wrapper):
     """
     """
 
 
     def __init__(self, env):
     def __init__(self, env):
-        super().__init__(env)
+        super().__init__(env, new_step_api=env.new_step_api)
         self.counts = {}
         self.counts = {}
 
 
     def step(self, action):
     def step(self, action):
@@ -73,7 +73,7 @@ class StateBonus(Wrapper):
     """
     """
 
 
     def __init__(self, env):
     def __init__(self, env):
-        super().__init__(env)
+        super().__init__(env, new_step_api=env.new_step_api)
         self.counts = {}
         self.counts = {}
 
 
     def step(self, action):
     def step(self, action):
@@ -108,7 +108,7 @@ class ImgObsWrapper(ObservationWrapper):
     """
     """
 
 
     def __init__(self, env):
     def __init__(self, env):
-        super().__init__(env)
+        super().__init__(env, new_step_api=env.new_step_api)
         self.observation_space = env.observation_space.spaces["image"]
         self.observation_space = env.observation_space.spaces["image"]
 
 
     def observation(self, obs):
     def observation(self, obs):
@@ -122,7 +122,7 @@ class OneHotPartialObsWrapper(ObservationWrapper):
     """
     """
 
 
     def __init__(self, env, tile_size=8):
     def __init__(self, env, tile_size=8):
-        super().__init__(env)
+        super().__init__(env, new_step_api=env.new_step_api)
 
 
         self.tile_size = tile_size
         self.tile_size = tile_size
 
 
@@ -162,7 +162,11 @@ class RGBImgObsWrapper(ObservationWrapper):
     """
     """
 
 
     def __init__(self, env, tile_size=8):
     def __init__(self, env, tile_size=8):
-        super().__init__(env)
+        super().__init__(env, new_step_api=env.new_step_api)
+
+        # Rendering attributes
+        self.highlight = True
+        self.tile_size = tile_size
 
 
         self.tile_size = tile_size
         self.tile_size = tile_size
 
 
@@ -178,9 +182,7 @@ class RGBImgObsWrapper(ObservationWrapper):
         )
         )
 
 
     def observation(self, obs):
     def observation(self, obs):
-        env = self.unwrapped
-
-        rgb_img = env.render(mode="rgb_array", highlight=True, tile_size=self.tile_size)
+        rgb_img = self.get_full_render()
 
 
         return {**obs, "image": rgb_img}
         return {**obs, "image": rgb_img}
 
 
@@ -192,8 +194,10 @@ class RGBImgPartialObsWrapper(ObservationWrapper):
     """
     """
 
 
     def __init__(self, env, tile_size=8):
     def __init__(self, env, tile_size=8):
-        super().__init__(env)
+        super().__init__(env, new_step_api=env.new_step_api)
 
 
+        # Rendering attributes
+        self.unwrapped.agent_pov = True
         self.tile_size = tile_size
         self.tile_size = tile_size
 
 
         obs_shape = env.observation_space.spaces["image"].shape
         obs_shape = env.observation_space.spaces["image"].shape
@@ -209,9 +213,7 @@ class RGBImgPartialObsWrapper(ObservationWrapper):
         )
         )
 
 
     def observation(self, obs):
     def observation(self, obs):
-        env = self.unwrapped
-
-        rgb_img_partial = env.get_obs_render(obs["image"], tile_size=self.tile_size)
+        rgb_img_partial = self.get_pov_render()
 
 
         return {**obs, "image": rgb_img_partial}
         return {**obs, "image": rgb_img_partial}
 
 
@@ -222,7 +224,7 @@ class FullyObsWrapper(ObservationWrapper):
     """
     """
 
 
     def __init__(self, env):
     def __init__(self, env):
-        super().__init__(env)
+        super().__init__(env, new_step_api=env.new_step_api)
 
 
         new_image_space = spaces.Box(
         new_image_space = spaces.Box(
             low=0,
             low=0,
@@ -257,7 +259,7 @@ class DictObservationSpaceWrapper(ObservationWrapper):
         word_dict is a dictionary of words to use (keys=words, values=indices from 1 to < max_words_in_mission),
         word_dict is a dictionary of words to use (keys=words, values=indices from 1 to < max_words_in_mission),
                   if None, use the Minigrid language
                   if None, use the Minigrid language
         """
         """
-        super().__init__(env)
+        super().__init__(env, new_step_api=env.new_step_api)
 
 
         if word_dict is None:
         if word_dict is None:
             word_dict = self.get_minigrid_words()
             word_dict = self.get_minigrid_words()
@@ -370,7 +372,7 @@ class FlatObsWrapper(ObservationWrapper):
     """
     """
 
 
     def __init__(self, env, maxStrLen=96):
     def __init__(self, env, maxStrLen=96):
-        super().__init__(env)
+        super().__init__(env, new_step_api=env.new_step_api)
 
 
         self.maxStrLen = maxStrLen
         self.maxStrLen = maxStrLen
         self.numCharCodes = 28
         self.numCharCodes = 28
@@ -431,7 +433,7 @@ class ViewSizeWrapper(Wrapper):
     """
     """
 
 
     def __init__(self, env, agent_view_size=7):
     def __init__(self, env, agent_view_size=7):
-        super().__init__(env)
+        super().__init__(env, new_step_api=env.new_step_api)
 
 
         assert agent_view_size % 2 == 1
         assert agent_view_size % 2 == 1
         assert agent_view_size >= 3
         assert agent_view_size >= 3
@@ -466,7 +468,7 @@ class DirectionObsWrapper(ObservationWrapper):
     """
     """
 
 
     def __init__(self, env, type="slope"):
     def __init__(self, env, type="slope"):
-        super().__init__(env)
+        super().__init__(env, new_step_api=env.new_step_api)
         self.goal_position: tuple = None
         self.goal_position: tuple = None
         self.type = type
         self.type = type
 
 
@@ -501,7 +503,7 @@ class SymbolicObsWrapper(ObservationWrapper):
     """
     """
 
 
     def __init__(self, env):
     def __init__(self, env):
-        super().__init__(env)
+        super().__init__(env, new_step_api=env.new_step_api)
 
 
         new_image_space = spaces.Box(
         new_image_space = spaces.Box(
             low=0,
             low=0,

+ 2 - 2
tests/test_envs.py

@@ -109,11 +109,11 @@ def test_render_modes(spec):
 
 
     for mode in env.metadata.get("render_modes", []):
     for mode in env.metadata.get("render_modes", []):
         if mode != "human":
         if mode != "human":
-            new_env = spec.make(new_step_api=True)
+            new_env = spec.make(new_step_api=True, render_mode=mode)
 
 
             new_env.reset()
             new_env.reset()
             new_env.step(new_env.action_space.sample())
             new_env.step(new_env.action_space.sample())
-            new_env.render(mode=mode)
+            new_env.render()
 
 
 
 
 @pytest.mark.parametrize("env_id", ["MiniGrid-DoorKey-6x6-v0"])
 @pytest.mark.parametrize("env_id", ["MiniGrid-DoorKey-6x6-v0"])

+ 64 - 0
tests/test_scripts.py

@@ -0,0 +1,64 @@
+import gym
+import numpy as np
+
+from gym_minigrid.benchmark import benchmark
+from gym_minigrid.manual_control import key_handler, reset
+from gym_minigrid.window import Window
+
+
+def test_benchmark():
+    "Test that the benchmark function works for a specific environment"
+    env_id = "MiniGrid-Empty-16x16-v0"
+    benchmark(env_id, num_resets=10, num_frames=100)
+
+
+def test_window():
+    "Testing the class functions of window.Window. This should locally open a window !"
+    title = "testing window"
+    window = Window(title)
+
+    img = np.random.rand(100, 100, 3)
+    window.show_img(img)
+
+    caption = "testing caption"
+    window.set_caption(caption)
+
+    window.show()
+
+    window.close()
+
+
+def test_manual_control():
+    class FakeRandomKeyboardEvent:
+        active_actions = ["left", "right", "up", " ", "pageup", "pagedown"]
+        reset_action = "backspace"
+        close_action = "escape"
+
+        def __init__(self, active_actions=True, reset_action=False) -> None:
+            if active_actions:
+                self.key = np.random.choice(self.active_actions)
+            elif reset_action:
+                self.key = self.reset_action
+            else:
+                self.key = self.close_action
+
+    env_id = "MiniGrid-Empty-16x16-v0"
+    env = gym.make(env_id, new_step_api=True)
+    window = Window(env_id)
+
+    reset(env, window)
+
+    for i in range(3):  # 3 resets
+        for j in range(20):  # Do 20 steps
+            key_handler(env, window, FakeRandomKeyboardEvent())
+
+        key_handler(
+            env,
+            window,
+            FakeRandomKeyboardEvent(active_actions=False, reset_action=True),
+        )
+
+    # Close the environment
+    key_handler(
+        env, window, FakeRandomKeyboardEvent(active_actions=False, reset_action=False)
+    )