Mark Towers, 2 years ago
parent
commit 6cd169375e
6 changed files with 73 additions and 80 deletions
  1. + 2 - 0     .github/workflows/build.yml
  2. + 13 - 14   minigrid/__init__.py
  3. + 6 - 5     minigrid/core/mission.py
  4. + 26 - 18   minigrid/minigrid_env.py
  5. + 25 - 42   minigrid/wrappers.py
  6. + 1 - 1     tests/test_wrappers.py

+ 2 - 0
.github/workflows/build.yml

@@ -18,3 +18,5 @@ jobs:
              --tag minigrid-docker .      
       - name: Run tests
         run: docker run minigrid-docker pytest
+      - name: Run doctest
+        run: docker run minigrid-docker pytest --doctest-modules minigrid/
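The new CI step runs every module under minigrid/ through pytest's doctest collector, which is why the docstring examples later in this commit are brought into doctest-compatible form. A rough local equivalent, sketched with the standard-library doctest runner instead of pytest (the choice of minigrid.wrappers as the module is only illustrative):

    # Minimal sketch: run one module's doctests with the stdlib runner.
    # The CI step instead has pytest collect all modules via --doctest-modules.
    import doctest

    import minigrid.wrappers

    results = doctest.testmod(minigrid.wrappers)  # TestResults(failed, attempted)
    print(f"{results.attempted} examples run, {results.failed} failed")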

+ 13 - 14
minigrid/__init__.py

@@ -6,6 +6,19 @@ from minigrid import minigrid_env, wrappers
 from minigrid.core import roomgrid
 from minigrid.core.world_object import Wall
 
+__version__ = "2.1.1"
+
+
+try:
+    import sys
+
+    from farama_notifications import notifications
+
+    if "minigrid" in notifications and __version__ in notifications["minigrid"]:
+        print(notifications["minigrid"][__version__], file=sys.stderr)
+except Exception:  # nosec
+    pass
+
 
 def register_minigrid_envs():
     # BlockedUnlockPickup
@@ -1071,17 +1084,3 @@ def register_minigrid_envs():
         id="BabyAI-BossLevelNoUnlock-v0",
         entry_point="minigrid.envs.babyai:BossLevelNoUnlock",
     )
-
-
-__version__ = "2.1.0"
-
-
-try:
-    import sys
-
-    from farama_notifications import notifications
-
-    if "minigrid" in notifications and __version__ in notifications["minigrid"]:
-        print(notifications["minigrid"][__version__], file=sys.stderr)
-except Exception:  # nosec
-    pass
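With __version__ (bumped to 2.1.1) and the farama-notifications check moved above register_minigrid_envs, the version string is defined before any registration code runs; the attribute itself is still read the usual way:

    >>> import minigrid
    >>> minigrid.__version__
    '2.1.1'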

+ 6 - 5
minigrid/core/mission.py

@@ -16,13 +16,14 @@ class MissionSpace(spaces.Space[str]):
     The space allows generating random mission strings constructed with an input placeholder list.
     Example Usage::
         >>> observation_space = MissionSpace(mission_func=lambda color: f"Get the {color} ball.",
-                                                ordered_placeholders=[["green", "blue"]])
+        ...                                  ordered_placeholders=[["green", "blue"]])
+        >>> _ = observation_space.seed(123)
         >>> observation_space.sample()
-            "Get the green ball."
-        >>> observation_space = MissionSpace(mission_func=lambda : "Get the ball.".,
-                                                ordered_placeholders=None)
+        'Get the green ball.'
+        >>> observation_space = MissionSpace(mission_func=lambda : "Get the ball.",
+        ...                                  ordered_placeholders=None)
         >>> observation_space.sample()
-            "Get the ball."
+        'Get the ball.'
     """
 
     def __init__(
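This docstring fix follows the conventions the new doctest CI step enforces: continuation lines need the "..." prompt, and the expected output must match what the interactive interpreter prints, i.e. the repr of the returned value, hence the single quotes. A minimal illustration of that last point:

    >>> "Get the green ball."         # a bare expression echoes its repr, with quotes
    'Get the green ball.'
    >>> print("Get the green ball.")  # print() shows the string unquoted
    Get the green ball.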

+ 26 - 18
minigrid/minigrid_env.py

@@ -3,11 +3,12 @@ from __future__ import annotations
 import hashlib
 import math
 from abc import abstractmethod
-from typing import Iterable, TypeVar
+from typing import Any, Iterable, SupportsFloat, TypeVar
 
 import gymnasium as gym
 import numpy as np
 from gymnasium import spaces
+from gymnasium.core import ActType, ObsType
 
 from minigrid.core.actions import Actions
 from minigrid.core.constants import COLOR_NAMES, DIR_TO_VEC, TILE_PIXELS
@@ -110,7 +111,12 @@ class MiniGridEnv(gym.Env):
         self.tile_size = tile_size
         self.agent_pov = agent_pov
 
-    def reset(self, *, seed=None, options=None):
+    def reset(
+        self,
+        *,
+        seed: int | None = None,
+        options: dict[str, Any] | None = None,
+    ) -> tuple[ObsType, dict[str, Any]]:
         super().reset(seed=seed)
 
         # Reinitialize episode-specific variables
@@ -183,36 +189,36 @@ class MiniGridEnv(gym.Env):
         # Map agent's direction to short string
         AGENT_DIR_TO_STR = {0: ">", 1: "V", 2: "<", 3: "^"}
 
-        str = ""
+        output = ""
 
         for j in range(self.grid.height):
 
             for i in range(self.grid.width):
                 if i == self.agent_pos[0] and j == self.agent_pos[1]:
-                    str += 2 * AGENT_DIR_TO_STR[self.agent_dir]
+                    output += 2 * AGENT_DIR_TO_STR[self.agent_dir]
                     continue
 
-                c = self.grid.get(i, j)
+                tile = self.grid.get(i, j)
 
-                if c is None:
-                    str += "  "
+                if tile is None:
+                    output += "  "
                     continue
 
-                if c.type == "door":
-                    if c.is_open:
-                        str += "__"
-                    elif c.is_locked:
-                        str += "L" + c.color[0].upper()
+                if tile.type == "door":
+                    if tile.is_open:
+                        output += "__"
+                    elif tile.is_locked:
+                        output += "L" + tile.color[0].upper()
                     else:
-                        str += "D" + c.color[0].upper()
+                        output += "D" + tile.color[0].upper()
                     continue
 
-                str += OBJECT_TO_STR[c.type] + c.color[0].upper()
+                output += OBJECT_TO_STR[tile.type] + tile.color[0].upper()
 
             if j < self.grid.height - 1:
-                str += "\n"
+                output += "\n"
 
-        return str
+        return output
 
     @abstractmethod
     def _gen_grid(self, width, height):
@@ -462,7 +468,7 @@ class MiniGridEnv(gym.Env):
         botX = topX + agent_view_size
         botY = topY + agent_view_size
 
-        return (topX, topY, botX, botY)
+        return topX, topY, botX, botY
 
     def relative_coords(self, x, y):
         """
@@ -503,7 +509,9 @@ class MiniGridEnv(gym.Env):
 
         return obs_cell is not None and obs_cell.type == world_cell.type
 
-    def step(self, action):
+    def step(
+        self, action: ActType
+    ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
         self.step_count += 1
 
         reward = 0
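The annotated signatures make explicit that reset and step follow the standard Gymnasium API, returning (obs, info) and the five-tuple (obs, reward, terminated, truncated, info) respectively. A minimal interaction sketch against one of the registered environments:

    import gymnasium as gym

    env = gym.make("MiniGrid-Empty-5x5-v0")
    obs, info = env.reset(seed=0)       # tuple[ObsType, dict[str, Any]]
    action = env.action_space.sample()  # e.g. left / right / forward / pickup
    obs, reward, terminated, truncated, info = env.step(action)
    env.close()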

+ 25 - 42
minigrid/wrappers.py

@@ -3,11 +3,12 @@ from __future__ import annotations
 import math
 import operator
 from functools import reduce
+from typing import Any
 
 import gymnasium as gym
 import numpy as np
-from gymnasium import spaces
-from gymnasium.core import ObservationWrapper, Wrapper
+from gymnasium import logger, spaces
+from gymnasium.core import ObservationWrapper, ObsType, Wrapper
 
 from minigrid.core.constants import COLOR_TO_IDX, OBJECT_TO_IDX, STATE_TO_IDX
 from minigrid.core.world_object import Goal
@@ -24,8 +25,9 @@ class ReseedWrapper(Wrapper):
         >>> import gymnasium as gym
         >>> from minigrid.wrappers import ReseedWrapper
         >>> env = gym.make("MiniGrid-Empty-5x5-v0")
+        >>> _ = env.reset(seed=123)
         >>> [env.np_random.integers(10) for i in range(10)]
-        [1, 9, 5, 8, 4, 3, 8, 8, 3, 1]
+        [0, 6, 5, 0, 9, 2, 2, 1, 3, 1]
         >>> env = ReseedWrapper(env, seeds=[0, 1], seed_idx=0)
         >>> _, _ = env.reset()
         >>> [env.np_random.integers(10) for i in range(10)]
@@ -41,7 +43,7 @@ class ReseedWrapper(Wrapper):
         [4, 5, 7, 9, 0, 1, 8, 9, 2, 3]
     """
 
-    def __init__(self, env, seeds=[0], seed_idx=0):
+    def __init__(self, env, seeds=(0,), seed_idx=0):
         """A wrapper that always regenerate an environment with the same set of seeds.
 
         Args:
@@ -53,15 +55,16 @@ class ReseedWrapper(Wrapper):
         self.seed_idx = seed_idx
         super().__init__(env)
 
-    def reset(self, **kwargs):
-        """Resets the environment with `kwargs`."""
+    def reset(
+        self, *, seed: int | None = None, options: dict[str, Any] | None = None
+    ) -> tuple[ObsType, dict[str, Any]]:
+        if seed is not None:
+            logger.warn(
+                "A seed has been passed to `ReseedWrapper.reset` which is ignored."
+            )
         seed = self.seeds[self.seed_idx]
         self.seed_idx = (self.seed_idx + 1) % len(self.seeds)
-        return self.env.reset(seed=seed, **kwargs)
-
-    def step(self, action):
-        """Steps through the environment with `action`."""
-        return self.env.step(action)
+        return self.env.reset(seed=seed, options=options)
 
 
 class ActionBonus(gym.Wrapper):
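The rewritten ReseedWrapper.reset above replaces the **kwargs signature with the explicit Gymnasium keywords, warns when a caller passes its own seed, and keeps cycling through the wrapper's seed list; the pass-through step is dropped because gymnasium's Wrapper already forwards it. A sketch of the resulting behaviour, reusing the seed values from the docstring example:

    import gymnasium as gym

    from minigrid.wrappers import ReseedWrapper

    env = ReseedWrapper(gym.make("MiniGrid-Empty-5x5-v0"), seeds=[0, 1], seed_idx=0)
    obs, info = env.reset(seed=42)  # logs a warning; 42 is ignored, seed 0 is used
    obs, info = env.reset()         # seed 1
    obs, info = env.reset()         # wraps back around to seed 0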
@@ -71,7 +74,6 @@ class ActionBonus(gym.Wrapper):
     visited (state,action) pairs.
 
     Example:
-        >>> import minigrid
         >>> import gymnasium as gym
         >>> from minigrid.wrappers import ActionBonus
         >>> env = gym.make("MiniGrid-Empty-5x5-v0")
@@ -122,10 +124,6 @@ class ActionBonus(gym.Wrapper):
 
         return obs, reward, terminated, truncated, info
 
-    def reset(self, **kwargs):
-        """Resets the environment with `kwargs`."""
-        return self.env.reset(**kwargs)
-
 
 class PositionBonus(Wrapper):
     """
@@ -136,7 +134,6 @@ class PositionBonus(Wrapper):
         This wrapper was previously called ``StateBonus``.
 
     Example:
-        >>> import minigrid
         >>> import gymnasium as gym
         >>> from minigrid.wrappers import PositionBonus
         >>> env = gym.make("MiniGrid-Empty-5x5-v0")
@@ -189,17 +186,12 @@ class PositionBonus(Wrapper):
 
         return obs, reward, terminated, truncated, info
 
-    def reset(self, **kwargs):
-        """Resets the environment with `kwargs`."""
-        return self.env.reset(**kwargs)
-
 
 class ImgObsWrapper(ObservationWrapper):
     """
     Use the image as the only observation output, no language/mission.
 
     Example:
-        >>> import minigrid
         >>> import gymnasium as gym
         >>> from minigrid.wrappers import ImgObsWrapper
         >>> env = gym.make("MiniGrid-Empty-5x5-v0")
@@ -231,7 +223,6 @@ class OneHotPartialObsWrapper(ObservationWrapper):
     agent view as observation.
 
     Example:
-        >>> import minigrid
         >>> import gymnasium as gym
         >>> from minigrid.wrappers import OneHotPartialObsWrapper
         >>> env = gym.make("MiniGrid-Empty-5x5-v0")
@@ -254,7 +245,7 @@ class OneHotPartialObsWrapper(ObservationWrapper):
                [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
                [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
                [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0]],
-               dtype=uint8)
+              dtype=uint8)
     """
 
     def __init__(self, env, tile_size=8):
@@ -302,17 +293,16 @@ class RGBImgObsWrapper(ObservationWrapper):
     This can be used to have the agent to solve the gridworld in pixel space.
 
     Example:
-        >>> import minigrid
         >>> import gymnasium as gym
         >>> import matplotlib.pyplot as plt
         >>> from minigrid.wrappers import RGBImgObsWrapper
         >>> env = gym.make("MiniGrid-Empty-5x5-v0")
         >>> obs, _ = env.reset()
-        >>> plt.imshow(obs['image'])
+        >>> plt.imshow(obs['image'])  # doctest: +SKIP
         ![NoWrapper](../figures/lavacrossing_NoWrapper.png)
         >>> env = RGBImgObsWrapper(env)
         >>> obs, _ = env.reset()
-        >>> plt.imshow(obs['image'])
+        >>> plt.imshow(obs['image'])  # doctest: +SKIP
         ![RGBImgObsWrapper](../figures/lavacrossing_RGBImgObsWrapper.png)
     """
 
@@ -344,21 +334,20 @@ class RGBImgPartialObsWrapper(ObservationWrapper):
     This can be used to have the agent to solve the gridworld in pixel space.
 
     Example:
-        >>> import minigrid
         >>> import gymnasium as gym
         >>> import matplotlib.pyplot as plt
         >>> from minigrid.wrappers import RGBImgObsWrapper, RGBImgPartialObsWrapper
         >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
         >>> obs, _ = env.reset()
-        >>> plt.imshow(obs["image"])
+        >>> plt.imshow(obs["image"])  # doctest: +SKIP
         ![NoWrapper](../figures/lavacrossing_NoWrapper.png)
         >>> env_obs = RGBImgObsWrapper(env)
         >>> obs, _ = env_obs.reset()
-        >>> plt.imshow(obs["image"])
+        >>> plt.imshow(obs["image"])  # doctest: +SKIP
         ![RGBImgObsWrapper](../figures/lavacrossing_RGBImgObsWrapper.png)
         >>> env_obs = RGBImgPartialObsWrapper(env)
         >>> obs, _ = env_obs.reset()
-        >>> plt.imshow(obs["image"])
+        >>> plt.imshow(obs["image"])  # doctest: +SKIP
         ![RGBImgPartialObsWrapper](../figures/lavacrossing_RGBImgPartialObsWrapper.png)
     """
 
@@ -391,7 +380,6 @@ class FullyObsWrapper(ObservationWrapper):
     Fully observable gridworld using a compact grid encoding instead of the agent view.
 
     Example:
-        >>> import minigrid
         >>> import gymnasium as gym
         >>> import matplotlib.pyplot as plt
         >>> from minigrid.wrappers import FullyObsWrapper
@@ -437,7 +425,6 @@ class DictObservationSpaceWrapper(ObservationWrapper):
     This wrapper is not applicable to BabyAI environments, given that these have their own language component.
 
     Example:
-        >>> import minigrid
         >>> import gymnasium as gym
         >>> import matplotlib.pyplot as plt
         >>> from minigrid.wrappers import DictObservationSpaceWrapper
@@ -571,7 +558,6 @@ class FlatObsWrapper(ObservationWrapper):
     This wrapper is not applicable to BabyAI environments, given that these have their own language component.
 
     Example:
-        >>> import minigrid
         >>> import gymnasium as gym
         >>> import matplotlib.pyplot as plt
         >>> from minigrid.wrappers import FlatObsWrapper
@@ -643,9 +629,7 @@ class ViewSizeWrapper(ObservationWrapper):
     This cannot be used with fully observable wrappers.
 
     Example:
-        >>> import minigrid
         >>> import gymnasium as gym
-        >>> import matplotlib.pyplot as plt
         >>> from minigrid.wrappers import ViewSizeWrapper
         >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
         >>> obs, _ = env.reset()
@@ -692,7 +676,6 @@ class DirectionObsWrapper(ObservationWrapper):
     type = {slope , angle}
 
     Example:
-        >>> import minigrid
         >>> import gymnasium as gym
         >>> import matplotlib.pyplot as plt
         >>> from minigrid.wrappers import DirectionObsWrapper
@@ -708,8 +691,10 @@ class DirectionObsWrapper(ObservationWrapper):
         self.goal_position: tuple = None
         self.type = type
 
-    def reset(self):
-        obs, _ = self.env.reset()
+    def reset(
+        self, *, seed: int | None = None, options: dict[str, Any] | None = None
+    ) -> tuple[ObsType, dict[str, Any]]:
+        obs, info = self.env.reset()
 
         if not self.goal_position:
             self.goal_position = [
@@ -722,7 +707,7 @@ class DirectionObsWrapper(ObservationWrapper):
                     self.goal_position[0] % self.width,
                 )
 
-        return self.observation(obs)
+        return self.observation(obs), info
 
     def observation(self, obs):
         slope = np.divide(
@@ -745,9 +730,7 @@ class SymbolicObsWrapper(ObservationWrapper):
     the coordinates on the grid, and IDX is the id of the object.
 
     Example:
-        >>> import minigrid
         >>> import gymnasium as gym
-        >>> import matplotlib.pyplot as plt
         >>> from minigrid.wrappers import SymbolicObsWrapper
         >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
         >>> obs, _ = env.reset()
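Besides trimming docstring imports that the doctests no longer need and marking the matplotlib calls with doctest: +SKIP, the one behavioural change in the remaining wrappers.py hunks is that DirectionObsWrapper.reset now returns the (obs, info) pair like every other Gymnasium wrapper instead of a bare observation, which is what the updated test below unpacks. A short usage sketch (environment id and type value chosen for illustration):

    import gymnasium as gym

    from minigrid.wrappers import DirectionObsWrapper

    env = DirectionObsWrapper(gym.make("MiniGrid-Empty-5x5-v0"), type="slope")
    obs, info = env.reset()  # previously returned only the observation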

+ 1 - 1
tests/test_wrappers.py

@@ -270,7 +270,7 @@ def test_viewsize_wrapper(view_size):
 def test_direction_obs_wrapper(env_id, type):
     env = gym.make(env_id)
     env = DirectionObsWrapper(env, type=type)
-    obs = env.reset()
+    obs, _ = env.reset()
 
     slope = np.divide(
         env.goal_position[1] - env.agent_pos[1],