|
@@ -77,6 +77,12 @@ class WorldObj:
|
|
self.color = color
|
|
self.color = color
|
|
self.contains = None
|
|
self.contains = None
|
|
|
|
|
|
|
|
+ # Initial position of the object
|
|
|
|
+ self.init_pos = None
|
|
|
|
+
|
|
|
|
+ # Current position of the object
|
|
|
|
+ self.cur_pos = None
|
|
|
|
+
|
|
def can_overlap(self):
|
|
def can_overlap(self):
|
|
"""Can the agent overlap with this?"""
|
|
"""Can the agent overlap with this?"""
|
|
return False
|
|
return False
|
|
@@ -893,6 +899,10 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
self.grid.set(*pos, obj)
|
|
self.grid.set(*pos, obj)
|
|
|
|
|
|
|
|
+ if obj is not None:
|
|
|
|
+ obj.init_pos = pos
|
|
|
|
+ obj.cur_pos = pos
|
|
|
|
+
|
|
return pos
|
|
return pos
|
|
|
|
|
|
def place_agent(self, top=None, size=None, rand_dir=True):
|
|
def place_agent(self, top=None, size=None, rand_dir=True):
|
|
@@ -909,7 +919,8 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
return pos
|
|
return pos
|
|
|
|
|
|
- def get_dir_vec(self):
|
|
|
|
|
|
+ @property
|
|
|
|
+ def dir_vec(self):
|
|
"""
|
|
"""
|
|
Get the direction vector for the agent, pointing in the direction
|
|
Get the direction vector for the agent, pointing in the direction
|
|
of forward movement.
|
|
of forward movement.
|
|
@@ -918,14 +929,23 @@ class MiniGridEnv(gym.Env):
|
|
assert self.agent_dir >= 0 and self.agent_dir < 4
|
|
assert self.agent_dir >= 0 and self.agent_dir < 4
|
|
return DIR_TO_VEC[self.agent_dir]
|
|
return DIR_TO_VEC[self.agent_dir]
|
|
|
|
|
|
- def get_right_vec(self):
|
|
|
|
|
|
+ @property
|
|
|
|
+ def right_vec(self):
|
|
"""
|
|
"""
|
|
Get the vector pointing to the right of the agent.
|
|
Get the vector pointing to the right of the agent.
|
|
"""
|
|
"""
|
|
|
|
|
|
- dx, dy = self.get_dir_vec()
|
|
|
|
|
|
+ dx, dy = self.dir_vec
|
|
return np.array((-dy, dx))
|
|
return np.array((-dy, dx))
|
|
|
|
|
|
|
|
+ @property
|
|
|
|
+ def front_pos(self):
|
|
|
|
+ """
|
|
|
|
+ Get the position of the cell that is right in front of the agent
|
|
|
|
+ """
|
|
|
|
+
|
|
|
|
+ return self.agent_pos + self.dir_vec
|
|
|
|
+
|
|
def get_view_coords(self, i, j):
|
|
def get_view_coords(self, i, j):
|
|
"""
|
|
"""
|
|
Translate and rotate absolute grid coordinates (i, j) into the
|
|
Translate and rotate absolute grid coordinates (i, j) into the
|
|
@@ -934,8 +954,8 @@ class MiniGridEnv(gym.Env):
|
|
"""
|
|
"""
|
|
|
|
|
|
ax, ay = self.agent_pos
|
|
ax, ay = self.agent_pos
|
|
- dx, dy = self.get_dir_vec()
|
|
|
|
- rx, ry = self.get_right_vec()
|
|
|
|
|
|
+ dx, dy = self.dir_vec
|
|
|
|
+ rx, ry = self.right_vec
|
|
|
|
|
|
# Compute the absolute coordinates of the top-left view corner
|
|
# Compute the absolute coordinates of the top-left view corner
|
|
sz = AGENT_VIEW_SIZE
|
|
sz = AGENT_VIEW_SIZE
|
|
@@ -1007,7 +1027,7 @@ class MiniGridEnv(gym.Env):
|
|
done = False
|
|
done = False
|
|
|
|
|
|
# Get the position in front of the agent
|
|
# Get the position in front of the agent
|
|
- fwd_pos = self.agent_pos + self.get_dir_vec()
|
|
|
|
|
|
+ fwd_pos = self.front_pos
|
|
|
|
|
|
# Get the contents of the cell in front of the agent
|
|
# Get the contents of the cell in front of the agent
|
|
fwd_cell = self.grid.get(*fwd_pos)
|
|
fwd_cell = self.grid.get(*fwd_pos)
|
|
@@ -1035,12 +1055,14 @@ class MiniGridEnv(gym.Env):
|
|
if fwd_cell and fwd_cell.can_pickup():
|
|
if fwd_cell and fwd_cell.can_pickup():
|
|
if self.carrying is None:
|
|
if self.carrying is None:
|
|
self.carrying = fwd_cell
|
|
self.carrying = fwd_cell
|
|
|
|
+ self.carrying.cur_pos = np.array([-1, -1])
|
|
self.grid.set(*fwd_pos, None)
|
|
self.grid.set(*fwd_pos, None)
|
|
|
|
|
|
# Drop an object
|
|
# Drop an object
|
|
elif action == self.actions.drop:
|
|
elif action == self.actions.drop:
|
|
if not fwd_cell and self.carrying:
|
|
if not fwd_cell and self.carrying:
|
|
self.grid.set(*fwd_pos, self.carrying)
|
|
self.grid.set(*fwd_pos, self.carrying)
|
|
|
|
+ self.carrying.cur_pos = fwd_pos
|
|
self.carrying = None
|
|
self.carrying = None
|
|
|
|
|
|
# Toggle/activate an object
|
|
# Toggle/activate an object
|
|
@@ -1205,8 +1227,8 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
# Compute the absolute coordinates of the bottom-left corner
|
|
# Compute the absolute coordinates of the bottom-left corner
|
|
# of the agent's view area
|
|
# of the agent's view area
|
|
- f_vec = self.get_dir_vec()
|
|
|
|
- r_vec = self.get_right_vec()
|
|
|
|
|
|
+ f_vec = self.dir_vec
|
|
|
|
+ r_vec = self.right_vec
|
|
top_left = self.agent_pos + f_vec * (AGENT_VIEW_SIZE-1) - r_vec * (AGENT_VIEW_SIZE // 2)
|
|
top_left = self.agent_pos + f_vec * (AGENT_VIEW_SIZE-1) - r_vec * (AGENT_VIEW_SIZE // 2)
|
|
|
|
|
|
# For each cell in the visibility mask
|
|
# For each cell in the visibility mask
|