Browse Source

Fixed bug in SymbolicObsWrapper (#331)

Co-authored-by: Mark Towers <mark.m.towers@gmail.com>
Bolun 2 years ago
parent
commit
d537e1c167
2 changed files with 45 additions and 18 deletions
  1. 17 15
      minigrid/wrappers.py
  2. 28 3
      tests/test_wrappers.py

+ 17 - 15
minigrid/wrappers.py

@@ -71,7 +71,7 @@ class ActionBonus(gym.Wrapper):
     visited (state,action) pairs.
     visited (state,action) pairs.
 
 
     Example:
     Example:
-        >>> import miniworld
+        >>> import minigrid
         >>> import gymnasium as gym
         >>> import gymnasium as gym
         >>> from minigrid.wrappers import ActionBonus
         >>> from minigrid.wrappers import ActionBonus
         >>> env = gym.make("MiniGrid-Empty-5x5-v0")
         >>> env = gym.make("MiniGrid-Empty-5x5-v0")
@@ -136,7 +136,7 @@ class PositionBonus(Wrapper):
         This wrapper was previously called ``StateBonus``.
         This wrapper was previously called ``StateBonus``.
 
 
     Example:
     Example:
-        >>> import miniworld
+        >>> import minigrid
         >>> import gymnasium as gym
         >>> import gymnasium as gym
         >>> from minigrid.wrappers import PositionBonus
         >>> from minigrid.wrappers import PositionBonus
         >>> env = gym.make("MiniGrid-Empty-5x5-v0")
         >>> env = gym.make("MiniGrid-Empty-5x5-v0")
@@ -199,7 +199,7 @@ class ImgObsWrapper(ObservationWrapper):
     Use the image as the only observation output, no language/mission.
     Use the image as the only observation output, no language/mission.
 
 
     Example:
     Example:
-        >>> import miniworld
+        >>> import minigrid
         >>> import gymnasium as gym
         >>> import gymnasium as gym
         >>> from minigrid.wrappers import ImgObsWrapper
         >>> from minigrid.wrappers import ImgObsWrapper
         >>> env = gym.make("MiniGrid-Empty-5x5-v0")
         >>> env = gym.make("MiniGrid-Empty-5x5-v0")
@@ -231,7 +231,7 @@ class OneHotPartialObsWrapper(ObservationWrapper):
     agent view as observation.
     agent view as observation.
 
 
     Example:
     Example:
-        >>> import miniworld
+        >>> import minigrid
         >>> import gymnasium as gym
         >>> import gymnasium as gym
         >>> from minigrid.wrappers import OneHotPartialObsWrapper
         >>> from minigrid.wrappers import OneHotPartialObsWrapper
         >>> env = gym.make("MiniGrid-Empty-5x5-v0")
         >>> env = gym.make("MiniGrid-Empty-5x5-v0")
@@ -302,7 +302,7 @@ class RGBImgObsWrapper(ObservationWrapper):
     This can be used to have the agent to solve the gridworld in pixel space.
     This can be used to have the agent to solve the gridworld in pixel space.
 
 
     Example:
     Example:
-        >>> import miniworld
+        >>> import minigrid
         >>> import gymnasium as gym
         >>> import gymnasium as gym
         >>> import matplotlib.pyplot as plt
         >>> import matplotlib.pyplot as plt
         >>> from minigrid.wrappers import RGBImgObsWrapper
         >>> from minigrid.wrappers import RGBImgObsWrapper
@@ -344,7 +344,7 @@ class RGBImgPartialObsWrapper(ObservationWrapper):
     This can be used to have the agent to solve the gridworld in pixel space.
     This can be used to have the agent to solve the gridworld in pixel space.
 
 
     Example:
     Example:
-        >>> import miniworld
+        >>> import minigrid
         >>> import gymnasium as gym
         >>> import gymnasium as gym
         >>> import matplotlib.pyplot as plt
         >>> import matplotlib.pyplot as plt
         >>> from minigrid.wrappers import RGBImgObsWrapper, RGBImgPartialObsWrapper
         >>> from minigrid.wrappers import RGBImgObsWrapper, RGBImgPartialObsWrapper
@@ -391,7 +391,7 @@ class FullyObsWrapper(ObservationWrapper):
     Fully observable gridworld using a compact grid encoding instead of the agent view.
     Fully observable gridworld using a compact grid encoding instead of the agent view.
 
 
     Example:
     Example:
-        >>> import miniworld
+        >>> import minigrid
         >>> import gymnasium as gym
         >>> import gymnasium as gym
         >>> import matplotlib.pyplot as plt
         >>> import matplotlib.pyplot as plt
         >>> from minigrid.wrappers import FullyObsWrapper
         >>> from minigrid.wrappers import FullyObsWrapper
@@ -437,7 +437,7 @@ class DictObservationSpaceWrapper(ObservationWrapper):
     This wrapper is not applicable to BabyAI environments, given that these have their own language component.
     This wrapper is not applicable to BabyAI environments, given that these have their own language component.
 
 
     Example:
     Example:
-        >>> import miniworld
+        >>> import minigrid
         >>> import gymnasium as gym
         >>> import gymnasium as gym
         >>> import matplotlib.pyplot as plt
         >>> import matplotlib.pyplot as plt
         >>> from minigrid.wrappers import DictObservationSpaceWrapper
         >>> from minigrid.wrappers import DictObservationSpaceWrapper
@@ -571,7 +571,7 @@ class FlatObsWrapper(ObservationWrapper):
     This wrapper is not applicable to BabyAI environments, given that these have their own language component.
     This wrapper is not applicable to BabyAI environments, given that these have their own language component.
 
 
     Example:
     Example:
-        >>> import miniworld
+        >>> import minigrid
         >>> import gymnasium as gym
         >>> import gymnasium as gym
         >>> import matplotlib.pyplot as plt
         >>> import matplotlib.pyplot as plt
         >>> from minigrid.wrappers import FlatObsWrapper
         >>> from minigrid.wrappers import FlatObsWrapper
@@ -643,7 +643,7 @@ class ViewSizeWrapper(ObservationWrapper):
     This cannot be used with fully observable wrappers.
     This cannot be used with fully observable wrappers.
 
 
     Example:
     Example:
-        >>> import miniworld
+        >>> import minigrid
         >>> import gymnasium as gym
         >>> import gymnasium as gym
         >>> import matplotlib.pyplot as plt
         >>> import matplotlib.pyplot as plt
         >>> from minigrid.wrappers import ViewSizeWrapper
         >>> from minigrid.wrappers import ViewSizeWrapper
@@ -692,7 +692,7 @@ class DirectionObsWrapper(ObservationWrapper):
     type = {slope , angle}
     type = {slope , angle}
 
 
     Example:
     Example:
-        >>> import miniworld
+        >>> import minigrid
         >>> import gymnasium as gym
         >>> import gymnasium as gym
         >>> import matplotlib.pyplot as plt
         >>> import matplotlib.pyplot as plt
         >>> from minigrid.wrappers import DirectionObsWrapper
         >>> from minigrid.wrappers import DirectionObsWrapper
@@ -745,7 +745,7 @@ class SymbolicObsWrapper(ObservationWrapper):
     the coordinates on the grid, and IDX is the id of the object.
     the coordinates on the grid, and IDX is the id of the object.
 
 
     Example:
     Example:
-        >>> import miniworld
+        >>> import minigrid
         >>> import gymnasium as gym
         >>> import gymnasium as gym
         >>> import matplotlib.pyplot as plt
         >>> import matplotlib.pyplot as plt
         >>> from minigrid.wrappers import SymbolicObsWrapper
         >>> from minigrid.wrappers import SymbolicObsWrapper
@@ -777,9 +777,11 @@ class SymbolicObsWrapper(ObservationWrapper):
             [OBJECT_TO_IDX[o.type] if o is not None else -1 for o in self.grid.grid]
             [OBJECT_TO_IDX[o.type] if o is not None else -1 for o in self.grid.grid]
         )
         )
         agent_pos = self.env.agent_pos
         agent_pos = self.env.agent_pos
-        w, h = self.width, self.height
-        grid = np.mgrid[:w, :h]
-        grid = np.concatenate([grid, objects.reshape(1, w, h)])
+        ncol, nrow = self.width, self.height
+        grid = np.mgrid[:ncol, :nrow]
+        _objects = np.transpose(objects.reshape(1, nrow, ncol), (0, 2, 1))
+
+        grid = np.concatenate([grid, _objects])
         grid = np.transpose(grid, (1, 2, 0))
         grid = np.transpose(grid, (1, 2, 0))
         grid[agent_pos[0], agent_pos[1], 2] = OBJECT_TO_IDX["agent"]
         grid[agent_pos[0], agent_pos[1], 2] = OBJECT_TO_IDX["agent"]
         obs["image"] = grid
         obs["image"] = grid

+ 28 - 3
tests/test_wrappers.py

@@ -7,6 +7,7 @@ import numpy as np
 import pytest
 import pytest
 
 
 from minigrid.core.actions import Actions
 from minigrid.core.actions import Actions
+from minigrid.core.constants import OBJECT_TO_IDX
 from minigrid.envs import EmptyEnv
 from minigrid.envs import EmptyEnv
 from minigrid.wrappers import (
 from minigrid.wrappers import (
     ActionBonus,
     ActionBonus,
@@ -293,12 +294,36 @@ def test_direction_obs_wrapper(env_id, type):
     env.close()
     env.close()
 
 
 
 
-@pytest.mark.parametrize("env_id", ["MiniGrid-Empty-16x16-v0"])
+@pytest.mark.parametrize("env_id", ["MiniGrid-DistShift1-v0"])
 def test_symbolic_obs_wrapper(env_id):
 def test_symbolic_obs_wrapper(env_id):
     env = gym.make(env_id)
     env = gym.make(env_id)
+
     env = SymbolicObsWrapper(env)
     env = SymbolicObsWrapper(env)
-    obs, _ = env.reset()
+    obs, _ = env.reset(seed=123)
+    agent_pos = env.agent_pos
+    goal_pos = env.goal_pos
+
     assert obs["image"].shape == (env.width, env.height, 3)
     assert obs["image"].shape == (env.width, env.height, 3)
-    obs, _, _, _, _ = env.step(0)
+    assert np.alltrue(
+        obs["image"][agent_pos[0], agent_pos[1], :]
+        == np.array([agent_pos[0], agent_pos[1], OBJECT_TO_IDX["agent"]])
+    )
+    assert np.alltrue(
+        obs["image"][goal_pos[0], goal_pos[1], :]
+        == np.array([goal_pos[0], goal_pos[1], OBJECT_TO_IDX["goal"]])
+    )
+
+    obs, _, _, _, _ = env.step(2)
+    agent_pos = env.agent_pos
+    goal_pos = env.goal_pos
+
     assert obs["image"].shape == (env.width, env.height, 3)
     assert obs["image"].shape == (env.width, env.height, 3)
+    assert np.alltrue(
+        obs["image"][agent_pos[0], agent_pos[1], :]
+        == np.array([agent_pos[0], agent_pos[1], OBJECT_TO_IDX["agent"]])
+    )
+    assert np.alltrue(
+        obs["image"][goal_pos[0], goal_pos[1], :]
+        == np.array([goal_pos[0], goal_pos[1], OBJECT_TO_IDX["goal"]])
+    )
     env.close()
     env.close()