|
@@ -709,9 +709,9 @@ class MiniGridEnv(gym.Env):
|
|
|
self.max_steps = max_steps
|
|
|
self.see_through_walls = see_through_walls
|
|
|
|
|
|
- # Starting position and direction for the agent
|
|
|
- self.start_pos = None
|
|
|
- self.start_dir = None
|
|
|
+ # Current position and direction of the agent
|
|
|
+ self.agent_pos = None
|
|
|
+ self.agent_dir = None
|
|
|
|
|
|
# Initialize the RNG
|
|
|
self.seed(seed=seed)
|
|
@@ -720,23 +720,23 @@ class MiniGridEnv(gym.Env):
|
|
|
self.reset()
|
|
|
|
|
|
def reset(self):
|
|
|
+ # Current position and direction of the agent
|
|
|
+ self.agent_pos = None
|
|
|
+ self.agent_dir = None
|
|
|
+
|
|
|
# Generate a new random grid at the start of each episode
|
|
|
# To keep the same grid for each episode, call env.seed() with
|
|
|
# the same seed before calling env.reset()
|
|
|
self._gen_grid(self.width, self.height)
|
|
|
|
|
|
# These fields should be defined by _gen_grid
|
|
|
- assert self.start_pos is not None
|
|
|
- assert self.start_dir is not None
|
|
|
+ assert self.agent_pos is not None
|
|
|
+ assert self.agent_dir is not None
|
|
|
|
|
|
# Check that the agent doesn't overlap with an object
|
|
|
- start_cell = self.grid.get(*self.start_pos)
|
|
|
+ start_cell = self.grid.get(*self.agent_pos)
|
|
|
assert start_cell is None or start_cell.can_overlap()
|
|
|
|
|
|
- # Place the agent in the starting position and direction
|
|
|
- self.agent_pos = self.start_pos
|
|
|
- self.agent_dir = self.start_dir
|
|
|
-
|
|
|
# Item picked up, being carried, initially nothing
|
|
|
self.carrying = None
|
|
|
|
|
@@ -934,7 +934,7 @@ class MiniGridEnv(gym.Env):
|
|
|
continue
|
|
|
|
|
|
# Don't place the object where the agent is
|
|
|
- if np.array_equal(pos, self.start_pos):
|
|
|
+ if np.array_equal(pos, self.agent_pos):
|
|
|
continue
|
|
|
|
|
|
# Check if there is a filtering criterion
|
|
@@ -962,12 +962,12 @@ class MiniGridEnv(gym.Env):
|
|
|
Set the agent's starting point at an empty position in the grid
|
|
|
"""
|
|
|
|
|
|
- self.start_pos = None
|
|
|
+ self.agent_pos = None
|
|
|
pos = self.place_obj(None, top, size, max_tries=max_tries)
|
|
|
- self.start_pos = pos
|
|
|
+ self.agent_pos = pos
|
|
|
|
|
|
if rand_dir:
|
|
|
- self.start_dir = self._rand_int(0, 4)
|
|
|
+ self.agent_dir = self._rand_int(0, 4)
|
|
|
|
|
|
return pos
|
|
|
|