瀏覽代碼

Implemented raycasting so agent can't see through walls

Maxime Chevalier-Boisvert 7 年之前
父節點
當前提交
97e3182fcf
共有 1 個文件被更改,包括 96 次插入31 次删除
  1. 96 31
      gym_minigrid/minigrid.py

+ 96 - 31
gym_minigrid/minigrid.py

@@ -65,7 +65,7 @@ class WorldObj:
         self.color = color
         self.contains = None
 
-    def canOverlap(self):
+    def can_overlap(self):
         """Can the agent overlap with this?"""
         return False
 
@@ -77,6 +77,10 @@ class WorldObj:
         """Can this contain another object?"""
         return False
 
+    def see_behind(self):
+        """Can the agent see behind this object?"""
+        return True
+
     def toggle(self, env, pos):
         """Method to trigger/toggle an action this object performs"""
         return False
@@ -84,7 +88,7 @@ class WorldObj:
     def render(self, r):
         assert False
 
-    def _setColor(self, r):
+    def _set_color(self, r):
         c = COLORS[self.color]
         r.setLineColor(c[0], c[1], c[2])
         r.setColor(c[0], c[1], c[2])
@@ -93,11 +97,11 @@ class Goal(WorldObj):
     def __init__(self):
         super(Goal, self).__init__('goal', 'green')
 
-    def canOverlap(self):
+    def can_overlap(self):
         return True
 
     def render(self, r):
-        self._setColor(r)
+        self._set_color(r)
         r.drawPolygon([
             (0          , CELL_PIXELS),
             (CELL_PIXELS, CELL_PIXELS),
@@ -109,8 +113,11 @@ class Wall(WorldObj):
     def __init__(self, color='grey'):
         super(Wall, self).__init__('wall', color)
 
+    def see_behind(self):
+        return False
+
     def render(self, r):
-        self._setColor(r)
+        self._set_color(r)
         r.drawPolygon([
             (0          , CELL_PIXELS),
             (CELL_PIXELS, CELL_PIXELS),
@@ -123,6 +130,19 @@ class Door(WorldObj):
         super(Door, self).__init__('door', color)
         self.isOpen = isOpen
 
+    def can_overlap(self):
+        """The agent can only walk over this cell when the door is open"""
+        return self.isOpen
+
+    def see_behind(self):
+        return self.isOpen
+
+    def toggle(self, env, pos):
+        if not self.isOpen:
+            self.isOpen = True
+            return True
+        return False
+
     def render(self, r):
         c = COLORS[self.color]
         r.setLineColor(c[0], c[1], c[2])
@@ -151,21 +171,24 @@ class Door(WorldObj):
         ])
         r.drawCircle(CELL_PIXELS * 0.75, CELL_PIXELS * 0.5, 2)
 
+class LockedDoor(WorldObj):
+    def __init__(self, color, isOpen=False):
+        super(LockedDoor, self).__init__('locked_door', color)
+        self.isOpen = isOpen
+
     def toggle(self, env, pos):
-        if not self.isOpen:
+        # If the player has the right key to open the door
+        if isinstance(env.carrying, Key) and env.carrying.color == self.color:
             self.isOpen = True
+            # The key has been used, remove it from the agent
+            env.carrying = None
             return True
         return False
 
-    def canOverlap(self):
+    def can_overlap(self):
         """The agent can only walk over this cell when the door is open"""
         return self.isOpen
 
-class LockedDoor(WorldObj):
-    def __init__(self, color, isOpen=False):
-        super(LockedDoor, self).__init__('locked_door', color)
-        self.isOpen = isOpen
-
     def render(self, r):
         c = COLORS[self.color]
         r.setLineColor(c[0], c[1], c[2])
@@ -199,19 +222,6 @@ class LockedDoor(WorldObj):
             CELL_PIXELS * 0.5
         )
 
-    def toggle(self, env, pos):
-        # If the player has the right key to open the door
-        if isinstance(env.carrying, Key) and env.carrying.color == self.color:
-            self.isOpen = True
-            # The key has been used, remove it from the agent
-            env.carrying = None
-            return True
-        return False
-
-    def canOverlap(self):
-        """The agent can only walk over this cell when the door is open"""
-        return self.isOpen
-
 class Key(WorldObj):
     def __init__(self, color='blue'):
         super(Key, self).__init__('key', color)
@@ -220,7 +230,7 @@ class Key(WorldObj):
         return True
 
     def render(self, r):
-        self._setColor(r)
+        self._set_color(r)
 
         # Vertical quad
         r.drawPolygon([
@@ -257,7 +267,7 @@ class Ball(WorldObj):
         return True
 
     def render(self, r):
-        self._setColor(r)
+        self._set_color(r)
         r.drawCircle(CELL_PIXELS * 0.5, CELL_PIXELS * 0.5, 10)
 
 class Box(WorldObj):
@@ -532,6 +542,58 @@ class Grid:
 
         return grid
 
+    def process_vis(
+        grid,
+        agent_pos,
+        n_rays = 32,
+        n_steps = 24,
+        a_min = math.pi,
+        a_max = 2 * math.pi
+    ):
+        """
+        Use ray casting to determine the visibility of each grid cell
+        """
+
+        mask = np.zeros(shape=(grid.width, grid.height), dtype=np.bool)
+
+        ang_step = (a_max - a_min) / n_rays
+        dst_step = math.sqrt(grid.width ** 2 + grid.height ** 2) / n_steps
+
+        ax = agent_pos[0] + 0.5
+        ay = agent_pos[1] + 0.5
+
+        for ray_idx in range(0, n_rays):
+            angle = a_min + ang_step * ray_idx
+            dx = dst_step * math.cos(angle)
+            dy = dst_step * math.sin(angle)
+
+            for step_idx in range(0, n_steps):
+                x = ax + (step_idx * dx)
+                y = ay + (step_idx * dy)
+
+                i = math.floor(x)
+                j = math.floor(y)
+
+                # If we're outside of the grid, stop
+                if i < 0 or i >= grid.width or j < 0 or j >= grid.height:
+                    break
+
+                # Mark this cell as visible
+                mask[i, j] = True
+
+                # If we hit the obstructor, stop
+                cell = grid.get(i, j)
+                if cell and not cell.see_behind():
+                    break
+
+        for j in range(0, grid.height):
+            for i in range(0, grid.width):
+                if not mask[i, j]:
+                    #grid.set(i, j, None)
+                    grid.set(i, j, Wall('red'))
+
+        return mask
+
 class MiniGridEnv(gym.Env):
     """
     2D grid world game environment
@@ -623,7 +685,7 @@ class MiniGridEnv(gym.Env):
         self.step_count = 0
 
         # Return first observation
-        obs = self._genObs()
+        obs = self._gen_obs()
         return obs
 
     def seed(self, seed=1337):
@@ -888,7 +950,7 @@ class MiniGridEnv(gym.Env):
 
         # Move forward
         elif action == self.actions.forward:
-            if fwdCell == None or fwdCell.canOverlap():
+            if fwdCell == None or fwdCell.can_overlap():
                 self.agent_pos = fwdPos
             if fwdCell != None and fwdCell.type == 'goal':
                 done = True
@@ -922,11 +984,11 @@ class MiniGridEnv(gym.Env):
         if self.step_count >= self.max_steps:
             done = True
 
-        obs = self._genObs()
+        obs = self._gen_obs()
 
         return obs, reward, done, {}
 
-    def _genObs(self):
+    def _gen_obs(self):
         """
         Generate the agent's view (partially observable, low-resolution encoding)
         """
@@ -947,6 +1009,9 @@ class MiniGridEnv(gym.Env):
         else:
             grid.set(*agent_pos, None)
 
+        # Process occluders and visibility
+        grid.process_vis(agent_pos=(3, 6))
+
         # Encode the partially observable view into a numpy array
         image = grid.encode()