|
@@ -106,7 +106,6 @@ class ImgObsWrapper(gym.core.ObservationWrapper):
|
|
|
|
|
|
def __init__(self, env):
|
|
|
super().__init__(env)
|
|
|
-
|
|
|
self.observation_space = env.observation_space.spaces['image']
|
|
|
|
|
|
def observation(self, obs):
|
|
@@ -125,14 +124,15 @@ class OneHotPartialObsWrapper(gym.core.ObservationWrapper):
|
|
|
|
|
|
obs_shape = env.observation_space['image'].shape
|
|
|
|
|
|
+ # Number of bits per cell
|
|
|
num_bits = len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + len(STATE_TO_IDX)
|
|
|
|
|
|
self.observation_space.spaces["image"] = spaces.Box(
|
|
|
- low=0,
|
|
|
- high=255,
|
|
|
- shape=(obs_shape[0], obs_shape[1], num_bits),
|
|
|
- dtype='uint8'
|
|
|
- )
|
|
|
+ low=0,
|
|
|
+ high=255,
|
|
|
+ shape=(obs_shape[0], obs_shape[1], num_bits),
|
|
|
+ dtype='uint8'
|
|
|
+ )
|
|
|
|
|
|
def observation(self, obs):
|
|
|
img = obs['image']
|
|
@@ -174,12 +174,19 @@ class RGBImgObsWrapper(gym.core.ObservationWrapper):
|
|
|
|
|
|
def observation(self, obs):
|
|
|
env = self.unwrapped
|
|
|
- return env.render(
|
|
|
+
|
|
|
+ rgb_img = env.render(
|
|
|
mode='rgb_array',
|
|
|
highlight=False,
|
|
|
tile_size=self.tile_size
|
|
|
)
|
|
|
|
|
|
+ return {
|
|
|
+ 'mission': obs['mission'],
|
|
|
+ 'image': rgb_img
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
|
|
|
"""
|
|
|
Wrapper to use partially observable RGB image as the only observation output
|
|
@@ -201,9 +208,16 @@ class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
|
|
|
|
|
|
def observation(self, obs):
|
|
|
env = self.unwrapped
|
|
|
+
|
|
|
+ rgb_img_partial = env.get_obs_render(
|
|
|
+ obs['image'],
|
|
|
+ tile_size=self.tile_size,
|
|
|
+ mode='rgb_array'
|
|
|
+ )
|
|
|
+
|
|
|
return {
|
|
|
'mission': obs['mission'],
|
|
|
- 'image': env.get_obs_render(obs['image'], tile_size=self.tile_size, mode='rgb_array')
|
|
|
+ 'image': rgb_img_partial
|
|
|
}
|
|
|
|
|
|
class FullyObsWrapper(gym.core.ObservationWrapper):
|
|
@@ -214,7 +228,7 @@ class FullyObsWrapper(gym.core.ObservationWrapper):
|
|
|
def __init__(self, env):
|
|
|
super().__init__(env)
|
|
|
|
|
|
- self.observation_space = spaces.Box(
|
|
|
+ self.observation_space.spaces["image"] = spaces.Box(
|
|
|
low=0,
|
|
|
high=255,
|
|
|
shape=(self.env.width, self.env.height, 3), # number of cells
|
|
@@ -230,7 +244,10 @@ class FullyObsWrapper(gym.core.ObservationWrapper):
|
|
|
env.agent_dir
|
|
|
])
|
|
|
|
|
|
- return full_grid
|
|
|
+ return {
|
|
|
+ 'mission': obs['mission'],
|
|
|
+ 'image': full_grid
|
|
|
+ }
|
|
|
|
|
|
class FlatObsWrapper(gym.core.ObservationWrapper):
|
|
|
"""
|
|
@@ -283,13 +300,14 @@ class FlatObsWrapper(gym.core.ObservationWrapper):
|
|
|
|
|
|
return obs
|
|
|
|
|
|
-class AgentViewWrapper(gym.core.Wrapper):
|
|
|
+class ViewSizeWrapper(gym.core.Wrapper):
|
|
|
"""
|
|
|
Wrapper to customize the agent field of view size.
|
|
|
+ This cannot be used with fully observable wrappers.
|
|
|
"""
|
|
|
|
|
|
def __init__(self, env, agent_view_size=7):
|
|
|
- super(AgentViewWrapper, self).__init__(env)
|
|
|
+ super().__init__(env)
|
|
|
|
|
|
# Override default view size
|
|
|
env.unwrapped.agent_view_size = agent_view_size
|