瀏覽代碼

Change based on #91

Maxime Chevalier-Boisvert 5 年之前
父節點
當前提交
2f54e86cc0
共有 1 個文件被更改,包括 7 次插入8 次删除
  1. 7 8
      gym_minigrid/wrappers.py

+ 7 - 8
gym_minigrid/wrappers.py

@@ -167,7 +167,7 @@ class RGBImgObsWrapper(gym.core.ObservationWrapper):
         self.observation_space.spaces['image'] = spaces.Box(
             low=0,
             high=255,
-            shape=(self.env.width*tile_size, self.env.height*tile_size, 3),
+            shape=(self.env.width * tile_size, self.env.height * tile_size, 3),
             dtype='uint8'
         )
 
@@ -197,7 +197,7 @@ class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
 
         self.tile_size = tile_size
 
-        obs_shape = env.observation_space['image'].shape
+        obs_shape = env.observation_space.spaces['image'].shape
         self.observation_space.spaces['image'] = spaces.Box(
             low=0,
             high=255,
@@ -328,7 +328,7 @@ class ViewSizeWrapper(gym.core.Wrapper):
 
     def step(self, action):
         return self.env.step(action)
-        
+
 from .minigrid import Goal
 class DirectionObsWrapper(gym.core.ObservationWrapper):
     """
@@ -339,17 +339,16 @@ class DirectionObsWrapper(gym.core.ObservationWrapper):
         super().__init__(env)
         self.goal_position = None
         self.type = type
-    
+
     def reset(self):
         obs = self.env.reset()
-        if not self.goal_position: 
+        if not self.goal_position:
             self.goal_position = [x for x,y in enumerate(self.grid.grid) if isinstance(y,(Goal) ) ]
             if len(self.goal_position) >= 1: # in case there are multiple goals , needs to be handled for other env types
                 self.goal_position = (int(self.goal_position[0]/self.height) , self.goal_position[0]%self.width)
         return obs
-        
+
     def observation(self, obs):
-        slope = np.divide( self.goal_position[1] - self.agent_pos[1] ,  self.goal_position[0] - self.agent_pos[0]) 
+        slope = np.divide( self.goal_position[1] - self.agent_pos[1] ,  self.goal_position[0] - self.agent_pos[0])
         obs['goal_direction'] = np.arctan( slope ) if self.type == 'angle' else slope
         return obs
-