|
@@ -126,12 +126,13 @@ class OneHotPartialObsWrapper(gym.core.ObservationWrapper):
|
|
|
# Number of bits per cell
|
|
|
num_bits = len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + len(STATE_TO_IDX)
|
|
|
|
|
|
- self.observation_space.spaces["image"] = spaces.Box(
|
|
|
+ new_image_space = spaces.Box(
|
|
|
low=0,
|
|
|
high=255,
|
|
|
shape=(obs_shape[0], obs_shape[1], num_bits),
|
|
|
dtype='uint8'
|
|
|
)
|
|
|
+ self.observation_space = spaces.Dict({**self.observation_space, 'image':new_image_space})
|
|
|
|
|
|
def observation(self, obs):
|
|
|
img = obs['image']
|
|
@@ -163,13 +164,15 @@ class RGBImgObsWrapper(gym.core.ObservationWrapper):
|
|
|
|
|
|
self.tile_size = tile_size
|
|
|
|
|
|
- self.observation_space.spaces['image'] = spaces.Box(
|
|
|
+ new_image_space = spaces.Box(
|
|
|
low=0,
|
|
|
high=255,
|
|
|
shape=(self.env.width * tile_size, self.env.height * tile_size, 3),
|
|
|
dtype='uint8'
|
|
|
)
|
|
|
|
|
|
+ self.observation_space = spaces.Dict({**self.observation_space, 'image':new_image_space})
|
|
|
+
|
|
|
def observation(self, obs):
|
|
|
env = self.unwrapped
|
|
|
|
|
@@ -197,13 +200,15 @@ class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
|
|
|
self.tile_size = tile_size
|
|
|
|
|
|
obs_shape = env.observation_space.spaces['image'].shape
|
|
|
- self.observation_space.spaces['image'] = spaces.Box(
|
|
|
+ new_image_space = spaces.Box(
|
|
|
low=0,
|
|
|
high=255,
|
|
|
shape=(obs_shape[0] * tile_size, obs_shape[1] * tile_size, 3),
|
|
|
dtype='uint8'
|
|
|
)
|
|
|
|
|
|
+ self.observation_space = spaces.Dict({**self.observation_space, 'image':new_image_space})
|
|
|
+
|
|
|
def observation(self, obs):
|
|
|
env = self.unwrapped
|
|
|
|
|
@@ -225,13 +230,15 @@ class FullyObsWrapper(gym.core.ObservationWrapper):
|
|
|
def __init__(self, env):
|
|
|
super().__init__(env)
|
|
|
|
|
|
- self.observation_space.spaces["image"] = spaces.Box(
|
|
|
+ new_image_space = spaces.Box(
|
|
|
low=0,
|
|
|
high=255,
|
|
|
shape=(self.env.width, self.env.height, 3), # number of cells
|
|
|
dtype='uint8'
|
|
|
)
|
|
|
|
|
|
+ self.observation_space = spaces.Dict({**self.observation_space, 'image':new_image_space})
|
|
|
+
|
|
|
def observation(self, obs):
|
|
|
env = self.unwrapped
|
|
|
full_grid = env.grid.encode()
|
|
@@ -436,12 +443,13 @@ class SymbolicObsWrapper(gym.core.ObservationWrapper):
|
|
|
def __init__(self, env):
|
|
|
super().__init__(env)
|
|
|
|
|
|
- self.observation_space.spaces["image"] = spaces.Box(
|
|
|
+ new_image_space = spaces.Box(
|
|
|
low=0,
|
|
|
high=max(OBJECT_TO_IDX.values()),
|
|
|
shape=(self.env.width, self.env.height, 3), # number of cells
|
|
|
dtype="uint8",
|
|
|
)
|
|
|
+ self.observation_space = spaces.Dict({**self.observation_space, 'image':new_image_space})
|
|
|
|
|
|
def observation(self, obs):
|
|
|
objects = np.array(
|