|
@@ -1,61 +1,65 @@
|
|
-import math
|
|
|
|
import hashlib
|
|
import hashlib
|
|
|
|
+import math
|
|
import string
|
|
import string
|
|
-import gym
|
|
|
|
from enum import IntEnum
|
|
from enum import IntEnum
|
|
|
|
+
|
|
|
|
+import gym
|
|
import numpy as np
|
|
import numpy as np
|
|
-from gym import error, spaces, utils
|
|
|
|
-from .rendering import *
|
|
|
|
|
|
+from gym import spaces
|
|
|
|
|
|
# Size in pixels of a tile in the full-scale human view
|
|
# Size in pixels of a tile in the full-scale human view
|
|
|
|
+from gym_minigrid.rendering import (
|
|
|
|
+ downsample,
|
|
|
|
+ fill_coords,
|
|
|
|
+ highlight_img,
|
|
|
|
+ point_in_circle,
|
|
|
|
+ point_in_line,
|
|
|
|
+ point_in_rect,
|
|
|
|
+ point_in_triangle,
|
|
|
|
+ rotate_fn,
|
|
|
|
+)
|
|
|
|
+
|
|
TILE_PIXELS = 32
|
|
TILE_PIXELS = 32
|
|
|
|
|
|
# Map of color names to RGB values
|
|
# Map of color names to RGB values
|
|
COLORS = {
|
|
COLORS = {
|
|
- 'red': np.array([255, 0, 0]),
|
|
|
|
- 'green': np.array([0, 255, 0]),
|
|
|
|
- 'blue': np.array([0, 0, 255]),
|
|
|
|
- 'purple': np.array([112, 39, 195]),
|
|
|
|
- 'yellow': np.array([255, 255, 0]),
|
|
|
|
- 'grey': np.array([100, 100, 100])
|
|
|
|
|
|
+ "red": np.array([255, 0, 0]),
|
|
|
|
+ "green": np.array([0, 255, 0]),
|
|
|
|
+ "blue": np.array([0, 0, 255]),
|
|
|
|
+ "purple": np.array([112, 39, 195]),
|
|
|
|
+ "yellow": np.array([255, 255, 0]),
|
|
|
|
+ "grey": np.array([100, 100, 100]),
|
|
}
|
|
}
|
|
|
|
|
|
COLOR_NAMES = sorted(list(COLORS.keys()))
|
|
COLOR_NAMES = sorted(list(COLORS.keys()))
|
|
|
|
|
|
# Used to map colors to integers
|
|
# Used to map colors to integers
|
|
-COLOR_TO_IDX = {
|
|
|
|
- 'red': 0,
|
|
|
|
- 'green': 1,
|
|
|
|
- 'blue': 2,
|
|
|
|
- 'purple': 3,
|
|
|
|
- 'yellow': 4,
|
|
|
|
- 'grey': 5
|
|
|
|
-}
|
|
|
|
|
|
+COLOR_TO_IDX = {"red": 0, "green": 1, "blue": 2, "purple": 3, "yellow": 4, "grey": 5}
|
|
|
|
|
|
IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
|
|
IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
|
|
|
|
|
|
# Map of object type to integers
|
|
# Map of object type to integers
|
|
OBJECT_TO_IDX = {
|
|
OBJECT_TO_IDX = {
|
|
- 'unseen': 0,
|
|
|
|
- 'empty': 1,
|
|
|
|
- 'wall': 2,
|
|
|
|
- 'floor': 3,
|
|
|
|
- 'door': 4,
|
|
|
|
- 'key': 5,
|
|
|
|
- 'ball': 6,
|
|
|
|
- 'box': 7,
|
|
|
|
- 'goal': 8,
|
|
|
|
- 'lava': 9,
|
|
|
|
- 'agent': 10,
|
|
|
|
|
|
+ "unseen": 0,
|
|
|
|
+ "empty": 1,
|
|
|
|
+ "wall": 2,
|
|
|
|
+ "floor": 3,
|
|
|
|
+ "door": 4,
|
|
|
|
+ "key": 5,
|
|
|
|
+ "ball": 6,
|
|
|
|
+ "box": 7,
|
|
|
|
+ "goal": 8,
|
|
|
|
+ "lava": 9,
|
|
|
|
+ "agent": 10,
|
|
}
|
|
}
|
|
|
|
|
|
IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
|
|
IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
|
|
|
|
|
|
# Map of state names to integers
|
|
# Map of state names to integers
|
|
STATE_TO_IDX = {
|
|
STATE_TO_IDX = {
|
|
- 'open': 0,
|
|
|
|
- 'closed': 1,
|
|
|
|
- 'locked': 2,
|
|
|
|
|
|
+ "open": 0,
|
|
|
|
+ "closed": 1,
|
|
|
|
+ "locked": 2,
|
|
}
|
|
}
|
|
|
|
|
|
# Map of agent direction indices to vectors
|
|
# Map of agent direction indices to vectors
|
|
@@ -120,28 +124,28 @@ class WorldObj:
|
|
obj_type = IDX_TO_OBJECT[type_idx]
|
|
obj_type = IDX_TO_OBJECT[type_idx]
|
|
color = IDX_TO_COLOR[color_idx]
|
|
color = IDX_TO_COLOR[color_idx]
|
|
|
|
|
|
- if obj_type == 'empty' or obj_type == 'unseen':
|
|
|
|
|
|
+ if obj_type == "empty" or obj_type == "unseen":
|
|
return None
|
|
return None
|
|
|
|
|
|
# State, 0: open, 1: closed, 2: locked
|
|
# State, 0: open, 1: closed, 2: locked
|
|
is_open = state == 0
|
|
is_open = state == 0
|
|
is_locked = state == 2
|
|
is_locked = state == 2
|
|
|
|
|
|
- if obj_type == 'wall':
|
|
|
|
|
|
+ if obj_type == "wall":
|
|
v = Wall(color)
|
|
v = Wall(color)
|
|
- elif obj_type == 'floor':
|
|
|
|
|
|
+ elif obj_type == "floor":
|
|
v = Floor(color)
|
|
v = Floor(color)
|
|
- elif obj_type == 'ball':
|
|
|
|
|
|
+ elif obj_type == "ball":
|
|
v = Ball(color)
|
|
v = Ball(color)
|
|
- elif obj_type == 'key':
|
|
|
|
|
|
+ elif obj_type == "key":
|
|
v = Key(color)
|
|
v = Key(color)
|
|
- elif obj_type == 'box':
|
|
|
|
|
|
+ elif obj_type == "box":
|
|
v = Box(color)
|
|
v = Box(color)
|
|
- elif obj_type == 'door':
|
|
|
|
|
|
+ elif obj_type == "door":
|
|
v = Door(color, is_open, is_locked)
|
|
v = Door(color, is_open, is_locked)
|
|
- elif obj_type == 'goal':
|
|
|
|
|
|
+ elif obj_type == "goal":
|
|
v = Goal()
|
|
v = Goal()
|
|
- elif obj_type == 'lava':
|
|
|
|
|
|
+ elif obj_type == "lava":
|
|
v = Lava()
|
|
v = Lava()
|
|
else:
|
|
else:
|
|
assert False, "unknown object type in decode '%s'" % obj_type
|
|
assert False, "unknown object type in decode '%s'" % obj_type
|
|
@@ -155,7 +159,7 @@ class WorldObj:
|
|
|
|
|
|
class Goal(WorldObj):
|
|
class Goal(WorldObj):
|
|
def __init__(self):
|
|
def __init__(self):
|
|
- super().__init__('goal', 'green')
|
|
|
|
|
|
+ super().__init__("goal", "green")
|
|
|
|
|
|
def can_overlap(self):
|
|
def can_overlap(self):
|
|
return True
|
|
return True
|
|
@@ -169,8 +173,8 @@ class Floor(WorldObj):
|
|
Colored floor tile the agent can walk over
|
|
Colored floor tile the agent can walk over
|
|
"""
|
|
"""
|
|
|
|
|
|
- def __init__(self, color='blue'):
|
|
|
|
- super().__init__('floor', color)
|
|
|
|
|
|
+ def __init__(self, color="blue"):
|
|
|
|
+ super().__init__("floor", color)
|
|
|
|
|
|
def can_overlap(self):
|
|
def can_overlap(self):
|
|
return True
|
|
return True
|
|
@@ -183,7 +187,7 @@ class Floor(WorldObj):
|
|
|
|
|
|
class Lava(WorldObj):
|
|
class Lava(WorldObj):
|
|
def __init__(self):
|
|
def __init__(self):
|
|
- super().__init__('lava', 'red')
|
|
|
|
|
|
+ super().__init__("lava", "red")
|
|
|
|
|
|
def can_overlap(self):
|
|
def can_overlap(self):
|
|
return True
|
|
return True
|
|
@@ -198,19 +202,15 @@ class Lava(WorldObj):
|
|
for i in range(3):
|
|
for i in range(3):
|
|
ylo = 0.3 + 0.2 * i
|
|
ylo = 0.3 + 0.2 * i
|
|
yhi = 0.4 + 0.2 * i
|
|
yhi = 0.4 + 0.2 * i
|
|
- fill_coords(img, point_in_line(
|
|
|
|
- 0.1, ylo, 0.3, yhi, r=0.03), (0, 0, 0))
|
|
|
|
- fill_coords(img, point_in_line(
|
|
|
|
- 0.3, yhi, 0.5, ylo, r=0.03), (0, 0, 0))
|
|
|
|
- fill_coords(img, point_in_line(
|
|
|
|
- 0.5, ylo, 0.7, yhi, r=0.03), (0, 0, 0))
|
|
|
|
- fill_coords(img, point_in_line(
|
|
|
|
- 0.7, yhi, 0.9, ylo, r=0.03), (0, 0, 0))
|
|
|
|
|
|
+ fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0, 0, 0))
|
|
|
|
+ fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0, 0, 0))
|
|
|
|
+ fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0, 0, 0))
|
|
|
|
+ fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0, 0, 0))
|
|
|
|
|
|
|
|
|
|
class Wall(WorldObj):
|
|
class Wall(WorldObj):
|
|
- def __init__(self, color='grey'):
|
|
|
|
- super().__init__('wall', color)
|
|
|
|
|
|
+ def __init__(self, color="grey"):
|
|
|
|
+ super().__init__("wall", color)
|
|
|
|
|
|
def see_behind(self):
|
|
def see_behind(self):
|
|
return False
|
|
return False
|
|
@@ -221,7 +221,7 @@ class Wall(WorldObj):
|
|
|
|
|
|
class Door(WorldObj):
|
|
class Door(WorldObj):
|
|
def __init__(self, color, is_open=False, is_locked=False):
|
|
def __init__(self, color, is_open=False, is_locked=False):
|
|
- super().__init__('door', color)
|
|
|
|
|
|
+ super().__init__("door", color)
|
|
self.is_open = is_open
|
|
self.is_open = is_open
|
|
self.is_locked = is_locked
|
|
self.is_locked = is_locked
|
|
|
|
|
|
@@ -268,8 +268,7 @@ class Door(WorldObj):
|
|
# Door frame and door
|
|
# Door frame and door
|
|
if self.is_locked:
|
|
if self.is_locked:
|
|
fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
|
|
fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
|
|
- fill_coords(img, point_in_rect(
|
|
|
|
- 0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))
|
|
|
|
|
|
+ fill_coords(img, point_in_rect(0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))
|
|
|
|
|
|
# Draw key slot
|
|
# Draw key slot
|
|
fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)
|
|
fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)
|
|
@@ -284,8 +283,8 @@ class Door(WorldObj):
|
|
|
|
|
|
|
|
|
|
class Key(WorldObj):
|
|
class Key(WorldObj):
|
|
- def __init__(self, color='blue'):
|
|
|
|
- super(Key, self).__init__('key', color)
|
|
|
|
|
|
+ def __init__(self, color="blue"):
|
|
|
|
+ super().__init__("key", color)
|
|
|
|
|
|
def can_pickup(self):
|
|
def can_pickup(self):
|
|
return True
|
|
return True
|
|
@@ -306,8 +305,8 @@ class Key(WorldObj):
|
|
|
|
|
|
|
|
|
|
class Ball(WorldObj):
|
|
class Ball(WorldObj):
|
|
- def __init__(self, color='blue'):
|
|
|
|
- super(Ball, self).__init__('ball', color)
|
|
|
|
|
|
+ def __init__(self, color="blue"):
|
|
|
|
+ super().__init__("ball", color)
|
|
|
|
|
|
def can_pickup(self):
|
|
def can_pickup(self):
|
|
return True
|
|
return True
|
|
@@ -318,7 +317,7 @@ class Ball(WorldObj):
|
|
|
|
|
|
class Box(WorldObj):
|
|
class Box(WorldObj):
|
|
def __init__(self, color, contains=None):
|
|
def __init__(self, color, contains=None):
|
|
- super(Box, self).__init__('box', color)
|
|
|
|
|
|
+ super().__init__("box", color)
|
|
self.contains = contains
|
|
self.contains = contains
|
|
|
|
|
|
def can_pickup(self):
|
|
def can_pickup(self):
|
|
@@ -382,6 +381,7 @@ class Grid:
|
|
|
|
|
|
def copy(self):
|
|
def copy(self):
|
|
from copy import deepcopy
|
|
from copy import deepcopy
|
|
|
|
+
|
|
return deepcopy(self)
|
|
return deepcopy(self)
|
|
|
|
|
|
def set(self, i, j, v):
|
|
def set(self, i, j, v):
|
|
@@ -408,9 +408,9 @@ class Grid:
|
|
|
|
|
|
def wall_rect(self, x, y, w, h):
|
|
def wall_rect(self, x, y, w, h):
|
|
self.horz_wall(x, y, w)
|
|
self.horz_wall(x, y, w)
|
|
- self.horz_wall(x, y+h-1, w)
|
|
|
|
|
|
+ self.horz_wall(x, y + h - 1, w)
|
|
self.vert_wall(x, y, h)
|
|
self.vert_wall(x, y, h)
|
|
- self.vert_wall(x+w-1, y, h)
|
|
|
|
|
|
+ self.vert_wall(x + w - 1, y, h)
|
|
|
|
|
|
def rotate_left(self):
|
|
def rotate_left(self):
|
|
"""
|
|
"""
|
|
@@ -438,8 +438,7 @@ class Grid:
|
|
x = topX + i
|
|
x = topX + i
|
|
y = topY + j
|
|
y = topY + j
|
|
|
|
|
|
- if x >= 0 and x < self.width and \
|
|
|
|
- y >= 0 and y < self.height:
|
|
|
|
|
|
+ if x >= 0 and x < self.width and y >= 0 and y < self.height:
|
|
v = self.get(x, y)
|
|
v = self.get(x, y)
|
|
else:
|
|
else:
|
|
v = Wall()
|
|
v = Wall()
|
|
@@ -450,12 +449,7 @@ class Grid:
|
|
|
|
|
|
@classmethod
|
|
@classmethod
|
|
def render_tile(
|
|
def render_tile(
|
|
- cls,
|
|
|
|
- obj,
|
|
|
|
- agent_dir=None,
|
|
|
|
- highlight=False,
|
|
|
|
- tile_size=TILE_PIXELS,
|
|
|
|
- subdivs=3
|
|
|
|
|
|
+ cls, obj, agent_dir=None, highlight=False, tile_size=TILE_PIXELS, subdivs=3
|
|
):
|
|
):
|
|
"""
|
|
"""
|
|
Render a tile and cache the result
|
|
Render a tile and cache the result
|
|
@@ -468,14 +462,15 @@ class Grid:
|
|
if key in cls.tile_cache:
|
|
if key in cls.tile_cache:
|
|
return cls.tile_cache[key]
|
|
return cls.tile_cache[key]
|
|
|
|
|
|
- img = np.zeros(shape=(tile_size * subdivs,
|
|
|
|
- tile_size * subdivs, 3), dtype=np.uint8)
|
|
|
|
|
|
+ img = np.zeros(
|
|
|
|
+ shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8
|
|
|
|
+ )
|
|
|
|
|
|
# Draw the grid lines (top and left edges)
|
|
# Draw the grid lines (top and left edges)
|
|
fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
|
|
fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
|
|
fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100))
|
|
fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100))
|
|
|
|
|
|
- if obj != None:
|
|
|
|
|
|
+ if obj is not None:
|
|
obj.render(img)
|
|
obj.render(img)
|
|
|
|
|
|
# Overlay the agent on top
|
|
# Overlay the agent on top
|
|
@@ -487,8 +482,7 @@ class Grid:
|
|
)
|
|
)
|
|
|
|
|
|
# Rotate the agent based on its direction
|
|
# Rotate the agent based on its direction
|
|
- tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5,
|
|
|
|
- theta=0.5*math.pi*agent_dir)
|
|
|
|
|
|
+ tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5 * math.pi * agent_dir)
|
|
fill_coords(img, tri_fn, (255, 0, 0))
|
|
fill_coords(img, tri_fn, (255, 0, 0))
|
|
|
|
|
|
# Highlight the cell if needed
|
|
# Highlight the cell if needed
|
|
@@ -503,13 +497,7 @@ class Grid:
|
|
|
|
|
|
return img
|
|
return img
|
|
|
|
|
|
- def render(
|
|
|
|
- self,
|
|
|
|
- tile_size,
|
|
|
|
- agent_pos=None,
|
|
|
|
- agent_dir=None,
|
|
|
|
- highlight_mask=None
|
|
|
|
- ):
|
|
|
|
|
|
+ def render(self, tile_size, agent_pos=None, agent_dir=None, highlight_mask=None):
|
|
"""
|
|
"""
|
|
Render this grid at a given scale
|
|
Render this grid at a given scale
|
|
:param r: target renderer object
|
|
:param r: target renderer object
|
|
@@ -517,8 +505,7 @@ class Grid:
|
|
"""
|
|
"""
|
|
|
|
|
|
if highlight_mask is None:
|
|
if highlight_mask is None:
|
|
- highlight_mask = np.zeros(
|
|
|
|
- shape=(self.width, self.height), dtype=bool)
|
|
|
|
|
|
+ highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
|
|
|
|
|
|
# Compute the total grid size
|
|
# Compute the total grid size
|
|
width_px = self.width * tile_size
|
|
width_px = self.width * tile_size
|
|
@@ -536,13 +523,13 @@ class Grid:
|
|
cell,
|
|
cell,
|
|
agent_dir=agent_dir if agent_here else None,
|
|
agent_dir=agent_dir if agent_here else None,
|
|
highlight=highlight_mask[i, j],
|
|
highlight=highlight_mask[i, j],
|
|
- tile_size=tile_size
|
|
|
|
|
|
+ tile_size=tile_size,
|
|
)
|
|
)
|
|
|
|
|
|
ymin = j * tile_size
|
|
ymin = j * tile_size
|
|
- ymax = (j+1) * tile_size
|
|
|
|
|
|
+ ymax = (j + 1) * tile_size
|
|
xmin = i * tile_size
|
|
xmin = i * tile_size
|
|
- xmax = (i+1) * tile_size
|
|
|
|
|
|
+ xmax = (i + 1) * tile_size
|
|
img[ymin:ymax, xmin:xmax, :] = tile_img
|
|
img[ymin:ymax, xmin:xmax, :] = tile_img
|
|
|
|
|
|
return img
|
|
return img
|
|
@@ -555,7 +542,7 @@ class Grid:
|
|
if vis_mask is None:
|
|
if vis_mask is None:
|
|
vis_mask = np.ones((self.width, self.height), dtype=bool)
|
|
vis_mask = np.ones((self.width, self.height), dtype=bool)
|
|
|
|
|
|
- array = np.zeros((self.width, self.height, 3), dtype='uint8')
|
|
|
|
|
|
+ array = np.zeros((self.width, self.height, 3), dtype="uint8")
|
|
|
|
|
|
for i in range(self.width):
|
|
for i in range(self.width):
|
|
for j in range(self.height):
|
|
for j in range(self.height):
|
|
@@ -563,7 +550,7 @@ class Grid:
|
|
v = self.get(i, j)
|
|
v = self.get(i, j)
|
|
|
|
|
|
if v is None:
|
|
if v is None:
|
|
- array[i, j, 0] = OBJECT_TO_IDX['empty']
|
|
|
|
|
|
+ array[i, j, 0] = OBJECT_TO_IDX["empty"]
|
|
array[i, j, 1] = 0
|
|
array[i, j, 1] = 0
|
|
array[i, j, 2] = 0
|
|
array[i, j, 2] = 0
|
|
|
|
|
|
@@ -589,7 +576,7 @@ class Grid:
|
|
type_idx, color_idx, state = array[i, j]
|
|
type_idx, color_idx, state = array[i, j]
|
|
v = WorldObj.decode(type_idx, color_idx, state)
|
|
v = WorldObj.decode(type_idx, color_idx, state)
|
|
grid.set(i, j, v)
|
|
grid.set(i, j, v)
|
|
- vis_mask[i, j] = (type_idx != OBJECT_TO_IDX['unseen'])
|
|
|
|
|
|
+ vis_mask[i, j] = type_idx != OBJECT_TO_IDX["unseen"]
|
|
|
|
|
|
return grid, vis_mask
|
|
return grid, vis_mask
|
|
|
|
|
|
@@ -599,7 +586,7 @@ class Grid:
|
|
mask[agent_pos[0], agent_pos[1]] = True
|
|
mask[agent_pos[0], agent_pos[1]] = True
|
|
|
|
|
|
for j in reversed(range(0, grid.height)):
|
|
for j in reversed(range(0, grid.height)):
|
|
- for i in range(0, grid.width-1):
|
|
|
|
|
|
+ for i in range(0, grid.width - 1):
|
|
if not mask[i, j]:
|
|
if not mask[i, j]:
|
|
continue
|
|
continue
|
|
|
|
|
|
@@ -607,10 +594,10 @@ class Grid:
|
|
if cell and not cell.see_behind():
|
|
if cell and not cell.see_behind():
|
|
continue
|
|
continue
|
|
|
|
|
|
- mask[i+1, j] = True
|
|
|
|
|
|
+ mask[i + 1, j] = True
|
|
if j > 0:
|
|
if j > 0:
|
|
- mask[i+1, j-1] = True
|
|
|
|
- mask[i, j-1] = True
|
|
|
|
|
|
+ mask[i + 1, j - 1] = True
|
|
|
|
+ mask[i, j - 1] = True
|
|
|
|
|
|
for i in reversed(range(1, grid.width)):
|
|
for i in reversed(range(1, grid.width)):
|
|
if not mask[i, j]:
|
|
if not mask[i, j]:
|
|
@@ -620,10 +607,10 @@ class Grid:
|
|
if cell and not cell.see_behind():
|
|
if cell and not cell.see_behind():
|
|
continue
|
|
continue
|
|
|
|
|
|
- mask[i-1, j] = True
|
|
|
|
|
|
+ mask[i - 1, j] = True
|
|
if j > 0:
|
|
if j > 0:
|
|
- mask[i-1, j-1] = True
|
|
|
|
- mask[i, j-1] = True
|
|
|
|
|
|
+ mask[i - 1, j - 1] = True
|
|
|
|
+ mask[i, j - 1] = True
|
|
|
|
|
|
for j in range(0, grid.height):
|
|
for j in range(0, grid.height):
|
|
for i in range(0, grid.width):
|
|
for i in range(0, grid.width):
|
|
@@ -640,10 +627,10 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
metadata = {
|
|
metadata = {
|
|
# Deprecated: use 'render_modes' instead
|
|
# Deprecated: use 'render_modes' instead
|
|
- 'render.modes': ['human', 'rgb_array'],
|
|
|
|
- 'video.frames_per_second': 10, # Deprecated: use 'render_fps' instead
|
|
|
|
- 'render_modes': ['human', 'rgb_array'],
|
|
|
|
- 'render_fps': 10
|
|
|
|
|
|
+ "render.modes": ["human", "rgb_array"],
|
|
|
|
+ "video.frames_per_second": 10, # Deprecated: use 'render_fps' instead
|
|
|
|
+ "render_modes": ["human", "rgb_array"],
|
|
|
|
+ "render_fps": 10,
|
|
}
|
|
}
|
|
|
|
|
|
# Enumeration of possible actions
|
|
# Enumeration of possible actions
|
|
@@ -676,7 +663,7 @@ class MiniGridEnv(gym.Env):
|
|
):
|
|
):
|
|
# Can't set both grid_size and width/height
|
|
# Can't set both grid_size and width/height
|
|
if grid_size:
|
|
if grid_size:
|
|
- assert width == None and height == None
|
|
|
|
|
|
+ assert width is None and height is None
|
|
width = grid_size
|
|
width = grid_size
|
|
height = grid_size
|
|
height = grid_size
|
|
|
|
|
|
@@ -697,15 +684,18 @@ class MiniGridEnv(gym.Env):
|
|
low=0,
|
|
low=0,
|
|
high=255,
|
|
high=255,
|
|
shape=(self.agent_view_size, self.agent_view_size, 3),
|
|
shape=(self.agent_view_size, self.agent_view_size, 3),
|
|
- dtype='uint8'
|
|
|
|
|
|
+ dtype="uint8",
|
|
|
|
+ )
|
|
|
|
+ self.observation_space = spaces.Dict(
|
|
|
|
+ {
|
|
|
|
+ "image": self.observation_space,
|
|
|
|
+ "direction": spaces.Discrete(4),
|
|
|
|
+ "mission": spaces.Text(
|
|
|
|
+ max_length=200,
|
|
|
|
+ charset=string.ascii_letters + string.digits + " .,!-",
|
|
|
|
+ ),
|
|
|
|
+ }
|
|
)
|
|
)
|
|
- self.observation_space = spaces.Dict({
|
|
|
|
- 'image': self.observation_space,
|
|
|
|
- 'direction': spaces.Discrete(4),
|
|
|
|
- 'mission': spaces.Text(max_length=200,
|
|
|
|
- charset=string.ascii_letters + string.digits + ' .,!-'
|
|
|
|
- )
|
|
|
|
- })
|
|
|
|
|
|
|
|
# render mode
|
|
# render mode
|
|
self.render_mode = render_mode
|
|
self.render_mode = render_mode
|
|
@@ -762,10 +752,9 @@ class MiniGridEnv(gym.Env):
|
|
"""
|
|
"""
|
|
sample_hash = hashlib.sha256()
|
|
sample_hash = hashlib.sha256()
|
|
|
|
|
|
- to_encode = [self.grid.encode().tolist(), self.agent_pos,
|
|
|
|
- self.agent_dir]
|
|
|
|
|
|
+ to_encode = [self.grid.encode().tolist(), self.agent_pos, self.agent_dir]
|
|
for item in to_encode:
|
|
for item in to_encode:
|
|
- sample_hash.update(str(item).encode('utf8'))
|
|
|
|
|
|
+ sample_hash.update(str(item).encode("utf8"))
|
|
|
|
|
|
return sample_hash.hexdigest()[:size]
|
|
return sample_hash.hexdigest()[:size]
|
|
|
|
|
|
@@ -782,28 +771,20 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
# Map of object types to short string
|
|
# Map of object types to short string
|
|
OBJECT_TO_STR = {
|
|
OBJECT_TO_STR = {
|
|
- 'wall': 'W',
|
|
|
|
- 'floor': 'F',
|
|
|
|
- 'door': 'D',
|
|
|
|
- 'key': 'K',
|
|
|
|
- 'ball': 'A',
|
|
|
|
- 'box': 'B',
|
|
|
|
- 'goal': 'G',
|
|
|
|
- 'lava': 'V',
|
|
|
|
|
|
+ "wall": "W",
|
|
|
|
+ "floor": "F",
|
|
|
|
+ "door": "D",
|
|
|
|
+ "key": "K",
|
|
|
|
+ "ball": "A",
|
|
|
|
+ "box": "B",
|
|
|
|
+ "goal": "G",
|
|
|
|
+ "lava": "V",
|
|
}
|
|
}
|
|
|
|
|
|
- # Short string for opened door
|
|
|
|
- OPENDED_DOOR_IDS = '_'
|
|
|
|
-
|
|
|
|
# Map agent's direction to short string
|
|
# Map agent's direction to short string
|
|
- AGENT_DIR_TO_STR = {
|
|
|
|
- 0: '>',
|
|
|
|
- 1: 'V',
|
|
|
|
- 2: '<',
|
|
|
|
- 3: '^'
|
|
|
|
- }
|
|
|
|
|
|
+ AGENT_DIR_TO_STR = {0: ">", 1: "V", 2: "<", 3: "^"}
|
|
|
|
|
|
- str = ''
|
|
|
|
|
|
+ str = ""
|
|
|
|
|
|
for j in range(self.grid.height):
|
|
for j in range(self.grid.height):
|
|
|
|
|
|
@@ -814,23 +795,23 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
c = self.grid.get(i, j)
|
|
c = self.grid.get(i, j)
|
|
|
|
|
|
- if c == None:
|
|
|
|
- str += ' '
|
|
|
|
|
|
+ if c is None:
|
|
|
|
+ str += " "
|
|
continue
|
|
continue
|
|
|
|
|
|
- if c.type == 'door':
|
|
|
|
|
|
+ if c.type == "door":
|
|
if c.is_open:
|
|
if c.is_open:
|
|
- str += '__'
|
|
|
|
|
|
+ str += "__"
|
|
elif c.is_locked:
|
|
elif c.is_locked:
|
|
- str += 'L' + c.color[0].upper()
|
|
|
|
|
|
+ str += "L" + c.color[0].upper()
|
|
else:
|
|
else:
|
|
- str += 'D' + c.color[0].upper()
|
|
|
|
|
|
+ str += "D" + c.color[0].upper()
|
|
continue
|
|
continue
|
|
|
|
|
|
str += OBJECT_TO_STR[c.type] + c.color[0].upper()
|
|
str += OBJECT_TO_STR[c.type] + c.color[0].upper()
|
|
|
|
|
|
if j < self.grid.height - 1:
|
|
if j < self.grid.height - 1:
|
|
- str += '\n'
|
|
|
|
|
|
+ str += "\n"
|
|
|
|
|
|
return str
|
|
return str
|
|
|
|
|
|
@@ -863,7 +844,7 @@ class MiniGridEnv(gym.Env):
|
|
Generate random boolean value
|
|
Generate random boolean value
|
|
"""
|
|
"""
|
|
|
|
|
|
- return (self.np_random.integers(0, 2) == 0)
|
|
|
|
|
|
+ return self.np_random.integers(0, 2) == 0
|
|
|
|
|
|
def _rand_elem(self, iterable):
|
|
def _rand_elem(self, iterable):
|
|
"""
|
|
"""
|
|
@@ -905,16 +886,10 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
return (
|
|
return (
|
|
self.np_random.integers(xLow, xHigh),
|
|
self.np_random.integers(xLow, xHigh),
|
|
- self.np_random.integers(yLow, yHigh)
|
|
|
|
|
|
+ self.np_random.integers(yLow, yHigh),
|
|
)
|
|
)
|
|
|
|
|
|
- def place_obj(self,
|
|
|
|
- obj,
|
|
|
|
- top=None,
|
|
|
|
- size=None,
|
|
|
|
- reject_fn=None,
|
|
|
|
- max_tries=math.inf
|
|
|
|
- ):
|
|
|
|
|
|
+ def place_obj(self, obj, top=None, size=None, reject_fn=None, max_tries=math.inf):
|
|
"""
|
|
"""
|
|
Place an object at an empty position in the grid
|
|
Place an object at an empty position in the grid
|
|
|
|
|
|
@@ -937,17 +912,19 @@ class MiniGridEnv(gym.Env):
|
|
# This is to handle with rare cases where rejection sampling
|
|
# This is to handle with rare cases where rejection sampling
|
|
# gets stuck in an infinite loop
|
|
# gets stuck in an infinite loop
|
|
if num_tries > max_tries:
|
|
if num_tries > max_tries:
|
|
- raise RecursionError('rejection sampling failed in place_obj')
|
|
|
|
|
|
+ raise RecursionError("rejection sampling failed in place_obj")
|
|
|
|
|
|
num_tries += 1
|
|
num_tries += 1
|
|
|
|
|
|
- pos = np.array((
|
|
|
|
- self._rand_int(top[0], min(top[0] + size[0], self.grid.width)),
|
|
|
|
- self._rand_int(top[1], min(top[1] + size[1], self.grid.height))
|
|
|
|
- ))
|
|
|
|
|
|
+ pos = np.array(
|
|
|
|
+ (
|
|
|
|
+ self._rand_int(top[0], min(top[0] + size[0], self.grid.width)),
|
|
|
|
+ self._rand_int(top[1], min(top[1] + size[1], self.grid.height)),
|
|
|
|
+ )
|
|
|
|
+ )
|
|
|
|
|
|
# Don't place the object on top of another object
|
|
# Don't place the object on top of another object
|
|
- if self.grid.get(*pos) != None:
|
|
|
|
|
|
+ if self.grid.get(*pos) is not None:
|
|
continue
|
|
continue
|
|
|
|
|
|
# Don't place the object where the agent is
|
|
# Don't place the object where the agent is
|
|
@@ -977,13 +954,7 @@ class MiniGridEnv(gym.Env):
|
|
obj.init_pos = (i, j)
|
|
obj.init_pos = (i, j)
|
|
obj.cur_pos = (i, j)
|
|
obj.cur_pos = (i, j)
|
|
|
|
|
|
- def place_agent(
|
|
|
|
- self,
|
|
|
|
- top=None,
|
|
|
|
- size=None,
|
|
|
|
- rand_dir=True,
|
|
|
|
- max_tries=math.inf
|
|
|
|
- ):
|
|
|
|
|
|
+ def place_agent(self, top=None, size=None, rand_dir=True, max_tries=math.inf):
|
|
"""
|
|
"""
|
|
Set the agent's starting point at an empty position in the grid
|
|
Set the agent's starting point at an empty position in the grid
|
|
"""
|
|
"""
|
|
@@ -1038,16 +1009,16 @@ class MiniGridEnv(gym.Env):
|
|
# Compute the absolute coordinates of the top-left view corner
|
|
# Compute the absolute coordinates of the top-left view corner
|
|
sz = self.agent_view_size
|
|
sz = self.agent_view_size
|
|
hs = self.agent_view_size // 2
|
|
hs = self.agent_view_size // 2
|
|
- tx = ax + (dx * (sz-1)) - (rx * hs)
|
|
|
|
- ty = ay + (dy * (sz-1)) - (ry * hs)
|
|
|
|
|
|
+ tx = ax + (dx * (sz - 1)) - (rx * hs)
|
|
|
|
+ ty = ay + (dy * (sz - 1)) - (ry * hs)
|
|
|
|
|
|
lx = i - tx
|
|
lx = i - tx
|
|
ly = j - ty
|
|
ly = j - ty
|
|
|
|
|
|
# Project the coordinates of the object relative to the top-left
|
|
# Project the coordinates of the object relative to the top-left
|
|
# corner onto the agent's own coordinate system
|
|
# corner onto the agent's own coordinate system
|
|
- vx = (rx*lx + ry*ly)
|
|
|
|
- vy = -(dx*lx + dy*ly)
|
|
|
|
|
|
+ vx = rx * lx + ry * ly
|
|
|
|
+ vy = -(dx * lx + dy * ly)
|
|
|
|
|
|
return vx, vy
|
|
return vx, vy
|
|
|
|
|
|
@@ -1114,7 +1085,7 @@ class MiniGridEnv(gym.Env):
|
|
vx, vy = coordinates
|
|
vx, vy = coordinates
|
|
|
|
|
|
obs = self.gen_obs()
|
|
obs = self.gen_obs()
|
|
- obs_grid, _ = Grid.decode(obs['image'])
|
|
|
|
|
|
+ obs_grid, _ = Grid.decode(obs["image"])
|
|
obs_cell = obs_grid.get(vx, vy)
|
|
obs_cell = obs_grid.get(vx, vy)
|
|
world_cell = self.grid.get(x, y)
|
|
world_cell = self.grid.get(x, y)
|
|
|
|
|
|
@@ -1144,12 +1115,12 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
# Move forward
|
|
# Move forward
|
|
elif action == self.actions.forward:
|
|
elif action == self.actions.forward:
|
|
- if fwd_cell == None or fwd_cell.can_overlap():
|
|
|
|
|
|
+ if fwd_cell is None or fwd_cell.can_overlap():
|
|
self.agent_pos = fwd_pos
|
|
self.agent_pos = fwd_pos
|
|
- if fwd_cell != None and fwd_cell.type == 'goal':
|
|
|
|
|
|
+ if fwd_cell is not None and fwd_cell.type == "goal":
|
|
done = True
|
|
done = True
|
|
reward = self._reward()
|
|
reward = self._reward()
|
|
- if fwd_cell != None and fwd_cell.type == 'lava':
|
|
|
|
|
|
+ if fwd_cell is not None and fwd_cell.type == "lava":
|
|
done = True
|
|
done = True
|
|
|
|
|
|
# Pick up an object
|
|
# Pick up an object
|
|
@@ -1206,8 +1177,9 @@ class MiniGridEnv(gym.Env):
|
|
# Process occluders and visibility
|
|
# Process occluders and visibility
|
|
# Note that this incurs some performance cost
|
|
# Note that this incurs some performance cost
|
|
if not self.see_through_walls:
|
|
if not self.see_through_walls:
|
|
- vis_mask = grid.process_vis(agent_pos=(
|
|
|
|
- agent_view_size // 2, agent_view_size - 1))
|
|
|
|
|
|
+ vis_mask = grid.process_vis(
|
|
|
|
+ agent_pos=(agent_view_size // 2, agent_view_size - 1)
|
|
|
|
+ )
|
|
else:
|
|
else:
|
|
vis_mask = np.ones(shape=(grid.width, grid.height), dtype=bool)
|
|
vis_mask = np.ones(shape=(grid.width, grid.height), dtype=bool)
|
|
|
|
|
|
@@ -1233,21 +1205,18 @@ class MiniGridEnv(gym.Env):
|
|
image = grid.encode(vis_mask)
|
|
image = grid.encode(vis_mask)
|
|
|
|
|
|
assert hasattr(
|
|
assert hasattr(
|
|
- self, 'mission'), "environments must define a textual mission string"
|
|
|
|
|
|
+ self, "mission"
|
|
|
|
+ ), "environments must define a textual mission string"
|
|
|
|
|
|
# Observations are dictionaries containing:
|
|
# Observations are dictionaries containing:
|
|
# - an image (partially observable view of the environment)
|
|
# - an image (partially observable view of the environment)
|
|
# - the agent's direction/orientation (acting as a compass)
|
|
# - the agent's direction/orientation (acting as a compass)
|
|
# - a textual mission string (instructions for the agent)
|
|
# - a textual mission string (instructions for the agent)
|
|
- obs = {
|
|
|
|
- 'image': image,
|
|
|
|
- 'direction': self.agent_dir,
|
|
|
|
- 'mission': self.mission
|
|
|
|
- }
|
|
|
|
|
|
+ obs = {"image": image, "direction": self.agent_dir, "mission": self.mission}
|
|
|
|
|
|
return obs
|
|
return obs
|
|
|
|
|
|
- def get_obs_render(self, obs, tile_size=TILE_PIXELS//2):
|
|
|
|
|
|
+ def get_obs_render(self, obs, tile_size=TILE_PIXELS // 2):
|
|
"""
|
|
"""
|
|
Render an agent observation for visualization
|
|
Render an agent observation for visualization
|
|
"""
|
|
"""
|
|
@@ -1259,12 +1228,12 @@ class MiniGridEnv(gym.Env):
|
|
tile_size,
|
|
tile_size,
|
|
agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1),
|
|
agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1),
|
|
agent_dir=3,
|
|
agent_dir=3,
|
|
- highlight_mask=vis_mask
|
|
|
|
|
|
+ highlight_mask=vis_mask,
|
|
)
|
|
)
|
|
|
|
|
|
return img
|
|
return img
|
|
|
|
|
|
- def render(self, mode='human', close=False, highlight=True, tile_size=TILE_PIXELS):
|
|
|
|
|
|
+ def render(self, mode="human", close=False, highlight=True, tile_size=TILE_PIXELS):
|
|
"""
|
|
"""
|
|
Render the whole-grid human view
|
|
Render the whole-grid human view
|
|
"""
|
|
"""
|
|
@@ -1275,9 +1244,10 @@ class MiniGridEnv(gym.Env):
|
|
self.window.close()
|
|
self.window.close()
|
|
return
|
|
return
|
|
|
|
|
|
- if mode == 'human' and not self.window:
|
|
|
|
|
|
+ if mode == "human" and not self.window:
|
|
import gym_minigrid.window
|
|
import gym_minigrid.window
|
|
- self.window = gym_minigrid.window.Window('gym_minigrid')
|
|
|
|
|
|
+
|
|
|
|
+ self.window = gym_minigrid.window.Window("gym_minigrid")
|
|
self.window.show(block=False)
|
|
self.window.show(block=False)
|
|
|
|
|
|
# Compute which cells are visible to the agent
|
|
# Compute which cells are visible to the agent
|
|
@@ -1287,8 +1257,11 @@ class MiniGridEnv(gym.Env):
|
|
# of the agent's view area
|
|
# of the agent's view area
|
|
f_vec = self.dir_vec
|
|
f_vec = self.dir_vec
|
|
r_vec = self.right_vec
|
|
r_vec = self.right_vec
|
|
- top_left = self.agent_pos + f_vec * \
|
|
|
|
- (self.agent_view_size-1) - r_vec * (self.agent_view_size // 2)
|
|
|
|
|
|
+ top_left = (
|
|
|
|
+ self.agent_pos
|
|
|
|
+ + f_vec * (self.agent_view_size - 1)
|
|
|
|
+ - r_vec * (self.agent_view_size // 2)
|
|
|
|
+ )
|
|
|
|
|
|
# Mask of which cells to highlight
|
|
# Mask of which cells to highlight
|
|
highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
|
|
highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
|
|
@@ -1316,10 +1289,10 @@ class MiniGridEnv(gym.Env):
|
|
tile_size,
|
|
tile_size,
|
|
self.agent_pos,
|
|
self.agent_pos,
|
|
self.agent_dir,
|
|
self.agent_dir,
|
|
- highlight_mask=highlight_mask if highlight else None
|
|
|
|
|
|
+ highlight_mask=highlight_mask if highlight else None,
|
|
)
|
|
)
|
|
|
|
|
|
- if mode == 'human':
|
|
|
|
|
|
+ if mode == "human":
|
|
self.window.set_caption(self.mission)
|
|
self.window.set_caption(self.mission)
|
|
self.window.show_img(img)
|
|
self.window.show_img(img)
|
|
|
|
|