# minigrid.py (36 KB)
  1. import hashlib
  2. import math
  3. from enum import IntEnum
  4. import gym
  5. import numpy as np
  6. from gym import spaces
  7. from gym.utils import seeding
  8. from gym_minigrid.rendering import (
  9. downsample,
  10. fill_coords,
  11. highlight_img,
  12. point_in_circle,
  13. point_in_line,
  14. point_in_rect,
  15. point_in_triangle,
  16. rotate_fn,
  17. )
  18. # Size in pixels of a tile in the full-scale human view
  19. TILE_PIXELS = 32
  20. # Map of color names to RGB values
  21. COLORS = {
  22. "red": np.array([255, 0, 0]),
  23. "green": np.array([0, 255, 0]),
  24. "blue": np.array([0, 0, 255]),
  25. "purple": np.array([112, 39, 195]),
  26. "yellow": np.array([255, 255, 0]),
  27. "grey": np.array([100, 100, 100]),
  28. }
  29. COLOR_NAMES = sorted(list(COLORS.keys()))
  30. # Used to map colors to integers
  31. COLOR_TO_IDX = {"red": 0, "green": 1, "blue": 2, "purple": 3, "yellow": 4, "grey": 5}
  32. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  33. # Map of object type to integers
  34. OBJECT_TO_IDX = {
  35. "unseen": 0,
  36. "empty": 1,
  37. "wall": 2,
  38. "floor": 3,
  39. "door": 4,
  40. "key": 5,
  41. "ball": 6,
  42. "box": 7,
  43. "goal": 8,
  44. "lava": 9,
  45. "agent": 10,
  46. }
  47. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  48. # Map of state names to integers
  49. STATE_TO_IDX = {
  50. "open": 0,
  51. "closed": 1,
  52. "locked": 2,
  53. }
  54. # Map of agent direction indices to vectors
  55. DIR_TO_VEC = [
  56. # Pointing right (positive X)
  57. np.array((1, 0)),
  58. # Down (positive Y)
  59. np.array((0, 1)),
  60. # Pointing left (negative X)
  61. np.array((-1, 0)),
  62. # Up (negative Y)
  63. np.array((0, -1)),
  64. ]
  65. class WorldObj:
  66. """
  67. Base class for grid world objects
  68. """
  69. def __init__(self, type, color):
  70. assert type in OBJECT_TO_IDX, type
  71. assert color in COLOR_TO_IDX, color
  72. self.type = type
  73. self.color = color
  74. self.contains = None
  75. # Initial position of the object
  76. self.init_pos = None
  77. # Current position of the object
  78. self.cur_pos = None
  79. def can_overlap(self):
  80. """Can the agent overlap with this?"""
  81. return False
  82. def can_pickup(self):
  83. """Can the agent pick this up?"""
  84. return False
  85. def can_contain(self):
  86. """Can this contain another object?"""
  87. return False
  88. def see_behind(self):
  89. """Can the agent see behind this object?"""
  90. return True
  91. def toggle(self, env, pos):
  92. """Method to trigger/toggle an action this object performs"""
  93. return False
  94. def encode(self):
  95. """Encode the a description of this object as a 3-tuple of integers"""
  96. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0)
  97. @staticmethod
  98. def decode(type_idx, color_idx, state):
  99. """Create an object from a 3-tuple state description"""
  100. obj_type = IDX_TO_OBJECT[type_idx]
  101. color = IDX_TO_COLOR[color_idx]
  102. if obj_type == "empty" or obj_type == "unseen":
  103. return None
  104. # State, 0: open, 1: closed, 2: locked
  105. is_open = state == 0
  106. is_locked = state == 2
  107. if obj_type == "wall":
  108. v = Wall(color)
  109. elif obj_type == "floor":
  110. v = Floor(color)
  111. elif obj_type == "ball":
  112. v = Ball(color)
  113. elif obj_type == "key":
  114. v = Key(color)
  115. elif obj_type == "box":
  116. v = Box(color)
  117. elif obj_type == "door":
  118. v = Door(color, is_open, is_locked)
  119. elif obj_type == "goal":
  120. v = Goal()
  121. elif obj_type == "lava":
  122. v = Lava()
  123. else:
  124. assert False, "unknown object type in decode '%s'" % obj_type
  125. return v
  126. def render(self, r):
  127. """Draw this object with the given renderer"""
  128. raise NotImplementedError
  129. class Goal(WorldObj):
  130. def __init__(self):
  131. super().__init__("goal", "green")
  132. def can_overlap(self):
  133. return True
  134. def render(self, img):
  135. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  136. class Floor(WorldObj):
  137. """
  138. Colored floor tile the agent can walk over
  139. """
  140. def __init__(self, color="blue"):
  141. super().__init__("floor", color)
  142. def can_overlap(self):
  143. return True
  144. def render(self, img):
  145. # Give the floor a pale color
  146. color = COLORS[self.color] / 2
  147. fill_coords(img, point_in_rect(0.031, 1, 0.031, 1), color)
  148. class Lava(WorldObj):
  149. def __init__(self):
  150. super().__init__("lava", "red")
  151. def can_overlap(self):
  152. return True
  153. def render(self, img):
  154. c = (255, 128, 0)
  155. # Background color
  156. fill_coords(img, point_in_rect(0, 1, 0, 1), c)
  157. # Little waves
  158. for i in range(3):
  159. ylo = 0.3 + 0.2 * i
  160. yhi = 0.4 + 0.2 * i
  161. fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0, 0, 0))
  162. fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0, 0, 0))
  163. fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0, 0, 0))
  164. fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0, 0, 0))
  165. class Wall(WorldObj):
  166. def __init__(self, color="grey"):
  167. super().__init__("wall", color)
  168. def see_behind(self):
  169. return False
  170. def render(self, img):
  171. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  172. class Door(WorldObj):
  173. def __init__(self, color, is_open=False, is_locked=False):
  174. super().__init__("door", color)
  175. self.is_open = is_open
  176. self.is_locked = is_locked
  177. def can_overlap(self):
  178. """The agent can only walk over this cell when the door is open"""
  179. return self.is_open
  180. def see_behind(self):
  181. return self.is_open
  182. def toggle(self, env, pos):
  183. # If the player has the right key to open the door
  184. if self.is_locked:
  185. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  186. self.is_locked = False
  187. self.is_open = True
  188. return True
  189. return False
  190. self.is_open = not self.is_open
  191. return True
  192. def encode(self):
  193. """Encode the a description of this object as a 3-tuple of integers"""
  194. # State, 0: open, 1: closed, 2: locked
  195. if self.is_open:
  196. state = 0
  197. elif self.is_locked:
  198. state = 2
  199. elif not self.is_open:
  200. state = 1
  201. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state)
  202. def render(self, img):
  203. c = COLORS[self.color]
  204. if self.is_open:
  205. fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c)
  206. fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0, 0, 0))
  207. return
  208. # Door frame and door
  209. if self.is_locked:
  210. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  211. fill_coords(img, point_in_rect(0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))
  212. # Draw key slot
  213. fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)
  214. else:
  215. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  216. fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0, 0, 0))
  217. fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c)
  218. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0, 0, 0))
  219. # Draw door handle
  220. fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c)
  221. class Key(WorldObj):
  222. def __init__(self, color="blue"):
  223. super().__init__("key", color)
  224. def can_pickup(self):
  225. return True
  226. def render(self, img):
  227. c = COLORS[self.color]
  228. # Vertical quad
  229. fill_coords(img, point_in_rect(0.50, 0.63, 0.31, 0.88), c)
  230. # Teeth
  231. fill_coords(img, point_in_rect(0.38, 0.50, 0.59, 0.66), c)
  232. fill_coords(img, point_in_rect(0.38, 0.50, 0.81, 0.88), c)
  233. # Ring
  234. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.190), c)
  235. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0, 0, 0))
  236. class Ball(WorldObj):
  237. def __init__(self, color="blue"):
  238. super().__init__("ball", color)
  239. def can_pickup(self):
  240. return True
  241. def render(self, img):
  242. fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color])
  243. class Box(WorldObj):
  244. def __init__(self, color, contains=None):
  245. super().__init__("box", color)
  246. self.contains = contains
  247. def can_pickup(self):
  248. return True
  249. def render(self, img):
  250. c = COLORS[self.color]
  251. # Outline
  252. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), c)
  253. fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0, 0, 0))
  254. # Horizontal slit
  255. fill_coords(img, point_in_rect(0.16, 0.84, 0.47, 0.53), c)
  256. def toggle(self, env, pos):
  257. # Replace the box by its contents
  258. env.grid.set(*pos, self.contains)
  259. return True
  260. class Grid:
  261. """
  262. Represent a grid and operations on it
  263. """
  264. # Static cache of pre-renderer tiles
  265. tile_cache = {}
  266. def __init__(self, width, height):
  267. assert width >= 3
  268. assert height >= 3
  269. self.width = width
  270. self.height = height
  271. self.grid = [None] * width * height
  272. def __contains__(self, key):
  273. if isinstance(key, WorldObj):
  274. for e in self.grid:
  275. if e is key:
  276. return True
  277. elif isinstance(key, tuple):
  278. for e in self.grid:
  279. if e is None:
  280. continue
  281. if (e.color, e.type) == key:
  282. return True
  283. if key[0] is None and key[1] == e.type:
  284. return True
  285. return False
  286. def __eq__(self, other):
  287. grid1 = self.encode()
  288. grid2 = other.encode()
  289. return np.array_equal(grid2, grid1)
  290. def __ne__(self, other):
  291. return not self == other
  292. def copy(self):
  293. from copy import deepcopy
  294. return deepcopy(self)
  295. def set(self, i, j, v):
  296. assert i >= 0 and i < self.width
  297. assert j >= 0 and j < self.height
  298. self.grid[j * self.width + i] = v
  299. def get(self, i, j):
  300. assert i >= 0 and i < self.width
  301. assert j >= 0 and j < self.height
  302. return self.grid[j * self.width + i]
  303. def horz_wall(self, x, y, length=None, obj_type=Wall):
  304. if length is None:
  305. length = self.width - x
  306. for i in range(0, length):
  307. self.set(x + i, y, obj_type())
  308. def vert_wall(self, x, y, length=None, obj_type=Wall):
  309. if length is None:
  310. length = self.height - y
  311. for j in range(0, length):
  312. self.set(x, y + j, obj_type())
  313. def wall_rect(self, x, y, w, h):
  314. self.horz_wall(x, y, w)
  315. self.horz_wall(x, y + h - 1, w)
  316. self.vert_wall(x, y, h)
  317. self.vert_wall(x + w - 1, y, h)
  318. def rotate_left(self):
  319. """
  320. Rotate the grid to the left (counter-clockwise)
  321. """
  322. grid = Grid(self.height, self.width)
  323. for i in range(self.width):
  324. for j in range(self.height):
  325. v = self.get(i, j)
  326. grid.set(j, grid.height - 1 - i, v)
  327. return grid
  328. def slice(self, topX, topY, width, height):
  329. """
  330. Get a subset of the grid
  331. """
  332. grid = Grid(width, height)
  333. for j in range(0, height):
  334. for i in range(0, width):
  335. x = topX + i
  336. y = topY + j
  337. if x >= 0 and x < self.width and y >= 0 and y < self.height:
  338. v = self.get(x, y)
  339. else:
  340. v = Wall()
  341. grid.set(i, j, v)
  342. return grid
  343. @classmethod
  344. def render_tile(
  345. cls, obj, agent_dir=None, highlight=False, tile_size=TILE_PIXELS, subdivs=3
  346. ):
  347. """
  348. Render a tile and cache the result
  349. """
  350. # Hash map lookup key for the cache
  351. key = (agent_dir, highlight, tile_size)
  352. key = obj.encode() + key if obj else key
  353. if key in cls.tile_cache:
  354. return cls.tile_cache[key]
  355. img = np.zeros(
  356. shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8
  357. )
  358. # Draw the grid lines (top and left edges)
  359. fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
  360. fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100))
  361. if obj is not None:
  362. obj.render(img)
  363. # Overlay the agent on top
  364. if agent_dir is not None:
  365. tri_fn = point_in_triangle(
  366. (0.12, 0.19),
  367. (0.87, 0.50),
  368. (0.12, 0.81),
  369. )
  370. # Rotate the agent based on its direction
  371. tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5 * math.pi * agent_dir)
  372. fill_coords(img, tri_fn, (255, 0, 0))
  373. # Highlight the cell if needed
  374. if highlight:
  375. highlight_img(img)
  376. # Downsample the image to perform supersampling/anti-aliasing
  377. img = downsample(img, subdivs)
  378. # Cache the rendered tile
  379. cls.tile_cache[key] = img
  380. return img
  381. def render(self, tile_size, agent_pos=None, agent_dir=None, highlight_mask=None):
  382. """
  383. Render this grid at a given scale
  384. :param r: target renderer object
  385. :param tile_size: tile size in pixels
  386. """
  387. if highlight_mask is None:
  388. highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  389. # Compute the total grid size
  390. width_px = self.width * tile_size
  391. height_px = self.height * tile_size
  392. img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)
  393. # Render the grid
  394. for j in range(0, self.height):
  395. for i in range(0, self.width):
  396. cell = self.get(i, j)
  397. agent_here = np.array_equal(agent_pos, (i, j))
  398. tile_img = Grid.render_tile(
  399. cell,
  400. agent_dir=agent_dir if agent_here else None,
  401. highlight=highlight_mask[i, j],
  402. tile_size=tile_size,
  403. )
  404. ymin = j * tile_size
  405. ymax = (j + 1) * tile_size
  406. xmin = i * tile_size
  407. xmax = (i + 1) * tile_size
  408. img[ymin:ymax, xmin:xmax, :] = tile_img
  409. return img
  410. def encode(self, vis_mask=None):
  411. """
  412. Produce a compact numpy encoding of the grid
  413. """
  414. if vis_mask is None:
  415. vis_mask = np.ones((self.width, self.height), dtype=bool)
  416. array = np.zeros((self.width, self.height, 3), dtype="uint8")
  417. for i in range(self.width):
  418. for j in range(self.height):
  419. if vis_mask[i, j]:
  420. v = self.get(i, j)
  421. if v is None:
  422. array[i, j, 0] = OBJECT_TO_IDX["empty"]
  423. array[i, j, 1] = 0
  424. array[i, j, 2] = 0
  425. else:
  426. array[i, j, :] = v.encode()
  427. return array
  428. @staticmethod
  429. def decode(array):
  430. """
  431. Decode an array grid encoding back into a grid
  432. """
  433. width, height, channels = array.shape
  434. assert channels == 3
  435. vis_mask = np.ones(shape=(width, height), dtype=bool)
  436. grid = Grid(width, height)
  437. for i in range(width):
  438. for j in range(height):
  439. type_idx, color_idx, state = array[i, j]
  440. v = WorldObj.decode(type_idx, color_idx, state)
  441. grid.set(i, j, v)
  442. vis_mask[i, j] = type_idx != OBJECT_TO_IDX["unseen"]
  443. return grid, vis_mask
  444. def process_vis(grid, agent_pos):
  445. mask = np.zeros(shape=(grid.width, grid.height), dtype=bool)
  446. mask[agent_pos[0], agent_pos[1]] = True
  447. for j in reversed(range(0, grid.height)):
  448. for i in range(0, grid.width - 1):
  449. if not mask[i, j]:
  450. continue
  451. cell = grid.get(i, j)
  452. if cell and not cell.see_behind():
  453. continue
  454. mask[i + 1, j] = True
  455. if j > 0:
  456. mask[i + 1, j - 1] = True
  457. mask[i, j - 1] = True
  458. for i in reversed(range(1, grid.width)):
  459. if not mask[i, j]:
  460. continue
  461. cell = grid.get(i, j)
  462. if cell and not cell.see_behind():
  463. continue
  464. mask[i - 1, j] = True
  465. if j > 0:
  466. mask[i - 1, j - 1] = True
  467. mask[i, j - 1] = True
  468. for j in range(0, grid.height):
  469. for i in range(0, grid.width):
  470. if not mask[i, j]:
  471. grid.set(i, j, None)
  472. return mask
  473. class MiniGridEnv(gym.Env):
  474. """
  475. 2D grid world game environment
  476. """
  477. metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": 10}
  478. # Enumeration of possible actions
  479. class Actions(IntEnum):
  480. # Turn left, turn right, move forward
  481. left = 0
  482. right = 1
  483. forward = 2
  484. # Pick up an object
  485. pickup = 3
  486. # Drop an object
  487. drop = 4
  488. # Toggle/activate an object
  489. toggle = 5
  490. # Done completing task
  491. done = 6
  492. def __init__(
  493. self,
  494. grid_size=None,
  495. width=None,
  496. height=None,
  497. max_steps=100,
  498. see_through_walls=False,
  499. seed=1337,
  500. agent_view_size=7,
  501. ):
  502. # Can't set both grid_size and width/height
  503. if grid_size:
  504. assert width is None and height is None
  505. width = grid_size
  506. height = grid_size
  507. # Action enumeration for this environment
  508. self.actions = MiniGridEnv.Actions
  509. # Actions are discrete integer values
  510. self.action_space = spaces.Discrete(len(self.actions))
  511. # Number of cells (width and height) in the agent view
  512. assert agent_view_size % 2 == 1
  513. assert agent_view_size >= 3
  514. self.agent_view_size = agent_view_size
  515. # Observations are dictionaries containing an
  516. # encoding of the grid and a textual 'mission' string
  517. self.observation_space = spaces.Box(
  518. low=0,
  519. high=255,
  520. shape=(self.agent_view_size, self.agent_view_size, 3),
  521. dtype="uint8",
  522. )
  523. self.observation_space = spaces.Dict({"image": self.observation_space})
  524. # Range of possible rewards
  525. self.reward_range = (0, 1)
  526. # Window to use for human rendering mode
  527. self.window = None
  528. # Environment configuration
  529. self.width = width
  530. self.height = height
  531. self.max_steps = max_steps
  532. self.see_through_walls = see_through_walls
  533. # Current position and direction of the agent
  534. self.agent_pos = None
  535. self.agent_dir = None
  536. # Initialize the RNG
  537. self.seed(seed=seed)
  538. # Initialize the state
  539. self.reset()
  540. def reset(self):
  541. # Current position and direction of the agent
  542. self.agent_pos = None
  543. self.agent_dir = None
  544. # Generate a new random grid at the start of each episode
  545. # To keep the same grid for each episode, call env.seed() with
  546. # the same seed before calling env.reset()
  547. self._gen_grid(self.width, self.height)
  548. # These fields should be defined by _gen_grid
  549. assert self.agent_pos is not None
  550. assert self.agent_dir is not None
  551. # Check that the agent doesn't overlap with an object
  552. start_cell = self.grid.get(*self.agent_pos)
  553. assert start_cell is None or start_cell.can_overlap()
  554. # Item picked up, being carried, initially nothing
  555. self.carrying = None
  556. # Step count since episode start
  557. self.step_count = 0
  558. # Return first observation
  559. obs = self.gen_obs()
  560. return obs
  561. def seed(self, seed=1337):
  562. # Seed the random number generator
  563. self.np_random, _ = seeding.np_random(seed)
  564. return [seed]
  565. def hash(self, size=16):
  566. """Compute a hash that uniquely identifies the current state of the environment.
  567. :param size: Size of the hashing
  568. """
  569. sample_hash = hashlib.sha256()
  570. to_encode = [self.grid.encode().tolist(), self.agent_pos, self.agent_dir]
  571. for item in to_encode:
  572. sample_hash.update(str(item).encode("utf8"))
  573. return sample_hash.hexdigest()[:size]
  574. @property
  575. def steps_remaining(self):
  576. return self.max_steps - self.step_count
  577. def __str__(self):
  578. """
  579. Produce a pretty string of the environment's grid along with the agent.
  580. A grid cell is represented by 2-character string, the first one for
  581. the object and the second one for the color.
  582. """
  583. # Map of object types to short string
  584. OBJECT_TO_STR = {
  585. "wall": "W",
  586. "floor": "F",
  587. "door": "D",
  588. "key": "K",
  589. "ball": "A",
  590. "box": "B",
  591. "goal": "G",
  592. "lava": "V",
  593. }
  594. # Map agent's direction to short string
  595. AGENT_DIR_TO_STR = {0: ">", 1: "V", 2: "<", 3: "^"}
  596. str = ""
  597. for j in range(self.grid.height):
  598. for i in range(self.grid.width):
  599. if i == self.agent_pos[0] and j == self.agent_pos[1]:
  600. str += 2 * AGENT_DIR_TO_STR[self.agent_dir]
  601. continue
  602. c = self.grid.get(i, j)
  603. if c is None:
  604. str += " "
  605. continue
  606. if c.type == "door":
  607. if c.is_open:
  608. str += "__"
  609. elif c.is_locked:
  610. str += "L" + c.color[0].upper()
  611. else:
  612. str += "D" + c.color[0].upper()
  613. continue
  614. str += OBJECT_TO_STR[c.type] + c.color[0].upper()
  615. if j < self.grid.height - 1:
  616. str += "\n"
  617. return str
  618. def _gen_grid(self, width, height):
  619. assert False, "_gen_grid needs to be implemented by each environment"
  620. def _reward(self):
  621. """
  622. Compute the reward to be given upon success
  623. """
  624. return 1 - 0.9 * (self.step_count / self.max_steps)
  625. def _rand_int(self, low, high):
  626. """
  627. Generate random integer in [low,high[
  628. """
  629. return self.np_random.randint(low, high)
  630. def _rand_float(self, low, high):
  631. """
  632. Generate random float in [low,high[
  633. """
  634. return self.np_random.uniform(low, high)
  635. def _rand_bool(self):
  636. """
  637. Generate random boolean value
  638. """
  639. return self.np_random.randint(0, 2) == 0
  640. def _rand_elem(self, iterable):
  641. """
  642. Pick a random element in a list
  643. """
  644. lst = list(iterable)
  645. idx = self._rand_int(0, len(lst))
  646. return lst[idx]
  647. def _rand_subset(self, iterable, num_elems):
  648. """
  649. Sample a random subset of distinct elements of a list
  650. """
  651. lst = list(iterable)
  652. assert num_elems <= len(lst)
  653. out = []
  654. while len(out) < num_elems:
  655. elem = self._rand_elem(lst)
  656. lst.remove(elem)
  657. out.append(elem)
  658. return out
  659. def _rand_color(self):
  660. """
  661. Generate a random color name (string)
  662. """
  663. return self._rand_elem(COLOR_NAMES)
  664. def _rand_pos(self, xLow, xHigh, yLow, yHigh):
  665. """
  666. Generate a random (x,y) position tuple
  667. """
  668. return (
  669. self.np_random.randint(xLow, xHigh),
  670. self.np_random.randint(yLow, yHigh),
  671. )
  672. def place_obj(self, obj, top=None, size=None, reject_fn=None, max_tries=math.inf):
  673. """
  674. Place an object at an empty position in the grid
  675. :param top: top-left position of the rectangle where to place
  676. :param size: size of the rectangle where to place
  677. :param reject_fn: function to filter out potential positions
  678. """
  679. if top is None:
  680. top = (0, 0)
  681. else:
  682. top = (max(top[0], 0), max(top[1], 0))
  683. if size is None:
  684. size = (self.grid.width, self.grid.height)
  685. num_tries = 0
  686. while True:
  687. # This is to handle with rare cases where rejection sampling
  688. # gets stuck in an infinite loop
  689. if num_tries > max_tries:
  690. raise RecursionError("rejection sampling failed in place_obj")
  691. num_tries += 1
  692. pos = np.array(
  693. (
  694. self._rand_int(top[0], min(top[0] + size[0], self.grid.width)),
  695. self._rand_int(top[1], min(top[1] + size[1], self.grid.height)),
  696. )
  697. )
  698. # Don't place the object on top of another object
  699. if self.grid.get(*pos) is not None:
  700. continue
  701. # Don't place the object where the agent is
  702. if np.array_equal(pos, self.agent_pos):
  703. continue
  704. # Check if there is a filtering criterion
  705. if reject_fn and reject_fn(self, pos):
  706. continue
  707. break
  708. self.grid.set(*pos, obj)
  709. if obj is not None:
  710. obj.init_pos = pos
  711. obj.cur_pos = pos
  712. return pos
  713. def put_obj(self, obj, i, j):
  714. """
  715. Put an object at a specific position in the grid
  716. """
  717. self.grid.set(i, j, obj)
  718. obj.init_pos = (i, j)
  719. obj.cur_pos = (i, j)
  720. def place_agent(self, top=None, size=None, rand_dir=True, max_tries=math.inf):
  721. """
  722. Set the agent's starting point at an empty position in the grid
  723. """
  724. self.agent_pos = None
  725. pos = self.place_obj(None, top, size, max_tries=max_tries)
  726. self.agent_pos = pos
  727. if rand_dir:
  728. self.agent_dir = self._rand_int(0, 4)
  729. return pos
  730. @property
  731. def dir_vec(self):
  732. """
  733. Get the direction vector for the agent, pointing in the direction
  734. of forward movement.
  735. """
  736. assert self.agent_dir >= 0 and self.agent_dir < 4
  737. return DIR_TO_VEC[self.agent_dir]
  738. @property
  739. def right_vec(self):
  740. """
  741. Get the vector pointing to the right of the agent.
  742. """
  743. dx, dy = self.dir_vec
  744. return np.array((-dy, dx))
  745. @property
  746. def front_pos(self):
  747. """
  748. Get the position of the cell that is right in front of the agent
  749. """
  750. return self.agent_pos + self.dir_vec
  751. def get_view_coords(self, i, j):
  752. """
  753. Translate and rotate absolute grid coordinates (i, j) into the
  754. agent's partially observable view (sub-grid). Note that the resulting
  755. coordinates may be negative or outside of the agent's view size.
  756. """
  757. ax, ay = self.agent_pos
  758. dx, dy = self.dir_vec
  759. rx, ry = self.right_vec
  760. # Compute the absolute coordinates of the top-left view corner
  761. sz = self.agent_view_size
  762. hs = self.agent_view_size // 2
  763. tx = ax + (dx * (sz - 1)) - (rx * hs)
  764. ty = ay + (dy * (sz - 1)) - (ry * hs)
  765. lx = i - tx
  766. ly = j - ty
  767. # Project the coordinates of the object relative to the top-left
  768. # corner onto the agent's own coordinate system
  769. vx = rx * lx + ry * ly
  770. vy = -(dx * lx + dy * ly)
  771. return vx, vy
  772. def get_view_exts(self):
  773. """
  774. Get the extents of the square set of tiles visible to the agent
  775. Note: the bottom extent indices are not included in the set
  776. """
  777. # Facing right
  778. if self.agent_dir == 0:
  779. topX = self.agent_pos[0]
  780. topY = self.agent_pos[1] - self.agent_view_size // 2
  781. # Facing down
  782. elif self.agent_dir == 1:
  783. topX = self.agent_pos[0] - self.agent_view_size // 2
  784. topY = self.agent_pos[1]
  785. # Facing left
  786. elif self.agent_dir == 2:
  787. topX = self.agent_pos[0] - self.agent_view_size + 1
  788. topY = self.agent_pos[1] - self.agent_view_size // 2
  789. # Facing up
  790. elif self.agent_dir == 3:
  791. topX = self.agent_pos[0] - self.agent_view_size // 2
  792. topY = self.agent_pos[1] - self.agent_view_size + 1
  793. else:
  794. assert False, "invalid agent direction"
  795. botX = topX + self.agent_view_size
  796. botY = topY + self.agent_view_size
  797. return (topX, topY, botX, botY)
  798. def relative_coords(self, x, y):
  799. """
  800. Check if a grid position belongs to the agent's field of view, and returns the corresponding coordinates
  801. """
  802. vx, vy = self.get_view_coords(x, y)
  803. if vx < 0 or vy < 0 or vx >= self.agent_view_size or vy >= self.agent_view_size:
  804. return None
  805. return vx, vy
  806. def in_view(self, x, y):
  807. """
  808. check if a grid position is visible to the agent
  809. """
  810. return self.relative_coords(x, y) is not None
  def agent_sees(self, x, y):
      """
      Check if a non-empty grid position is visible to the agent.

      Returns True only when ``(x, y)`` lies inside the agent's view, the
      world grid holds an object there, and the agent's decoded observation
      shows an object of the same type at the matching view coordinates.
      Returns False otherwise, including when the world cell is empty.
      """
      coordinates = self.relative_coords(x, y)
      if coordinates is None:
          return False

      # An empty world cell is never a "seen object". Checking it up front
      # avoids dereferencing None below (the observation can hold an object
      # where the world grid is empty, e.g. the carried object drawn at the
      # agent's own cell) and skips building an observation needlessly.
      world_cell = self.grid.get(x, y)
      if world_cell is None:
          return False

      vx, vy = coordinates

      # Compare against what the agent actually observes, which accounts
      # for the visibility mask applied when the observation is encoded
      obs = self.gen_obs()
      obs_grid, _ = Grid.decode(obs["image"])
      obs_cell = obs_grid.get(vx, vy)

      return obs_cell is not None and obs_cell.type == world_cell.type
  def step(self, action):
      """
      Apply ``action`` to the environment and advance it by one step.

      Returns the tuple ``(obs, reward, done, info)``: ``obs`` is the result
      of ``gen_obs()``, ``reward`` is 0 unless the agent moves onto a goal
      cell (then it is ``self._reward()``), ``done`` is set on reaching a
      goal or lava cell or once ``step_count`` reaches ``max_steps``, and
      ``info`` is an empty dict.
      """
      self.step_count += 1

      reward = 0
      done = False

      # Position and contents of the cell directly in front of the agent
      fwd_pos = self.front_pos
      fwd_cell = self.grid.get(*fwd_pos)

      if action == self.actions.left:
          # Rotate left, wrapping below zero back into range
          turned = self.agent_dir - 1
          self.agent_dir = turned + 4 if turned < 0 else turned

      elif action == self.actions.right:
          # Rotate right
          self.agent_dir = (self.agent_dir + 1) % 4

      elif action == self.actions.forward:
          # Move forward when the cell ahead is free or can be overlapped
          if fwd_cell is None or fwd_cell.can_overlap():
              self.agent_pos = fwd_pos
          if fwd_cell is not None:
              if fwd_cell.type == "goal":
                  done = True
                  reward = self._reward()
              if fwd_cell.type == "lava":
                  done = True

      elif action == self.actions.pickup:
          # Pick up the object ahead, provided nothing is carried already
          if fwd_cell and fwd_cell.can_pickup() and self.carrying is None:
              self.carrying = fwd_cell
              self.carrying.cur_pos = np.array([-1, -1])
              self.grid.set(*fwd_pos, None)

      elif action == self.actions.drop:
          # Drop the carried object into the empty cell ahead
          if not fwd_cell and self.carrying:
              self.grid.set(*fwd_pos, self.carrying)
              self.carrying.cur_pos = fwd_pos
              self.carrying = None

      elif action == self.actions.toggle:
          # Toggle/activate the object ahead
          if fwd_cell:
              fwd_cell.toggle(self, fwd_pos)

      elif action == self.actions.done:
          # Done action (not used by default)
          pass

      else:
          assert False, "unknown action"

      if self.step_count >= self.max_steps:
          done = True

      obs = self.gen_obs()

      return obs, reward, done, {}
  def gen_obs_grid(self):
      """
      Build the sub-grid observed by the agent.

      Returns ``(grid, vis_mask)``: the sliced and rotated view grid, plus a
      visibility mask telling which of its cells the agent can actually see.
      """
      left, top, _, _ = self.get_view_exts()

      view_grid = self.grid.slice(
          left, top, self.agent_view_size, self.agent_view_size
      )

      # Rotate the slice once more than the agent's direction index
      for _ in range(self.agent_dir + 1):
          view_grid = view_grid.rotate_left()

      if self.see_through_walls:
          # Everything in the view counts as visible
          vis_mask = np.ones(shape=(view_grid.width, view_grid.height), dtype=bool)
      else:
          # Process occluders and visibility
          # Note that this incurs some performance cost
          vis_mask = view_grid.process_vis(
              agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1)
          )

      # Make the agent see what it is carrying by placing the carried
      # object (or nothing) at the agent's own cell in the view grid
      agent_cell = view_grid.width // 2, view_grid.height - 1
      view_grid.set(*agent_cell, self.carrying if self.carrying else None)

      return view_grid, vis_mask
  def gen_obs(self):
      """
      Produce the agent's observation (partially observable, low-resolution
      encoding).

      The observation is a dictionary containing:
      - "image": the encoded partially observable view of the environment
      - "direction": the agent's direction/orientation (acting as a compass)
      - "mission": a textual mission string (instructions for the agent)
      """
      view_grid, vis_mask = self.gen_obs_grid()

      # Encode the partially observable view together with its mask
      encoded = view_grid.encode(vis_mask)

      assert hasattr(
          self, "mission"
      ), "environments must define a textual mission string"

      return {
          "image": encoded,
          "direction": self.agent_dir,
          "mission": self.mission,
      }
  def get_obs_render(self, obs, tile_size=TILE_PIXELS // 2):
      """
      Turn an encoded agent observation into an image for visualization.

      ``obs`` is decoded with ``Grid.decode`` and the resulting grid is
      rendered at ``tile_size``, with the decoded mask used for highlighting.
      """
      view_grid, visible = Grid.decode(obs)

      # Draw the agent at the bottom-centre cell of the view, using
      # direction 3
      agent_cell = (self.agent_view_size // 2, self.agent_view_size - 1)

      return view_grid.render(
          tile_size,
          agent_pos=agent_cell,
          agent_dir=3,
          highlight_mask=visible,
      )
  def render(self, mode="human", close=False, highlight=True, tile_size=TILE_PIXELS):
      """
      Render the whole-grid human view.

      When ``close`` is set, any open window is closed and None is returned.
      Otherwise the full grid is rendered at ``tile_size`` and the image is
      returned; cells visible to the agent are highlighted when
      ``highlight`` is set. In "human" mode the image is also shown in a
      window captioned with the mission.
      """
      if close:
          if self.window:
              self.window.close()
          return

      # Lazily create the display window the first time it is needed
      if mode == "human" and not self.window:
          import gym_minigrid.window

          self.window = gym_minigrid.window.Window("gym_minigrid")
          self.window.show(block=False)

      # Compute which cells are visible to the agent
      _, vis_mask = self.gen_obs_grid()

      # World coordinates of the top-left corner of the agent's view area
      # (same formula as used when mapping world cells to view coordinates)
      forward = self.dir_vec
      rightward = self.right_vec
      corner = (
          self.agent_pos
          + forward * (self.agent_view_size - 1)
          - rightward * (self.agent_view_size // 2)
      )

      # Mask of which world cells to highlight
      highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)

      for view_row in range(self.agent_view_size):
          for view_col in range(self.agent_view_size):
              # Only cells the agent can see are candidates
              if vis_mask[view_col, view_row]:
                  # Translate the view cell back into world coordinates
                  abs_i, abs_j = corner - (forward * view_row) + (rightward * view_col)

                  # Skip view cells that fall outside the world grid
                  if 0 <= abs_i < self.width and 0 <= abs_j < self.height:
                      highlight_mask[abs_i, abs_j] = True

      # Render the whole grid
      img = self.grid.render(
          tile_size,
          self.agent_pos,
          self.agent_dir,
          highlight_mask=highlight_mask if highlight else None,
      )

      if mode == "human":
          self.window.set_caption(self.mission)
          self.window.show_img(img)

      return img
  def close(self):
      """
      Close the rendering window, if one is currently set.
      """
      window = self.window
      if not window:
          return
      window.close()