|
@@ -5,7 +5,7 @@ from functools import reduce
|
|
|
import numpy as np
|
|
|
import gym
|
|
|
from gym import error, spaces, utils
|
|
|
-from .minigrid import OBJECT_TO_IDX, COLOR_TO_IDX
|
|
|
+from .minigrid import OBJECT_TO_IDX, COLOR_TO_IDX, STATE_TO_IDX
|
|
|
from .minigrid import CELL_PIXELS
|
|
|
|
|
|
class ReseedWrapper(gym.core.Wrapper):
|
|
@@ -112,6 +112,47 @@ class ImgObsWrapper(gym.core.ObservationWrapper):
|
|
|
def observation(self, obs):
|
|
|
return obs['image']
|
|
|
|
|
|
+class OneHotPartialObsWrapper(gym.core.ObservationWrapper):
|
|
|
+ """
|
|
|
+ Wrapper to get a one-hot encoding of a partially observable
|
|
|
+ agent view as observation.
|
|
|
+ """
|
|
|
+
|
|
|
+ def __init__(self, env, tile_size=8):
|
|
|
+ super().__init__(env)
|
|
|
+
|
|
|
+ self.tile_size = tile_size
|
|
|
+
|
|
|
+ obs_shape = env.observation_space['image'].shape
|
|
|
+
|
|
|
+ num_bits = len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + len(STATE_TO_IDX)
|
|
|
+
|
|
|
+ self.observation_space = spaces.Box(
|
|
|
+ low=0,
|
|
|
+ high=255,
|
|
|
+ shape=(obs_shape[0], obs_shape[1], num_bits),
|
|
|
+ dtype='uint8'
|
|
|
+ )
|
|
|
+
|
|
|
+ def observation(self, obs):
|
|
|
+ img = obs['image']
|
|
|
+ out = np.zeros(self.observation_space.shape, dtype='uint8')
|
|
|
+
|
|
|
+ for i in range(img.shape[0]):
|
|
|
+ for j in range(img.shape[1]):
|
|
|
+ type = img[i, j, 0]
|
|
|
+ color = img[i, j, 1]
|
|
|
+ state = img[i, j, 2]
|
|
|
+
|
|
|
+ out[i, j, type] = 1
|
|
|
+ out[i, j, len(OBJECT_TO_IDX) + color] = 1
|
|
|
+ out[i, j, len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + color] = 1
|
|
|
+
|
|
|
+ return {
|
|
|
+ 'mission': obs['mission'],
|
|
|
+ 'image': out
|
|
|
+ }
|
|
|
+
|
|
|
class RGBImgObsWrapper(gym.core.ObservationWrapper):
|
|
|
"""
|
|
|
Wrapper to use fully observable RGB image as the only observation output,
|