فهرست منبع

Added DIR_TO_VEC array. Agent position is now a numpy array.

Maxime Chevalier-Boisvert 7 سال پیش
والد
کامیت
76b43b7534
1فایلهای تغییر یافته به همراه21 افزوده شده و 22 حذف شده
  1. 21 22
      gym_minigrid/minigrid.py

+ 21 - 22
gym_minigrid/minigrid.py

@@ -53,6 +53,18 @@ OBJECT_TO_IDX = {
 
 IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
 
+# Map of agent direction indices to vectors
+DIR_TO_VEC = [
+    # Pointing right (positive X)
+    np.array((1, 0)),
+    # Down (positive Y)
+    np.array((0, 1)),
+    # Pointing left (negative X)
+    np.array((-1, 0)),
+    # Up (negative Y)
+    np.array((0, -1)),
+]
+
 class WorldObj:
     """
     Base class for grid world objects
@@ -723,8 +735,8 @@ class MiniGridEnv(gym.Env):
         self._gen_grid(self.grid_size, self.grid_size)
 
         # These fields should be defined by _gen_grid
-        assert self.start_pos != None
-        assert self.start_dir != None
+        assert self.start_pos is not None
+        assert self.start_dir is not None
 
         # Check that the agent doesn't overlap with an object
         assert self.grid.get(*self.start_pos) is None
@@ -899,17 +911,17 @@ class MiniGridEnv(gym.Env):
             size = (self.grid.width, self.grid.height)
 
         while True:
-            pos = (
+            pos = np.array((
                 self._rand_int(top[0], top[0] + size[0]),
                 self._rand_int(top[1], top[1] + size[1])
-            )
+            ))
 
             # Don't place the object on top of another object
             if self.grid.get(*pos) != None:
                 continue
 
             # Don't place the object where the agent is
-            if pos == self.start_pos:
+            if np.array_equal(pos, self.start_pos):
                 continue
 
             # Check if there is a filtering criterion
@@ -941,20 +953,8 @@ class MiniGridEnv(gym.Env):
         of forward movement.
         """
 
-        # Pointing right
-        if self.agent_dir == 0:
-            return (1, 0)
-        # Down (positive Y)
-        elif self.agent_dir == 1:
-            return (0, 1)
-        # Pointing left
-        elif self.agent_dir == 2:
-            return (-1, 0)
-        # Up (negative Y)
-        elif self.agent_dir == 3:
-            return (0, -1)
-        else:
-            assert False
+        assert self.agent_dir >= 0 and self.agent_dir < 4
+        return DIR_TO_VEC[self.agent_dir]
 
     def get_right_vec(self):
         """
@@ -962,7 +962,7 @@ class MiniGridEnv(gym.Env):
         """
 
         dx, dy = self.get_dir_vec()
-        return -dy, dx
+        return np.array((-dy, dx))
 
     def get_view_coords(self, i, j):
         """
@@ -1045,8 +1045,7 @@ class MiniGridEnv(gym.Env):
         done = False
 
         # Get the position in front of the agent
-        u, v = self.get_dir_vec()
-        fwd_pos = (self.agent_pos[0] + u, self.agent_pos[1] + v)
+        fwd_pos = self.agent_pos + self.get_dir_vec()
 
         # Get the contents of the cell in front of the agent
         fwd_cell = self.grid.get(*fwd_pos)