Forráskód Böngészése

Added object position tracking

Maxime Chevalier-Boisvert 6 éve
szülő
commit
e270e76ee5
2 módosított fájl, 31 hozzáadás és 9 törlés
  1. 1 1
      gym_minigrid/envs/putnear.py
  2. 30 8
      gym_minigrid/minigrid.py

+ 1 - 1
gym_minigrid/envs/putnear.py

@@ -93,7 +93,7 @@ class PutNearEnv(MiniGridEnv):
 
         obs, reward, done, info = super().step(action)
 
-        u, v = self.get_dir_vec()
+        u, v = self.dir_vec
         ox, oy = (self.agent_pos[0] + u, self.agent_pos[1] + v)
         tx, ty = self.target_pos
 

+ 30 - 8
gym_minigrid/minigrid.py

@@ -77,6 +77,12 @@ class WorldObj:
         self.color = color
         self.contains = None
 
+        # Initial position of the object
+        self.init_pos = None
+
+        # Current position of the object
+        self.cur_pos = None
+
     def can_overlap(self):
         """Can the agent overlap with this?"""
         return False
@@ -893,6 +899,10 @@ class MiniGridEnv(gym.Env):
 
         self.grid.set(*pos, obj)
 
+        if obj is not None:
+            obj.init_pos = pos
+            obj.cur_pos = pos
+
         return pos
 
     def place_agent(self, top=None, size=None, rand_dir=True):
@@ -909,7 +919,8 @@ class MiniGridEnv(gym.Env):
 
         return pos
 
-    def get_dir_vec(self):
+    @property
+    def dir_vec(self):
         """
         Get the direction vector for the agent, pointing in the direction
         of forward movement.
@@ -918,14 +929,23 @@ class MiniGridEnv(gym.Env):
         assert self.agent_dir >= 0 and self.agent_dir < 4
         return DIR_TO_VEC[self.agent_dir]
 
-    def get_right_vec(self):
+    @property
+    def right_vec(self):
         """
         Get the vector pointing to the right of the agent.
         """
 
-        dx, dy = self.get_dir_vec()
+        dx, dy = self.dir_vec
         return np.array((-dy, dx))
 
+    @property
+    def front_pos(self):
+        """
+        Get the position of the cell that is right in front of the agent
+        """
+
+        return self.agent_pos + self.dir_vec
+
     def get_view_coords(self, i, j):
         """
         Translate and rotate absolute grid coordinates (i, j) into the
@@ -934,8 +954,8 @@ class MiniGridEnv(gym.Env):
         """
 
         ax, ay = self.agent_pos
-        dx, dy = self.get_dir_vec()
-        rx, ry = self.get_right_vec()
+        dx, dy = self.dir_vec
+        rx, ry = self.right_vec
 
         # Compute the absolute coordinates of the top-left view corner
         sz = AGENT_VIEW_SIZE
@@ -1007,7 +1027,7 @@ class MiniGridEnv(gym.Env):
         done = False
 
         # Get the position in front of the agent
-        fwd_pos = self.agent_pos + self.get_dir_vec()
+        fwd_pos = self.front_pos
 
         # Get the contents of the cell in front of the agent
         fwd_cell = self.grid.get(*fwd_pos)
@@ -1035,12 +1055,14 @@ class MiniGridEnv(gym.Env):
             if fwd_cell and fwd_cell.can_pickup():
                 if self.carrying is None:
                     self.carrying = fwd_cell
+                    self.carrying.cur_pos = np.array([-1, -1])
                     self.grid.set(*fwd_pos, None)
 
         # Drop an object
         elif action == self.actions.drop:
             if not fwd_cell and self.carrying:
                 self.grid.set(*fwd_pos, self.carrying)
+                self.carrying.cur_pos = fwd_pos
                 self.carrying = None
 
         # Toggle/activate an object
@@ -1205,8 +1227,8 @@ class MiniGridEnv(gym.Env):
 
         # Compute the absolute coordinates of the bottom-left corner
         # of the agent's view area
-        f_vec = self.get_dir_vec()
-        r_vec = self.get_right_vec()
+        f_vec = self.dir_vec
+        r_vec = self.right_vec
         top_left = self.agent_pos + f_vec * (AGENT_VIEW_SIZE-1) - r_vec * (AGENT_VIEW_SIZE // 2)
 
         # For each cell in the visibility mask