|
@@ -889,7 +889,13 @@ class MiniGridEnv(gym.Env):
|
|
|
self.np_random.randint(yLow, yHigh)
|
|
|
)
|
|
|
|
|
|
- def place_obj(self, obj, top=None, size=None, reject_fn=None):
|
|
|
+ def place_obj(self,
|
|
|
+ obj,
|
|
|
+ top=None,
|
|
|
+ size=None,
|
|
|
+ reject_fn=None,
|
|
|
+ max_tries=math.inf
|
|
|
+ ):
|
|
|
"""
|
|
|
Place an object at an empty position in the grid
|
|
|
|
|
@@ -904,7 +910,16 @@ class MiniGridEnv(gym.Env):
|
|
|
if size is None:
|
|
|
size = (self.grid.width, self.grid.height)
|
|
|
|
|
|
+ num_tries = 0
|
|
|
+
|
|
|
while True:
|
|
|
+
|
|
|
+
|
|
|
+ if num_tries > max_tries:
|
|
|
+ raise RecursionError('rejection sampling failed in place_obj')
|
|
|
+
|
|
|
+ num_tries += 1
|
|
|
+
|
|
|
pos = np.array((
|
|
|
self._rand_int(top[0], top[0] + size[0]),
|
|
|
self._rand_int(top[1], top[1] + size[1])
|
|
@@ -932,13 +947,19 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
return pos
|
|
|
|
|
|
- def place_agent(self, top=None, size=None, rand_dir=True):
|
|
|
+ def place_agent(
|
|
|
+ self,
|
|
|
+ top=None,
|
|
|
+ size=None,
|
|
|
+ rand_dir=True,
|
|
|
+ max_tries=math.inf
|
|
|
+ ):
|
|
|
"""
|
|
|
Set the agent's starting point at an empty position in the grid
|
|
|
"""
|
|
|
|
|
|
self.start_pos = None
|
|
|
- pos = self.place_obj(None, top, size)
|
|
|
+ pos = self.place_obj(None, top, size, max_tries=max_tries)
|
|
|
self.start_pos = pos
|
|
|
|
|
|
if rand_dir:
|