|
@@ -167,7 +167,7 @@ class RGBImgObsWrapper(gym.core.ObservationWrapper):
|
|
|
self.observation_space.spaces['image'] = spaces.Box(
|
|
|
low=0,
|
|
|
high=255,
|
|
|
- shape=(self.env.width*tile_size, self.env.height*tile_size, 3),
|
|
|
+ shape=(self.env.width * tile_size, self.env.height * tile_size, 3),
|
|
|
dtype='uint8'
|
|
|
)
|
|
|
|
|
@@ -197,7 +197,7 @@ class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
|
|
|
|
|
|
self.tile_size = tile_size
|
|
|
|
|
|
- obs_shape = env.observation_space['image'].shape
|
|
|
+ obs_shape = env.observation_space.spaces['image'].shape
|
|
|
self.observation_space.spaces['image'] = spaces.Box(
|
|
|
low=0,
|
|
|
high=255,
|
|
@@ -328,7 +328,7 @@ class ViewSizeWrapper(gym.core.Wrapper):
|
|
|
|
|
|
def step(self, action):
|
|
|
return self.env.step(action)
|
|
|
-
|
|
|
+
|
|
|
from .minigrid import Goal
|
|
|
class DirectionObsWrapper(gym.core.ObservationWrapper):
|
|
|
"""
|
|
@@ -339,17 +339,16 @@ class DirectionObsWrapper(gym.core.ObservationWrapper):
|
|
|
super().__init__(env)
|
|
|
self.goal_position = None
|
|
|
self.type = type
|
|
|
-
|
|
|
+
|
|
|
def reset(self):
|
|
|
obs = self.env.reset()
|
|
|
- if not self.goal_position:
|
|
|
+ if not self.goal_position:
|
|
|
self.goal_position = [x for x,y in enumerate(self.grid.grid) if isinstance(y,(Goal) ) ]
|
|
|
if len(self.goal_position) >= 1: # in case there are multiple goals , needs to be handled for other env types
|
|
|
self.goal_position = (int(self.goal_position[0]/self.height) , self.goal_position[0]%self.width)
|
|
|
return obs
|
|
|
-
|
|
|
+
|
|
|
def observation(self, obs):
|
|
|
- slope = np.divide( self.goal_position[1] - self.agent_pos[1] , self.goal_position[0] - self.agent_pos[0])
|
|
|
+ slope = np.divide( self.goal_position[1] - self.agent_pos[1] , self.goal_position[0] - self.agent_pos[0])
|
|
|
obs['goal_direction'] = np.arctan( slope ) if self.type == 'angle' else slope
|
|
|
return obs
|
|
|
-
|