Browse Source

Added position filtering function for object placement

Maxime Chevalier-Boisvert 7 years ago
parent
commit
03a5d5099b
2 changed files with 17 additions and 7 deletions
  1. 2 2
      gym_minigrid/envs/roomgrid.py
  2. 15 5
      gym_minigrid/minigrid.py

+ 2 - 2
gym_minigrid/envs/roomgrid.py

@@ -134,7 +134,7 @@ class RoomGrid(MiniGridEnv):
         # By default, this environment has no mission
         self.mission = ''
 
-    def add_object(self, i, j, kind, color):
+    def add_object(self, i, j, kind, color, reject_fn=None):
         """
         Add a new object to room (i, j)
         """
@@ -150,7 +150,7 @@ class RoomGrid(MiniGridEnv):
 
         room = self.get_room(i, j)
 
-        self.placeObj(obj, room.top, room.size)
+        self.placeObj(obj, room.top, room.size, reject_fn)
 
         room.objs.append(obj)
 

+ 15 - 5
gym_minigrid/minigrid.py

@@ -596,7 +596,7 @@ class MiniGridEnv(gym.Env):
         # Initialize the state
         self.seed()
         self.reset()
-    
+
     def __str__(self):
         """
         Produce a pretty string of the environment's grid along with the agent.
@@ -652,7 +652,7 @@ class MiniGridEnv(gym.Env):
             2: '⏪',
             3: '⏫'
         }
-        
+
         array = self.grid.encode()
 
         array = rotate_left(array)
@@ -679,10 +679,10 @@ class MiniGridEnv(gym.Env):
                 new_line.append(object_ids + color_ids)
 
             new_array.append(new_line)
-        
+
         # Add the agent
         new_array[self.agentPos[1]][self.agentPos[0]] = AGENT_DIR_TO_IDS[self.agentDir]
-        
+
         return "\n".join([" ".join(line) for line in new_array])
 
     def _genGrid(self, width, height):
@@ -740,12 +740,13 @@ class MiniGridEnv(gym.Env):
             self.np_random.randint(yLow, yHigh)
         )
 
-    def placeObj(self, obj, top=None, size=None):
+    def placeObj(self, obj, top=None, size=None, reject_fn=None):
         """
         Place an object at an empty position in the grid
 
         :param top: top-left position of the rectangle where to place
         :param size: size of the rectangle where to place
+        :param reject_fn: function to filter out potential positions
         """
 
         if top is None:
@@ -759,10 +760,19 @@ class MiniGridEnv(gym.Env):
                 self._randInt(top[0], top[0] + size[0]),
                 self._randInt(top[1], top[1] + size[1])
             )
+
+            # Don't place the object on top of another object
             if self.grid.get(*pos) != None:
                 continue
+
+            # Don't place the object where the agent is
             if pos == self.startPos:
                 continue
+
+            # Check if there is a filtering criterion
+            if reject_fn and reject_fn(self, pos):
+                continue
+
             break
 
         self.grid.set(*pos, obj)