|
@@ -6,7 +6,7 @@ from enum import IntEnum
|
|
|
import gym
|
|
|
import numpy as np
|
|
|
from gym import spaces
|
|
|
-
|
|
|
+from abc import abstractmethod
|
|
|
# Size in pixels of a tile in the full-scale human view
|
|
|
from gym_minigrid.rendering import (
|
|
|
downsample,
|
|
@@ -34,7 +34,8 @@ COLORS = {
|
|
|
COLOR_NAMES = sorted(list(COLORS.keys()))
|
|
|
|
|
|
# Used to map colors to integers
|
|
|
-COLOR_TO_IDX = {"red": 0, "green": 1, "blue": 2, "purple": 3, "yellow": 4, "grey": 5}
|
|
|
+COLOR_TO_IDX = {"red": 0, "green": 1, "blue": 2,
|
|
|
+ "purple": 3, "yellow": 4, "grey": 5}
|
|
|
|
|
|
IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
|
|
|
|
|
@@ -202,10 +203,14 @@ class Lava(WorldObj):
|
|
|
for i in range(3):
|
|
|
ylo = 0.3 + 0.2 * i
|
|
|
yhi = 0.4 + 0.2 * i
|
|
|
- fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0, 0, 0))
|
|
|
- fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0, 0, 0))
|
|
|
- fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0, 0, 0))
|
|
|
- fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0, 0, 0))
|
|
|
+ fill_coords(img, point_in_line(
|
|
|
+ 0.1, ylo, 0.3, yhi, r=0.03), (0, 0, 0))
|
|
|
+ fill_coords(img, point_in_line(
|
|
|
+ 0.3, yhi, 0.5, ylo, r=0.03), (0, 0, 0))
|
|
|
+ fill_coords(img, point_in_line(
|
|
|
+ 0.5, ylo, 0.7, yhi, r=0.03), (0, 0, 0))
|
|
|
+ fill_coords(img, point_in_line(
|
|
|
+ 0.7, yhi, 0.9, ylo, r=0.03), (0, 0, 0))
|
|
|
|
|
|
|
|
|
class Wall(WorldObj):
|
|
@@ -252,7 +257,7 @@ class Door(WorldObj):
|
|
|
state = 0
|
|
|
elif self.is_locked:
|
|
|
state = 2
|
|
|
- elif not self.is_open:
|
|
|
+ else:
|
|
|
state = 1
|
|
|
|
|
|
return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state)
|
|
@@ -268,7 +273,8 @@ class Door(WorldObj):
|
|
|
# Door frame and door
|
|
|
if self.is_locked:
|
|
|
fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
|
|
|
- fill_coords(img, point_in_rect(0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))
|
|
|
+ fill_coords(img, point_in_rect(
|
|
|
+ 0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))
|
|
|
|
|
|
# Draw key slot
|
|
|
fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)
|
|
@@ -323,6 +329,9 @@ class Box(WorldObj):
|
|
|
def can_pickup(self):
|
|
|
return True
|
|
|
|
|
|
+ def set_contains(self, contains):
|
|
|
+ self.contains = contains
|
|
|
+
|
|
|
def render(self, img):
|
|
|
c = COLORS[self.color]
|
|
|
|
|
@@ -335,7 +344,7 @@ class Box(WorldObj):
|
|
|
|
|
|
def toggle(self, env, pos):
|
|
|
# Replace the box by its contents
|
|
|
- env.grid.set(*pos, self.contains)
|
|
|
+ env.grid.set(pos[0], pos[1], self.contains)
|
|
|
return True
|
|
|
|
|
|
|
|
@@ -482,7 +491,8 @@ class Grid:
|
|
|
)
|
|
|
|
|
|
# Rotate the agent based on its direction
|
|
|
- tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5 * math.pi * agent_dir)
|
|
|
+ tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5,
|
|
|
+ theta=0.5 * math.pi * agent_dir)
|
|
|
fill_coords(img, tri_fn, (255, 0, 0))
|
|
|
|
|
|
# Highlight the cell if needed
|
|
@@ -497,7 +507,7 @@ class Grid:
|
|
|
|
|
|
return img
|
|
|
|
|
|
- def render(self, tile_size, agent_pos=None, agent_dir=None, highlight_mask=None):
|
|
|
+ def render(self, tile_size, agent_pos, agent_dir=None, highlight_mask=None):
|
|
|
"""
|
|
|
Render this grid at a given scale
|
|
|
:param r: target renderer object
|
|
@@ -505,7 +515,8 @@ class Grid:
|
|
|
"""
|
|
|
|
|
|
if highlight_mask is None:
|
|
|
- highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
|
|
|
+ highlight_mask = np.zeros(
|
|
|
+ shape=(self.width, self.height), dtype=bool)
|
|
|
|
|
|
# Compute the total grid size
|
|
|
width_px = self.width * tile_size
|
|
@@ -580,17 +591,17 @@ class Grid:
|
|
|
|
|
|
return grid, vis_mask
|
|
|
|
|
|
- def process_vis(grid, agent_pos):
|
|
|
- mask = np.zeros(shape=(grid.width, grid.height), dtype=bool)
|
|
|
+ def process_vis(self, agent_pos):
|
|
|
+ mask = np.zeros(shape=(self.width, self.height), dtype=bool)
|
|
|
|
|
|
mask[agent_pos[0], agent_pos[1]] = True
|
|
|
|
|
|
- for j in reversed(range(0, grid.height)):
|
|
|
- for i in range(0, grid.width - 1):
|
|
|
+ for j in reversed(range(0, self.height)):
|
|
|
+ for i in range(0, self.width - 1):
|
|
|
if not mask[i, j]:
|
|
|
continue
|
|
|
|
|
|
- cell = grid.get(i, j)
|
|
|
+ cell = self.get(i, j)
|
|
|
if cell and not cell.see_behind():
|
|
|
continue
|
|
|
|
|
@@ -599,11 +610,11 @@ class Grid:
|
|
|
mask[i + 1, j - 1] = True
|
|
|
mask[i, j - 1] = True
|
|
|
|
|
|
- for i in reversed(range(1, grid.width)):
|
|
|
+ for i in reversed(range(1, self.width)):
|
|
|
if not mask[i, j]:
|
|
|
continue
|
|
|
|
|
|
- cell = grid.get(i, j)
|
|
|
+ cell = self.get(i, j)
|
|
|
if cell and not cell.see_behind():
|
|
|
continue
|
|
|
|
|
@@ -612,10 +623,10 @@ class Grid:
|
|
|
mask[i - 1, j - 1] = True
|
|
|
mask[i, j - 1] = True
|
|
|
|
|
|
- for j in range(0, grid.height):
|
|
|
- for i in range(0, grid.width):
|
|
|
+ for j in range(0, self.height):
|
|
|
+ for i in range(0, self.width):
|
|
|
if not mask[i, j]:
|
|
|
- grid.set(i, j, None)
|
|
|
+ self.set(i, j, None)
|
|
|
|
|
|
return mask
|
|
|
|
|
@@ -713,24 +724,29 @@ class MiniGridEnv(gym.Env):
|
|
|
self.see_through_walls = see_through_walls
|
|
|
|
|
|
# Current position and direction of the agent
|
|
|
- self.agent_pos = None
|
|
|
- self.agent_dir = None
|
|
|
+ self.agent_pos = (-1, -1)
|
|
|
+ self.agent_dir = -1
|
|
|
+
|
|
|
+ # Current grid and mission and carryinh
|
|
|
+ self.grid = Grid(width, height)
|
|
|
+ self.mission = ""
|
|
|
+ self.carrying = None
|
|
|
|
|
|
# Initialize the state
|
|
|
self.reset()
|
|
|
|
|
|
def reset(self, *, seed=None, return_info=False, options=None):
|
|
|
super().reset(seed=seed)
|
|
|
- # Current position and direction of the agent
|
|
|
- self.agent_pos = None
|
|
|
- self.agent_dir = None
|
|
|
+
|
|
|
+ # Reinitialize episode-specific variables
|
|
|
+ self.agent_pos = (-1, -1)
|
|
|
+ self.agent_dir = -1
|
|
|
|
|
|
# Generate a new random grid at the start of each episode
|
|
|
self._gen_grid(self.width, self.height)
|
|
|
|
|
|
# These fields should be defined by _gen_grid
|
|
|
- assert self.agent_pos is not None
|
|
|
- assert self.agent_dir is not None
|
|
|
+ assert self.agent_pos >= (0, 0) and self.agent_dir >= 0
|
|
|
|
|
|
# Check that the agent doesn't overlap with an object
|
|
|
start_cell = self.grid.get(*self.agent_pos)
|
|
@@ -752,7 +768,8 @@ class MiniGridEnv(gym.Env):
|
|
|
"""
|
|
|
sample_hash = hashlib.sha256()
|
|
|
|
|
|
- to_encode = [self.grid.encode().tolist(), self.agent_pos, self.agent_dir]
|
|
|
+ to_encode = [self.grid.encode().tolist(), self.agent_pos,
|
|
|
+ self.agent_dir]
|
|
|
for item in to_encode:
|
|
|
sample_hash.update(str(item).encode("utf8"))
|
|
|
|
|
@@ -815,8 +832,9 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
return str
|
|
|
|
|
|
+ @abstractmethod
|
|
|
def _gen_grid(self, width, height):
|
|
|
- assert False, "_gen_grid needs to be implemented by each environment"
|
|
|
+ pass
|
|
|
|
|
|
def _reward(self):
|
|
|
"""
|
|
@@ -918,11 +936,15 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
pos = np.array(
|
|
|
(
|
|
|
- self._rand_int(top[0], min(top[0] + size[0], self.grid.width)),
|
|
|
- self._rand_int(top[1], min(top[1] + size[1], self.grid.height)),
|
|
|
+ self._rand_int(top[0], min(
|
|
|
+ top[0] + size[0], self.grid.width)),
|
|
|
+ self._rand_int(top[1], min(
|
|
|
+ top[1] + size[1], self.grid.height)),
|
|
|
)
|
|
|
)
|
|
|
|
|
|
+ pos = tuple(pos)
|
|
|
+
|
|
|
# Don't place the object on top of another object
|
|
|
if self.grid.get(*pos) is not None:
|
|
|
continue
|
|
@@ -937,7 +959,7 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
break
|
|
|
|
|
|
- self.grid.set(*pos, obj)
|
|
|
+ self.grid.set(pos[0], pos[1], obj)
|
|
|
|
|
|
if obj is not None:
|
|
|
obj.init_pos = pos
|
|
@@ -959,7 +981,7 @@ class MiniGridEnv(gym.Env):
|
|
|
Set the agent's starting point at an empty position in the grid
|
|
|
"""
|
|
|
|
|
|
- self.agent_pos = None
|
|
|
+ self.agent_pos = (-1, -1)
|
|
|
pos = self.place_obj(None, top, size, max_tries=max_tries)
|
|
|
self.agent_pos = pos
|
|
|
|
|
@@ -1089,13 +1111,16 @@ class MiniGridEnv(gym.Env):
|
|
|
obs_cell = obs_grid.get(vx, vy)
|
|
|
world_cell = self.grid.get(x, y)
|
|
|
|
|
|
+ assert world_cell is not None
|
|
|
+
|
|
|
return obs_cell is not None and obs_cell.type == world_cell.type
|
|
|
|
|
|
def step(self, action):
|
|
|
self.step_count += 1
|
|
|
|
|
|
reward = 0
|
|
|
- done = False
|
|
|
+ terminated = False
|
|
|
+ truncated = False
|
|
|
|
|
|
# Get the position in front of the agent
|
|
|
fwd_pos = self.front_pos
|
|
@@ -1116,12 +1141,12 @@ class MiniGridEnv(gym.Env):
|
|
|
# Move forward
|
|
|
elif action == self.actions.forward:
|
|
|
if fwd_cell is None or fwd_cell.can_overlap():
|
|
|
- self.agent_pos = fwd_pos
|
|
|
+ self.agent_pos = tuple(fwd_pos)
|
|
|
if fwd_cell is not None and fwd_cell.type == "goal":
|
|
|
- done = True
|
|
|
+ terminated = True
|
|
|
reward = self._reward()
|
|
|
if fwd_cell is not None and fwd_cell.type == "lava":
|
|
|
- done = True
|
|
|
+ terminated = True
|
|
|
|
|
|
# Pick up an object
|
|
|
elif action == self.actions.pickup:
|
|
@@ -1129,12 +1154,12 @@ class MiniGridEnv(gym.Env):
|
|
|
if self.carrying is None:
|
|
|
self.carrying = fwd_cell
|
|
|
self.carrying.cur_pos = np.array([-1, -1])
|
|
|
- self.grid.set(*fwd_pos, None)
|
|
|
+ self.grid.set(fwd_pos[0], fwd_pos[1], None)
|
|
|
|
|
|
# Drop an object
|
|
|
elif action == self.actions.drop:
|
|
|
if not fwd_cell and self.carrying:
|
|
|
- self.grid.set(*fwd_pos, self.carrying)
|
|
|
+ self.grid.set(fwd_pos[0], fwd_pos[1], self.carrying)
|
|
|
self.carrying.cur_pos = fwd_pos
|
|
|
self.carrying = None
|
|
|
|
|
@@ -1148,14 +1173,14 @@ class MiniGridEnv(gym.Env):
|
|
|
pass
|
|
|
|
|
|
else:
|
|
|
- assert False, "unknown action"
|
|
|
+ raise ValueError('Unknown action: {}'.format(action))
|
|
|
|
|
|
if self.step_count >= self.max_steps:
|
|
|
- done = True
|
|
|
+ truncated = True
|
|
|
|
|
|
obs = self.gen_obs()
|
|
|
|
|
|
- return obs, reward, done, {}
|
|
|
+ return obs, reward, terminated, truncated, {}
|
|
|
|
|
|
def gen_obs_grid(self, agent_view_size=None):
|
|
|
"""
|
|
@@ -1204,15 +1229,12 @@ class MiniGridEnv(gym.Env):
|
|
|
# Encode the partially observable view into a numpy array
|
|
|
image = grid.encode(vis_mask)
|
|
|
|
|
|
- assert hasattr(
|
|
|
- self, "mission"
|
|
|
- ), "environments must define a textual mission string"
|
|
|
-
|
|
|
# Observations are dictionaries containing:
|
|
|
# - an image (partially observable view of the environment)
|
|
|
# - the agent's direction/orientation (acting as a compass)
|
|
|
# - a textual mission string (instructions for the agent)
|
|
|
- obs = {"image": image, "direction": self.agent_dir, "mission": self.mission}
|
|
|
+ obs = {"image": image, "direction": self.agent_dir,
|
|
|
+ "mission": self.mission}
|
|
|
|
|
|
return obs
|
|
|
|
|
@@ -1293,6 +1315,7 @@ class MiniGridEnv(gym.Env):
|
|
|
)
|
|
|
|
|
|
if mode == "human":
|
|
|
+ assert self.window is not None
|
|
|
self.window.set_caption(self.mission)
|
|
|
self.window.show_img(img)
|
|
|
|