瀏覽代碼

Add pyright to pre-commit

	* add pyright pre-commit

	* add pyproject.toml

	* fix process_vis

	* fix unbounded door state

	* fix character unbounded

	* fix unbounded obj type

	* fix window None

	* type annotate agent pos and dir

	* remove unnecessary conditional

	* type annotate flatobswrapper

	* type annotate goal_position
Rodrigo Perez-Vicente 2 年之前
父節點
當前提交
3aab399ab1

+ 11 - 11
.pre-commit-config.yaml

@@ -40,14 +40,14 @@ repos:
     hooks:
     hooks:
       - id: pyupgrade
       - id: pyupgrade
         args: ["--py37-plus"]
         args: ["--py37-plus"]
-#  - repo: local
-#    hooks:
-#      - id: pyright
-#        name: pyright
-#        entry: pyright
-#        language: node
-#        pass_filenames: false
-#        types: [python]
-#        additional_dependencies: ["pyright"]
-#        args:
-#          - --project=pyproject.toml
+ - repo: local
+   hooks:
+     - id: pyright
+       name: pyright
+       entry: pyright
+       language: node
+       pass_filenames: false
+       types: [python]
+       additional_dependencies: ["pyright"]
+       args:
+         - --project=pyproject.toml

+ 4 - 0
gym_minigrid/envs/fetch.py

@@ -41,6 +41,10 @@ class FetchEnv(MiniGridEnv):
                 obj = Key(objColor)
                 obj = Key(objColor)
             elif objType == "ball":
             elif objType == "ball":
                 obj = Ball(objColor)
                 obj = Ball(objColor)
+            else:
+                raise ValueError(
+                    "{} object type given. Object type can only be of values key and ball.".format(objType)
+                )
 
 
             self.place_obj(obj)
             self.place_obj(obj)
             objs.append(obj)
             objs.append(obj)

+ 3 - 0
gym_minigrid/envs/fourrooms.py

@@ -65,7 +65,10 @@ class FourRoomsEnv(MiniGridEnv):
             self.place_obj(Goal())
             self.place_obj(Goal())
 
 
         self.mission = "reach the goal"
         self.mission = "reach the goal"
+<<<<<<< HEAD
         self.mission = "Reach the goal"
         self.mission = "Reach the goal"
+=======
+>>>>>>> Add pyright to pre-commit
 
 
     def step(self, action):
     def step(self, action):
         obs, reward, done, info = MiniGridEnv.step(self, action)
         obs, reward, done, info = MiniGridEnv.step(self, action)

+ 4 - 0
gym_minigrid/envs/gotoobject.py

@@ -46,6 +46,10 @@ class GoToObjectEnv(MiniGridEnv):
                 obj = Ball(objColor)
                 obj = Ball(objColor)
             elif objType == "box":
             elif objType == "box":
                 obj = Box(objColor)
                 obj = Box(objColor)
+            else:
+                raise ValueError(
+                    "{} object type given. Object type can only be of values key, ball and box.".format(objType)
+                )
 
 
             pos = self.place_obj(obj)
             pos = self.place_obj(obj)
             objs.append((objType, objColor))
             objs.append((objType, objColor))

+ 1 - 1
gym_minigrid/envs/obstructedmaze.py

@@ -71,7 +71,7 @@ class ObstructedMazeEnv(RoomGrid):
         if locked:
         if locked:
             obj = Key(door.color)
             obj = Key(door.color)
             if key_in_box:
             if key_in_box:
-                box = Box(self.box_color) if key_in_box else None
+                box = Box(self.box_color)
                 box.contains = obj
                 box.contains = obj
                 obj = box
                 obj = box
             self.place_in_room(i, j, obj)
             self.place_in_room(i, j, obj)

+ 4 - 0
gym_minigrid/envs/playground.py

@@ -62,6 +62,10 @@ class PlaygroundEnv(MiniGridEnv):
                 obj = Ball(objColor)
                 obj = Ball(objColor)
             elif objType == "box":
             elif objType == "box":
                 obj = Box(objColor)
                 obj = Box(objColor)
+            else:
+                raise ValueError(
+                    "{} object type given. Object type can only be of values key, ball and box.".format(objType)
+                )
             self.place_obj(obj)
             self.place_obj(obj)
 
 
         # No explicit mission in this environment
         # No explicit mission in this environment

+ 4 - 0
gym_minigrid/envs/putnear.py

@@ -57,6 +57,10 @@ class PutNearEnv(MiniGridEnv):
                 obj = Ball(objColor)
                 obj = Ball(objColor)
             elif objType == "box":
             elif objType == "box":
                 obj = Box(objColor)
                 obj = Box(objColor)
+            else:
+                raise ValueError(
+                    "{} object type given. Object type can only be of values key, ball and box.".format(objType)
+                )
 
 
             pos = self.place_obj(obj, reject_fn=near_obj)
             pos = self.place_obj(obj, reject_fn=near_obj)
 
 

+ 14 - 15
gym_minigrid/minigrid.py

@@ -252,8 +252,9 @@ class Door(WorldObj):
             state = 0
             state = 0
         elif self.is_locked:
         elif self.is_locked:
             state = 2
             state = 2
-        elif not self.is_open:
-            state = 1
+        # if door is closed and unlocked
+        else:
+            state = 1 
 
 
         return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state)
         return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state)
 
 
@@ -580,8 +581,8 @@ class Grid:
 
 
         return grid, vis_mask
         return grid, vis_mask
 
 
-    def process_vis(grid, agent_pos):
-        mask = np.zeros(shape=(grid.width, grid.height), dtype=bool)
+    def process_vis(self, agent_pos):
+        mask = np.zeros(shape=(self.width, self.height), dtype=bool)
 
 
         mask[agent_pos[0], agent_pos[1]] = True
         mask[agent_pos[0], agent_pos[1]] = True
 
 
@@ -590,7 +591,7 @@ class Grid:
                 if not mask[i, j]:
                 if not mask[i, j]:
                     continue
                     continue
 
 
-                cell = grid.get(i, j)
+                cell = self.get(i, j)
                 if cell and not cell.see_behind():
                 if cell and not cell.see_behind():
                     continue
                     continue
 
 
@@ -599,11 +600,11 @@ class Grid:
                     mask[i + 1, j - 1] = True
                     mask[i + 1, j - 1] = True
                     mask[i, j - 1] = True
                     mask[i, j - 1] = True
 
 
-            for i in reversed(range(1, grid.width)):
+            for i in reversed(range(1, self.width)):
                 if not mask[i, j]:
                 if not mask[i, j]:
                     continue
                     continue
 
 
-                cell = grid.get(i, j)
+                cell = self.get(i, j)
                 if cell and not cell.see_behind():
                 if cell and not cell.see_behind():
                     continue
                     continue
 
 
@@ -612,10 +613,10 @@ class Grid:
                     mask[i - 1, j - 1] = True
                     mask[i - 1, j - 1] = True
                     mask[i, j - 1] = True
                     mask[i, j - 1] = True
 
 
-        for j in range(0, grid.height):
-            for i in range(0, grid.width):
+        for j in range(0, self.height):
+            for i in range(0, self.width):
                 if not mask[i, j]:
                 if not mask[i, j]:
-                    grid.set(i, j, None)
+                    self.set(i, j, None)
 
 
         return mask
         return mask
 
 
@@ -703,9 +704,6 @@ class MiniGridEnv(gym.Env):
         # Range of possible rewards
         # Range of possible rewards
         self.reward_range = (0, 1)
         self.reward_range = (0, 1)
 
 
-        # Window to use for human rendering mode
-        self.window = None
-
         # Environment configuration
         # Environment configuration
         self.width = width
         self.width = width
         self.height = height
         self.height = height
@@ -722,8 +720,9 @@ class MiniGridEnv(gym.Env):
     def reset(self, *, seed=None, return_info=False, options=None):
     def reset(self, *, seed=None, return_info=False, options=None):
         super().reset(seed=seed)
         super().reset(seed=seed)
         # Current position and direction of the agent
         # Current position and direction of the agent
-        self.agent_pos = None
-        self.agent_dir = None
+        NDArrayInt = npt.NDArray[np.int_]
+        self.agent_pos: NDArrayInt = None
+        self.agent_dir: int = None
 
 
         # Generate a new random grid at the start of each episode
         # Generate a new random grid at the start of each episode
         self._gen_grid(self.width, self.height)
         self._gen_grid(self.width, self.height)

+ 3 - 0
gym_minigrid/roomgrid.py

@@ -1,6 +1,7 @@
 from gym_minigrid.minigrid import COLOR_NAMES, Ball, Box, Door, Grid, Key, MiniGridEnv
 from gym_minigrid.minigrid import COLOR_NAMES, Ball, Box, Door, Grid, Key, MiniGridEnv
 
 
 
 
+
 def reject_next_to(env, pos):
 def reject_next_to(env, pos):
     """
     """
     Function to filter out object positions that are right next to
     Function to filter out object positions that are right next to
@@ -203,6 +204,8 @@ class RoomGrid(MiniGridEnv):
             obj = Ball(color)
             obj = Ball(color)
         elif kind == "box":
         elif kind == "box":
             obj = Box(color)
             obj = Box(color)
+        else:
+            raise "{} object kind is not available in this environment.".format(kind)
 
 
         return self.place_in_room(i, j, obj)
         return self.place_in_room(i, j, obj)
 
 

+ 2 - 3
gym_minigrid/window.py

@@ -7,15 +7,14 @@ except ImportError:
     )
     )
 
 
 
 
+
 class Window:
 class Window:
     """
     """
     Window to draw a gridworld instance using Matplotlib
     Window to draw a gridworld instance using Matplotlib
     """
     """
 
 
     def __init__(self, title):
     def __init__(self, title):
-        self.fig = None
-
-        self.imshow_obj = None
+        self.no_image_shown = True
 
 
         # Create the figure and axes
         # Create the figure and axes
         self.fig, self.ax = plt.subplots()
         self.fig, self.ax = plt.subplots()

+ 17 - 16
gym_minigrid/wrappers.py

@@ -8,8 +8,10 @@ from gym import spaces
 
 
 from gym_minigrid.minigrid import COLOR_TO_IDX, OBJECT_TO_IDX, STATE_TO_IDX, Goal
 from gym_minigrid.minigrid import COLOR_TO_IDX, OBJECT_TO_IDX, STATE_TO_IDX, Goal
 
 
+from gym_minigrid.minigrid import COLOR_TO_IDX, OBJECT_TO_IDX, STATE_TO_IDX, Goal
+
 
 
-class ReseedWrapper(gym.Wrapper):
+class ReseedWrapper(Wrapper):
     """
     """
     Wrapper to always regenerate an environment with the same set of seeds.
     Wrapper to always regenerate an environment with the same set of seeds.
     This can be used to force an environment to always keep the same
     This can be used to force an environment to always keep the same
@@ -31,7 +33,7 @@ class ReseedWrapper(gym.Wrapper):
         return obs, reward, done, info
         return obs, reward, done, info
 
 
 
 
-class ActionBonus(gym.Wrapper):
+class ActionBonus(Wrapper):
     """
     """
     Wrapper which adds an exploration bonus.
     Wrapper which adds an exploration bonus.
     This is a reward to encourage exploration of less
     This is a reward to encourage exploration of less
@@ -66,7 +68,7 @@ class ActionBonus(gym.Wrapper):
         return self.env.reset(**kwargs)
         return self.env.reset(**kwargs)
 
 
 
 
-class StateBonus(gym.Wrapper):
+class StateBonus(Wrapper):
     """
     """
     Adds an exploration bonus based on which positions
     Adds an exploration bonus based on which positions
     are visited on the grid.
     are visited on the grid.
@@ -102,7 +104,7 @@ class StateBonus(gym.Wrapper):
         return self.env.reset(**kwargs)
         return self.env.reset(**kwargs)
 
 
 
 
-class ImgObsWrapper(gym.ObservationWrapper):
+class ImgObsWrapper(ObservationWrapper):
     """
     """
     Use the image as the only observation output, no language/mission.
     Use the image as the only observation output, no language/mission.
     """
     """
@@ -115,7 +117,7 @@ class ImgObsWrapper(gym.ObservationWrapper):
         return obs["image"]
         return obs["image"]
 
 
 
 
-class OneHotPartialObsWrapper(gym.ObservationWrapper):
+class OneHotPartialObsWrapper(ObservationWrapper):
     """
     """
     Wrapper to get a one-hot encoding of a partially observable
     Wrapper to get a one-hot encoding of a partially observable
     agent view as observation.
     agent view as observation.
@@ -155,7 +157,7 @@ class OneHotPartialObsWrapper(gym.ObservationWrapper):
         return {**obs, "image": out}
         return {**obs, "image": out}
 
 
 
 
-class RGBImgObsWrapper(gym.ObservationWrapper):
+class RGBImgObsWrapper(ObservationWrapper):
     """
     """
     Wrapper to use fully observable RGB image as observation,
     Wrapper to use fully observable RGB image as observation,
     This can be used to have the agent to solve the gridworld in pixel space.
     This can be used to have the agent to solve the gridworld in pixel space.
@@ -187,7 +189,7 @@ class RGBImgObsWrapper(gym.ObservationWrapper):
         return {**obs, "image": rgb_img}
         return {**obs, "image": rgb_img}
 
 
 
 
-class RGBImgPartialObsWrapper(gym.ObservationWrapper):
+class RGBImgPartialObsWrapper(ObservationWrapper):
     """
     """
     Wrapper to use partially observable RGB image as observation.
     Wrapper to use partially observable RGB image as observation.
     This can be used to have the agent to solve the gridworld in pixel space.
     This can be used to have the agent to solve the gridworld in pixel space.
@@ -218,7 +220,7 @@ class RGBImgPartialObsWrapper(gym.ObservationWrapper):
         return {**obs, "image": rgb_img_partial}
         return {**obs, "image": rgb_img_partial}
 
 
 
 
-class FullyObsWrapper(gym.ObservationWrapper):
+class FullyObsWrapper(ObservationWrapper):
     """
     """
     Fully observable gridworld using a compact grid encoding
     Fully observable gridworld using a compact grid encoding
     """
     """
@@ -247,7 +249,7 @@ class FullyObsWrapper(gym.ObservationWrapper):
         return {**obs, "image": full_grid}
         return {**obs, "image": full_grid}
 
 
 
 
-class DictObservationSpaceWrapper(gym.ObservationWrapper):
+class DictObservationSpaceWrapper(ObservationWrapper):
     """
     """
     Transforms the observation space (that has a textual component) to a fully numerical observation space,
     Transforms the observation space (that has a textual component) to a fully numerical observation space,
     where the textual instructions are replaced by arrays representing the indices of each word in a fixed vocabulary.
     where the textual instructions are replaced by arrays representing the indices of each word in a fixed vocabulary.
@@ -365,7 +367,7 @@ class DictObservationSpaceWrapper(gym.ObservationWrapper):
         return obs
         return obs
 
 
 
 
-class FlatObsWrapper(gym.ObservationWrapper):
+class FlatObsWrapper(ObservationWrapper):
     """
     """
     Encode mission strings using a one-hot scheme,
     Encode mission strings using a one-hot scheme,
     and combine these with observed images into one flat array
     and combine these with observed images into one flat array
@@ -387,8 +389,7 @@ class FlatObsWrapper(gym.ObservationWrapper):
             dtype="uint8",
             dtype="uint8",
         )
         )
 
 
-        self.cachedStr = None
-        self.cachedArray = None
+        self.cachedStr: str = None
 
 
     def observation(self, obs):
     def observation(self, obs):
         image = obs["image"]
         image = obs["image"]
@@ -421,7 +422,7 @@ class FlatObsWrapper(gym.ObservationWrapper):
         return obs
         return obs
 
 
 
 
-class ViewSizeWrapper(gym.Wrapper):
+class ViewSizeWrapper(Wrapper):
     """
     """
     Wrapper to customize the agent field of view size.
     Wrapper to customize the agent field of view size.
     This cannot be used with fully observable wrappers.
     This cannot be used with fully observable wrappers.
@@ -456,7 +457,7 @@ class ViewSizeWrapper(gym.Wrapper):
         return {**obs, "image": image}
         return {**obs, "image": image}
 
 
 
 
-class DirectionObsWrapper(gym.ObservationWrapper):
+class DirectionObsWrapper(ObservationWrapper):
     """
     """
     Provides the slope/angular direction to the goal with the observations as modeled by (y2 - y2 )/( x2 - x1)
     Provides the slope/angular direction to the goal with the observations as modeled by (y2 - y2 )/( x2 - x1)
     type = {slope , angle}
     type = {slope , angle}
@@ -464,7 +465,7 @@ class DirectionObsWrapper(gym.ObservationWrapper):
 
 
     def __init__(self, env, type="slope"):
     def __init__(self, env, type="slope"):
         super().__init__(env)
         super().__init__(env)
-        self.goal_position = None
+        self.goal_position: tuple = None
         self.type = type
         self.type = type
 
 
     def reset(self):
     def reset(self):
@@ -490,7 +491,7 @@ class DirectionObsWrapper(gym.ObservationWrapper):
         return obs
         return obs
 
 
 
 
-class SymbolicObsWrapper(gym.ObservationWrapper):
+class SymbolicObsWrapper(ObservationWrapper):
     """
     """
     Fully observable grid with a symbolic state representation.
     Fully observable grid with a symbolic state representation.
     The symbol is a triple of (X, Y, IDX), where X and Y are
     The symbol is a triple of (X, Y, IDX), where X and Y are

+ 4 - 0
manual_control.py

@@ -28,7 +28,11 @@ def reset():
 
 
 def step(action):
 def step(action):
     obs, reward, done, info = env.step(action)
     obs, reward, done, info = env.step(action)
+<<<<<<< HEAD
     print(f"step={env.step_count}, reward={reward:.2f}")
     print(f"step={env.step_count}, reward={reward:.2f}")
+=======
+    print("step={}, reward={:.2f}".format(env.step_count, reward))
+>>>>>>> Add pyright to pre-commit
 
 
     if done:
     if done:
         print("done!")
         print("done!")

+ 35 - 0
pyproject.toml

@@ -0,0 +1,35 @@
+[tool.pyright]
+
+include = [
+    "gym_minigrid/**",
+]
+
+exclude = [
+    "**/node_modules",
+    "**/__pycache__",
+
+   #"gym_minigrid/**",
+]
+
+strict = [
+
+]
+
+typeCheckingMode = "basic"
+pythonVersion = "3.7"
+typeshedPath = "typeshed"
+enableTypeIgnoreComments = true
+
+# This is required as the CI pre-commit does not download the module (i.e. numpy)
+#   Therefore, we have to ignore missing imports
+reportMissingImports = "none"
+
+reportUnknownMemberType = "none"
+reportUnknownParameterType = "none"
+reportUnknownVariableType = "none"
+reportUnknownArgumentType = "none"
+reportPrivateUsage = "warning"
+reportUntypedFunctionDecorator = "none"
+reportMissingTypeStubs = false
+reportUnboundVariable = "warning"
+reportGeneralTypeIssues ="none"

+ 4 - 0
run_tests.py

@@ -139,7 +139,11 @@ for env_idx, env_name in enumerate(env_list):
         obs_space, wrapper_name = env.observation_space, wrapper.__name__
         obs_space, wrapper_name = env.observation_space, wrapper.__name__
         assert isinstance(
         assert isinstance(
             obs_space, spaces.Dict
             obs_space, spaces.Dict
+<<<<<<< HEAD
         ), f"Observation space for {wrapper_name} is not a Dict: {obs_space}."
         ), f"Observation space for {wrapper_name} is not a Dict: {obs_space}."
+=======
+        ), "Observation space for {} is not a Dict: {}.".format(wrapper_name, obs_space)
+>>>>>>> Add pyright to pre-commit
         # This should not fail either
         # This should not fail either
         ImgObsWrapper(env)
         ImgObsWrapper(env)
         env.reset()
         env.reset()

+ 0 - 0
tests/__init__.py


+ 0 - 0
tests/envs/__init__.py


+ 103 - 0
tests/envs/test_envs.py

@@ -0,0 +1,103 @@
+import gym
+import pytest
+from gym.envs.registration import EnvSpec
+from gym.utils.env_checker import check_env
+
+from tests.envs.utils import all_testing_env_specs, assert_equals
+
+# This runs a smoketest on each official registered env. We may want
+# to try also running environments which are not officially registered envs.
+IGNORE_WARNINGS = [
+    "Agent's minimum observation space value is -infinity. This is probably too low.",
+    "Agent's maximum observation space value is infinity. This is probably too high.",
+    "We recommend you to use a symmetric and normalized Box action space (range=[-1, 1]) https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html",
+]
+IGNORE_WARNINGS = [f"\x1b[33mWARN: {message}\x1b[0m" for message in IGNORE_WARNINGS]
+
+
+@pytest.mark.parametrize(
+    "spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
+)
+def test_env(spec):
+    # Capture warnings
+    env = spec.make(disable_env_checker=True).unwrapped
+
+    # Test if env adheres to Gym API
+    with pytest.warns(None) as warnings:
+        check_env(env)
+
+    for warning in warnings.list:
+        if warning.message.args[0] not in IGNORE_WARNINGS:
+            raise gym.error.Error(f"Unexpected warning: {warning.message}")
+
+
+# Note that this precludes running this test in multiple threads.
+# However, we probably already can't do multithreading due to some environments.
+SEED = 0
+NUM_STEPS = 50
+
+
+@pytest.mark.parametrize(
+    "env_spec", all_testing_env_specs, ids=[env.id for env in all_testing_env_specs]
+)
+def test_env_determinism_rollout(env_spec: EnvSpec):
+    """Run a rollout with two environments and assert equality.
+
+    This test run a rollout of NUM_STEPS steps with two environments
+    initialized with the same seed and assert that:
+
+    - observation after first reset are the same
+    - same actions are sampled by the two envs
+    - observations are contained in the observation space
+    - obs, rew, done and info are equals between the two envs
+    """
+    # Don't check rollout equality if it's a nondeterministic environment.
+    if env_spec.nondeterministic is True:
+        return
+
+    env_1 = env_spec.make(disable_env_checker=True)
+    env_2 = env_spec.make(disable_env_checker=True)
+
+    initial_obs_1 = env_1.reset(seed=SEED)
+    initial_obs_2 = env_2.reset(seed=SEED)
+    assert_equals(initial_obs_1, initial_obs_2)
+
+    env_1.action_space.seed(SEED)
+
+    for time_step in range(NUM_STEPS):
+        # We don't evaluate the determinism of actions
+        action = env_1.action_space.sample()
+
+        obs_1, rew_1, done_1, info_1 = env_1.step(action)
+        obs_2, rew_2, done_2, info_2 = env_2.step(action)
+
+        assert_equals(obs_1, obs_2, f"[{time_step}] ")
+        assert env_1.observation_space.contains(
+            obs_1
+        )  # obs_2 verified by previous assertion
+
+        assert rew_1 == rew_2, f"[{time_step}] reward 1={rew_1}, reward 2={rew_2}"
+        assert done_1 == done_2, f"[{time_step}] done 1={done_1}, done 2={done_2}"
+        assert_equals(info_1, info_2, f"[{time_step}] ")
+
+        if done_1:  # done_2 verified by previous assertion
+            env_1.reset(seed=SEED)
+            env_2.reset(seed=SEED)
+
+    env_1.close()
+    env_2.close()
+
+
+@pytest.mark.parametrize(
+    "spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
+)
+def test_render_modes(spec):
+    env = spec.make()
+
+    for mode in env.metadata.get("render_modes", []):
+        if mode != "human":
+            new_env = spec.make(render_mode=mode)
+
+            new_env.reset()
+            new_env.step(new_env.action_space.sample())
+            new_env.render()

+ 55 - 0
tests/envs/utils.py

@@ -0,0 +1,55 @@
+"""Finds all the specs that we can test with"""
+from typing import Optional
+
+import gym
+import numpy as np
+from gym import logger
+from gym.envs.registration import EnvSpec
+
+
+def try_make_env(env_spec: EnvSpec) -> Optional[gym.Env]:
+    """Tries to make the environment showing if it is possible. Warning the environments have no wrappers, including time limit and order enforcing."""
+    try:
+        return env_spec.make(disable_env_checker=True).unwrapped
+    except ImportError as e:
+        logger.warn(f"Not testing {env_spec.id} due to error: {e}")
+        return None
+
+
+# Tries to make all gym_minigrid environment to test with
+all_testing_initialised_envs = list(
+    filter(
+        None,
+        [
+            try_make_env(env_spec)
+            for env_spec in gym.envs.registry.values()
+            if env_spec.entry_point.startswith("gym_minigrid.envs")
+        ],
+    )
+)
+all_testing_env_specs = [env.spec for env in all_testing_initialised_envs]
+
+
+def assert_equals(a, b, prefix=None):
+    """Assert equality of data structures `a` and `b`.
+
+    Args:
+        a: first data structure
+        b: second data structure
+        prefix: prefix for failed assertion message for types and dicts
+    """
+    assert type(a) == type(b), f"{prefix}Differing types: {a} and {b}"
+    if isinstance(a, dict):
+        assert list(a.keys()) == list(b.keys()), f"{prefix}Key sets differ: {a} and {b}"
+
+        for k in a.keys():
+            v_a = a[k]
+            v_b = b[k]
+            assert_equals(v_a, v_b)
+    elif isinstance(a, np.ndarray):
+        np.testing.assert_array_equal(a, b)
+    elif isinstance(a, tuple):
+        for elem_from_a, elem_from_b in zip(a, b):
+            assert_equals(elem_from_a, elem_from_b)
+    else:
+        assert a == b