minigrid.py 37 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384
  1. import math
  2. import gym
  3. from enum import IntEnum
  4. import numpy as np
  5. from gym import error, spaces, utils
  6. from gym.utils import seeding
  7. # Size in pixels of a cell in the full-scale human view
  8. CELL_PIXELS = 32
  9. # Number of cells (width and height) in the agent view
  10. AGENT_VIEW_SIZE = 7
  11. # Size of the array given as an observation to the agent
  12. OBS_ARRAY_SIZE = (AGENT_VIEW_SIZE, AGENT_VIEW_SIZE, 3)
  13. # Map of color names to RGB values
  14. COLORS = {
  15. 'red' : np.array([255, 0, 0]),
  16. 'green' : np.array([0, 255, 0]),
  17. 'blue' : np.array([0, 0, 255]),
  18. 'purple': np.array([112, 39, 195]),
  19. 'yellow': np.array([255, 255, 0]),
  20. 'grey' : np.array([100, 100, 100])
  21. }
  22. COLOR_NAMES = sorted(list(COLORS.keys()))
  23. # Used to map colors to integers
  24. COLOR_TO_IDX = {
  25. 'red' : 0,
  26. 'green' : 1,
  27. 'blue' : 2,
  28. 'purple': 3,
  29. 'yellow': 4,
  30. 'grey' : 5
  31. }
  32. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  33. # Map of object type to integers
  34. OBJECT_TO_IDX = {
  35. 'unseen' : 0,
  36. 'empty' : 1,
  37. 'wall' : 2,
  38. 'floor' : 3,
  39. 'door' : 4,
  40. 'locked_door' : 5,
  41. 'key' : 6,
  42. 'ball' : 7,
  43. 'box' : 8,
  44. 'goal' : 9,
  45. 'lava' : 10
  46. }
  47. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  48. # Map of agent direction indices to vectors
  49. DIR_TO_VEC = [
  50. # Pointing right (positive X)
  51. np.array((1, 0)),
  52. # Down (positive Y)
  53. np.array((0, 1)),
  54. # Pointing left (negative X)
  55. np.array((-1, 0)),
  56. # Up (negative Y)
  57. np.array((0, -1)),
  58. ]
  59. class WorldObj:
  60. """
  61. Base class for grid world objects
  62. """
  63. def __init__(self, type, color):
  64. assert type in OBJECT_TO_IDX, type
  65. assert color in COLOR_TO_IDX, color
  66. self.type = type
  67. self.color = color
  68. self.contains = None
  69. # Initial position of the object
  70. self.init_pos = None
  71. # Current position of the object
  72. self.cur_pos = None
  73. def can_overlap(self):
  74. """Can the agent overlap with this?"""
  75. return False
  76. def can_pickup(self):
  77. """Can the agent pick this up?"""
  78. return False
  79. def can_contain(self):
  80. """Can this contain another object?"""
  81. return False
  82. def see_behind(self):
  83. """Can the agent see behind this object?"""
  84. return True
  85. def toggle(self, env, pos):
  86. """Method to trigger/toggle an action this object performs"""
  87. return False
  88. def render(self, r):
  89. """Draw this object with the given renderer"""
  90. raise NotImplementedError
  91. def _set_color(self, r):
  92. """Set the color of this object as the active drawing color"""
  93. c = COLORS[self.color]
  94. r.setLineColor(c[0], c[1], c[2])
  95. r.setColor(c[0], c[1], c[2])
  96. class Goal(WorldObj):
  97. def __init__(self):
  98. super().__init__('goal', 'green')
  99. def can_overlap(self):
  100. return True
  101. def render(self, r):
  102. self._set_color(r)
  103. r.drawPolygon([
  104. (0 , CELL_PIXELS),
  105. (CELL_PIXELS, CELL_PIXELS),
  106. (CELL_PIXELS, 0),
  107. (0 , 0)
  108. ])
  109. class Floor(WorldObj):
  110. """
  111. Colored floor tile the agent can walk over
  112. """
  113. def __init__(self, color='blue'):
  114. super().__init__('floor', color)
  115. def can_overlap(self):
  116. return True
  117. def render(self, r):
  118. # Give the floor a pale color
  119. c = COLORS[self.color]
  120. r.setLineColor(100, 100, 100, 0)
  121. r.setColor(*c/2)
  122. r.drawPolygon([
  123. (1 , CELL_PIXELS),
  124. (CELL_PIXELS, CELL_PIXELS),
  125. (CELL_PIXELS, 1),
  126. (1 , 1)
  127. ])
  128. class Lava(WorldObj):
  129. def __init__(self):
  130. super().__init__('lava', 'red')
  131. def can_overlap(self):
  132. return True
  133. def render(self, r):
  134. orange = 255, 128, 0
  135. r.setLineColor(*orange)
  136. r.setColor(*orange)
  137. r.drawPolygon([
  138. (0 , CELL_PIXELS),
  139. (CELL_PIXELS, CELL_PIXELS),
  140. (CELL_PIXELS, 0),
  141. (0 , 0)
  142. ])
  143. # drawing the waves
  144. r.setLineColor(0, 0, 0)
  145. r.drawPolyline([
  146. (.1 * CELL_PIXELS, .3 * CELL_PIXELS),
  147. (.3 * CELL_PIXELS, .4 * CELL_PIXELS),
  148. (.5 * CELL_PIXELS, .3 * CELL_PIXELS),
  149. (.7 * CELL_PIXELS, .4 * CELL_PIXELS),
  150. (.9 * CELL_PIXELS, .3 * CELL_PIXELS),
  151. ])
  152. r.drawPolyline([
  153. (.1 * CELL_PIXELS, .5 * CELL_PIXELS),
  154. (.3 * CELL_PIXELS, .6 * CELL_PIXELS),
  155. (.5 * CELL_PIXELS, .5 * CELL_PIXELS),
  156. (.7 * CELL_PIXELS, .6 * CELL_PIXELS),
  157. (.9 * CELL_PIXELS, .5 * CELL_PIXELS),
  158. ])
  159. r.drawPolyline([
  160. (.1 * CELL_PIXELS, .7 * CELL_PIXELS),
  161. (.3 * CELL_PIXELS, .8 * CELL_PIXELS),
  162. (.5 * CELL_PIXELS, .7 * CELL_PIXELS),
  163. (.7 * CELL_PIXELS, .8 * CELL_PIXELS),
  164. (.9 * CELL_PIXELS, .7 * CELL_PIXELS),
  165. ])
  166. class Wall(WorldObj):
  167. def __init__(self, color='grey'):
  168. super().__init__('wall', color)
  169. def see_behind(self):
  170. return False
  171. def render(self, r):
  172. self._set_color(r)
  173. r.drawPolygon([
  174. (0 , CELL_PIXELS),
  175. (CELL_PIXELS, CELL_PIXELS),
  176. (CELL_PIXELS, 0),
  177. (0 , 0)
  178. ])
  179. class Door(WorldObj):
  180. def __init__(self, color, is_open=False):
  181. super().__init__('door', color)
  182. self.is_open = is_open
  183. def can_overlap(self):
  184. """The agent can only walk over this cell when the door is open"""
  185. return self.is_open
  186. def see_behind(self):
  187. return self.is_open
  188. def toggle(self, env, pos):
  189. self.is_open = not self.is_open
  190. return True
  191. def render(self, r):
  192. c = COLORS[self.color]
  193. r.setLineColor(c[0], c[1], c[2])
  194. r.setColor(0, 0, 0)
  195. if self.is_open:
  196. r.drawPolygon([
  197. (CELL_PIXELS-2, CELL_PIXELS),
  198. (CELL_PIXELS , CELL_PIXELS),
  199. (CELL_PIXELS , 0),
  200. (CELL_PIXELS-2, 0)
  201. ])
  202. return
  203. r.drawPolygon([
  204. (0 , CELL_PIXELS),
  205. (CELL_PIXELS, CELL_PIXELS),
  206. (CELL_PIXELS, 0),
  207. (0 , 0)
  208. ])
  209. r.drawPolygon([
  210. (2 , CELL_PIXELS-2),
  211. (CELL_PIXELS-2, CELL_PIXELS-2),
  212. (CELL_PIXELS-2, 2),
  213. (2 , 2)
  214. ])
  215. r.drawCircle(CELL_PIXELS * 0.75, CELL_PIXELS * 0.5, 2)
  216. class LockedDoor(WorldObj):
  217. def __init__(self, color, is_open=False):
  218. super(LockedDoor, self).__init__('locked_door', color)
  219. self.is_open = is_open
  220. def toggle(self, env, pos):
  221. # If the player has the right key to open the door
  222. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  223. self.is_open = True
  224. # The key has been used, remove it from the agent
  225. env.carrying = None
  226. return True
  227. return False
  228. def can_overlap(self):
  229. """The agent can only walk over this cell when the door is open"""
  230. return self.is_open
  231. def see_behind(self):
  232. return self.is_open
  233. def render(self, r):
  234. c = COLORS[self.color]
  235. r.setLineColor(c[0], c[1], c[2])
  236. r.setColor(c[0], c[1], c[2], 50)
  237. if self.is_open:
  238. r.drawPolygon([
  239. (CELL_PIXELS-2, CELL_PIXELS),
  240. (CELL_PIXELS , CELL_PIXELS),
  241. (CELL_PIXELS , 0),
  242. (CELL_PIXELS-2, 0)
  243. ])
  244. return
  245. r.drawPolygon([
  246. (0 , CELL_PIXELS),
  247. (CELL_PIXELS, CELL_PIXELS),
  248. (CELL_PIXELS, 0),
  249. (0 , 0)
  250. ])
  251. r.drawPolygon([
  252. (2 , CELL_PIXELS-2),
  253. (CELL_PIXELS-2, CELL_PIXELS-2),
  254. (CELL_PIXELS-2, 2),
  255. (2 , 2)
  256. ])
  257. r.drawLine(
  258. CELL_PIXELS * 0.55,
  259. CELL_PIXELS * 0.5,
  260. CELL_PIXELS * 0.75,
  261. CELL_PIXELS * 0.5
  262. )
  263. class Key(WorldObj):
  264. def __init__(self, color='blue'):
  265. super(Key, self).__init__('key', color)
  266. def can_pickup(self):
  267. return True
  268. def render(self, r):
  269. self._set_color(r)
  270. # Vertical quad
  271. r.drawPolygon([
  272. (16, 10),
  273. (20, 10),
  274. (20, 28),
  275. (16, 28)
  276. ])
  277. # Teeth
  278. r.drawPolygon([
  279. (12, 19),
  280. (16, 19),
  281. (16, 21),
  282. (12, 21)
  283. ])
  284. r.drawPolygon([
  285. (12, 26),
  286. (16, 26),
  287. (16, 28),
  288. (12, 28)
  289. ])
  290. r.drawCircle(18, 9, 6)
  291. r.setLineColor(0, 0, 0)
  292. r.setColor(0, 0, 0)
  293. r.drawCircle(18, 9, 2)
  294. class Ball(WorldObj):
  295. def __init__(self, color='blue'):
  296. super(Ball, self).__init__('ball', color)
  297. def can_pickup(self):
  298. return True
  299. def render(self, r):
  300. self._set_color(r)
  301. r.drawCircle(CELL_PIXELS * 0.5, CELL_PIXELS * 0.5, 10)
  302. class Box(WorldObj):
  303. def __init__(self, color, contains=None):
  304. super(Box, self).__init__('box', color)
  305. self.contains = contains
  306. def can_pickup(self):
  307. return True
  308. def render(self, r):
  309. c = COLORS[self.color]
  310. r.setLineColor(c[0], c[1], c[2])
  311. r.setColor(0, 0, 0)
  312. r.setLineWidth(2)
  313. r.drawPolygon([
  314. (4 , CELL_PIXELS-4),
  315. (CELL_PIXELS-4, CELL_PIXELS-4),
  316. (CELL_PIXELS-4, 4),
  317. (4 , 4)
  318. ])
  319. r.drawLine(
  320. 4,
  321. CELL_PIXELS / 2,
  322. CELL_PIXELS - 4,
  323. CELL_PIXELS / 2
  324. )
  325. r.setLineWidth(1)
  326. def toggle(self, env, pos):
  327. # Replace the box by its contents
  328. env.grid.set(*pos, self.contains)
  329. return True
  330. class Grid:
  331. """
  332. Represent a grid and operations on it
  333. """
  334. def __init__(self, width, height):
  335. assert width >= 4
  336. assert height >= 4
  337. self.width = width
  338. self.height = height
  339. self.grid = [None] * width * height
  340. def __contains__(self, key):
  341. if isinstance(key, WorldObj):
  342. for e in self.grid:
  343. if e is key:
  344. return True
  345. elif isinstance(key, tuple):
  346. for e in self.grid:
  347. if e is None:
  348. continue
  349. if (e.color, e.type) == key:
  350. return True
  351. if key[0] is None and key[1] == e.type:
  352. return True
  353. return False
  354. def __eq__(self, other):
  355. grid1 = self.encode()
  356. grid2 = other.encode()
  357. return np.array_equal(grid2, grid1)
  358. def __ne__(self, other):
  359. return not self == other
  360. def copy(self):
  361. from copy import deepcopy
  362. return deepcopy(self)
  363. def set(self, i, j, v):
  364. assert i >= 0 and i < self.width
  365. assert j >= 0 and j < self.height
  366. self.grid[j * self.width + i] = v
  367. def get(self, i, j):
  368. assert i >= 0 and i < self.width
  369. assert j >= 0 and j < self.height
  370. return self.grid[j * self.width + i]
  371. def horz_wall(self, x, y, length=None):
  372. if length is None:
  373. length = self.width - x
  374. for i in range(0, length):
  375. self.set(x + i, y, Wall())
  376. def vert_wall(self, x, y, length=None):
  377. if length is None:
  378. length = self.height - y
  379. for j in range(0, length):
  380. self.set(x, y + j, Wall())
  381. def wall_rect(self, x, y, w, h):
  382. self.horz_wall(x, y, w)
  383. self.horz_wall(x, y+h-1, w)
  384. self.vert_wall(x, y, h)
  385. self.vert_wall(x+w-1, y, h)
  386. def rotate_left(self):
  387. """
  388. Rotate the grid to the left (counter-clockwise)
  389. """
  390. grid = Grid(self.height, self.width)
  391. for i in range(self.width):
  392. for j in range(self.height):
  393. v = self.get(i, j)
  394. grid.set(j, grid.height - 1 - i, v)
  395. return grid
  396. def slice(self, topX, topY, width, height):
  397. """
  398. Get a subset of the grid
  399. """
  400. grid = Grid(width, height)
  401. for j in range(0, height):
  402. for i in range(0, width):
  403. x = topX + i
  404. y = topY + j
  405. if x >= 0 and x < self.width and \
  406. y >= 0 and y < self.height:
  407. v = self.get(x, y)
  408. else:
  409. v = Wall()
  410. grid.set(i, j, v)
  411. return grid
  412. def render(self, r, tile_size):
  413. """
  414. Render this grid at a given scale
  415. :param r: target renderer object
  416. :param tile_size: tile size in pixels
  417. """
  418. assert r.width == self.width * tile_size
  419. assert r.height == self.height * tile_size
  420. # Total grid size at native scale
  421. widthPx = self.width * CELL_PIXELS
  422. heightPx = self.height * CELL_PIXELS
  423. r.push()
  424. # Internally, we draw at the "large" full-grid resolution, but we
  425. # use the renderer to scale back to the desired size
  426. r.scale(tile_size / CELL_PIXELS, tile_size / CELL_PIXELS)
  427. # Draw the background of the in-world cells black
  428. r.fillRect(
  429. 0,
  430. 0,
  431. widthPx,
  432. heightPx,
  433. 0, 0, 0
  434. )
  435. # Draw grid lines
  436. r.setLineColor(100, 100, 100)
  437. for rowIdx in range(0, self.height):
  438. y = CELL_PIXELS * rowIdx
  439. r.drawLine(0, y, widthPx, y)
  440. for colIdx in range(0, self.width):
  441. x = CELL_PIXELS * colIdx
  442. r.drawLine(x, 0, x, heightPx)
  443. # Render the grid
  444. for j in range(0, self.height):
  445. for i in range(0, self.width):
  446. cell = self.get(i, j)
  447. if cell == None:
  448. continue
  449. r.push()
  450. r.translate(i * CELL_PIXELS, j * CELL_PIXELS)
  451. cell.render(r)
  452. r.pop()
  453. r.pop()
  454. def encode(self, vis_mask=None):
  455. """
  456. Produce a compact numpy encoding of the grid
  457. """
  458. if vis_mask is None:
  459. vis_mask = np.ones((self.width, self.height), dtype=bool)
  460. array = np.zeros((self.width, self.height, 3), dtype='uint8')
  461. for i in range(self.width):
  462. for j in range(self.height):
  463. if vis_mask[i, j]:
  464. v = self.get(i, j)
  465. if v is None:
  466. array[i, j, 0] = OBJECT_TO_IDX['empty']
  467. array[i, j, 1] = 0
  468. array[i, j, 2] = 0
  469. else:
  470. array[i, j, 0] = OBJECT_TO_IDX[v.type]
  471. array[i, j, 1] = COLOR_TO_IDX[v.color]
  472. array[i, j, 2] = hasattr(v, 'is_open') and v.is_open
  473. return array
  474. @staticmethod
  475. def decode(array):
  476. """
  477. Decode an array grid encoding back into a grid
  478. """
  479. width, height, channels = array.shape
  480. assert channels == 3
  481. grid = Grid(width, height)
  482. for i in range(width):
  483. for j in range(height):
  484. typeIdx, colorIdx, openIdx = array[i, j]
  485. if typeIdx == OBJECT_TO_IDX['unseen'] or \
  486. typeIdx == OBJECT_TO_IDX['empty']:
  487. continue
  488. objType = IDX_TO_OBJECT[typeIdx]
  489. color = IDX_TO_COLOR[colorIdx]
  490. is_open = openIdx == 1
  491. if objType == 'wall':
  492. v = Wall(color)
  493. elif objType == 'floor':
  494. v = Floor(color)
  495. elif objType == 'ball':
  496. v = Ball(color)
  497. elif objType == 'key':
  498. v = Key(color)
  499. elif objType == 'box':
  500. v = Box(color)
  501. elif objType == 'door':
  502. v = Door(color, is_open)
  503. elif objType == 'locked_door':
  504. v = LockedDoor(color, is_open)
  505. elif objType == 'goal':
  506. v = Goal()
  507. elif objType == 'lava':
  508. v = Lava()
  509. else:
  510. assert False, "unknown obj type in decode '%s'" % objType
  511. grid.set(i, j, v)
  512. return grid
  513. def process_vis(grid, agent_pos):
  514. mask = np.zeros(shape=(grid.width, grid.height), dtype=np.bool)
  515. mask[agent_pos[0], agent_pos[1]] = True
  516. for j in reversed(range(0, grid.height)):
  517. for i in range(0, grid.width-1):
  518. if not mask[i, j]:
  519. continue
  520. cell = grid.get(i, j)
  521. if cell and not cell.see_behind():
  522. continue
  523. mask[i+1, j] = True
  524. if j > 0:
  525. mask[i+1, j-1] = True
  526. mask[i, j-1] = True
  527. for i in reversed(range(1, grid.width)):
  528. if not mask[i, j]:
  529. continue
  530. cell = grid.get(i, j)
  531. if cell and not cell.see_behind():
  532. continue
  533. mask[i-1, j] = True
  534. if j > 0:
  535. mask[i-1, j-1] = True
  536. mask[i, j-1] = True
  537. for j in range(0, grid.height):
  538. for i in range(0, grid.width):
  539. if not mask[i, j]:
  540. grid.set(i, j, None)
  541. return mask
  542. class MiniGridEnv(gym.Env):
  543. """
  544. 2D grid world game environment
  545. """
  546. metadata = {
  547. 'render.modes': ['human', 'rgb_array', 'pixmap'],
  548. 'video.frames_per_second' : 10
  549. }
  550. # Enumeration of possible actions
  551. class Actions(IntEnum):
  552. # Turn left, turn right, move forward
  553. left = 0
  554. right = 1
  555. forward = 2
  556. # Pick up an object
  557. pickup = 3
  558. # Drop an object
  559. drop = 4
  560. # Toggle/activate an object
  561. toggle = 5
  562. # Done completing task
  563. done = 6
  564. def __init__(
  565. self,
  566. grid_size=16,
  567. max_steps=100,
  568. see_through_walls=False,
  569. seed=1337
  570. ):
  571. # Action enumeration for this environment
  572. self.actions = MiniGridEnv.Actions
  573. # Actions are discrete integer values
  574. self.action_space = spaces.Discrete(len(self.actions))
  575. # Observations are dictionaries containing an
  576. # encoding of the grid and a textual 'mission' string
  577. self.observation_space = spaces.Box(
  578. low=0,
  579. high=255,
  580. shape=OBS_ARRAY_SIZE,
  581. dtype='uint8'
  582. )
  583. self.observation_space = spaces.Dict({
  584. 'image': self.observation_space
  585. })
  586. # Range of possible rewards
  587. self.reward_range = (0, 1)
  588. # Renderer object used to render the whole grid (full-scale)
  589. self.grid_render = None
  590. # Renderer used to render observations (small-scale agent view)
  591. self.obs_render = None
  592. # Environment configuration
  593. self.grid_size = grid_size
  594. self.max_steps = max_steps
  595. self.see_through_walls = see_through_walls
  596. # Starting position and direction for the agent
  597. self.start_pos = None
  598. self.start_dir = None
  599. # Initialize the RNG
  600. self.seed(seed=seed)
  601. # Initialize the state
  602. self.reset()
  603. def reset(self):
  604. # Generate a new random grid at the start of each episode
  605. # To keep the same grid for each episode, call env.seed() with
  606. # the same seed before calling env.reset()
  607. self._gen_grid(self.grid_size, self.grid_size)
  608. # These fields should be defined by _gen_grid
  609. assert self.start_pos is not None
  610. assert self.start_dir is not None
  611. # Check that the agent doesn't overlap with an object
  612. start_cell = self.grid.get(*self.start_pos)
  613. assert start_cell is None or start_cell.can_overlap()
  614. # Place the agent in the starting position and direction
  615. self.agent_pos = self.start_pos
  616. self.agent_dir = self.start_dir
  617. # Item picked up, being carried, initially nothing
  618. self.carrying = None
  619. # Step count since episode start
  620. self.step_count = 0
  621. # Return first observation
  622. obs = self.gen_obs()
  623. return obs
  624. def seed(self, seed=1337):
  625. # Seed the random number generator
  626. self.np_random, _ = seeding.np_random(seed)
  627. return [seed]
  628. @property
  629. def steps_remaining(self):
  630. return self.max_steps - self.step_count
  631. def __str__(self):
  632. """
  633. Produce a pretty string of the environment's grid along with the agent.
  634. The agent is represented by `⏩`. A grid pixel is represented by 2-character
  635. string, the first one for the object and the second one for the color.
  636. """
  637. from copy import deepcopy
  638. def rotate_left(array):
  639. new_array = deepcopy(array)
  640. for i in range(len(array)):
  641. for j in range(len(array[0])):
  642. new_array[j][len(array[0])-1-i] = array[i][j]
  643. return new_array
  644. def vertically_symmetrize(array):
  645. new_array = deepcopy(array)
  646. for i in range(len(array)):
  647. for j in range(len(array[0])):
  648. new_array[i][len(array[0])-1-j] = array[i][j]
  649. return new_array
  650. # Map of object id to short string
  651. OBJECT_IDX_TO_IDS = {
  652. 0: ' ',
  653. 1: 'W',
  654. 2: 'D',
  655. 3: 'L',
  656. 4: 'K',
  657. 5: 'B',
  658. 6: 'X',
  659. 7: 'G'
  660. }
  661. # Short string for opened door
  662. OPENDED_DOOR_IDS = '_'
  663. # Map of color id to short string
  664. COLOR_IDX_TO_IDS = {
  665. 0: 'R',
  666. 1: 'G',
  667. 2: 'B',
  668. 3: 'P',
  669. 4: 'Y',
  670. 5: 'E'
  671. }
  672. # Map agent's direction to short string
  673. AGENT_DIR_TO_IDS = {
  674. 0: '⏩ ',
  675. 1: '⏬ ',
  676. 2: '⏪ ',
  677. 3: '⏫ '
  678. }
  679. array = self.grid.encode()
  680. array = rotate_left(array)
  681. array = vertically_symmetrize(array)
  682. new_array = []
  683. for line in array:
  684. new_line = []
  685. for pixel in line:
  686. # If the door is opened
  687. if pixel[0] in [2, 3] and pixel[2] == 1:
  688. object_ids = OPENDED_DOOR_IDS
  689. else:
  690. object_ids = OBJECT_IDX_TO_IDS[pixel[0]]
  691. # If no object
  692. if pixel[0] == 0:
  693. color_ids = ' '
  694. else:
  695. color_ids = COLOR_IDX_TO_IDS[pixel[1]]
  696. new_line.append(object_ids + color_ids)
  697. new_array.append(new_line)
  698. # Add the agent
  699. new_array[self.agent_pos[1]][self.agent_pos[0]] = AGENT_DIR_TO_IDS[self.agent_dir]
  700. return "\n".join([" ".join(line) for line in new_array])
  701. def _gen_grid(self, width, height):
  702. assert False, "_gen_grid needs to be implemented by each environment"
  703. def _reward(self):
  704. """
  705. Compute the reward to be given upon success
  706. """
  707. return 1 - 0.9 * (self.step_count / self.max_steps)
  708. def _rand_int(self, low, high):
  709. """
  710. Generate random integer in [low,high[
  711. """
  712. return self.np_random.randint(low, high)
  713. def _rand_float(self, low, high):
  714. """
  715. Generate random float in [low,high[
  716. """
  717. return self.np_random.uniform(low, high)
  718. def _rand_bool(self):
  719. """
  720. Generate random boolean value
  721. """
  722. return (self.np_random.randint(0, 2) == 0)
  723. def _rand_elem(self, iterable):
  724. """
  725. Pick a random element in a list
  726. """
  727. lst = list(iterable)
  728. idx = self._rand_int(0, len(lst))
  729. return lst[idx]
  730. def _rand_subset(self, iterable, num_elems):
  731. """
  732. Sample a random subset of distinct elements of a list
  733. """
  734. lst = list(iterable)
  735. assert num_elems <= len(lst)
  736. out = []
  737. while len(out) < num_elems:
  738. elem = self._rand_elem(lst)
  739. lst.remove(elem)
  740. out.append(elem)
  741. return out
  742. def _rand_color(self):
  743. """
  744. Generate a random color name (string)
  745. """
  746. return self._rand_elem(COLOR_NAMES)
  747. def _rand_pos(self, xLow, xHigh, yLow, yHigh):
  748. """
  749. Generate a random (x,y) position tuple
  750. """
  751. return (
  752. self.np_random.randint(xLow, xHigh),
  753. self.np_random.randint(yLow, yHigh)
  754. )
  755. def place_obj(self,
  756. obj,
  757. top=None,
  758. size=None,
  759. reject_fn=None,
  760. max_tries=math.inf
  761. ):
  762. """
  763. Place an object at an empty position in the grid
  764. :param top: top-left position of the rectangle where to place
  765. :param size: size of the rectangle where to place
  766. :param reject_fn: function to filter out potential positions
  767. """
  768. if top is None:
  769. top = (0, 0)
  770. if size is None:
  771. size = (self.grid.width, self.grid.height)
  772. num_tries = 0
  773. while True:
  774. # This is to handle with rare cases where rejection sampling
  775. # gets stuck in an infinite loop
  776. if num_tries > max_tries:
  777. raise RecursionError('rejection sampling failed in place_obj')
  778. num_tries += 1
  779. pos = np.array((
  780. self._rand_int(top[0], top[0] + size[0]),
  781. self._rand_int(top[1], top[1] + size[1])
  782. ))
  783. # Don't place the object on top of another object
  784. if self.grid.get(*pos) != None:
  785. continue
  786. # Don't place the object where the agent is
  787. if np.array_equal(pos, self.start_pos):
  788. continue
  789. # Check if there is a filtering criterion
  790. if reject_fn and reject_fn(self, pos):
  791. continue
  792. break
  793. self.grid.set(*pos, obj)
  794. if obj is not None:
  795. obj.init_pos = pos
  796. obj.cur_pos = pos
  797. return pos
  798. def place_agent(
  799. self,
  800. top=None,
  801. size=None,
  802. rand_dir=True,
  803. max_tries=math.inf
  804. ):
  805. """
  806. Set the agent's starting point at an empty position in the grid
  807. """
  808. self.start_pos = None
  809. pos = self.place_obj(None, top, size, max_tries=max_tries)
  810. self.start_pos = pos
  811. if rand_dir:
  812. self.start_dir = self._rand_int(0, 4)
  813. return pos
  814. @property
  815. def dir_vec(self):
  816. """
  817. Get the direction vector for the agent, pointing in the direction
  818. of forward movement.
  819. """
  820. assert self.agent_dir >= 0 and self.agent_dir < 4
  821. return DIR_TO_VEC[self.agent_dir]
  822. @property
  823. def right_vec(self):
  824. """
  825. Get the vector pointing to the right of the agent.
  826. """
  827. dx, dy = self.dir_vec
  828. return np.array((-dy, dx))
  829. @property
  830. def front_pos(self):
  831. """
  832. Get the position of the cell that is right in front of the agent
  833. """
  834. return self.agent_pos + self.dir_vec
  835. def get_view_coords(self, i, j):
  836. """
  837. Translate and rotate absolute grid coordinates (i, j) into the
  838. agent's partially observable view (sub-grid). Note that the resulting
  839. coordinates may be negative or outside of the agent's view size.
  840. """
  841. ax, ay = self.agent_pos
  842. dx, dy = self.dir_vec
  843. rx, ry = self.right_vec
  844. # Compute the absolute coordinates of the top-left view corner
  845. sz = AGENT_VIEW_SIZE
  846. hs = AGENT_VIEW_SIZE // 2
  847. tx = ax + (dx * (sz-1)) - (rx * hs)
  848. ty = ay + (dy * (sz-1)) - (ry * hs)
  849. lx = i - tx
  850. ly = j - ty
  851. # Project the coordinates of the object relative to the top-left
  852. # corner onto the agent's own coordinate system
  853. vx = (rx*lx + ry*ly)
  854. vy = -(dx*lx + dy*ly)
  855. return vx, vy
  856. def get_view_exts(self):
  857. """
  858. Get the extents of the square set of tiles visible to the agent
  859. Note: the bottom extent indices are not included in the set
  860. """
  861. # Facing right
  862. if self.agent_dir == 0:
  863. topX = self.agent_pos[0]
  864. topY = self.agent_pos[1] - AGENT_VIEW_SIZE // 2
  865. # Facing down
  866. elif self.agent_dir == 1:
  867. topX = self.agent_pos[0] - AGENT_VIEW_SIZE // 2
  868. topY = self.agent_pos[1]
  869. # Facing left
  870. elif self.agent_dir == 2:
  871. topX = self.agent_pos[0] - AGENT_VIEW_SIZE + 1
  872. topY = self.agent_pos[1] - AGENT_VIEW_SIZE // 2
  873. # Facing up
  874. elif self.agent_dir == 3:
  875. topX = self.agent_pos[0] - AGENT_VIEW_SIZE // 2
  876. topY = self.agent_pos[1] - AGENT_VIEW_SIZE + 1
  877. else:
  878. assert False, "invalid agent direction"
  879. botX = topX + AGENT_VIEW_SIZE
  880. botY = topY + AGENT_VIEW_SIZE
  881. return (topX, topY, botX, botY)
  882. def relative_coords(self, x, y):
  883. """
  884. Check if a grid position belongs to the agent's field of view, and returns the corresponding coordinates
  885. """
  886. vx, vy = self.get_view_coords(x, y)
  887. if vx < 0 or vy < 0 or vx >= AGENT_VIEW_SIZE or vy >= AGENT_VIEW_SIZE:
  888. return None
  889. return vx, vy
  890. def in_view(self, x, y):
  891. """
  892. check if a grid position is visible to the agent
  893. """
  894. return self.relative_coords(x, y) is not None
  895. def agent_sees(self, x, y):
  896. """
  897. Check if a non-empty grid position is visible to the agent
  898. """
  899. coordinates = self.relative_coords(x, y)
  900. if coordinates is None:
  901. return False
  902. vx, vy = coordinates
  903. obs = self.gen_obs()
  904. obs_grid = Grid.decode(obs['image'])
  905. obs_cell = obs_grid.get(vx, vy)
  906. world_cell = self.grid.get(x, y)
  907. return obs_cell is not None and obs_cell.type == world_cell.type
  908. def step(self, action):
  909. self.step_count += 1
  910. reward = 0
  911. done = False
  912. # Get the position in front of the agent
  913. fwd_pos = self.front_pos
  914. # Get the contents of the cell in front of the agent
  915. fwd_cell = self.grid.get(*fwd_pos)
  916. # Rotate left
  917. if action == self.actions.left:
  918. self.agent_dir -= 1
  919. if self.agent_dir < 0:
  920. self.agent_dir += 4
  921. # Rotate right
  922. elif action == self.actions.right:
  923. self.agent_dir = (self.agent_dir + 1) % 4
  924. # Move forward
  925. elif action == self.actions.forward:
  926. if fwd_cell == None or fwd_cell.can_overlap():
  927. self.agent_pos = fwd_pos
  928. if fwd_cell != None and fwd_cell.type == 'goal':
  929. done = True
  930. reward = self._reward()
  931. if fwd_cell != None and fwd_cell.type == 'lava':
  932. done = True
  933. # Pick up an object
  934. elif action == self.actions.pickup:
  935. if fwd_cell and fwd_cell.can_pickup():
  936. if self.carrying is None:
  937. self.carrying = fwd_cell
  938. self.carrying.cur_pos = np.array([-1, -1])
  939. self.grid.set(*fwd_pos, None)
  940. # Drop an object
  941. elif action == self.actions.drop:
  942. if not fwd_cell and self.carrying:
  943. self.grid.set(*fwd_pos, self.carrying)
  944. self.carrying.cur_pos = fwd_pos
  945. self.carrying = None
  946. # Toggle/activate an object
  947. elif action == self.actions.toggle:
  948. if fwd_cell:
  949. fwd_cell.toggle(self, fwd_pos)
  950. # Done action (not used by default)
  951. elif action == self.actions.done:
  952. pass
  953. else:
  954. assert False, "unknown action"
  955. if self.step_count >= self.max_steps:
  956. done = True
  957. obs = self.gen_obs()
  958. return obs, reward, done, {}
  959. def gen_obs_grid(self):
  960. """
  961. Generate the sub-grid observed by the agent.
  962. This method also outputs a visibility mask telling us which grid
  963. cells the agent can actually see.
  964. """
  965. topX, topY, botX, botY = self.get_view_exts()
  966. grid = self.grid.slice(topX, topY, AGENT_VIEW_SIZE, AGENT_VIEW_SIZE)
  967. for i in range(self.agent_dir + 1):
  968. grid = grid.rotate_left()
  969. # Process occluders and visibility
  970. # Note that this incurs some performance cost
  971. if not self.see_through_walls:
  972. vis_mask = grid.process_vis(agent_pos=(AGENT_VIEW_SIZE // 2 , AGENT_VIEW_SIZE - 1))
  973. else:
  974. vis_mask = np.ones(shape=(grid.width, grid.height), dtype=np.bool)
  975. # Make it so the agent sees what it's carrying
  976. # We do this by placing the carried object at the agent's position
  977. # in the agent's partially observable view
  978. agent_pos = grid.width // 2, grid.height - 1
  979. if self.carrying:
  980. grid.set(*agent_pos, self.carrying)
  981. else:
  982. grid.set(*agent_pos, None)
  983. return grid, vis_mask
  984. def gen_obs(self):
  985. """
  986. Generate the agent's view (partially observable, low-resolution encoding)
  987. """
  988. grid, vis_mask = self.gen_obs_grid()
  989. # Encode the partially observable view into a numpy array
  990. image = grid.encode(vis_mask)
  991. assert hasattr(self, 'mission'), "environments must define a textual mission string"
  992. # Observations are dictionaries containing:
  993. # - an image (partially observable view of the environment)
  994. # - the agent's direction/orientation (acting as a compass)
  995. # - a textual mission string (instructions for the agent)
  996. obs = {
  997. 'image': image,
  998. 'direction': self.agent_dir,
  999. 'mission': self.mission
  1000. }
  1001. return obs
  1002. def get_obs_render(self, obs, tile_pixels=CELL_PIXELS//2):
  1003. """
  1004. Render an agent observation for visualization
  1005. """
  1006. if self.obs_render == None:
  1007. from gym_minigrid.rendering import Renderer
  1008. self.obs_render = Renderer(
  1009. AGENT_VIEW_SIZE * tile_pixels,
  1010. AGENT_VIEW_SIZE * tile_pixels
  1011. )
  1012. r = self.obs_render
  1013. r.beginFrame()
  1014. grid = Grid.decode(obs)
  1015. # Render the whole grid
  1016. grid.render(r, tile_pixels)
  1017. # Draw the agent
  1018. ratio = tile_pixels / CELL_PIXELS
  1019. r.push()
  1020. r.scale(ratio, ratio)
  1021. r.translate(
  1022. CELL_PIXELS * (0.5 + AGENT_VIEW_SIZE // 2),
  1023. CELL_PIXELS * (AGENT_VIEW_SIZE - 0.5)
  1024. )
  1025. r.rotate(3 * 90)
  1026. r.setLineColor(255, 0, 0)
  1027. r.setColor(255, 0, 0)
  1028. r.drawPolygon([
  1029. (-12, 10),
  1030. ( 12, 0),
  1031. (-12, -10)
  1032. ])
  1033. r.pop()
  1034. r.endFrame()
  1035. return r.getPixmap()
  1036. def render(self, mode='human', close=False):
  1037. """
  1038. Render the whole-grid human view
  1039. """
  1040. if close:
  1041. if self.grid_render:
  1042. self.grid_render.close()
  1043. return
  1044. if self.grid_render is None:
  1045. from gym_minigrid.rendering import Renderer
  1046. self.grid_render = Renderer(
  1047. self.grid_size * CELL_PIXELS,
  1048. self.grid_size * CELL_PIXELS,
  1049. True if mode == 'human' else False
  1050. )
  1051. r = self.grid_render
  1052. if r.window:
  1053. r.window.setText(self.mission)
  1054. r.beginFrame()
  1055. # Render the whole grid
  1056. self.grid.render(r, CELL_PIXELS)
  1057. # Draw the agent
  1058. r.push()
  1059. r.translate(
  1060. CELL_PIXELS * (self.agent_pos[0] + 0.5),
  1061. CELL_PIXELS * (self.agent_pos[1] + 0.5)
  1062. )
  1063. r.rotate(self.agent_dir * 90)
  1064. r.setLineColor(255, 0, 0)
  1065. r.setColor(255, 0, 0)
  1066. r.drawPolygon([
  1067. (-12, 10),
  1068. ( 12, 0),
  1069. (-12, -10)
  1070. ])
  1071. r.pop()
  1072. # Compute which cells are visible to the agent
  1073. _, vis_mask = self.gen_obs_grid()
  1074. # Compute the absolute coordinates of the bottom-left corner
  1075. # of the agent's view area
  1076. f_vec = self.dir_vec
  1077. r_vec = self.right_vec
  1078. top_left = self.agent_pos + f_vec * (AGENT_VIEW_SIZE-1) - r_vec * (AGENT_VIEW_SIZE // 2)
  1079. # For each cell in the visibility mask
  1080. for vis_j in range(0, AGENT_VIEW_SIZE):
  1081. for vis_i in range(0, AGENT_VIEW_SIZE):
  1082. # If this cell is not visible, don't highlight it
  1083. if not vis_mask[vis_i, vis_j]:
  1084. continue
  1085. # Compute the world coordinates of this cell
  1086. abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)
  1087. # Highlight the cell
  1088. r.fillRect(
  1089. abs_i * CELL_PIXELS,
  1090. abs_j * CELL_PIXELS,
  1091. CELL_PIXELS,
  1092. CELL_PIXELS,
  1093. 255, 255, 255, 75
  1094. )
  1095. r.endFrame()
  1096. if mode == 'rgb_array':
  1097. return r.getArray()
  1098. elif mode == 'pixmap':
  1099. return r.getPixmap()
  1100. return r