Browse Source

Make sure ObstructedMaze is solvable. (#334)

Co-authored-by: Jiang Haobin <jianghaobin@pku.edu.cn>
H.B. Jiang 2 years ago
parent
commit
83731bd911

+ 44 - 0
minigrid/__init__.py

@@ -472,6 +472,50 @@ def register_minigrid_envs():
         entry_point="minigrid.envs:ObstructedMaze_Full",
     )
 
+    # ObstructedMaze-v1
+    # ----------------------------------------
+
+    register(
+        id="MiniGrid-ObstructedMaze-2Dlhb-v1",
+        entry_point="minigrid.envs.obstructedmaze_v1:ObstructedMaze_Full",
+        kwargs={
+            "agent_room": (2, 1),
+            "key_in_box": True,
+            "blocked": True,
+            "num_quarters": 1,
+            "num_rooms_visited": 4,
+        },
+    )
+
+    register(
+        id="MiniGrid-ObstructedMaze-1Q-v1",
+        entry_point="minigrid.envs.obstructedmaze_v1:ObstructedMaze_Full",
+        kwargs={
+            "agent_room": (1, 1),
+            "key_in_box": True,
+            "blocked": True,
+            "num_quarters": 1,
+            "num_rooms_visited": 5,
+        },
+    )
+
+    register(
+        id="MiniGrid-ObstructedMaze-2Q-v1",
+        entry_point="minigrid.envs.obstructedmaze_v1:ObstructedMaze_Full",
+        kwargs={
+            "agent_room": (2, 1),
+            "key_in_box": True,
+            "blocked": True,
+            "num_quarters": 2,
+            "num_rooms_visited": 11,
+        },
+    )
+
+    register(
+        id="MiniGrid-ObstructedMaze-Full-v1",
+        entry_point="minigrid.envs.obstructedmaze_v1:ObstructedMaze_Full",
+    )
+
     # Playground
     # ----------------------------------------
 

+ 7 - 0
minigrid/envs/obstructedmaze.py

@@ -58,6 +58,9 @@ class ObstructedMazeEnv(RoomGrid):
     "Q" number of quarters that will have doors and keys out of the 9 that the
     map already has.
     "Full" 3x3 maze with "h" and "b" options.
+    "v1" prevents the key from being covered by the blocking ball. Only 2Dlhb, 1Q, 2Q, and Full are
+    updated to v1. Other configurations won't face this issue because there is no blocking ball (1Dl,
+    1Dlh, 2Dl, 2Dlh) or the only blocking ball is added before the key (1Dlhb).
 
     - `MiniGrid-ObstructedMaze-1Dl-v0`
     - `MiniGrid-ObstructedMaze-1Dlh-v0`
@@ -65,9 +68,13 @@ class ObstructedMazeEnv(RoomGrid):
     - `MiniGrid-ObstructedMaze-2Dl-v0`
     - `MiniGrid-ObstructedMaze-2Dlh-v0`
     - `MiniGrid-ObstructedMaze-2Dlhb-v0`
+    - `MiniGrid-ObstructedMaze-2Dlhb-v1`
     - `MiniGrid-ObstructedMaze-1Q-v0`
+    - `MiniGrid-ObstructedMaze-1Q-v1`
     - `MiniGrid-ObstructedMaze-2Q-v0`
+    - `MiniGrid-ObstructedMaze-2Q-v1`
     - `MiniGrid-ObstructedMaze-Full-v0`
+    - `MiniGrid-ObstructedMaze-Full-v1`
 
     """
 

+ 99 - 0
minigrid/envs/obstructedmaze_v1.py

@@ -0,0 +1,99 @@
+from __future__ import annotations
+
+from minigrid.core.constants import DIR_TO_VEC
+from minigrid.core.roomgrid import RoomGrid
+from minigrid.core.world_object import Ball, Box, Key
+from minigrid.envs.obstructedmaze import ObstructedMazeEnv
+
+
+class ObstructedMaze_Full(ObstructedMazeEnv):
+    """
+    A blue ball is hidden in one of the 4 corners of a 3x3 maze. Doors
+    are locked, doors are obstructed by a ball and keys are hidden in
+    boxes.
+
+    All doors and their corresponding blocking balls will be added first,
+    followed by the boxes containing the keys.
+    """
+
+    def __init__(
+        self,
+        agent_room=(1, 1),
+        key_in_box=True,
+        blocked=True,
+        num_quarters=4,
+        num_rooms_visited=25,
+        **kwargs,
+    ):
+        self.agent_room = agent_room
+        self.key_in_box = key_in_box
+        self.blocked = blocked
+        self.num_quarters = num_quarters
+
+        super().__init__(
+            num_rows=3, num_cols=3, num_rooms_visited=num_rooms_visited, **kwargs
+        )
+
+    def _gen_grid(self, width, height):
+        super()._gen_grid(width, height)
+
+        middle_room = (1, 1)
+        # Define positions of "side rooms" i.e. rooms that are neither
+        # corners nor the center.
+        side_rooms = [(2, 1), (1, 2), (0, 1), (1, 0)][: self.num_quarters]
+        for i in range(len(side_rooms)):
+            side_room = side_rooms[i]
+
+            # Add a door between the center room and the side room
+            self.add_door(
+                *middle_room, door_idx=i, color=self.door_colors[i], locked=False
+            )
+
+            for k in [-1, 1]:
+                # Add a door to each side of the side room w/o placing a key
+                self.add_locked_door(
+                    *side_room,
+                    door_idx=(i + k) % 4,
+                    color=self.door_colors[(i + k) % len(self.door_colors)],
+                    blocked=self.blocked,
+                )
+
+            # Add keys after all doors and their blocking balls are added
+            for k in [-1, 1]:
+                self.add_key(
+                    *side_room,
+                    color=self.door_colors[(i + k) % len(self.door_colors)],
+                    key_in_box=self.key_in_box,
+                )
+
+        corners = [(2, 0), (2, 2), (0, 2), (0, 0)][: self.num_quarters]
+        ball_room = self._rand_elem(corners)
+
+        self.obj, _ = self.add_object(
+            ball_room[0], ball_room[1], "ball", color=self.ball_to_find_color
+        )
+        self.place_agent(*self.agent_room)
+
+    def add_locked_door(self, i, j, door_idx=0, color=None, blocked=False):
+        door, door_pos = RoomGrid.add_door(self, i, j, door_idx, color, locked=True)
+
+        if blocked:
+            vec = DIR_TO_VEC[door_idx]
+            blocking_ball = Ball(self.blocking_ball_color) if blocked else None
+            self.grid.set(door_pos[0] - vec[0], door_pos[1] - vec[1], blocking_ball)
+
+        return door, door_pos
+
+    def add_key(
+        self,
+        i,
+        j,
+        color=None,
+        key_in_box=False,
+    ):
+        obj = Key(color)
+        if key_in_box:
+            box = Box(self.box_color)
+            box.contains = obj
+            obj = box
+        self.place_in_room(i, j, obj)

+ 75 - 0
tests/test_obstructed_maze.py

@@ -0,0 +1,75 @@
+from __future__ import annotations
+
+import gymnasium as gym
+import pytest
+
+from minigrid.core.constants import COLOR_NAMES
+from minigrid.core.world_object import Ball, Box
+
+# Base environment ids without a version suffix; callers append "-v0" or
+# "-v1" before handing the id to env_test().
+TESTING_ENVS = [
+    "MiniGrid-ObstructedMaze-2Dlhb",
+    "MiniGrid-ObstructedMaze-1Q",
+    "MiniGrid-ObstructedMaze-2Q",
+    "MiniGrid-ObstructedMaze-Full",
+]
+
+
+def find_ball_room(env):
+    for obj in env.grid.grid:
+        if isinstance(obj, Ball) and obj.color == COLOR_NAMES[0]:
+            return env.room_from_pos(*obj.cur_pos)
+
+
+def find_target_key(env, color):
+    for obj in env.grid.grid:
+        if isinstance(obj, Box) and obj.contains.color == color:
+            return True
+    return False
+
+
+def env_test(env_id, repeats=10000):
+    env = gym.make(env_id)
+
+    cnt = 0
+    for _ in range(repeats):
+        env.reset()
+        ball_room = find_ball_room(env)
+        ball_room_doors = list(filter(None, ball_room.doors))
+        keys_exit = [find_target_key(env, door.color) for door in ball_room_doors]
+        if not any(keys_exit):
+            cnt += 1
+
+    return (cnt / repeats) * 100
+
+
+@pytest.mark.parametrize("env_id", TESTING_ENVS)
+def test_solvable_env(env_id):
+    assert env_test(env_id + "-v1") == 0, f"{env_id} is unsolvable."
+
+
+def main():
+    """
+    Test the frequency of unsolvable situation in this environment, including
+    MiniGrid-ObstructedMaze-2Dlhb, -1Q, -2Q, and -Full. The reason for the unsolvable
+    situation is that in the v0 version of these environments, the box containing
+    the key to the door connecting the upper-right room may be covered by the
+    blocking ball of the door connecting the lower-right room.
+
+    Note: Covering that occurs in MiniGrid-ObstructedMaze-Full won't lead to an
+    unsolvable situation.
+
+    Expected probability of unsolvable situation:
+    - MiniGrid-ObstructedMaze-2Dlhb-v0: 1 / 15 = 6.67%
+    - MiniGrid-ObstructedMaze-1Q-v0: 1/ 15 = 6.67%
+    - MiniGrid-ObstructedMaze-2Q-v0: 1 / 30 = 3.33%
+    - MiniGrid-ObstructedMaze-Full-v0: 0%
+    """
+
+    for env_id in TESTING_ENVS:
+        print(f"{env_id}: {env_test(env_id + '-v0'):.2f}% unsolvable.")
+    for env_id in TESTING_ENVS:
+        print(f"{env_id}: {env_test(env_id + '-v1'):.2f}% unsolvable.")
+
+
+# Print the v0/v1 unsolvable-rate report when the file is run as a script.
+if __name__ == "__main__":
+    main()