Browse Source

Update wrappers.py (#85)

* Update wrappers.py

* Update wrappers.py
Aishwarya Dabhade 5 years ago
parent
commit
af0ac22e1d
1 changed files with 25 additions and 0 deletions
  1. 25 0
      gym_minigrid/wrappers.py

+ 25 - 0
gym_minigrid/wrappers.py

@@ -328,3 +328,28 @@ class ViewSizeWrapper(gym.core.Wrapper):
 
     def step(self, action):
         return self.env.step(action)
+        
+from .minigrid import Goal
+class DirectionObsWrapper(gym.core.ObservationWrapper):
+    """
+    Provides the slope/angular direction to the goal with the observations as modeled by (y2 - y2 )/( x2 - x1)
+    type = {slope , angle}
+    """
+    def __init__(self, env,type='slope'):
+        super().__init__(env)
+        self.goal_position = None
+        self.type = type
+    
+    def reset(self):
+        obs = self.env.reset()
+        if not self.goal_position: 
+            self.goal_position = [x for x,y in enumerate(self.grid.grid) if isinstance(y,(Goal) ) ]
+            if len(self.goal_position) >= 1: # in case there are multiple goals , needs to be handled for other env types
+                self.goal_position = (int(self.goal_position[0]/self.height) , self.goal_position[0]%self.width)
+        return obs
+        
+    def observation(self, obs):
+        slope = np.divide( self.goal_position[1] - self.agent_pos[1] ,  self.goal_position[0] - self.agent_pos[0]) 
+        obs['goal_direction'] = np.arctan( slope ) if self.type == 'angle' else slope
+        return obs
+