|
@@ -74,7 +74,6 @@ class StateBonus(gym.core.Wrapper):
|
|
|
|
|
|
return obs, reward, done, info
|
|
|
|
|
|
-
|
|
|
class ImgObsWrapper(gym.core.ObservationWrapper):
|
|
|
"""
|
|
|
Use rgb image as the only observation output
|
|
@@ -82,13 +81,13 @@ class ImgObsWrapper(gym.core.ObservationWrapper):
|
|
|
|
|
|
def __init__(self, env):
|
|
|
super().__init__(env)
|
|
|
- self.__dict__.update(vars(env)) # hack to pass values to super wrapper
|
|
|
+ # Hack to pass values to super wrapper
|
|
|
+ self.__dict__.update(vars(env))
|
|
|
self.observation_space = env.observation_space.spaces['image']
|
|
|
|
|
|
def observation(self, obs):
|
|
|
return obs['image']
|
|
|
|
|
|
-
|
|
|
class FullyObsWrapper(gym.core.ObservationWrapper):
|
|
|
"""
|
|
|
Fully observable gridworld using a compact grid encoding
|
|
@@ -109,7 +108,6 @@ class FullyObsWrapper(gym.core.ObservationWrapper):
|
|
|
full_grid[self.env.agent_pos[0]][self.env.agent_pos[1]] = np.array([255, self.env.agent_dir, 0])
|
|
|
return full_grid
|
|
|
|
|
|
-
|
|
|
class FlatObsWrapper(gym.core.ObservationWrapper):
|
|
|
"""
|
|
|
Encode mission strings using a one-hot scheme,
|