瀏覽代碼

Add doctest (#330)

Mark Towers 2 年之前
父節點
當前提交
6cd169375e
共有 6 個文件被更改,包括 73 次插入和 80 次删除
  1. 2 0
      .github/workflows/build.yml
  2. 13 14
      minigrid/__init__.py
  3. 6 5
      minigrid/core/mission.py
  4. 26 18
      minigrid/minigrid_env.py
  5. 25 42
      minigrid/wrappers.py
  6. 1 1
      tests/test_wrappers.py

+ 2 - 0
.github/workflows/build.yml

@@ -18,3 +18,5 @@ jobs:
              --tag minigrid-docker .      
              --tag minigrid-docker .      
       - name: Run tests
       - name: Run tests
         run: docker run minigrid-docker pytest
         run: docker run minigrid-docker pytest
+      - name: Run doctest
+        run: docker run minigrid-docker pytest --doctest-modules minigrid/

+ 13 - 14
minigrid/__init__.py

@@ -6,6 +6,19 @@ from minigrid import minigrid_env, wrappers
 from minigrid.core import roomgrid
 from minigrid.core import roomgrid
 from minigrid.core.world_object import Wall
 from minigrid.core.world_object import Wall
 
 
+__version__ = "2.1.1"
+
+
+try:
+    import sys
+
+    from farama_notifications import notifications
+
+    if "minigrid" in notifications and __version__ in notifications["minigrid"]:
+        print(notifications["minigrid"][__version__], file=sys.stderr)
+except Exception:  # nosec
+    pass
+
 
 
 def register_minigrid_envs():
 def register_minigrid_envs():
     # BlockedUnlockPickup
     # BlockedUnlockPickup
@@ -1071,17 +1084,3 @@ def register_minigrid_envs():
         id="BabyAI-BossLevelNoUnlock-v0",
         id="BabyAI-BossLevelNoUnlock-v0",
         entry_point="minigrid.envs.babyai:BossLevelNoUnlock",
         entry_point="minigrid.envs.babyai:BossLevelNoUnlock",
     )
     )
-
-
-__version__ = "2.1.0"
-
-
-try:
-    import sys
-
-    from farama_notifications import notifications
-
-    if "minigrid" in notifications and __version__ in notifications["minigrid"]:
-        print(notifications["minigrid"][__version__], file=sys.stderr)
-except Exception:  # nosec
-    pass

+ 6 - 5
minigrid/core/mission.py

@@ -16,13 +16,14 @@ class MissionSpace(spaces.Space[str]):
     The space allows generating random mission strings constructed with an input placeholder list.
     The space allows generating random mission strings constructed with an input placeholder list.
     Example Usage::
     Example Usage::
         >>> observation_space = MissionSpace(mission_func=lambda color: f"Get the {color} ball.",
         >>> observation_space = MissionSpace(mission_func=lambda color: f"Get the {color} ball.",
-                                                ordered_placeholders=[["green", "blue"]])
+        ...                                  ordered_placeholders=[["green", "blue"]])
+        >>> _ = observation_space.seed(123)
         >>> observation_space.sample()
         >>> observation_space.sample()
-            "Get the green ball."
-        >>> observation_space = MissionSpace(mission_func=lambda : "Get the ball.".,
-                                                ordered_placeholders=None)
+        'Get the green ball.'
+        >>> observation_space = MissionSpace(mission_func=lambda : "Get the ball.",
+        ...                                  ordered_placeholders=None)
         >>> observation_space.sample()
         >>> observation_space.sample()
-            "Get the ball."
+        'Get the ball.'
     """
     """
 
 
     def __init__(
     def __init__(

+ 26 - 18
minigrid/minigrid_env.py

@@ -3,11 +3,12 @@ from __future__ import annotations
 import hashlib
 import hashlib
 import math
 import math
 from abc import abstractmethod
 from abc import abstractmethod
-from typing import Iterable, TypeVar
+from typing import Any, Iterable, SupportsFloat, TypeVar
 
 
 import gymnasium as gym
 import gymnasium as gym
 import numpy as np
 import numpy as np
 from gymnasium import spaces
 from gymnasium import spaces
+from gymnasium.core import ActType, ObsType
 
 
 from minigrid.core.actions import Actions
 from minigrid.core.actions import Actions
 from minigrid.core.constants import COLOR_NAMES, DIR_TO_VEC, TILE_PIXELS
 from minigrid.core.constants import COLOR_NAMES, DIR_TO_VEC, TILE_PIXELS
@@ -110,7 +111,12 @@ class MiniGridEnv(gym.Env):
         self.tile_size = tile_size
         self.tile_size = tile_size
         self.agent_pov = agent_pov
         self.agent_pov = agent_pov
 
 
-    def reset(self, *, seed=None, options=None):
+    def reset(
+        self,
+        *,
+        seed: int | None = None,
+        options: dict[str, Any] | None = None,
+    ) -> tuple[ObsType, dict[str, Any]]:
         super().reset(seed=seed)
         super().reset(seed=seed)
 
 
         # Reinitialize episode-specific variables
         # Reinitialize episode-specific variables
@@ -183,36 +189,36 @@ class MiniGridEnv(gym.Env):
         # Map agent's direction to short string
         # Map agent's direction to short string
         AGENT_DIR_TO_STR = {0: ">", 1: "V", 2: "<", 3: "^"}
         AGENT_DIR_TO_STR = {0: ">", 1: "V", 2: "<", 3: "^"}
 
 
-        str = ""
+        output = ""
 
 
         for j in range(self.grid.height):
         for j in range(self.grid.height):
 
 
             for i in range(self.grid.width):
             for i in range(self.grid.width):
                 if i == self.agent_pos[0] and j == self.agent_pos[1]:
                 if i == self.agent_pos[0] and j == self.agent_pos[1]:
-                    str += 2 * AGENT_DIR_TO_STR[self.agent_dir]
+                    output += 2 * AGENT_DIR_TO_STR[self.agent_dir]
                     continue
                     continue
 
 
-                c = self.grid.get(i, j)
+                tile = self.grid.get(i, j)
 
 
-                if c is None:
-                    str += "  "
+                if tile is None:
+                    output += "  "
                     continue
                     continue
 
 
-                if c.type == "door":
-                    if c.is_open:
-                        str += "__"
-                    elif c.is_locked:
-                        str += "L" + c.color[0].upper()
+                if tile.type == "door":
+                    if tile.is_open:
+                        output += "__"
+                    elif tile.is_locked:
+                        output += "L" + tile.color[0].upper()
                     else:
                     else:
-                        str += "D" + c.color[0].upper()
+                        output += "D" + tile.color[0].upper()
                     continue
                     continue
 
 
-                str += OBJECT_TO_STR[c.type] + c.color[0].upper()
+                output += OBJECT_TO_STR[tile.type] + tile.color[0].upper()
 
 
             if j < self.grid.height - 1:
             if j < self.grid.height - 1:
-                str += "\n"
+                output += "\n"
 
 
-        return str
+        return output
 
 
     @abstractmethod
     @abstractmethod
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
@@ -462,7 +468,7 @@ class MiniGridEnv(gym.Env):
         botX = topX + agent_view_size
         botX = topX + agent_view_size
         botY = topY + agent_view_size
         botY = topY + agent_view_size
 
 
-        return (topX, topY, botX, botY)
+        return topX, topY, botX, botY
 
 
     def relative_coords(self, x, y):
     def relative_coords(self, x, y):
         """
         """
@@ -503,7 +509,9 @@ class MiniGridEnv(gym.Env):
 
 
         return obs_cell is not None and obs_cell.type == world_cell.type
         return obs_cell is not None and obs_cell.type == world_cell.type
 
 
-    def step(self, action):
+    def step(
+        self, action: ActType
+    ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
         self.step_count += 1
         self.step_count += 1
 
 
         reward = 0
         reward = 0

+ 25 - 42
minigrid/wrappers.py

@@ -3,11 +3,12 @@ from __future__ import annotations
 import math
 import math
 import operator
 import operator
 from functools import reduce
 from functools import reduce
+from typing import Any
 
 
 import gymnasium as gym
 import gymnasium as gym
 import numpy as np
 import numpy as np
-from gymnasium import spaces
-from gymnasium.core import ObservationWrapper, Wrapper
+from gymnasium import logger, spaces
+from gymnasium.core import ObservationWrapper, ObsType, Wrapper
 
 
 from minigrid.core.constants import COLOR_TO_IDX, OBJECT_TO_IDX, STATE_TO_IDX
 from minigrid.core.constants import COLOR_TO_IDX, OBJECT_TO_IDX, STATE_TO_IDX
 from minigrid.core.world_object import Goal
 from minigrid.core.world_object import Goal
@@ -24,8 +25,9 @@ class ReseedWrapper(Wrapper):
         >>> import gymnasium as gym
         >>> import gymnasium as gym
         >>> from minigrid.wrappers import ReseedWrapper
         >>> from minigrid.wrappers import ReseedWrapper
         >>> env = gym.make("MiniGrid-Empty-5x5-v0")
         >>> env = gym.make("MiniGrid-Empty-5x5-v0")
+        >>> _ = env.reset(seed=123)
         >>> [env.np_random.integers(10) for i in range(10)]
         >>> [env.np_random.integers(10) for i in range(10)]
-        [1, 9, 5, 8, 4, 3, 8, 8, 3, 1]
+        [0, 6, 5, 0, 9, 2, 2, 1, 3, 1]
         >>> env = ReseedWrapper(env, seeds=[0, 1], seed_idx=0)
         >>> env = ReseedWrapper(env, seeds=[0, 1], seed_idx=0)
         >>> _, _ = env.reset()
         >>> _, _ = env.reset()
         >>> [env.np_random.integers(10) for i in range(10)]
         >>> [env.np_random.integers(10) for i in range(10)]
@@ -41,7 +43,7 @@ class ReseedWrapper(Wrapper):
         [4, 5, 7, 9, 0, 1, 8, 9, 2, 3]
         [4, 5, 7, 9, 0, 1, 8, 9, 2, 3]
     """
     """
 
 
-    def __init__(self, env, seeds=[0], seed_idx=0):
+    def __init__(self, env, seeds=(0,), seed_idx=0):
         """A wrapper that always regenerate an environment with the same set of seeds.
         """A wrapper that always regenerate an environment with the same set of seeds.
 
 
         Args:
         Args:
@@ -53,15 +55,16 @@ class ReseedWrapper(Wrapper):
         self.seed_idx = seed_idx
         self.seed_idx = seed_idx
         super().__init__(env)
         super().__init__(env)
 
 
-    def reset(self, **kwargs):
-        """Resets the environment with `kwargs`."""
+    def reset(
+        self, *, seed: int | None = None, options: dict[str, Any] | None = None
+    ) -> tuple[ObsType, dict[str, Any]]:
+        if seed is not None:
+            logger.warn(
+                "A seed has been passed to `ReseedWrapper.reset` which is ignored."
+            )
         seed = self.seeds[self.seed_idx]
         seed = self.seeds[self.seed_idx]
         self.seed_idx = (self.seed_idx + 1) % len(self.seeds)
         self.seed_idx = (self.seed_idx + 1) % len(self.seeds)
-        return self.env.reset(seed=seed, **kwargs)
-
-    def step(self, action):
-        """Steps through the environment with `action`."""
-        return self.env.step(action)
+        return self.env.reset(seed=seed, options=options)
 
 
 
 
 class ActionBonus(gym.Wrapper):
 class ActionBonus(gym.Wrapper):
@@ -71,7 +74,6 @@ class ActionBonus(gym.Wrapper):
     visited (state,action) pairs.
     visited (state,action) pairs.
 
 
     Example:
     Example:
-        >>> import minigrid
         >>> import gymnasium as gym
         >>> import gymnasium as gym
         >>> from minigrid.wrappers import ActionBonus
         >>> from minigrid.wrappers import ActionBonus
         >>> env = gym.make("MiniGrid-Empty-5x5-v0")
         >>> env = gym.make("MiniGrid-Empty-5x5-v0")
@@ -122,10 +124,6 @@ class ActionBonus(gym.Wrapper):
 
 
         return obs, reward, terminated, truncated, info
         return obs, reward, terminated, truncated, info
 
 
-    def reset(self, **kwargs):
-        """Resets the environment with `kwargs`."""
-        return self.env.reset(**kwargs)
-
 
 
 class PositionBonus(Wrapper):
 class PositionBonus(Wrapper):
     """
     """
@@ -136,7 +134,6 @@ class PositionBonus(Wrapper):
         This wrapper was previously called ``StateBonus``.
         This wrapper was previously called ``StateBonus``.
 
 
     Example:
     Example:
-        >>> import minigrid
         >>> import gymnasium as gym
         >>> import gymnasium as gym
         >>> from minigrid.wrappers import PositionBonus
         >>> from minigrid.wrappers import PositionBonus
         >>> env = gym.make("MiniGrid-Empty-5x5-v0")
         >>> env = gym.make("MiniGrid-Empty-5x5-v0")
@@ -189,17 +186,12 @@ class PositionBonus(Wrapper):
 
 
         return obs, reward, terminated, truncated, info
         return obs, reward, terminated, truncated, info
 
 
-    def reset(self, **kwargs):
-        """Resets the environment with `kwargs`."""
-        return self.env.reset(**kwargs)
-
 
 
 class ImgObsWrapper(ObservationWrapper):
 class ImgObsWrapper(ObservationWrapper):
     """
     """
     Use the image as the only observation output, no language/mission.
     Use the image as the only observation output, no language/mission.
 
 
     Example:
     Example:
-        >>> import minigrid
         >>> import gymnasium as gym
         >>> import gymnasium as gym
         >>> from minigrid.wrappers import ImgObsWrapper
         >>> from minigrid.wrappers import ImgObsWrapper
         >>> env = gym.make("MiniGrid-Empty-5x5-v0")
         >>> env = gym.make("MiniGrid-Empty-5x5-v0")
@@ -231,7 +223,6 @@ class OneHotPartialObsWrapper(ObservationWrapper):
     agent view as observation.
     agent view as observation.
 
 
     Example:
     Example:
-        >>> import minigrid
         >>> import gymnasium as gym
         >>> import gymnasium as gym
         >>> from minigrid.wrappers import OneHotPartialObsWrapper
         >>> from minigrid.wrappers import OneHotPartialObsWrapper
         >>> env = gym.make("MiniGrid-Empty-5x5-v0")
         >>> env = gym.make("MiniGrid-Empty-5x5-v0")
@@ -254,7 +245,7 @@ class OneHotPartialObsWrapper(ObservationWrapper):
                [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
                [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
                [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
                [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
                [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0]],
                [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0]],
-               dtype=uint8)
+              dtype=uint8)
     """
     """
 
 
     def __init__(self, env, tile_size=8):
     def __init__(self, env, tile_size=8):
@@ -302,17 +293,16 @@ class RGBImgObsWrapper(ObservationWrapper):
     This can be used to have the agent to solve the gridworld in pixel space.
     This can be used to have the agent to solve the gridworld in pixel space.
 
 
     Example:
     Example:
-        >>> import minigrid
         >>> import gymnasium as gym
         >>> import gymnasium as gym
         >>> import matplotlib.pyplot as plt
         >>> import matplotlib.pyplot as plt
         >>> from minigrid.wrappers import RGBImgObsWrapper
         >>> from minigrid.wrappers import RGBImgObsWrapper
         >>> env = gym.make("MiniGrid-Empty-5x5-v0")
         >>> env = gym.make("MiniGrid-Empty-5x5-v0")
         >>> obs, _ = env.reset()
         >>> obs, _ = env.reset()
-        >>> plt.imshow(obs['image'])
+        >>> plt.imshow(obs['image'])  # doctest: +SKIP
         ![NoWrapper](../figures/lavacrossing_NoWrapper.png)
         ![NoWrapper](../figures/lavacrossing_NoWrapper.png)
         >>> env = RGBImgObsWrapper(env)
         >>> env = RGBImgObsWrapper(env)
         >>> obs, _ = env.reset()
         >>> obs, _ = env.reset()
-        >>> plt.imshow(obs['image'])
+        >>> plt.imshow(obs['image'])  # doctest: +SKIP
         ![RGBImgObsWrapper](../figures/lavacrossing_RGBImgObsWrapper.png)
         ![RGBImgObsWrapper](../figures/lavacrossing_RGBImgObsWrapper.png)
     """
     """
 
 
@@ -344,21 +334,20 @@ class RGBImgPartialObsWrapper(ObservationWrapper):
     This can be used to have the agent to solve the gridworld in pixel space.
     This can be used to have the agent to solve the gridworld in pixel space.
 
 
     Example:
     Example:
-        >>> import minigrid
         >>> import gymnasium as gym
         >>> import gymnasium as gym
         >>> import matplotlib.pyplot as plt
         >>> import matplotlib.pyplot as plt
         >>> from minigrid.wrappers import RGBImgObsWrapper, RGBImgPartialObsWrapper
         >>> from minigrid.wrappers import RGBImgObsWrapper, RGBImgPartialObsWrapper
         >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
         >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
         >>> obs, _ = env.reset()
         >>> obs, _ = env.reset()
-        >>> plt.imshow(obs["image"])
+        >>> plt.imshow(obs["image"])  # doctest: +SKIP
         ![NoWrapper](../figures/lavacrossing_NoWrapper.png)
         ![NoWrapper](../figures/lavacrossing_NoWrapper.png)
         >>> env_obs = RGBImgObsWrapper(env)
         >>> env_obs = RGBImgObsWrapper(env)
         >>> obs, _ = env_obs.reset()
         >>> obs, _ = env_obs.reset()
-        >>> plt.imshow(obs["image"])
+        >>> plt.imshow(obs["image"])  # doctest: +SKIP
         ![RGBImgObsWrapper](../figures/lavacrossing_RGBImgObsWrapper.png)
         ![RGBImgObsWrapper](../figures/lavacrossing_RGBImgObsWrapper.png)
         >>> env_obs = RGBImgPartialObsWrapper(env)
         >>> env_obs = RGBImgPartialObsWrapper(env)
         >>> obs, _ = env_obs.reset()
         >>> obs, _ = env_obs.reset()
-        >>> plt.imshow(obs["image"])
+        >>> plt.imshow(obs["image"])  # doctest: +SKIP
         ![RGBImgPartialObsWrapper](../figures/lavacrossing_RGBImgPartialObsWrapper.png)
         ![RGBImgPartialObsWrapper](../figures/lavacrossing_RGBImgPartialObsWrapper.png)
     """
     """
 
 
@@ -391,7 +380,6 @@ class FullyObsWrapper(ObservationWrapper):
     Fully observable gridworld using a compact grid encoding instead of the agent view.
     Fully observable gridworld using a compact grid encoding instead of the agent view.
 
 
     Example:
     Example:
-        >>> import minigrid
         >>> import gymnasium as gym
         >>> import gymnasium as gym
         >>> import matplotlib.pyplot as plt
         >>> import matplotlib.pyplot as plt
         >>> from minigrid.wrappers import FullyObsWrapper
         >>> from minigrid.wrappers import FullyObsWrapper
@@ -437,7 +425,6 @@ class DictObservationSpaceWrapper(ObservationWrapper):
     This wrapper is not applicable to BabyAI environments, given that these have their own language component.
     This wrapper is not applicable to BabyAI environments, given that these have their own language component.
 
 
     Example:
     Example:
-        >>> import minigrid
         >>> import gymnasium as gym
         >>> import gymnasium as gym
         >>> import matplotlib.pyplot as plt
         >>> import matplotlib.pyplot as plt
         >>> from minigrid.wrappers import DictObservationSpaceWrapper
         >>> from minigrid.wrappers import DictObservationSpaceWrapper
@@ -571,7 +558,6 @@ class FlatObsWrapper(ObservationWrapper):
     This wrapper is not applicable to BabyAI environments, given that these have their own language component.
     This wrapper is not applicable to BabyAI environments, given that these have their own language component.
 
 
     Example:
     Example:
-        >>> import minigrid
         >>> import gymnasium as gym
         >>> import gymnasium as gym
         >>> import matplotlib.pyplot as plt
         >>> import matplotlib.pyplot as plt
         >>> from minigrid.wrappers import FlatObsWrapper
         >>> from minigrid.wrappers import FlatObsWrapper
@@ -643,9 +629,7 @@ class ViewSizeWrapper(ObservationWrapper):
     This cannot be used with fully observable wrappers.
     This cannot be used with fully observable wrappers.
 
 
     Example:
     Example:
-        >>> import minigrid
         >>> import gymnasium as gym
         >>> import gymnasium as gym
-        >>> import matplotlib.pyplot as plt
         >>> from minigrid.wrappers import ViewSizeWrapper
         >>> from minigrid.wrappers import ViewSizeWrapper
         >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
         >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
         >>> obs, _ = env.reset()
         >>> obs, _ = env.reset()
@@ -692,7 +676,6 @@ class DirectionObsWrapper(ObservationWrapper):
     type = {slope , angle}
     type = {slope , angle}
 
 
     Example:
     Example:
-        >>> import minigrid
         >>> import gymnasium as gym
         >>> import gymnasium as gym
         >>> import matplotlib.pyplot as plt
         >>> import matplotlib.pyplot as plt
         >>> from minigrid.wrappers import DirectionObsWrapper
         >>> from minigrid.wrappers import DirectionObsWrapper
@@ -708,8 +691,10 @@ class DirectionObsWrapper(ObservationWrapper):
         self.goal_position: tuple = None
         self.goal_position: tuple = None
         self.type = type
         self.type = type
 
 
-    def reset(self):
-        obs, _ = self.env.reset()
+    def reset(
+        self, *, seed: int | None = None, options: dict[str, Any] | None = None
+    ) -> tuple[ObsType, dict[str, Any]]:
+        obs, info = self.env.reset()
 
 
         if not self.goal_position:
         if not self.goal_position:
             self.goal_position = [
             self.goal_position = [
@@ -722,7 +707,7 @@ class DirectionObsWrapper(ObservationWrapper):
                     self.goal_position[0] % self.width,
                     self.goal_position[0] % self.width,
                 )
                 )
 
 
-        return self.observation(obs)
+        return self.observation(obs), info
 
 
     def observation(self, obs):
     def observation(self, obs):
         slope = np.divide(
         slope = np.divide(
@@ -745,9 +730,7 @@ class SymbolicObsWrapper(ObservationWrapper):
     the coordinates on the grid, and IDX is the id of the object.
     the coordinates on the grid, and IDX is the id of the object.
 
 
     Example:
     Example:
-        >>> import minigrid
         >>> import gymnasium as gym
         >>> import gymnasium as gym
-        >>> import matplotlib.pyplot as plt
         >>> from minigrid.wrappers import SymbolicObsWrapper
         >>> from minigrid.wrappers import SymbolicObsWrapper
         >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
         >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
         >>> obs, _ = env.reset()
         >>> obs, _ = env.reset()

+ 1 - 1
tests/test_wrappers.py

@@ -270,7 +270,7 @@ def test_viewsize_wrapper(view_size):
 def test_direction_obs_wrapper(env_id, type):
 def test_direction_obs_wrapper(env_id, type):
     env = gym.make(env_id)
     env = gym.make(env_id)
     env = DirectionObsWrapper(env, type=type)
     env = DirectionObsWrapper(env, type=type)
-    obs = env.reset()
+    obs, _ = env.reset()
 
 
     slope = np.divide(
     slope = np.divide(
         env.goal_position[1] - env.agent_pos[1],
         env.goal_position[1] - env.agent_pos[1],