Browse Source

fix unintended observation space size change (#345)

Joseph Bloom 2 years ago
parent
commit
9242742873
2 changed files with 12 additions and 7 deletions
  1. 1 7
      minigrid/wrappers.py
  2. 11 0
      tests/test_wrappers.py

+ 1 - 7
minigrid/wrappers.py

@@ -452,15 +452,9 @@ class DictObservationSpaceWrapper(ObservationWrapper):
         self.max_words_in_mission = max_words_in_mission
         self.word_dict = word_dict
 
-        image_observation_space = spaces.Box(
-            low=0,
-            high=255,
-            shape=(self.agent_view_size, self.agent_view_size, 3),
-            dtype="uint8",
-        )
         self.observation_space = spaces.Dict(
             {
-                "image": image_observation_space,
+                "image": env.observation_space["image"],
                 "direction": spaces.Discrete(4),
                 "mission": spaces.MultiDiscrete(
                     [len(self.word_dict.keys())] * max_words_in_mission

+ 11 - 0
tests/test_wrappers.py

@@ -327,3 +327,14 @@ def test_symbolic_obs_wrapper(env_id):
         == np.array([goal_pos[0], goal_pos[1], OBJECT_TO_IDX["goal"]])
     )
     env.close()
+
+
+def test_dict_observation_space_doesnt_clash_with_one_hot():
+    env = gym.make("MiniGrid-Empty-5x5-v0")
+    env = OneHotPartialObsWrapper(env)
+    env = DictObservationSpaceWrapper(env)
+    env.reset()
+    obs, _, _, _, _ = env.step(0)
+    assert obs["image"].shape == (7, 7, 20)
+    assert env.observation_space["image"].shape == (7, 7, 20)
+    env.close()