|
@@ -53,6 +53,18 @@ OBJECT_TO_IDX = {
|
|
|
|
|
|
IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
|
|
|
|
|
|
+# Map of agent direction indices to vectors
|
|
|
+DIR_TO_VEC = [
|
|
|
+ # Pointing right (positive X)
|
|
|
+ np.array((1, 0)),
|
|
|
+ # Down (positive Y)
|
|
|
+ np.array((0, 1)),
|
|
|
+ # Pointing left (negative X)
|
|
|
+ np.array((-1, 0)),
|
|
|
+ # Up (negative Y)
|
|
|
+ np.array((0, -1)),
|
|
|
+]
|
|
|
+
|
|
|
class WorldObj:
|
|
|
"""
|
|
|
Base class for grid world objects
|
|
@@ -723,8 +735,8 @@ class MiniGridEnv(gym.Env):
|
|
|
self._gen_grid(self.grid_size, self.grid_size)
|
|
|
|
|
|
# These fields should be defined by _gen_grid
|
|
|
- assert self.start_pos != None
|
|
|
- assert self.start_dir != None
|
|
|
+ assert self.start_pos is not None
|
|
|
+ assert self.start_dir is not None
|
|
|
|
|
|
# Check that the agent doesn't overlap with an object
|
|
|
assert self.grid.get(*self.start_pos) is None
|
|
@@ -899,17 +911,17 @@ class MiniGridEnv(gym.Env):
|
|
|
size = (self.grid.width, self.grid.height)
|
|
|
|
|
|
while True:
|
|
|
- pos = (
|
|
|
+ pos = np.array((
|
|
|
self._rand_int(top[0], top[0] + size[0]),
|
|
|
self._rand_int(top[1], top[1] + size[1])
|
|
|
- )
|
|
|
+ ))
|
|
|
|
|
|
# Don't place the object on top of another object
|
|
|
if self.grid.get(*pos) != None:
|
|
|
continue
|
|
|
|
|
|
# Don't place the object where the agent is
|
|
|
- if pos == self.start_pos:
|
|
|
+ if np.array_equal(pos, self.start_pos):
|
|
|
continue
|
|
|
|
|
|
# Check if there is a filtering criterion
|
|
@@ -941,20 +953,8 @@ class MiniGridEnv(gym.Env):
|
|
|
of forward movement.
|
|
|
"""
|
|
|
|
|
|
- # Pointing right
|
|
|
- if self.agent_dir == 0:
|
|
|
- return (1, 0)
|
|
|
- # Down (positive Y)
|
|
|
- elif self.agent_dir == 1:
|
|
|
- return (0, 1)
|
|
|
- # Pointing left
|
|
|
- elif self.agent_dir == 2:
|
|
|
- return (-1, 0)
|
|
|
- # Up (negative Y)
|
|
|
- elif self.agent_dir == 3:
|
|
|
- return (0, -1)
|
|
|
- else:
|
|
|
- assert False
|
|
|
+ assert self.agent_dir >= 0 and self.agent_dir < 4
|
|
|
+ return DIR_TO_VEC[self.agent_dir]
|
|
|
|
|
|
def get_right_vec(self):
|
|
|
"""
|
|
@@ -962,7 +962,7 @@ class MiniGridEnv(gym.Env):
|
|
|
"""
|
|
|
|
|
|
dx, dy = self.get_dir_vec()
|
|
|
- return -dy, dx
|
|
|
+ return np.array((-dy, dx))
|
|
|
|
|
|
def get_view_coords(self, i, j):
|
|
|
"""
|
|
@@ -1045,8 +1045,7 @@ class MiniGridEnv(gym.Env):
|
|
|
done = False
|
|
|
|
|
|
# Get the position in front of the agent
|
|
|
- u, v = self.get_dir_vec()
|
|
|
- fwd_pos = (self.agent_pos[0] + u, self.agent_pos[1] + v)
|
|
|
+ fwd_pos = self.agent_pos + self.get_dir_vec()
|
|
|
|
|
|
# Get the contents of the cell in front of the agent
|
|
|
fwd_cell = self.grid.get(*fwd_pos)
|