Преглед изворни кода

Use the right setters for the wrapped env's observation_space (i.e. without modifying the unwrapped one's)

saleml пре 2 година
родитељ
комит
bde32a2cd2
1 измењених фајлова са 13 додато и 5 уклоњено
  1. 13 5
      gym_minigrid/wrappers.py

+ 13 - 5
gym_minigrid/wrappers.py

@@ -126,12 +126,13 @@ class OneHotPartialObsWrapper(gym.core.ObservationWrapper):
         # Number of bits per cell
         num_bits = len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + len(STATE_TO_IDX)
 
-        self.observation_space.spaces["image"] = spaces.Box(
+        new_image_space = spaces.Box(
             low=0,
             high=255,
             shape=(obs_shape[0], obs_shape[1], num_bits),
             dtype='uint8'
         )
+        self.observation_space = spaces.Dict({**self.observation_space, 'image':new_image_space})
 
     def observation(self, obs):
         img = obs['image']
@@ -163,13 +164,15 @@ class RGBImgObsWrapper(gym.core.ObservationWrapper):
 
         self.tile_size = tile_size
 
-        self.observation_space.spaces['image'] = spaces.Box(
+        new_image_space = spaces.Box(
             low=0,
             high=255,
             shape=(self.env.width * tile_size, self.env.height * tile_size, 3),
             dtype='uint8'
         )
 
+        self.observation_space = spaces.Dict({**self.observation_space, 'image':new_image_space})
+
     def observation(self, obs):
         env = self.unwrapped
 
@@ -197,13 +200,15 @@ class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
         self.tile_size = tile_size
 
         obs_shape = env.observation_space.spaces['image'].shape
-        self.observation_space.spaces['image'] = spaces.Box(
+        new_image_space = spaces.Box(
             low=0,
             high=255,
             shape=(obs_shape[0] * tile_size, obs_shape[1] * tile_size, 3),
             dtype='uint8'
         )
 
+        self.observation_space = spaces.Dict({**self.observation_space, 'image':new_image_space})
+
     def observation(self, obs):
         env = self.unwrapped
 
@@ -225,13 +230,15 @@ class FullyObsWrapper(gym.core.ObservationWrapper):
     def __init__(self, env):
         super().__init__(env)
 
-        self.observation_space.spaces["image"] = spaces.Box(
+        new_image_space = spaces.Box(
             low=0,
             high=255,
             shape=(self.env.width, self.env.height, 3),  # number of cells
             dtype='uint8'
         )
 
+        self.observation_space = spaces.Dict({**self.observation_space, 'image':new_image_space})
+
     def observation(self, obs):
         env = self.unwrapped
         full_grid = env.grid.encode()
@@ -436,12 +443,13 @@ class SymbolicObsWrapper(gym.core.ObservationWrapper):
     def __init__(self, env):
         super().__init__(env)
 
-        self.observation_space.spaces["image"] = spaces.Box(
+        new_image_space = spaces.Box(
             low=0,
             high=max(OBJECT_TO_IDX.values()),
             shape=(self.env.width, self.env.height, 3),  # number of cells
             dtype="uint8",
         )
+        self.observation_space = spaces.Dict({**self.observation_space, 'image':new_image_space})
 
     def observation(self, obs):
         objects = np.array(