|
@@ -8,12 +8,6 @@ from gym.utils import seeding
|
|
# Size in pixels of a cell in the full-scale human view
|
|
# Size in pixels of a cell in the full-scale human view
|
|
CELL_PIXELS = 32
|
|
CELL_PIXELS = 32
|
|
|
|
|
|
-# Number of cells (width and height) in the agent view
|
|
|
|
-AGENT_VIEW_SIZE = 7
|
|
|
|
-
|
|
|
|
-# Size of the array given as an observation to the agent
|
|
|
|
-OBS_ARRAY_SIZE = (AGENT_VIEW_SIZE, AGENT_VIEW_SIZE, 3)
|
|
|
|
-
|
|
|
|
# Map of color names to RGB values
|
|
# Map of color names to RGB values
|
|
COLORS = {
|
|
COLORS = {
|
|
'red' : np.array([255, 0, 0]),
|
|
'red' : np.array([255, 0, 0]),
|
|
@@ -669,7 +663,8 @@ class MiniGridEnv(gym.Env):
|
|
height=None,
|
|
height=None,
|
|
max_steps=100,
|
|
max_steps=100,
|
|
see_through_walls=False,
|
|
see_through_walls=False,
|
|
- seed=1337
|
|
|
|
|
|
+ seed=1337,
|
|
|
|
+ agent_view_size=7
|
|
):
|
|
):
|
|
# Can't set both grid_size and width/height
|
|
# Can't set both grid_size and width/height
|
|
if grid_size:
|
|
if grid_size:
|
|
@@ -683,12 +678,15 @@ class MiniGridEnv(gym.Env):
|
|
# Actions are discrete integer values
|
|
# Actions are discrete integer values
|
|
self.action_space = spaces.Discrete(len(self.actions))
|
|
self.action_space = spaces.Discrete(len(self.actions))
|
|
|
|
|
|
|
|
+ # Number of cells (width and height) in the agent view
|
|
|
|
+ self.agent_view_size = agent_view_size
|
|
|
|
+
|
|
# Observations are dictionaries containing an
|
|
# Observations are dictionaries containing an
|
|
# encoding of the grid and a textual 'mission' string
|
|
# encoding of the grid and a textual 'mission' string
|
|
self.observation_space = spaces.Box(
|
|
self.observation_space = spaces.Box(
|
|
low=0,
|
|
low=0,
|
|
high=255,
|
|
high=255,
|
|
- shape=OBS_ARRAY_SIZE,
|
|
|
|
|
|
+ shape=(self.agent_view_size, self.agent_view_size, 3),
|
|
dtype='uint8'
|
|
dtype='uint8'
|
|
)
|
|
)
|
|
self.observation_space = spaces.Dict({
|
|
self.observation_space = spaces.Dict({
|
|
@@ -1009,8 +1007,8 @@ class MiniGridEnv(gym.Env):
|
|
rx, ry = self.right_vec
|
|
rx, ry = self.right_vec
|
|
|
|
|
|
# Compute the absolute coordinates of the top-left view corner
|
|
# Compute the absolute coordinates of the top-left view corner
|
|
- sz = AGENT_VIEW_SIZE
|
|
|
|
- hs = AGENT_VIEW_SIZE // 2
|
|
|
|
|
|
+ sz = self.agent_view_size
|
|
|
|
+ hs = self.agent_view_size // 2
|
|
tx = ax + (dx * (sz-1)) - (rx * hs)
|
|
tx = ax + (dx * (sz-1)) - (rx * hs)
|
|
ty = ay + (dy * (sz-1)) - (ry * hs)
|
|
ty = ay + (dy * (sz-1)) - (ry * hs)
|
|
|
|
|
|
@@ -1033,24 +1031,24 @@ class MiniGridEnv(gym.Env):
|
|
# Facing right
|
|
# Facing right
|
|
if self.agent_dir == 0:
|
|
if self.agent_dir == 0:
|
|
topX = self.agent_pos[0]
|
|
topX = self.agent_pos[0]
|
|
- topY = self.agent_pos[1] - AGENT_VIEW_SIZE // 2
|
|
|
|
|
|
+ topY = self.agent_pos[1] - self.agent_view_size // 2
|
|
# Facing down
|
|
# Facing down
|
|
elif self.agent_dir == 1:
|
|
elif self.agent_dir == 1:
|
|
- topX = self.agent_pos[0] - AGENT_VIEW_SIZE // 2
|
|
|
|
|
|
+ topX = self.agent_pos[0] - self.agent_view_size // 2
|
|
topY = self.agent_pos[1]
|
|
topY = self.agent_pos[1]
|
|
# Facing left
|
|
# Facing left
|
|
elif self.agent_dir == 2:
|
|
elif self.agent_dir == 2:
|
|
- topX = self.agent_pos[0] - AGENT_VIEW_SIZE + 1
|
|
|
|
- topY = self.agent_pos[1] - AGENT_VIEW_SIZE // 2
|
|
|
|
|
|
+ topX = self.agent_pos[0] - self.agent_view_size + 1
|
|
|
|
+ topY = self.agent_pos[1] - self.agent_view_size // 2
|
|
# Facing up
|
|
# Facing up
|
|
elif self.agent_dir == 3:
|
|
elif self.agent_dir == 3:
|
|
- topX = self.agent_pos[0] - AGENT_VIEW_SIZE // 2
|
|
|
|
- topY = self.agent_pos[1] - AGENT_VIEW_SIZE + 1
|
|
|
|
|
|
+ topX = self.agent_pos[0] - self.agent_view_size // 2
|
|
|
|
+ topY = self.agent_pos[1] - self.agent_view_size + 1
|
|
else:
|
|
else:
|
|
assert False, "invalid agent direction"
|
|
assert False, "invalid agent direction"
|
|
|
|
|
|
- botX = topX + AGENT_VIEW_SIZE
|
|
|
|
- botY = topY + AGENT_VIEW_SIZE
|
|
|
|
|
|
+ botX = topX + self.agent_view_size
|
|
|
|
+ botY = topY + self.agent_view_size
|
|
|
|
|
|
return (topX, topY, botX, botY)
|
|
return (topX, topY, botX, botY)
|
|
|
|
|
|
@@ -1061,7 +1059,7 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
vx, vy = self.get_view_coords(x, y)
|
|
vx, vy = self.get_view_coords(x, y)
|
|
|
|
|
|
- if vx < 0 or vy < 0 or vx >= AGENT_VIEW_SIZE or vy >= AGENT_VIEW_SIZE:
|
|
|
|
|
|
+ if vx < 0 or vy < 0 or vx >= self.agent_view_size or vy >= self.agent_view_size:
|
|
return None
|
|
return None
|
|
|
|
|
|
return vx, vy
|
|
return vx, vy
|
|
@@ -1165,7 +1163,7 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
topX, topY, botX, botY = self.get_view_exts()
|
|
topX, topY, botX, botY = self.get_view_exts()
|
|
|
|
|
|
- grid = self.grid.slice(topX, topY, AGENT_VIEW_SIZE, AGENT_VIEW_SIZE)
|
|
|
|
|
|
+ grid = self.grid.slice(topX, topY, self.agent_view_size, self.agent_view_size)
|
|
|
|
|
|
for i in range(self.agent_dir + 1):
|
|
for i in range(self.agent_dir + 1):
|
|
grid = grid.rotate_left()
|
|
grid = grid.rotate_left()
|
|
@@ -1173,7 +1171,7 @@ class MiniGridEnv(gym.Env):
|
|
# Process occluders and visibility
|
|
# Process occluders and visibility
|
|
# Note that this incurs some performance cost
|
|
# Note that this incurs some performance cost
|
|
if not self.see_through_walls:
|
|
if not self.see_through_walls:
|
|
- vis_mask = grid.process_vis(agent_pos=(AGENT_VIEW_SIZE // 2 , AGENT_VIEW_SIZE - 1))
|
|
|
|
|
|
+ vis_mask = grid.process_vis(agent_pos=(self.agent_view_size // 2 , self.agent_view_size - 1))
|
|
else:
|
|
else:
|
|
vis_mask = np.ones(shape=(grid.width, grid.height), dtype=np.bool)
|
|
vis_mask = np.ones(shape=(grid.width, grid.height), dtype=np.bool)
|
|
|
|
|
|
@@ -1220,8 +1218,8 @@ class MiniGridEnv(gym.Env):
|
|
if self.obs_render == None:
|
|
if self.obs_render == None:
|
|
from gym_minigrid.rendering import Renderer
|
|
from gym_minigrid.rendering import Renderer
|
|
self.obs_render = Renderer(
|
|
self.obs_render = Renderer(
|
|
- AGENT_VIEW_SIZE * tile_pixels,
|
|
|
|
- AGENT_VIEW_SIZE * tile_pixels
|
|
|
|
|
|
+ self.agent_view_size * tile_pixels,
|
|
|
|
+ self.agent_view_size * tile_pixels
|
|
)
|
|
)
|
|
|
|
|
|
r = self.obs_render
|
|
r = self.obs_render
|
|
@@ -1238,8 +1236,8 @@ class MiniGridEnv(gym.Env):
|
|
r.push()
|
|
r.push()
|
|
r.scale(ratio, ratio)
|
|
r.scale(ratio, ratio)
|
|
r.translate(
|
|
r.translate(
|
|
- CELL_PIXELS * (0.5 + AGENT_VIEW_SIZE // 2),
|
|
|
|
- CELL_PIXELS * (AGENT_VIEW_SIZE - 0.5)
|
|
|
|
|
|
+ CELL_PIXELS * (0.5 + self.agent_view_size // 2),
|
|
|
|
+ CELL_PIXELS * (self.agent_view_size - 0.5)
|
|
)
|
|
)
|
|
r.rotate(3 * 90)
|
|
r.rotate(3 * 90)
|
|
r.setLineColor(255, 0, 0)
|
|
r.setLineColor(255, 0, 0)
|
|
@@ -1306,11 +1304,11 @@ class MiniGridEnv(gym.Env):
|
|
# of the agent's view area
|
|
# of the agent's view area
|
|
f_vec = self.dir_vec
|
|
f_vec = self.dir_vec
|
|
r_vec = self.right_vec
|
|
r_vec = self.right_vec
|
|
- top_left = self.agent_pos + f_vec * (AGENT_VIEW_SIZE-1) - r_vec * (AGENT_VIEW_SIZE // 2)
|
|
|
|
|
|
+ top_left = self.agent_pos + f_vec * (self.agent_view_size-1) - r_vec * (self.agent_view_size // 2)
|
|
|
|
|
|
# For each cell in the visibility mask
|
|
# For each cell in the visibility mask
|
|
- for vis_j in range(0, AGENT_VIEW_SIZE):
|
|
|
|
- for vis_i in range(0, AGENT_VIEW_SIZE):
|
|
|
|
|
|
+ for vis_j in range(0, self.agent_view_size):
|
|
|
|
+ for vis_i in range(0, self.agent_view_size):
|
|
# If this cell is not visible, don't highlight it
|
|
# If this cell is not visible, don't highlight it
|
|
if not vis_mask[vis_i, vis_j]:
|
|
if not vis_mask[vis_i, vis_j]:
|
|
continue
|
|
continue
|