|
@@ -139,6 +139,32 @@ class RGBImgObsWrapper(gym.core.ObservationWrapper):
|
|
|
tile_size=self.tile_size
|
|
|
)
|
|
|
|
|
|
+class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
|
|
|
+ """
|
|
|
+ Wrapper to use partially observable RGB image as the only observation output
|
|
|
+ This can be used to have the agent to solve the gridworld in pixel space.
|
|
|
+ """
|
|
|
+
|
|
|
+ def __init__(self, env, tile_size=8):
|
|
|
+ super().__init__(env)
|
|
|
+
|
|
|
+ self.tile_size = tile_size
|
|
|
+
|
|
|
+ obs_shape = env.observation_space['image'].shape
|
|
|
+ self.observation_space = spaces.Box(
|
|
|
+ low=0,
|
|
|
+ high=255,
|
|
|
+ shape=(obs_shape[0] * tile_size, obs_shape[1] * tile_size, 3),
|
|
|
+ dtype='uint8'
|
|
|
+ )
|
|
|
+
|
|
|
+ def observation(self, obs):
|
|
|
+ env = self.unwrapped
|
|
|
+ return {
|
|
|
+ 'mission': obs['mission'],
|
|
|
+ 'image': env.get_obs_render(obs['image'], tile_size=self.tile_size, mode='rgb_array')
|
|
|
+ }
|
|
|
+
|
|
|
class FullyObsWrapper(gym.core.ObservationWrapper):
|
|
|
"""
|
|
|
Fully observable gridworld using a compact grid encoding
|