|
@@ -83,7 +83,7 @@ class ImgObsWrapper(gym.core.ObservationWrapper):
|
|
|
def __init__(self, env):
|
|
|
super().__init__(env)
|
|
|
self.__dict__.update(vars(env)) # hack to pass values to super wrapper
|
|
|
- self.observation_space = env.observation_space['image']
|
|
|
+ self.observation_space = env.observation_space.spaces['image']
|
|
|
|
|
|
def observation(self, obs):
|
|
|
return obs['image']
|
|
@@ -91,7 +91,7 @@ class ImgObsWrapper(gym.core.ObservationWrapper):
|
|
|
|
|
|
class FullyObsWrapper(gym.core.ObservationWrapper):
|
|
|
"""
|
|
|
- Fully observable gridworld
|
|
|
+ Fully observable gridworld using a compact grid encoding
|
|
|
"""
|
|
|
|
|
|
def __init__(self, env):
|
|
@@ -99,15 +99,15 @@ class FullyObsWrapper(gym.core.ObservationWrapper):
|
|
|
self.__dict__.update(vars(env)) # hack to pass values to super wrapper
|
|
|
self.observation_space = spaces.Box(
|
|
|
low=0,
|
|
|
- high=255,
|
|
|
- shape=(self.env.grid_size * 32, self.env.grid_size * 32, 3), # number of cells
|
|
|
+ high=self.env.grid_size,
|
|
|
+ shape=(self.env.grid_size, self.env.grid_size, 3), # number of cells
|
|
|
dtype='uint8'
|
|
|
)
|
|
|
|
|
|
def observation(self, obs):
|
|
|
- if self.env.grid_render is None:
|
|
|
- return np.zeros(shape=self.observation_space.shape) # dark screen as init state?
|
|
|
- return self.env.grid_render.getArray()
|
|
|
+ full_grid = self.env.grid.encode()
|
|
|
+ full_grid[self.env.agent_pos[0]][self.env.agent_pos[1]] = np.array([255, self.env.agent_dir, 0])
|
|
|
+ return full_grid
|
|
|
|
|
|
|
|
|
class FlatObsWrapper(gym.core.ObservationWrapper):
|