Bläddra i källkod

Fixed bug in SymbolicObsWrapper (#331)

Co-authored-by: Mark Towers <mark.m.towers@gmail.com>
Bolun 2 år sedan
förälder
incheckning
d537e1c167
2 ändrade filer med 45 tillägg och 18 borttagningar
  1. 17 15
      minigrid/wrappers.py
  2. 28 3
      tests/test_wrappers.py

+ 17 - 15
minigrid/wrappers.py

@@ -71,7 +71,7 @@ class ActionBonus(gym.Wrapper):
     visited (state,action) pairs.
 
     Example:
-        >>> import miniworld
+        >>> import minigrid
         >>> import gymnasium as gym
         >>> from minigrid.wrappers import ActionBonus
         >>> env = gym.make("MiniGrid-Empty-5x5-v0")
@@ -136,7 +136,7 @@ class PositionBonus(Wrapper):
         This wrapper was previously called ``StateBonus``.
 
     Example:
-        >>> import miniworld
+        >>> import minigrid
         >>> import gymnasium as gym
         >>> from minigrid.wrappers import PositionBonus
         >>> env = gym.make("MiniGrid-Empty-5x5-v0")
@@ -199,7 +199,7 @@ class ImgObsWrapper(ObservationWrapper):
     Use the image as the only observation output, no language/mission.
 
     Example:
-        >>> import miniworld
+        >>> import minigrid
         >>> import gymnasium as gym
         >>> from minigrid.wrappers import ImgObsWrapper
         >>> env = gym.make("MiniGrid-Empty-5x5-v0")
@@ -231,7 +231,7 @@ class OneHotPartialObsWrapper(ObservationWrapper):
     agent view as observation.
 
     Example:
-        >>> import miniworld
+        >>> import minigrid
         >>> import gymnasium as gym
         >>> from minigrid.wrappers import OneHotPartialObsWrapper
         >>> env = gym.make("MiniGrid-Empty-5x5-v0")
@@ -302,7 +302,7 @@ class RGBImgObsWrapper(ObservationWrapper):
     This can be used to have the agent solve the gridworld in pixel space.
 
     Example:
-        >>> import miniworld
+        >>> import minigrid
         >>> import gymnasium as gym
         >>> import matplotlib.pyplot as plt
         >>> from minigrid.wrappers import RGBImgObsWrapper
@@ -344,7 +344,7 @@ class RGBImgPartialObsWrapper(ObservationWrapper):
     This can be used to have the agent solve the gridworld in pixel space.
 
     Example:
-        >>> import miniworld
+        >>> import minigrid
         >>> import gymnasium as gym
         >>> import matplotlib.pyplot as plt
         >>> from minigrid.wrappers import RGBImgObsWrapper, RGBImgPartialObsWrapper
@@ -391,7 +391,7 @@ class FullyObsWrapper(ObservationWrapper):
     Fully observable gridworld using a compact grid encoding instead of the agent view.
 
     Example:
-        >>> import miniworld
+        >>> import minigrid
         >>> import gymnasium as gym
         >>> import matplotlib.pyplot as plt
         >>> from minigrid.wrappers import FullyObsWrapper
@@ -437,7 +437,7 @@ class DictObservationSpaceWrapper(ObservationWrapper):
     This wrapper is not applicable to BabyAI environments, given that these have their own language component.
 
     Example:
-        >>> import miniworld
+        >>> import minigrid
         >>> import gymnasium as gym
         >>> import matplotlib.pyplot as plt
         >>> from minigrid.wrappers import DictObservationSpaceWrapper
@@ -571,7 +571,7 @@ class FlatObsWrapper(ObservationWrapper):
     This wrapper is not applicable to BabyAI environments, given that these have their own language component.
 
     Example:
-        >>> import miniworld
+        >>> import minigrid
         >>> import gymnasium as gym
         >>> import matplotlib.pyplot as plt
         >>> from minigrid.wrappers import FlatObsWrapper
@@ -643,7 +643,7 @@ class ViewSizeWrapper(ObservationWrapper):
     This cannot be used with fully observable wrappers.
 
     Example:
-        >>> import miniworld
+        >>> import minigrid
         >>> import gymnasium as gym
         >>> import matplotlib.pyplot as plt
         >>> from minigrid.wrappers import ViewSizeWrapper
@@ -692,7 +692,7 @@ class DirectionObsWrapper(ObservationWrapper):
     type = {slope , angle}
 
     Example:
-        >>> import miniworld
+        >>> import minigrid
         >>> import gymnasium as gym
         >>> import matplotlib.pyplot as plt
         >>> from minigrid.wrappers import DirectionObsWrapper
@@ -745,7 +745,7 @@ class SymbolicObsWrapper(ObservationWrapper):
     the coordinates on the grid, and IDX is the id of the object.
 
     Example:
-        >>> import miniworld
+        >>> import minigrid
         >>> import gymnasium as gym
         >>> import matplotlib.pyplot as plt
         >>> from minigrid.wrappers import SymbolicObsWrapper
@@ -777,9 +777,11 @@ class SymbolicObsWrapper(ObservationWrapper):
             [OBJECT_TO_IDX[o.type] if o is not None else -1 for o in self.grid.grid]
         )
         agent_pos = self.env.agent_pos
-        w, h = self.width, self.height
-        grid = np.mgrid[:w, :h]
-        grid = np.concatenate([grid, objects.reshape(1, w, h)])
+        ncol, nrow = self.width, self.height
+        grid = np.mgrid[:ncol, :nrow]
+        _objects = np.transpose(objects.reshape(1, nrow, ncol), (0, 2, 1))
+
+        grid = np.concatenate([grid, _objects])
         grid = np.transpose(grid, (1, 2, 0))
         grid[agent_pos[0], agent_pos[1], 2] = OBJECT_TO_IDX["agent"]
         obs["image"] = grid

+ 28 - 3
tests/test_wrappers.py

@@ -7,6 +7,7 @@ import numpy as np
 import pytest
 
 from minigrid.core.actions import Actions
+from minigrid.core.constants import OBJECT_TO_IDX
 from minigrid.envs import EmptyEnv
 from minigrid.wrappers import (
     ActionBonus,
@@ -293,12 +294,36 @@ def test_direction_obs_wrapper(env_id, type):
     env.close()
 
 
-@pytest.mark.parametrize("env_id", ["MiniGrid-Empty-16x16-v0"])
+@pytest.mark.parametrize("env_id", ["MiniGrid-DistShift1-v0"])
 def test_symbolic_obs_wrapper(env_id):
     env = gym.make(env_id)
+
     env = SymbolicObsWrapper(env)
-    obs, _ = env.reset()
+    obs, _ = env.reset(seed=123)
+    agent_pos = env.agent_pos
+    goal_pos = env.goal_pos
+
     assert obs["image"].shape == (env.width, env.height, 3)
-    obs, _, _, _, _ = env.step(0)
+    assert np.alltrue(
+        obs["image"][agent_pos[0], agent_pos[1], :]
+        == np.array([agent_pos[0], agent_pos[1], OBJECT_TO_IDX["agent"]])
+    )
+    assert np.alltrue(
+        obs["image"][goal_pos[0], goal_pos[1], :]
+        == np.array([goal_pos[0], goal_pos[1], OBJECT_TO_IDX["goal"]])
+    )
+
+    obs, _, _, _, _ = env.step(2)
+    agent_pos = env.agent_pos
+    goal_pos = env.goal_pos
+
     assert obs["image"].shape == (env.width, env.height, 3)
+    assert np.alltrue(
+        obs["image"][agent_pos[0], agent_pos[1], :]
+        == np.array([agent_pos[0], agent_pos[1], OBJECT_TO_IDX["agent"]])
+    )
+    assert np.alltrue(
+        obs["image"][goal_pos[0], goal_pos[1], :]
+        == np.array([goal_pos[0], goal_pos[1], OBJECT_TO_IDX["goal"]])
+    )
     env.close()