Browse Source

Added timeout to place_obj and place_agent

Maxime Chevalier-Boisvert 6 years ago
parent
commit
7acd1ea326
2 changed files with 25 additions and 4 deletions
  1. 1 1
      gym_minigrid/envs/redbluedoors.py
  2. 24 3
      gym_minigrid/minigrid.py

+ 1 - 1
gym_minigrid/envs/redbluedoors.py

@@ -11,7 +11,7 @@ class RedBlueDoorEnv(MiniGridEnv):
 
         super().__init__(
             grid_size=2*size,
-            max_steps=20*size*size
+            max_steps=10*size*size
         )
 
     def _gen_grid(self, width, height):

+ 24 - 3
gym_minigrid/minigrid.py

@@ -889,7 +889,13 @@ class MiniGridEnv(gym.Env):
             self.np_random.randint(yLow, yHigh)
         )
 
-    def place_obj(self, obj, top=None, size=None, reject_fn=None):
+    def place_obj(self,
+        obj,
+        top=None,
+        size=None,
+        reject_fn=None,
+        max_tries=math.inf
+    ):
         """
         Place an object at an empty position in the grid
 
@@ -904,7 +910,16 @@ class MiniGridEnv(gym.Env):
         if size is None:
             size = (self.grid.width, self.grid.height)
 
+        num_tries = 0
+
         while True:
+            # This is to handle with rare cases where rejection sampling
+            # gets stuck in an infinite loop
+            if num_tries > max_tries:
+                raise RecursionError('rejection sampling failed in place_obj')
+
+            num_tries += 1
+
             pos = np.array((
                 self._rand_int(top[0], top[0] + size[0]),
                 self._rand_int(top[1], top[1] + size[1])
@@ -932,13 +947,19 @@ class MiniGridEnv(gym.Env):
 
         return pos
 
-    def place_agent(self, top=None, size=None, rand_dir=True):
+    def place_agent(
+        self,
+        top=None,
+        size=None,
+        rand_dir=True,
+        max_tries=math.inf
+    ):
         """
         Set the agent's starting point at an empty position in the grid
         """
 
         self.start_pos = None
-        pos = self.place_obj(None, top, size)
+        pos = self.place_obj(None, top, size, max_tries=max_tries)
         self.start_pos = pos
 
         if rand_dir: