|
@@ -69,11 +69,11 @@ class WorldObj:
|
|
|
"""Can the agent overlap with this?"""
|
|
|
return False
|
|
|
|
|
|
- def canPickup(self):
|
|
|
+ def can_pickup(self):
|
|
|
"""Can the agent pick this up?"""
|
|
|
return False
|
|
|
|
|
|
- def canContain(self):
|
|
|
+ def can_contain(self):
|
|
|
"""Can this contain another object?"""
|
|
|
return False
|
|
|
|
|
@@ -86,9 +86,11 @@ class WorldObj:
|
|
|
return False
|
|
|
|
|
|
def render(self, r):
|
|
|
- assert False
|
|
|
+ """Draw this object with the given renderer"""
|
|
|
+ raise NotImplementedError
|
|
|
|
|
|
def _set_color(self, r):
|
|
|
+ """Set the color of this object as the active drawing color"""
|
|
|
c = COLORS[self.color]
|
|
|
r.setLineColor(c[0], c[1], c[2])
|
|
|
r.setColor(c[0], c[1], c[2])
|
|
@@ -189,6 +191,9 @@ class LockedDoor(WorldObj):
|
|
|
"""The agent can only walk over this cell when the door is open"""
|
|
|
return self.is_open
|
|
|
|
|
|
+ def see_behind(self):
|
|
|
+ return self.is_open
|
|
|
+
|
|
|
def render(self, r):
|
|
|
c = COLORS[self.color]
|
|
|
r.setLineColor(c[0], c[1], c[2])
|
|
@@ -226,7 +231,7 @@ class Key(WorldObj):
|
|
|
def __init__(self, color='blue'):
|
|
|
super(Key, self).__init__('key', color)
|
|
|
|
|
|
- def canPickup(self):
|
|
|
+ def can_pickup(self):
|
|
|
return True
|
|
|
|
|
|
def render(self, r):
|
|
@@ -263,7 +268,7 @@ class Ball(WorldObj):
|
|
|
def __init__(self, color='blue'):
|
|
|
super(Ball, self).__init__('ball', color)
|
|
|
|
|
|
- def canPickup(self):
|
|
|
+ def can_pickup(self):
|
|
|
return True
|
|
|
|
|
|
def render(self, r):
|
|
@@ -275,7 +280,7 @@ class Box(WorldObj):
|
|
|
super(Box, self).__init__('box', color)
|
|
|
self.contains = contains
|
|
|
|
|
|
- def canPickup(self):
|
|
|
+ def can_pickup(self):
|
|
|
return True
|
|
|
|
|
|
def render(self, r):
|
|
@@ -596,6 +601,45 @@ class Grid:
|
|
|
|
|
|
return mask
|
|
|
|
|
|
+ def process_vis_prop(
|
|
|
+ grid,
|
|
|
+ agent_pos
|
|
|
+ ):
|
|
|
+ mask = np.zeros(shape=(grid.width, grid.height), dtype=np.bool)
|
|
|
+
|
|
|
+ mask[agent_pos[0], agent_pos[1]] = True
|
|
|
+
|
|
|
+ for j in reversed(range(1, grid.height)):
|
|
|
+ for i in range(0, grid.width-1):
|
|
|
+ if not mask[i, j]:
|
|
|
+ continue
|
|
|
+
|
|
|
+ cell = grid.get(i, j)
|
|
|
+ if cell and not cell.see_behind():
|
|
|
+ continue
|
|
|
+
|
|
|
+ mask[i+1, j] = True
|
|
|
+ mask[i+1, j-1] = True
|
|
|
+ mask[i, j-1] = True
|
|
|
+
|
|
|
+ for i in reversed(range(1, grid.width)):
|
|
|
+ if not mask[i, j]:
|
|
|
+ continue
|
|
|
+
|
|
|
+ cell = grid.get(i, j)
|
|
|
+ if cell and not cell.see_behind():
|
|
|
+ continue
|
|
|
+
|
|
|
+ mask[i-1, j-1] = True
|
|
|
+ mask[i-1, j] = True
|
|
|
+ mask[i, j-1] = True
|
|
|
+
|
|
|
+ for j in range(0, grid.height):
|
|
|
+ for i in range(0, grid.width):
|
|
|
+ if not mask[i, j]:
|
|
|
+ grid.set(i, j, None)
|
|
|
+ #grid.set(i, j, Wall('red'))
|
|
|
+
|
|
|
class MiniGridEnv(gym.Env):
|
|
|
"""
|
|
|
2D grid world game environment
|
|
@@ -623,7 +667,12 @@ class MiniGridEnv(gym.Env):
|
|
|
# Wait/stay put/do nothing
|
|
|
wait = 6
|
|
|
|
|
|
- def __init__(self, grid_size=16, max_steps=100):
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ grid_size=16,
|
|
|
+ max_steps=100,
|
|
|
+ see_through_walls=False
|
|
|
+ ):
|
|
|
# Action enumeration for this environment
|
|
|
self.actions = MiniGridEnv.Actions
|
|
|
|
|
@@ -654,6 +703,7 @@ class MiniGridEnv(gym.Env):
|
|
|
# Environment configuration
|
|
|
self.grid_size = grid_size
|
|
|
self.max_steps = max_steps
|
|
|
+ self.see_through_walls = see_through_walls
|
|
|
|
|
|
# Starting position and direction for the agent
|
|
|
self.start_pos = None
|
|
@@ -667,9 +717,9 @@ class MiniGridEnv(gym.Env):
|
|
|
# Generate a new random grid at the start of each episode
|
|
|
# To keep the same grid for each episode, call env.seed() with
|
|
|
# the same seed before calling env.reset()
|
|
|
- self._genGrid(self.grid_size, self.grid_size)
|
|
|
+ self._gen_grid(self.grid_size, self.grid_size)
|
|
|
|
|
|
- # These fields should be defined by _genGrid
|
|
|
+ # These fields should be defined by _gen_grid
|
|
|
assert self.start_pos != None
|
|
|
assert self.start_dir != None
|
|
|
|
|
@@ -788,8 +838,8 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
return "\n".join([" ".join(line) for line in new_array])
|
|
|
|
|
|
- def _genGrid(self, width, height):
|
|
|
- assert False, "_genGrid needs to be implemented by each environment"
|
|
|
+ def _gen_grid(self, width, height):
|
|
|
+ assert False, "_gen_grid needs to be implemented by each environment"
|
|
|
|
|
|
def _randInt(self, low, high):
|
|
|
"""
|
|
@@ -1005,7 +1055,7 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
# Pick up an object
|
|
|
elif action == self.actions.pickup:
|
|
|
- if fwdCell and fwdCell.canPickup():
|
|
|
+ if fwdCell and fwdCell.can_pickup():
|
|
|
if self.carrying is None:
|
|
|
self.carrying = fwdCell
|
|
|
self.grid.set(*fwdPos, None)
|
|
@@ -1057,7 +1107,9 @@ class MiniGridEnv(gym.Env):
|
|
|
grid.set(*agent_pos, None)
|
|
|
|
|
|
# Process occluders and visibility
|
|
|
- grid.process_vis(agent_pos=(3, 6))
|
|
|
+ # Note that this incurs some performance cost
|
|
|
+ if not self.see_through_walls:
|
|
|
+ grid.process_vis_prop(agent_pos=(3, 6))
|
|
|
|
|
|
# Encode the partially observable view into a numpy array
|
|
|
image = grid.encode()
|