|
@@ -126,20 +126,20 @@ class Wall(WorldObj):
|
|
|
])
|
|
|
|
|
|
class Door(WorldObj):
|
|
|
def __init__(self, color, is_open=False):
    """
    Create a door object.

    color: color name passed through to the WorldObj constructor
    is_open: initial open state of the door (closed by default)
    """
    super(Door, self).__init__('door', color)
    self.is_open = is_open
|
|
|
|
|
|
def can_overlap(self):
    """The agent can only walk over this cell when the door is open"""
    return self.is_open
|
|
|
|
|
|
def see_behind(self):
    """Return True only while the door is open"""
    return self.is_open
|
|
|
|
|
|
def toggle(self, env, pos):
    """
    Open the door if it is currently closed.

    Returns True when the door was closed and has just been opened,
    False when it was already open. env and pos are not used here.
    """
    # Already open: nothing to do
    if self.is_open:
        return False

    self.is_open = True
    return True
|
|
|
|
|
@@ -148,7 +148,7 @@ class Door(WorldObj):
|
|
|
r.setLineColor(c[0], c[1], c[2])
|
|
|
r.setColor(0, 0, 0)
|
|
|
|
|
|
- if self.isOpen:
|
|
|
+ if self.is_open:
|
|
|
r.drawPolygon([
|
|
|
(CELL_PIXELS-2, CELL_PIXELS),
|
|
|
(CELL_PIXELS , CELL_PIXELS),
|
|
@@ -172,14 +172,14 @@ class Door(WorldObj):
|
|
|
r.drawCircle(CELL_PIXELS * 0.75, CELL_PIXELS * 0.5, 2)
|
|
|
|
|
|
class LockedDoor(WorldObj):
|
|
|
def __init__(self, color, is_open=False):
    """
    Create a locked door object.

    color: color name passed through to the WorldObj constructor; toggle()
           compares it against the color of the key the agent carries
    is_open: initial open state of the door (closed by default)
    """
    super(LockedDoor, self).__init__('locked_door', color)
    self.is_open = is_open
|
|
|
|
|
|
def toggle(self, env, pos):
|
|
|
# If the player has the right key to open the door
|
|
|
if isinstance(env.carrying, Key) and env.carrying.color == self.color:
|
|
|
- self.isOpen = True
|
|
|
+ self.is_open = True
|
|
|
# The key has been used, remove it from the agent
|
|
|
env.carrying = None
|
|
|
return True
|
|
@@ -187,14 +187,14 @@ class LockedDoor(WorldObj):
|
|
|
|
|
|
def can_overlap(self):
    """The agent can only walk over this cell when the door is open"""
    return self.is_open
|
|
|
|
|
|
def render(self, r):
|
|
|
c = COLORS[self.color]
|
|
|
r.setLineColor(c[0], c[1], c[2])
|
|
|
r.setColor(c[0], c[1], c[2], 50)
|
|
|
|
|
|
- if self.isOpen:
|
|
|
+ if self.is_open:
|
|
|
r.drawPolygon([
|
|
|
(CELL_PIXELS-2, CELL_PIXELS),
|
|
|
(CELL_PIXELS , CELL_PIXELS),
|
|
@@ -422,6 +422,7 @@ class Grid:
|
|
|
widthPx = self.width * CELL_PIXELS
|
|
|
heightPx = self.height * CELL_PIXELS
|
|
|
|
|
|
+ """
|
|
|
# Draw background (out-of-world) tiles the same colors as walls
|
|
|
# so the agent understands these areas are not reachable
|
|
|
c = COLORS['grey']
|
|
@@ -433,6 +434,7 @@ class Grid:
|
|
|
(widthPx, 0),
|
|
|
(0 , 0)
|
|
|
])
|
|
|
+ """
|
|
|
|
|
|
r.push()
|
|
|
|
|
@@ -491,7 +493,7 @@ class Grid:
|
|
|
array[i, j, 0] = OBJECT_TO_IDX[v.type]
|
|
|
array[i, j, 1] = COLOR_TO_IDX[v.color]
|
|
|
|
|
|
- if hasattr(v, 'isOpen') and v.isOpen:
|
|
|
+ if hasattr(v, 'is_open') and v.is_open:
|
|
|
array[i, j, 2] = 1
|
|
|
|
|
|
return array
|
|
@@ -519,7 +521,7 @@ class Grid:
|
|
|
|
|
|
objType = IDX_TO_OBJECT[typeIdx]
|
|
|
color = IDX_TO_COLOR[colorIdx]
|
|
|
- isOpen = True if openIdx == 1 else 0
|
|
|
+ is_open = True if openIdx == 1 else 0
|
|
|
|
|
|
if objType == 'wall':
|
|
|
v = Wall(color)
|
|
@@ -530,9 +532,9 @@ class Grid:
|
|
|
elif objType == 'box':
|
|
|
v = Box(color)
|
|
|
elif objType == 'door':
|
|
|
- v = Door(color, isOpen)
|
|
|
+ v = Door(color, is_open)
|
|
|
elif objType == 'locked_door':
|
|
|
- v = LockedDoor(color, isOpen)
|
|
|
+ v = LockedDoor(color, is_open)
|
|
|
elif objType == 'goal':
|
|
|
v = Goal()
|
|
|
else:
|
|
@@ -589,8 +591,8 @@ class Grid:
|
|
|
for j in range(0, grid.height):
|
|
|
for i in range(0, grid.width):
|
|
|
if not mask[i, j]:
|
|
|
- #grid.set(i, j, None)
|
|
|
- grid.set(i, j, Wall('red'))
|
|
|
+ grid.set(i, j, None)
|
|
|
+ #grid.set(i, j, Wall('red'))
|
|
|
|
|
|
return mask
|
|
|
|
|
@@ -887,7 +889,42 @@ class MiniGridEnv(gym.Env):
|
|
|
else:
|
|
|
assert False
|
|
|
|
|
|
- def getViewExts(self):
|
|
|
def get_right_vec(self):
    """
    Get the vector pointing to the right of the agent.

    This is the agent's direction vector turned by a quarter turn:
    (dx, dy) becomes (-dy, dx).
    """
    dx, dy = self.getDirVec()
    return -dy, dx
|
|
|
+
|
|
|
def get_view_coords(self, i, j):
    """
    Translate and rotate absolute grid coordinates (i, j) into the
    agent's partially observable view (sub-grid). Note that the resulting
    coordinates may be negative or outside of the agent's view size.

    Returns the pair (vx, vy) of view-relative coordinates.
    """
    ax, ay = self.agent_pos
    dx, dy = self.getDirVec()
    rx, ry = self.get_right_vec()

    # Compute the absolute coordinates of the top-left view corner:
    # (sz - 1) cells along the facing direction and hs cells against
    # the right vector, starting from the agent position
    sz = AGENT_VIEW_SIZE
    hs = AGENT_VIEW_SIZE // 2
    tx = ax + (dx * (sz - 1)) - (rx * hs)
    ty = ay + (dy * (sz - 1)) - (ry * hs)

    # Offset of the queried cell from that corner, in world axes
    lx = i - tx
    ly = j - ty

    # Project the coordinates of the object relative to the top-left
    # corner onto the agent's own coordinate system
    vx = (rx * lx + ry * ly)
    vy = -(dx * lx + dy * ly)

    return vx, vy
|
|
|
+
|
|
|
+ def get_view_exts(self):
|
|
|
"""
|
|
|
Get the extents of the square set of tiles visible to the agent
|
|
|
Note: the bottom extent indices are not included in the set
|
|
@@ -917,13 +954,22 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
return (topX, topY, botX, botY)
|
|
|
|
|
|
def agent_sees(self, x, y):
    """
    Check if a grid position is visible to the agent

    Returns True only when (x, y) falls inside the agent's view square
    and the cell decoded from the agent's observation at that spot holds
    an object of the same type as the world grid cell. Returns False for
    positions outside the view or when either cell is empty.
    """
    vx, vy = self.get_view_coords(x, y)

    # Positions that map outside the view square can never be seen
    if not (0 <= vx < AGENT_VIEW_SIZE and 0 <= vy < AGENT_VIEW_SIZE):
        return False

    obs = self._gen_obs()
    obs_grid = Grid.decode(obs['image'])
    obs_cell = obs_grid.get(vx, vy)
    world_cell = self.grid.get(x, y)

    # Guard both lookups: either one may be None (empty cell), and
    # comparing .type on None would raise AttributeError
    if obs_cell is None or world_cell is None:
        return False

    return obs_cell.type == world_cell.type
|
|
|
|
|
|
def step(self, action):
|
|
|
self.step_count += 1
|
|
@@ -993,7 +1039,7 @@ class MiniGridEnv(gym.Env):
|
|
|
Generate the agent's view (partially observable, low-resolution encoding)
|
|
|
"""
|
|
|
|
|
|
- topX, topY, botX, botY = self.getViewExts()
|
|
|
+ topX, topY, botX, botY = self.get_view_exts()
|
|
|
|
|
|
grid = self.grid.slice(topX, topY, AGENT_VIEW_SIZE, AGENT_VIEW_SIZE)
|
|
|
|
|
@@ -1111,7 +1157,7 @@ class MiniGridEnv(gym.Env):
|
|
|
r.pop()
|
|
|
|
|
|
# Highlight what the agent can see
|
|
|
- topX, topY, botX, botY = self.getViewExts()
|
|
|
+ topX, topY, botX, botY = self.get_view_exts()
|
|
|
r.fillRect(
|
|
|
topX * CELL_PIXELS,
|
|
|
topY * CELL_PIXELS,
|