|
@@ -127,12 +127,12 @@ class OneHotPartialObsWrapper(gym.core.ObservationWrapper):
|
|
|
|
|
|
num_bits = len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + len(STATE_TO_IDX)
|
|
|
|
|
|
- self.observation_space = spaces.Box(
|
|
|
- low=0,
|
|
|
- high=255,
|
|
|
- shape=(obs_shape[0], obs_shape[1], num_bits),
|
|
|
- dtype='uint8'
|
|
|
- )
|
|
|
+ self.observation_space.spaces["image"] = spaces.Box(
|
|
|
+ low=0,
|
|
|
+ high=255,
|
|
|
+ shape=(obs_shape[0], obs_shape[1], num_bits),
|
|
|
+ dtype='uint8'
|
|
|
+ )
|
|
|
|
|
|
def observation(self, obs):
|
|
|
img = obs['image']
|
|
@@ -165,7 +165,7 @@ class RGBImgObsWrapper(gym.core.ObservationWrapper):
|
|
|
|
|
|
self.tile_size = tile_size
|
|
|
|
|
|
- self.observation_space = spaces.Box(
|
|
|
+ self.observation_space.spaces['image'] = spaces.Box(
|
|
|
low=0,
|
|
|
high=255,
|
|
|
shape=(self.env.width*tile_size, self.env.height*tile_size, 3),
|
|
@@ -192,7 +192,7 @@ class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
|
|
|
self.tile_size = tile_size
|
|
|
|
|
|
obs_shape = env.observation_space['image'].shape
|
|
|
- self.observation_space = spaces.Box(
|
|
|
+ self.observation_space.spaces['image'] = spaces.Box(
|
|
|
low=0,
|
|
|
high=255,
|
|
|
shape=(obs_shape[0] * tile_size, obs_shape[1] * tile_size, 3),
|