Browse Source

:hammer: Refactor manual_control (#268)

Co-authored-by: Rodrigo Perez-Vicente <r.l.p.v96@gmail.com>
Mathïs Fédérico 2 years ago
parent
commit
8fdebee79e
6 changed files with 131 additions and 113 deletions
  1. 5 4
      .gitignore
  2. 80 82
      minigrid/manual_control.py
  3. 16 6
      setup.py
  4. 2 1
      test_requirements.txt
  5. 1 0
      tests/test_envs.py
  6. 27 20
      tests/test_scripts.py

+ 5 - 4
.gitignore

@@ -8,15 +8,16 @@ build/*
 dist/*
 .idea/
 
-#docs
-_build/
+# Docs
+_build/*
 .DS_Store
 _site
 .jekyll-cache
 __pycache__
 .vscode/
-docs/environments/**/*.*
-!docs/environments/**/index.md
+/docs/environments/*.md
+!docs/environments/index.md
+!docs/environments/babyAI_index.md
 
 # Virtual environments
 .env

+ 80 - 82
minigrid/manual_control.py

@@ -2,76 +2,82 @@
 
 import gymnasium as gym
 
+from minigrid.minigrid_env import MiniGridEnv
 from minigrid.utils.window import Window
 from minigrid.wrappers import ImgObsWrapper, RGBImgPartialObsWrapper
 
 
-def redraw(window, img):
-    window.show_img(img)
-
-
-def reset(env, window, seed=None):
-    env.reset(seed=seed)
-
-    if hasattr(env, "mission"):
-        print("Mission: %s" % env.mission)
-        window.set_caption(env.mission)
-
-    img = env.get_frame()
-
-    redraw(window, img)
-
-
-def step(env, window, action):
-    obs, reward, terminated, truncated, info = env.step(action)
-    print(f"step={env.step_count}, reward={reward:.2f}")
-
-    if terminated:
-        print("terminated!")
-        reset(env, window)
-    elif truncated:
-        print("truncated!")
-        reset(env, window)
-    else:
-        img = env.get_frame()
-        redraw(window, img)
-
-
-def key_handler(env, window, event):
-    print("pressed", event.key)
-
-    if event.key == "escape":
-        window.close()
-        return
-
-    if event.key == "backspace":
-        reset(env, window)
-        return
-
-    if event.key == "left":
-        step(env, window, env.actions.left)
-        return
-    if event.key == "right":
-        step(env, window, env.actions.right)
-        return
-    if event.key == "up":
-        step(env, window, env.actions.forward)
-        return
-
-    # Spacebar
-    if event.key == " ":
-        step(env, window, env.actions.toggle)
-        return
-    if event.key == "pageup":
-        step(env, window, env.actions.pickup)
-        return
-    if event.key == "pagedown":
-        step(env, window, env.actions.drop)
-        return
-
-    if event.key == "enter":
-        step(env, window, env.actions.done)
-        return
+class ManualControl:
+    def __init__(
+        self,
+        env: MiniGridEnv,
+        agent_view: bool = False,
+        window: Window = None,
+        seed=None,
+    ) -> None:
+        self.env = env
+        self.agent_view = agent_view
+        self.seed = seed
+
+        if window is None:
+            window = Window("minigrid - " + str(env.__class__))
+        self.window = window
+        self.window.reg_key_handler(self.key_handler)
+
+    def start(self):
+        """Start the window display with blocking event loop"""
+        self.reset(self.seed)
+        self.window.show(block=True)
+
+    def step(self, action: MiniGridEnv.Actions):
+        _, reward, terminated, truncated, _ = self.env.step(action)
+        print(f"step={self.env.step_count}, reward={reward:.2f}")
+
+        if terminated:
+            print("terminated!")
+            self.reset(self.seed)
+        elif truncated:
+            print("truncated!")
+            self.reset(self.seed)
+        else:
+            self.redraw()
+
+    def redraw(self):
+        frame = self.env.get_frame(agent_pov=self.agent_view)
+        self.window.show_img(frame)
+
+    def reset(self, seed=None):
+        self.env.reset(seed=seed)
+
+        if hasattr(self.env, "mission"):
+            print("Mission: %s" % self.env.mission)
+            self.window.set_caption(self.env.mission)
+
+        self.redraw()
+
+    def key_handler(self, event):
+        key: str = event.key
+        print("pressed", key)
+
+        if key == "escape":
+            self.window.close()
+            return
+        if key == "backspace":
+            self.reset()
+            return
+
+        key_to_action = {
+            "left": MiniGridEnv.Actions.left,
+            "right": MiniGridEnv.Actions.right,
+            "up": MiniGridEnv.Actions.forward,
+            " ": MiniGridEnv.Actions.toggle,
+            "pageup": MiniGridEnv.Actions.pickup,
+            "pagedown": MiniGridEnv.Actions.drop,
+            "enter": MiniGridEnv.Actions.done,
+        }
+
+        action = key_to_action[key]
+        self.step(action)
 
 
 if __name__ == "__main__":
@@ -85,13 +91,13 @@ if __name__ == "__main__":
         "--seed",
         type=int,
         help="random seed to generate the environment with",
-        default=-1,
+        default=None,
     )
     parser.add_argument(
-        "--tile_size", type=int, help="size at which to render tiles", default=32
+        "--tile-size", type=int, help="size at which to render tiles", default=32
     )
     parser.add_argument(
-        "--agent_view",
+        "--agent-view",
         default=False,
         help="draw the agent sees (partially observable view)",
         action="store_true",
@@ -99,20 +105,12 @@ if __name__ == "__main__":
 
     args = parser.parse_args()
 
-    env = gym.make(
-        args.env,
-        tile_size=args.tile_size,
-    )
+    env: MiniGridEnv = gym.make(args.env, tile_size=args.tile_size)
 
     if args.agent_view:
-        env = RGBImgPartialObsWrapper(env)
+        print("Using agent view")
+        env = RGBImgPartialObsWrapper(env, env.tile_size)
         env = ImgObsWrapper(env)
 
-    window = Window("minigrid - " + args.env)
-    window.reg_key_handler(lambda event: key_handler(env, window, event))
-
-    seed = None if args.seed == -1 else args.seed
-    reset(env, window, seed)
-
-    # Blocking event loop
-    window.show(block=True)
+    manual_control = ManualControl(env, agent_view=args.agent_view, seed=args.seed)
+    manual_control.start()

+ 16 - 6
setup.py

@@ -30,8 +30,22 @@ def get_version():
     raise RuntimeError("bad version data in __init__.py")
 
 
+def get_requirements():
+    """Gets the install requirements from requirements.txt."""
+    with open("requirements.txt") as reqs_file:
+        reqs = reqs_file.readlines()
+    return reqs
+
+
+def get_tests_requirements():
+    """Gets the testing requirements from test_requirements.txt."""
+    with open("test_requirements.txt") as test_reqs_file:
+        test_reqs = test_reqs_file.readlines()
+    return test_reqs
+
+
 # pytest is pinned to 7.0.1 as this is last version for python 3.6
-extras = {"testing": ["pytest==7.0.1"]}
+extras = {"testing": get_tests_requirements()}
 
 version = get_version()
 header_count, long_description = get_description()
@@ -51,11 +65,7 @@ setup(
     python_requires=">=3.7, <3.11",
     packages=[package for package in find_packages() if package.startswith("minigrid")],
     include_package_data=True,
-    install_requires=[
-        "gymnasium>=0.26",
-        "numpy>=1.18.0",
-        "matplotlib>=3.0",
-    ],
+    install_requires=get_requirements(),
     classifiers=[
         "Development Status :: 5 - Production/Stable",
         "Programming Language :: Python :: 3",

+ 2 - 1
test_requirements.txt

@@ -1 +1,2 @@
-pytest==7.0.1
+pytest>=7.0.1
+pytest-mock>=3.10.0

+ 1 - 0
tests/test_envs.py

@@ -169,6 +169,7 @@ def test_max_steps_argument(env_spec):
     ids=[spec.id for spec in all_testing_env_specs],
 )
 def test_pickle_env(env_spec):
+    """Test that all environments are picklable."""
     env: gym.Env = env_spec.make()
     pickled_env: gym.Env = pickle.loads(pickle.dumps(env))
 

+ 27 - 20
tests/test_scripts.py

@@ -1,8 +1,10 @@
 import gymnasium as gym
 import numpy as np
+from pytest_mock import MockerFixture
 
 from minigrid.benchmark import benchmark
-from minigrid.manual_control import key_handler, reset
+from minigrid.manual_control import ManualControl
+from minigrid.minigrid_env import MiniGridEnv
 from minigrid.utils.window import Window
 
 
@@ -28,37 +30,42 @@ def test_window():
     window.close()
 
 
-def test_manual_control():
+def test_manual_control(mocker: MockerFixture):
     class FakeRandomKeyboardEvent:
         active_actions = ["left", "right", "up", " ", "pageup", "pagedown"]
         reset_action = "backspace"
         close_action = "escape"
 
-        def __init__(self, active_actions=True, reset_action=False) -> None:
-            if active_actions:
-                self.key = np.random.choice(self.active_actions)
-            elif reset_action:
+        def __init__(self, reset: bool = False, close: bool = False) -> None:
+            if reset:
                 self.key = self.reset_action
-            else:
+                return
+            if close:
                 self.key = self.close_action
+                return
+            self.key = np.random.choice(self.active_actions)
 
     env_id = "MiniGrid-Empty-16x16-v0"
-    env = gym.make(env_id)
-    window = Window(env_id)
-
-    reset(env, window)
+    env: MiniGridEnv = gym.make(env_id)
+    window = mocker.MagicMock()
+    window.close = mocker.MagicMock()
+    window.set_caption = mocker.MagicMock()
+    manual_control = ManualControl(env, window=window)
 
     for i in range(3):  # 3 resets
+        mission = f"Mission {i}"
+        env.mission = mission
+        manual_control.reset()
+        window.set_caption.assert_called_with(mission)
         for j in range(20):  # Do 20 steps
-            key_handler(env, window, FakeRandomKeyboardEvent())
+            manual_control.key_handler(FakeRandomKeyboardEvent())
+
+        fake_event = FakeRandomKeyboardEvent(reset=True)
+        manual_control.key_handler(fake_event)
 
-        key_handler(
-            env,
-            window,
-            FakeRandomKeyboardEvent(active_actions=False, reset_action=True),
-        )
+    window.close.assert_not_called()
 
     # Close the environment
-    key_handler(
-        env, window, FakeRandomKeyboardEvent(active_actions=False, reset_action=False)
-    )
+    fake_event = FakeRandomKeyboardEvent(close=True)
+    manual_control.key_handler(fake_event)
+    window.close.assert_called()