Explorar o código

Bug fix in DirectionObsWrapper, new tests, & name change (#310)

Co-authored-by: Mark Towers <mark.m.towers@gmail.com>
Bolun %!s(int64=2) %!d(string=hai) anos
pai
achega
9677ebc899
Modificáronse 3 ficheiros con 89 adicións e 12 borrados
  1. 2 2
      docs/api/wrappers.md
  2. 42 7
      minigrid/wrappers.py
  3. 45 3
      tests/test_wrappers.py

+ 2 - 2
docs/api/wrappers.md

@@ -58,10 +58,10 @@ lastpage:
 .. autoclass:: minigrid.wrappers.RGBImgObsWrapper
 ```
 
-# State Bonus
+# Position Bonus
 
 ```{eval-rst}
-.. autoclass:: minigrid.wrappers.StateBonus
+.. autoclass:: minigrid.wrappers.PositionBonus
 ```
 
 # Symbolic Obs

+ 42 - 7
minigrid/wrappers.py

@@ -127,16 +127,18 @@ class ActionBonus(gym.Wrapper):
         return self.env.reset(**kwargs)
 
 
-# Should be named PositionBonus
-class StateBonus(Wrapper):
+class PositionBonus(Wrapper):
     """
     Adds an exploration bonus based on which positions
     are visited on the grid.
 
+    Note:
+        This wrapper was previously called ``StateBonus``.
+
     Example:
         >>> import miniworld
         >>> import gymnasium as gym
-        >>> from minigrid.wrappers import StateBonus
+        >>> from minigrid.wrappers import PositionBonus
         >>> env = gym.make("MiniGrid-Empty-5x5-v0")
         >>> _, _ = env.reset(seed=0)
         >>> _, reward, _, _, _ = env.step(1)
@@ -145,7 +147,7 @@ class StateBonus(Wrapper):
         >>> _, reward, _, _, _ = env.step(1)
         >>> print(reward)
         0
-        >>> env_bonus = StateBonus(env)
+        >>> env_bonus = PositionBonus(env)
         >>> obs, _ = env_bonus.reset(seed=0)
         >>> obs, reward, terminated, truncated, info = env_bonus.step(1)
         >>> print(reward)
@@ -688,6 +690,17 @@ class DirectionObsWrapper(ObservationWrapper):
     """
     Provides the slope/angular direction to the goal with the observations as modeled by (y2 - y2 )/( x2 - x1)
     type = {slope , angle}
+
+    Example:
+        >>> import miniworld
+        >>> import gymnasium as gym
+        >>> import matplotlib.pyplot as plt
+        >>> from minigrid.wrappers import DirectionObsWrapper
+        >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
+        >>> env_obs = DirectionObsWrapper(env, type="slope")
+        >>> obs, _ = env_obs.reset()
+        >>> obs['goal_direction']
+        1.0
     """
 
     def __init__(self, env, type="slope"):
@@ -696,7 +709,8 @@ class DirectionObsWrapper(ObservationWrapper):
         self.type = type
 
     def reset(self):
-        obs = self.env.reset()
+        obs, _ = self.env.reset()
+
         if not self.goal_position:
             self.goal_position = [
                 x for x, y in enumerate(self.grid.grid) if isinstance(y, Goal)
@@ -707,14 +721,20 @@ class DirectionObsWrapper(ObservationWrapper):
                     int(self.goal_position[0] / self.height),
                     self.goal_position[0] % self.width,
                 )
-        return obs
+
+        return self.observation(obs)
 
     def observation(self, obs):
         slope = np.divide(
             self.goal_position[1] - self.agent_pos[1],
             self.goal_position[0] - self.agent_pos[0],
         )
-        obs["goal_direction"] = np.arctan(slope) if self.type == "angle" else slope
+
+        if self.type == "angle":
+            obs["goal_direction"] = np.arctan(slope)
+        else:
+            obs["goal_direction"] = slope
+
         return obs
 
 
@@ -723,6 +743,20 @@ class SymbolicObsWrapper(ObservationWrapper):
     Fully observable grid with a symbolic state representation.
     The symbol is a triple of (X, Y, IDX), where X and Y are
     the coordinates on the grid, and IDX is the id of the object.
+
+    Example:
+        >>> import miniworld
+        >>> import gymnasium as gym
+        >>> import matplotlib.pyplot as plt
+        >>> from minigrid.wrappers import SymbolicObsWrapper
+        >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
+        >>> obs, _ = env.reset()
+        >>> obs['image'].shape
+        (7, 7, 3)
+        >>> env_obs = SymbolicObsWrapper(env)
+        >>> obs, _ = env_obs.reset()
+        >>> obs['image'].shape
+        (11, 11, 3)
     """
 
     def __init__(self, env):
@@ -749,4 +783,5 @@ class SymbolicObsWrapper(ObservationWrapper):
         grid = np.transpose(grid, (1, 2, 0))
         grid[agent_pos[0], agent_pos[1], 2] = OBJECT_TO_IDX["agent"]
         obs["image"] = grid
+
         return obs

+ 45 - 3
tests/test_wrappers.py

@@ -11,14 +11,16 @@ from minigrid.envs import EmptyEnv
 from minigrid.wrappers import (
     ActionBonus,
     DictObservationSpaceWrapper,
+    DirectionObsWrapper,
     FlatObsWrapper,
     FullyObsWrapper,
     ImgObsWrapper,
     OneHotPartialObsWrapper,
+    PositionBonus,
     ReseedWrapper,
     RGBImgObsWrapper,
     RGBImgPartialObsWrapper,
-    StateBonus,
+    SymbolicObsWrapper,
     ViewSizeWrapper,
 )
 from tests.utils import all_testing_env_specs, assert_equals, minigrid_testing_env_specs
@@ -77,9 +79,9 @@ def test_reseed_wrapper(env_spec):
 
 
 @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-16x16-v0"])
-def test_state_bonus_wrapper(env_id):
+def test_position_bonus_wrapper(env_id):
     env = gym.make(env_id)
-    wrapped_env = StateBonus(gym.make(env_id))
+    wrapped_env = PositionBonus(gym.make(env_id))
 
     action_forward = Actions.forward
     action_left = Actions.left
@@ -260,3 +262,43 @@ def test_viewsize_wrapper(view_size):
     obs, _, _, _, _ = env.step(0)
     assert obs["image"].shape == (view_size, view_size, 3)
     env.close()
+
+
+@pytest.mark.parametrize("env_id", ["MiniGrid-LavaCrossingS11N5-v0"])
+@pytest.mark.parametrize("type", ["slope", "angle"])
+def test_direction_obs_wrapper(env_id, type):
+    env = gym.make(env_id)
+    env = DirectionObsWrapper(env, type=type)
+    obs = env.reset()
+
+    slope = np.divide(
+        env.goal_position[1] - env.agent_pos[1],
+        env.goal_position[0] - env.agent_pos[0],
+    )
+    if type == "slope":
+        assert obs["goal_direction"] == slope
+    elif type == "angle":
+        assert obs["goal_direction"] == np.arctan(slope)
+
+    obs, _, _, _, _ = env.step(0)
+    slope = np.divide(
+        env.goal_position[1] - env.agent_pos[1],
+        env.goal_position[0] - env.agent_pos[0],
+    )
+    if type == "slope":
+        assert obs["goal_direction"] == slope
+    elif type == "angle":
+        assert obs["goal_direction"] == np.arctan(slope)
+
+    env.close()
+
+
+@pytest.mark.parametrize("env_id", ["MiniGrid-Empty-16x16-v0"])
+def test_symbolic_obs_wrapper(env_id):
+    env = gym.make(env_id)
+    env = SymbolicObsWrapper(env)
+    obs, _ = env.reset()
+    assert obs["image"].shape == (env.width, env.height, 3)
+    obs, _, _, _, _ = env.step(0)
+    assert obs["image"].shape == (env.width, env.height, 3)
+    env.close()