|
@@ -12,48 +12,48 @@ TILE_PIXELS = 32
|
|
|
|
|
|
# Map of color names to RGB values
|
|
# Map of color names to RGB values
|
|
COLORS = {
|
|
COLORS = {
|
|
- 'red' : np.array([255, 0, 0]),
|
|
|
|
- 'green' : np.array([0, 255, 0]),
|
|
|
|
- 'blue' : np.array([0, 0, 255]),
|
|
|
|
|
|
+ 'red': np.array([255, 0, 0]),
|
|
|
|
+ 'green': np.array([0, 255, 0]),
|
|
|
|
+ 'blue': np.array([0, 0, 255]),
|
|
'purple': np.array([112, 39, 195]),
|
|
'purple': np.array([112, 39, 195]),
|
|
'yellow': np.array([255, 255, 0]),
|
|
'yellow': np.array([255, 255, 0]),
|
|
- 'grey' : np.array([100, 100, 100])
|
|
|
|
|
|
+ 'grey': np.array([100, 100, 100])
|
|
}
|
|
}
|
|
|
|
|
|
COLOR_NAMES = sorted(list(COLORS.keys()))
|
|
COLOR_NAMES = sorted(list(COLORS.keys()))
|
|
|
|
|
|
# Used to map colors to integers
|
|
# Used to map colors to integers
|
|
COLOR_TO_IDX = {
|
|
COLOR_TO_IDX = {
|
|
- 'red' : 0,
|
|
|
|
- 'green' : 1,
|
|
|
|
- 'blue' : 2,
|
|
|
|
|
|
+ 'red': 0,
|
|
|
|
+ 'green': 1,
|
|
|
|
+ 'blue': 2,
|
|
'purple': 3,
|
|
'purple': 3,
|
|
'yellow': 4,
|
|
'yellow': 4,
|
|
- 'grey' : 5
|
|
|
|
|
|
+ 'grey': 5
|
|
}
|
|
}
|
|
|
|
|
|
IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
|
|
IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
|
|
|
|
|
|
# Map of object type to integers
|
|
# Map of object type to integers
|
|
OBJECT_TO_IDX = {
|
|
OBJECT_TO_IDX = {
|
|
- 'unseen' : 0,
|
|
|
|
- 'empty' : 1,
|
|
|
|
- 'wall' : 2,
|
|
|
|
- 'floor' : 3,
|
|
|
|
- 'door' : 4,
|
|
|
|
- 'key' : 5,
|
|
|
|
- 'ball' : 6,
|
|
|
|
- 'box' : 7,
|
|
|
|
- 'goal' : 8,
|
|
|
|
- 'lava' : 9,
|
|
|
|
- 'agent' : 10,
|
|
|
|
|
|
+ 'unseen': 0,
|
|
|
|
+ 'empty': 1,
|
|
|
|
+ 'wall': 2,
|
|
|
|
+ 'floor': 3,
|
|
|
|
+ 'door': 4,
|
|
|
|
+ 'key': 5,
|
|
|
|
+ 'ball': 6,
|
|
|
|
+ 'box': 7,
|
|
|
|
+ 'goal': 8,
|
|
|
|
+ 'lava': 9,
|
|
|
|
+ 'agent': 10,
|
|
}
|
|
}
|
|
|
|
|
|
IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
|
|
IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
|
|
|
|
|
|
# Map of state names to integers
|
|
# Map of state names to integers
|
|
STATE_TO_IDX = {
|
|
STATE_TO_IDX = {
|
|
- 'open' : 0,
|
|
|
|
|
|
+ 'open': 0,
|
|
'closed': 1,
|
|
'closed': 1,
|
|
'locked': 2,
|
|
'locked': 2,
|
|
}
|
|
}
|
|
@@ -70,6 +70,7 @@ DIR_TO_VEC = [
|
|
np.array((0, -1)),
|
|
np.array((0, -1)),
|
|
]
|
|
]
|
|
|
|
|
|
|
|
+
|
|
class WorldObj:
|
|
class WorldObj:
|
|
"""
|
|
"""
|
|
Base class for grid world objects
|
|
Base class for grid world objects
|
|
@@ -151,6 +152,7 @@ class WorldObj:
|
|
"""Draw this object with the given renderer"""
|
|
"""Draw this object with the given renderer"""
|
|
raise NotImplementedError
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
+
|
|
class Goal(WorldObj):
|
|
class Goal(WorldObj):
|
|
def __init__(self):
|
|
def __init__(self):
|
|
super().__init__('goal', 'green')
|
|
super().__init__('goal', 'green')
|
|
@@ -161,6 +163,7 @@ class Goal(WorldObj):
|
|
def render(self, img):
|
|
def render(self, img):
|
|
fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
|
|
fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
|
|
|
|
|
|
|
|
+
|
|
class Floor(WorldObj):
|
|
class Floor(WorldObj):
|
|
"""
|
|
"""
|
|
Colored floor tile the agent can walk over
|
|
Colored floor tile the agent can walk over
|
|
@@ -195,10 +198,15 @@ class Lava(WorldObj):
|
|
for i in range(3):
|
|
for i in range(3):
|
|
ylo = 0.3 + 0.2 * i
|
|
ylo = 0.3 + 0.2 * i
|
|
yhi = 0.4 + 0.2 * i
|
|
yhi = 0.4 + 0.2 * i
|
|
- fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0,0,0))
|
|
|
|
- fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0,0,0))
|
|
|
|
- fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0,0,0))
|
|
|
|
- fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0,0,0))
|
|
|
|
|
|
+ fill_coords(img, point_in_line(
|
|
|
|
+ 0.1, ylo, 0.3, yhi, r=0.03), (0, 0, 0))
|
|
|
|
+ fill_coords(img, point_in_line(
|
|
|
|
+ 0.3, yhi, 0.5, ylo, r=0.03), (0, 0, 0))
|
|
|
|
+ fill_coords(img, point_in_line(
|
|
|
|
+ 0.5, ylo, 0.7, yhi, r=0.03), (0, 0, 0))
|
|
|
|
+ fill_coords(img, point_in_line(
|
|
|
|
+ 0.7, yhi, 0.9, ylo, r=0.03), (0, 0, 0))
|
|
|
|
+
|
|
|
|
|
|
class Wall(WorldObj):
|
|
class Wall(WorldObj):
|
|
def __init__(self, color='grey'):
|
|
def __init__(self, color='grey'):
|
|
@@ -210,6 +218,7 @@ class Wall(WorldObj):
|
|
def render(self, img):
|
|
def render(self, img):
|
|
fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
|
|
fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
|
|
|
|
|
|
|
|
+
|
|
class Door(WorldObj):
|
|
class Door(WorldObj):
|
|
def __init__(self, color, is_open=False, is_locked=False):
|
|
def __init__(self, color, is_open=False, is_locked=False):
|
|
super().__init__('door', color)
|
|
super().__init__('door', color)
|
|
@@ -253,25 +262,27 @@ class Door(WorldObj):
|
|
|
|
|
|
if self.is_open:
|
|
if self.is_open:
|
|
fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c)
|
|
fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c)
|
|
- fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0,0,0))
|
|
|
|
|
|
+ fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0, 0, 0))
|
|
return
|
|
return
|
|
|
|
|
|
# Door frame and door
|
|
# Door frame and door
|
|
if self.is_locked:
|
|
if self.is_locked:
|
|
fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
|
|
fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
|
|
- fill_coords(img, point_in_rect(0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))
|
|
|
|
|
|
+ fill_coords(img, point_in_rect(
|
|
|
|
+ 0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))
|
|
|
|
|
|
# Draw key slot
|
|
# Draw key slot
|
|
fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)
|
|
fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)
|
|
else:
|
|
else:
|
|
fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
|
|
fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
|
|
- fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0,0,0))
|
|
|
|
|
|
+ fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0, 0, 0))
|
|
fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c)
|
|
fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c)
|
|
- fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0,0,0))
|
|
|
|
|
|
+ fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0, 0, 0))
|
|
|
|
|
|
# Draw door handle
|
|
# Draw door handle
|
|
fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c)
|
|
fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c)
|
|
|
|
|
|
|
|
+
|
|
class Key(WorldObj):
|
|
class Key(WorldObj):
|
|
def __init__(self, color='blue'):
|
|
def __init__(self, color='blue'):
|
|
super(Key, self).__init__('key', color)
|
|
super(Key, self).__init__('key', color)
|
|
@@ -291,7 +302,8 @@ class Key(WorldObj):
|
|
|
|
|
|
# Ring
|
|
# Ring
|
|
fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.190), c)
|
|
fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.190), c)
|
|
- fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0,0,0))
|
|
|
|
|
|
+ fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0, 0, 0))
|
|
|
|
+
|
|
|
|
|
|
class Ball(WorldObj):
|
|
class Ball(WorldObj):
|
|
def __init__(self, color='blue'):
|
|
def __init__(self, color='blue'):
|
|
@@ -303,6 +315,7 @@ class Ball(WorldObj):
|
|
def render(self, img):
|
|
def render(self, img):
|
|
fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color])
|
|
fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color])
|
|
|
|
|
|
|
|
+
|
|
class Box(WorldObj):
|
|
class Box(WorldObj):
|
|
def __init__(self, color, contains=None):
|
|
def __init__(self, color, contains=None):
|
|
super(Box, self).__init__('box', color)
|
|
super(Box, self).__init__('box', color)
|
|
@@ -316,7 +329,7 @@ class Box(WorldObj):
|
|
|
|
|
|
# Outline
|
|
# Outline
|
|
fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), c)
|
|
fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), c)
|
|
- fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0,0,0))
|
|
|
|
|
|
+ fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0, 0, 0))
|
|
|
|
|
|
# Horizontal slit
|
|
# Horizontal slit
|
|
fill_coords(img, point_in_rect(0.16, 0.84, 0.47, 0.53), c)
|
|
fill_coords(img, point_in_rect(0.16, 0.84, 0.47, 0.53), c)
|
|
@@ -326,6 +339,7 @@ class Box(WorldObj):
|
|
env.grid.set(*pos, self.contains)
|
|
env.grid.set(*pos, self.contains)
|
|
return True
|
|
return True
|
|
|
|
|
|
|
|
+
|
|
class Grid:
|
|
class Grid:
|
|
"""
|
|
"""
|
|
Represent a grid and operations on it
|
|
Represent a grid and operations on it
|
|
@@ -359,7 +373,7 @@ class Grid:
|
|
return False
|
|
return False
|
|
|
|
|
|
def __eq__(self, other):
|
|
def __eq__(self, other):
|
|
- grid1 = self.encode()
|
|
|
|
|
|
+ grid1 = self.encode()
|
|
grid2 = other.encode()
|
|
grid2 = other.encode()
|
|
return np.array_equal(grid2, grid1)
|
|
return np.array_equal(grid2, grid1)
|
|
|
|
|
|
@@ -454,7 +468,8 @@ class Grid:
|
|
if key in cls.tile_cache:
|
|
if key in cls.tile_cache:
|
|
return cls.tile_cache[key]
|
|
return cls.tile_cache[key]
|
|
|
|
|
|
- img = np.zeros(shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8)
|
|
|
|
|
|
+ img = np.zeros(shape=(tile_size * subdivs,
|
|
|
|
+ tile_size * subdivs, 3), dtype=np.uint8)
|
|
|
|
|
|
# Draw the grid lines (top and left edges)
|
|
# Draw the grid lines (top and left edges)
|
|
fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
|
|
fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
|
|
@@ -472,7 +487,8 @@ class Grid:
|
|
)
|
|
)
|
|
|
|
|
|
# Rotate the agent based on its direction
|
|
# Rotate the agent based on its direction
|
|
- tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5*math.pi*agent_dir)
|
|
|
|
|
|
+ tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5,
|
|
|
|
+ theta=0.5*math.pi*agent_dir)
|
|
fill_coords(img, tri_fn, (255, 0, 0))
|
|
fill_coords(img, tri_fn, (255, 0, 0))
|
|
|
|
|
|
# Highlight the cell if needed
|
|
# Highlight the cell if needed
|
|
@@ -501,7 +517,8 @@ class Grid:
|
|
"""
|
|
"""
|
|
|
|
|
|
if highlight_mask is None:
|
|
if highlight_mask is None:
|
|
- highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
|
|
|
|
|
|
+ highlight_mask = np.zeros(
|
|
|
|
+ shape=(self.width, self.height), dtype=bool)
|
|
|
|
|
|
# Compute the total grid size
|
|
# Compute the total grid size
|
|
width_px = self.width * tile_size
|
|
width_px = self.width * tile_size
|
|
@@ -615,14 +632,18 @@ class Grid:
|
|
|
|
|
|
return mask
|
|
return mask
|
|
|
|
|
|
|
|
+
|
|
class MiniGridEnv(gym.Env):
|
|
class MiniGridEnv(gym.Env):
|
|
"""
|
|
"""
|
|
2D grid world game environment
|
|
2D grid world game environment
|
|
"""
|
|
"""
|
|
|
|
|
|
metadata = {
|
|
metadata = {
|
|
|
|
+ # Deprecated: use 'render_modes' instead
|
|
|
|
+ 'render.modes': ['human', 'rgb_array'],
|
|
|
|
+ 'video.frames_per_second': 10, # Deprecated: use 'render_fps' instead
|
|
'render_modes': ['human', 'rgb_array'],
|
|
'render_modes': ['human', 'rgb_array'],
|
|
- 'render_fps' : 10
|
|
|
|
|
|
+ 'render_fps': 10
|
|
}
|
|
}
|
|
|
|
|
|
# Enumeration of possible actions
|
|
# Enumeration of possible actions
|
|
@@ -682,7 +703,7 @@ class MiniGridEnv(gym.Env):
|
|
'direction': spaces.Discrete(4),
|
|
'direction': spaces.Discrete(4),
|
|
'mission': spaces.Text(max_length=200,
|
|
'mission': spaces.Text(max_length=200,
|
|
charset=string.ascii_letters + string.digits + ' .,!- '
|
|
charset=string.ascii_letters + string.digits + ' .,!- '
|
|
- )
|
|
|
|
|
|
+ )
|
|
})
|
|
})
|
|
|
|
|
|
# render mode
|
|
# render mode
|
|
@@ -704,7 +725,6 @@ class MiniGridEnv(gym.Env):
|
|
self.agent_pos = None
|
|
self.agent_pos = None
|
|
self.agent_dir = None
|
|
self.agent_dir = None
|
|
|
|
|
|
-
|
|
|
|
# Initialize the state
|
|
# Initialize the state
|
|
self.reset()
|
|
self.reset()
|
|
|
|
|
|
@@ -735,14 +755,14 @@ class MiniGridEnv(gym.Env):
|
|
obs = self.gen_obs()
|
|
obs = self.gen_obs()
|
|
return obs
|
|
return obs
|
|
|
|
|
|
-
|
|
|
|
def hash(self, size=16):
|
|
def hash(self, size=16):
|
|
"""Compute a hash that uniquely identifies the current state of the environment.
|
|
"""Compute a hash that uniquely identifies the current state of the environment.
|
|
:param size: Size of the hashing
|
|
:param size: Size of the hashing
|
|
"""
|
|
"""
|
|
sample_hash = hashlib.sha256()
|
|
sample_hash = hashlib.sha256()
|
|
|
|
|
|
- to_encode = [self.grid.encode().tolist(), self.agent_pos, self.agent_dir]
|
|
|
|
|
|
+ to_encode = [self.grid.encode().tolist(), self.agent_pos,
|
|
|
|
+ self.agent_dir]
|
|
for item in to_encode:
|
|
for item in to_encode:
|
|
sample_hash.update(str(item).encode('utf8'))
|
|
sample_hash.update(str(item).encode('utf8'))
|
|
|
|
|
|
@@ -761,14 +781,14 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
# Map of object types to short string
|
|
# Map of object types to short string
|
|
OBJECT_TO_STR = {
|
|
OBJECT_TO_STR = {
|
|
- 'wall' : 'W',
|
|
|
|
- 'floor' : 'F',
|
|
|
|
- 'door' : 'D',
|
|
|
|
- 'key' : 'K',
|
|
|
|
- 'ball' : 'A',
|
|
|
|
- 'box' : 'B',
|
|
|
|
- 'goal' : 'G',
|
|
|
|
- 'lava' : 'V',
|
|
|
|
|
|
+ 'wall': 'W',
|
|
|
|
+ 'floor': 'F',
|
|
|
|
+ 'door': 'D',
|
|
|
|
+ 'key': 'K',
|
|
|
|
+ 'ball': 'A',
|
|
|
|
+ 'box': 'B',
|
|
|
|
+ 'goal': 'G',
|
|
|
|
+ 'lava': 'V',
|
|
}
|
|
}
|
|
|
|
|
|
# Short string for opened door
|
|
# Short string for opened door
|
|
@@ -888,12 +908,12 @@ class MiniGridEnv(gym.Env):
|
|
)
|
|
)
|
|
|
|
|
|
def place_obj(self,
|
|
def place_obj(self,
|
|
- obj,
|
|
|
|
- top=None,
|
|
|
|
- size=None,
|
|
|
|
- reject_fn=None,
|
|
|
|
- max_tries=math.inf
|
|
|
|
- ):
|
|
|
|
|
|
+ obj,
|
|
|
|
+ top=None,
|
|
|
|
+ size=None,
|
|
|
|
+ reject_fn=None,
|
|
|
|
+ max_tries=math.inf
|
|
|
|
+ ):
|
|
"""
|
|
"""
|
|
Place an object at an empty position in the grid
|
|
Place an object at an empty position in the grid
|
|
|
|
|
|
@@ -1174,7 +1194,7 @@ class MiniGridEnv(gym.Env):
|
|
"""
|
|
"""
|
|
|
|
|
|
topX, topY, botX, botY = self.get_view_exts(agent_view_size)
|
|
topX, topY, botX, botY = self.get_view_exts(agent_view_size)
|
|
-
|
|
|
|
|
|
+
|
|
agent_view_size = agent_view_size or self.agent_view_size
|
|
agent_view_size = agent_view_size or self.agent_view_size
|
|
|
|
|
|
grid = self.grid.slice(topX, topY, agent_view_size, agent_view_size)
|
|
grid = self.grid.slice(topX, topY, agent_view_size, agent_view_size)
|
|
@@ -1185,7 +1205,8 @@ class MiniGridEnv(gym.Env):
|
|
# Process occluders and visibility
|
|
# Process occluders and visibility
|
|
# Note that this incurs some performance cost
|
|
# Note that this incurs some performance cost
|
|
if not self.see_through_walls:
|
|
if not self.see_through_walls:
|
|
- vis_mask = grid.process_vis(agent_pos=(agent_view_size // 2 , agent_view_size - 1))
|
|
|
|
|
|
+ vis_mask = grid.process_vis(agent_pos=(
|
|
|
|
+ agent_view_size // 2, agent_view_size - 1))
|
|
else:
|
|
else:
|
|
vis_mask = np.ones(shape=(grid.width, grid.height), dtype=bool)
|
|
vis_mask = np.ones(shape=(grid.width, grid.height), dtype=bool)
|
|
|
|
|
|
@@ -1210,7 +1231,8 @@ class MiniGridEnv(gym.Env):
|
|
# Encode the partially observable view into a numpy array
|
|
# Encode the partially observable view into a numpy array
|
|
image = grid.encode(vis_mask)
|
|
image = grid.encode(vis_mask)
|
|
|
|
|
|
- assert hasattr(self, 'mission'), "environments must define a textual mission string"
|
|
|
|
|
|
+ assert hasattr(
|
|
|
|
+ self, 'mission'), "environments must define a textual mission string"
|
|
|
|
|
|
# Observations are dictionaries containing:
|
|
# Observations are dictionaries containing:
|
|
# - an image (partially observable view of the environment)
|
|
# - an image (partially observable view of the environment)
|
|
@@ -1264,7 +1286,8 @@ class MiniGridEnv(gym.Env):
|
|
# of the agent's view area
|
|
# of the agent's view area
|
|
f_vec = self.dir_vec
|
|
f_vec = self.dir_vec
|
|
r_vec = self.right_vec
|
|
r_vec = self.right_vec
|
|
- top_left = self.agent_pos + f_vec * (self.agent_view_size-1) - r_vec * (self.agent_view_size // 2)
|
|
|
|
|
|
+ top_left = self.agent_pos + f_vec * \
|
|
|
|
+ (self.agent_view_size-1) - r_vec * (self.agent_view_size // 2)
|
|
|
|
|
|
# Mask of which cells to highlight
|
|
# Mask of which cells to highlight
|
|
highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
|
|
highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
|