12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576 |
- from __future__ import annotations
- import gymnasium as gym
- import pytest
- from minigrid.core.constants import COLOR_NAMES
- from minigrid.core.world_object import Ball, Box
# Base ids of the ObstructedMaze environments exercised by this module.
# The "-v0" / "-v1" version suffix is intentionally left off here; the
# callers below append it when building the id passed to the environment
# factory.
TESTING_ENVS = [
    "MiniGrid-ObstructedMaze-2Dlhb",
    "MiniGrid-ObstructedMaze-1Q",
    "MiniGrid-ObstructedMaze-2Q",
    "MiniGrid-ObstructedMaze-Full",
]
def find_ball_room(env):
    """Return the room holding the first ball coloured ``COLOR_NAMES[0]``.

    The flat cell list ``env.unwrapped.grid.grid`` is scanned in order and
    the first matching ``Ball`` decides the result.

    Args:
        env: An environment whose ``unwrapped`` object exposes ``grid.grid``
            and ``room_from_pos``.

    Returns:
        Whatever ``room_from_pos`` returns for the ball's ``cur_pos``, or
        ``None`` when no ball of that colour is on the grid.
    """
    base_env = env.unwrapped
    for cell in base_env.grid.grid:
        if not isinstance(cell, Ball):
            continue
        # NOTE(review): presumably COLOR_NAMES[0] is the colour reserved for
        # the ball the agent must fetch; confirm against the env definition.
        if cell.color != COLOR_NAMES[0]:
            continue
        return base_env.room_from_pos(*cell.cur_pos)
    return None
def find_target_key(env, color):
    """Report whether some box on the grid holds an object of ``color``.

    Every cell of ``env.unwrapped.grid.grid`` is inspected; a cell matches
    when it is a ``Box`` whose ``contains`` object has the requested colour.
    Empty boxes (``contains is None``) are skipped instead of raising
    ``AttributeError``.

    Args:
        env: An environment whose ``unwrapped`` object exposes ``grid.grid``.
        color: Colour to look for on the boxed object.

    Returns:
        ``True`` if at least one box contains an object of ``color``,
        ``False`` otherwise (including when the grid has no boxes).
    """
    # NOTE(review): the contained object's type is not checked, only its
    # colour; presumably boxes in these envs only ever hold keys.
    return any(
        isinstance(obj, Box)
        and obj.contains is not None
        and obj.contains.color == color
        for obj in env.unwrapped.grid.grid
    )
def env_test(env_id, repeats=10000):
    """Estimate how often ``env_id`` resets into an unsolvable layout.

    The environment is reset ``repeats`` times. After each reset the room
    returned by ``find_ball_room`` is located and, for each of its doors,
    ``find_target_key`` checks whether a box holds an object of that door's
    colour. A reset counts as unsolvable when no door of the room has such
    a box.

    Args:
        env_id: Full environment id, including the version suffix.
        repeats: Number of resets to sample; must be positive.

    Returns:
        Percentage (0-100) of sampled resets that were unsolvable.

    Raises:
        ValueError: If ``repeats`` is not positive. A zero count would
            divide by zero and a negative one would report a vacuous 0%.
        RuntimeError: If a reset produces a grid without the target ball.
    """
    # Validate before creating the environment so nothing needs cleaning up.
    if repeats <= 0:
        raise ValueError(f"repeats must be a positive integer, got {repeats!r}")

    env = gym.make(env_id)
    unsolvable = 0
    try:
        for _ in range(repeats):
            env.reset()
            ball_room = find_ball_room(env)
            if ball_room is None:
                raise RuntimeError(f"No target ball found on the grid of {env_id}.")
            # Room door slots may be empty; keep only the doors that exist.
            doors = [door for door in ball_room.doors if door]
            if not any(find_target_key(env, door.color) for door in doors):
                unsolvable += 1
    finally:
        # Release the environment even if a reset or lookup fails.
        env.close()
    return (unsolvable / repeats) * 100
@pytest.mark.parametrize("env_id", TESTING_ENVS)
def test_solvable_env(env_id):
    """Check that the v1 variant of ``env_id`` never samples an unsolvable layout."""
    versioned_id = env_id + "-v1"
    unsolvable_percentage = env_test(versioned_id)
    assert unsolvable_percentage == 0, f"{env_id} is unsolvable."
def main():
    """
    Test the frequency of unsolvable situation in this environment, including
    MiniGrid-ObstructedMaze-2Dlhb, -1Q, -2Q, and -Full. The reason for the unsolvable
    situation is that in the v0 version of these environments, the box containing
    the key to the door connecting the upper-right room may be covered by the
    blocking ball of the door connecting the lower-right room.

    Note: Covering that occurs in MiniGrid-ObstructedMaze-Full won't lead to an
    unsolvable situation.

    Expected probability of unsolvable situation:
    - MiniGrid-ObstructedMaze-2Dlhb-v0: 1 / 15 = 6.67%
    - MiniGrid-ObstructedMaze-1Q-v0: 1 / 15 = 6.67%
    - MiniGrid-ObstructedMaze-2Q-v0: 1 / 30 = 3.33%
    - MiniGrid-ObstructedMaze-Full-v0: 0%

    All v0 environments are reported first, then all v1 environments. Each
    printed line carries the full versioned id so the two passes can be told
    apart in the output.
    """
    for version in ("v0", "v1"):
        for env_id in TESTING_ENVS:
            versioned_id = f"{env_id}-{version}"
            rate = env_test(versioned_id)
            print(f"{versioned_id}: {rate:.2f}% unsolvable.")
# Print the unsolvable-rate report only when this file is executed as a
# script; importing it (e.g. during pytest collection) must not run it.
if __name__ == "__main__":
    main()
|