Browse Source

Fix type of observation_space for wrappers. (#72)

* Fix type of observation_space for wrappers.

Wrappers changing the observation space need to be of return
gym.spaces.Dict types.

* Replace f-strings with format().
Florin Gogianu 5 years ago
parent
commit
60d3df9039
2 changed files with 25 additions and 8 deletions
  1. 8 8
      gym_minigrid/wrappers.py
  2. 17 0
      run_tests.py

+ 8 - 8
gym_minigrid/wrappers.py

@@ -127,12 +127,12 @@ class OneHotPartialObsWrapper(gym.core.ObservationWrapper):
 
         num_bits = len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + len(STATE_TO_IDX)
 
-        self.observation_space = spaces.Box(
-            low=0,
-            high=255,
-            shape=(obs_shape[0], obs_shape[1], num_bits),
-            dtype='uint8'
-        )
+        self.observation_space.spaces["image"] = spaces.Box(
+                low=0,
+                high=255,
+                shape=(obs_shape[0], obs_shape[1], num_bits),
+                dtype='uint8'
+            )
 
     def observation(self, obs):
         img = obs['image']
@@ -165,7 +165,7 @@ class RGBImgObsWrapper(gym.core.ObservationWrapper):
 
         self.tile_size = tile_size
 
-        self.observation_space = spaces.Box(
+        self.observation_space.spaces['image'] = spaces.Box(
             low=0,
             high=255,
             shape=(self.env.width*tile_size, self.env.height*tile_size, 3),
@@ -192,7 +192,7 @@ class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
         self.tile_size = tile_size
 
         obs_shape = env.observation_space['image'].shape
-        self.observation_space = spaces.Box(
+        self.observation_space.spaces['image'] = spaces.Box(
             low=0,
             high=255,
             shape=(obs_shape[0] * tile_size, obs_shape[1] * tile_size, 3),

+ 17 - 0
run_tests.py

@@ -103,6 +103,23 @@ for env_name in env_list:
     env.step(0)
     env.close()
 
+    # Test the wrappers return proper observation spaces.
+    wrappers = [
+        RGBImgObsWrapper,
+        RGBImgPartialObsWrapper,
+        OneHotPartialObsWrapper
+    ]
+    for wrapper in wrappers:
+        env = wrapper(gym.make(env_name))
+        obs_space, wrapper_name = env.observation_space, wrapper.__name__
+        assert isinstance(
+            obs_space, spaces.Dict
+        ), "Observation space for {0} is not a Dict: {1}.".format(
+            wrapper_name, obs_space
+        )
+        # this shuld not fail either
+        ImgObsWrapper(env)
+
 ##############################################################################
 
 print('testing agent_sees method')