minigrid.py 36 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336
  1. import math
  2. import gym
  3. from enum import IntEnum
  4. import numpy as np
  5. from gym import error, spaces, utils
  6. from gym.utils import seeding
  7. # Size in pixels of a cell in the full-scale human view
  8. CELL_PIXELS = 32
  9. # Map of color names to RGB values
  10. COLORS = {
  11. 'red' : np.array([255, 0, 0]),
  12. 'green' : np.array([0, 255, 0]),
  13. 'blue' : np.array([0, 0, 255]),
  14. 'purple': np.array([112, 39, 195]),
  15. 'yellow': np.array([255, 255, 0]),
  16. 'grey' : np.array([100, 100, 100])
  17. }
  18. COLOR_NAMES = sorted(list(COLORS.keys()))
  19. # Used to map colors to integers
  20. COLOR_TO_IDX = {
  21. 'red' : 0,
  22. 'green' : 1,
  23. 'blue' : 2,
  24. 'purple': 3,
  25. 'yellow': 4,
  26. 'grey' : 5
  27. }
  28. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  29. # Map of object type to integers
  30. OBJECT_TO_IDX = {
  31. 'unseen' : 0,
  32. 'empty' : 1,
  33. 'wall' : 2,
  34. 'floor' : 3,
  35. 'door' : 4,
  36. 'key' : 5,
  37. 'ball' : 6,
  38. 'box' : 7,
  39. 'goal' : 8,
  40. 'lava' : 9
  41. }
  42. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  43. # Map of agent direction indices to vectors
  44. DIR_TO_VEC = [
  45. # Pointing right (positive X)
  46. np.array((1, 0)),
  47. # Down (positive Y)
  48. np.array((0, 1)),
  49. # Pointing left (negative X)
  50. np.array((-1, 0)),
  51. # Up (negative Y)
  52. np.array((0, -1)),
  53. ]
  54. class WorldObj:
  55. """
  56. Base class for grid world objects
  57. """
  58. def __init__(self, type, color):
  59. assert type in OBJECT_TO_IDX, type
  60. assert color in COLOR_TO_IDX, color
  61. self.type = type
  62. self.color = color
  63. self.contains = None
  64. # Initial position of the object
  65. self.init_pos = None
  66. # Current position of the object
  67. self.cur_pos = None
  68. def can_overlap(self):
  69. """Can the agent overlap with this?"""
  70. return False
  71. def can_pickup(self):
  72. """Can the agent pick this up?"""
  73. return False
  74. def can_contain(self):
  75. """Can this contain another object?"""
  76. return False
  77. def see_behind(self):
  78. """Can the agent see behind this object?"""
  79. return True
  80. def toggle(self, env, pos):
  81. """Method to trigger/toggle an action this object performs"""
  82. return False
  83. def render(self, r):
  84. """Draw this object with the given renderer"""
  85. raise NotImplementedError
  86. def _set_color(self, r):
  87. """Set the color of this object as the active drawing color"""
  88. c = COLORS[self.color]
  89. r.setLineColor(c[0], c[1], c[2])
  90. r.setColor(c[0], c[1], c[2])
  91. class Goal(WorldObj):
  92. def __init__(self):
  93. super().__init__('goal', 'green')
  94. def can_overlap(self):
  95. return True
  96. def render(self, r):
  97. self._set_color(r)
  98. r.drawPolygon([
  99. (0 , CELL_PIXELS),
  100. (CELL_PIXELS, CELL_PIXELS),
  101. (CELL_PIXELS, 0),
  102. (0 , 0)
  103. ])
  104. class Floor(WorldObj):
  105. """
  106. Colored floor tile the agent can walk over
  107. """
  108. def __init__(self, color='blue'):
  109. super().__init__('floor', color)
  110. def can_overlap(self):
  111. return True
  112. def render(self, r):
  113. # Give the floor a pale color
  114. c = COLORS[self.color]
  115. r.setLineColor(100, 100, 100, 0)
  116. r.setColor(*c/2)
  117. r.drawPolygon([
  118. (1 , CELL_PIXELS),
  119. (CELL_PIXELS, CELL_PIXELS),
  120. (CELL_PIXELS, 1),
  121. (1 , 1)
  122. ])
  123. class Lava(WorldObj):
  124. def __init__(self):
  125. super().__init__('lava', 'red')
  126. def can_overlap(self):
  127. return True
  128. def render(self, r):
  129. orange = 255, 128, 0
  130. r.setLineColor(*orange)
  131. r.setColor(*orange)
  132. r.drawPolygon([
  133. (0 , CELL_PIXELS),
  134. (CELL_PIXELS, CELL_PIXELS),
  135. (CELL_PIXELS, 0),
  136. (0 , 0)
  137. ])
  138. # drawing the waves
  139. r.setLineColor(0, 0, 0)
  140. r.drawPolyline([
  141. (.1 * CELL_PIXELS, .3 * CELL_PIXELS),
  142. (.3 * CELL_PIXELS, .4 * CELL_PIXELS),
  143. (.5 * CELL_PIXELS, .3 * CELL_PIXELS),
  144. (.7 * CELL_PIXELS, .4 * CELL_PIXELS),
  145. (.9 * CELL_PIXELS, .3 * CELL_PIXELS),
  146. ])
  147. r.drawPolyline([
  148. (.1 * CELL_PIXELS, .5 * CELL_PIXELS),
  149. (.3 * CELL_PIXELS, .6 * CELL_PIXELS),
  150. (.5 * CELL_PIXELS, .5 * CELL_PIXELS),
  151. (.7 * CELL_PIXELS, .6 * CELL_PIXELS),
  152. (.9 * CELL_PIXELS, .5 * CELL_PIXELS),
  153. ])
  154. r.drawPolyline([
  155. (.1 * CELL_PIXELS, .7 * CELL_PIXELS),
  156. (.3 * CELL_PIXELS, .8 * CELL_PIXELS),
  157. (.5 * CELL_PIXELS, .7 * CELL_PIXELS),
  158. (.7 * CELL_PIXELS, .8 * CELL_PIXELS),
  159. (.9 * CELL_PIXELS, .7 * CELL_PIXELS),
  160. ])
  161. class Wall(WorldObj):
  162. def __init__(self, color='grey'):
  163. super().__init__('wall', color)
  164. def see_behind(self):
  165. return False
  166. def render(self, r):
  167. self._set_color(r)
  168. r.drawPolygon([
  169. (0 , CELL_PIXELS),
  170. (CELL_PIXELS, CELL_PIXELS),
  171. (CELL_PIXELS, 0),
  172. (0 , 0)
  173. ])
  174. class Door(WorldObj):
  175. def __init__(self, color, is_open=False, is_locked=False):
  176. super().__init__('door', color)
  177. self.is_open = is_open
  178. self.is_locked = is_locked
  179. def can_overlap(self):
  180. """The agent can only walk over this cell when the door is open"""
  181. return self.is_open
  182. def see_behind(self):
  183. return self.is_open
  184. def toggle(self, env, pos):
  185. # If the player has the right key to open the door
  186. if self.is_locked:
  187. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  188. self.is_locked = False
  189. self.is_open = True
  190. return True
  191. return False
  192. self.is_open = not self.is_open
  193. return True
  194. def render(self, r):
  195. c = COLORS[self.color]
  196. r.setLineColor(c[0], c[1], c[2])
  197. r.setColor(c[0], c[1], c[2], 50 if self.is_locked else 0)
  198. if self.is_open:
  199. r.drawPolygon([
  200. (CELL_PIXELS-2, CELL_PIXELS),
  201. (CELL_PIXELS , CELL_PIXELS),
  202. (CELL_PIXELS , 0),
  203. (CELL_PIXELS-2, 0)
  204. ])
  205. return
  206. r.drawPolygon([
  207. (0 , CELL_PIXELS),
  208. (CELL_PIXELS, CELL_PIXELS),
  209. (CELL_PIXELS, 0),
  210. (0 , 0)
  211. ])
  212. r.drawPolygon([
  213. (2 , CELL_PIXELS-2),
  214. (CELL_PIXELS-2, CELL_PIXELS-2),
  215. (CELL_PIXELS-2, 2),
  216. (2 , 2)
  217. ])
  218. if self.is_locked:
  219. # Draw key slot
  220. r.drawLine(
  221. CELL_PIXELS * 0.55,
  222. CELL_PIXELS * 0.5,
  223. CELL_PIXELS * 0.75,
  224. CELL_PIXELS * 0.5
  225. )
  226. else:
  227. # Draw door handle
  228. r.drawCircle(CELL_PIXELS * 0.75, CELL_PIXELS * 0.5, 2)
  229. class Key(WorldObj):
  230. def __init__(self, color='blue'):
  231. super(Key, self).__init__('key', color)
  232. def can_pickup(self):
  233. return True
  234. def render(self, r):
  235. self._set_color(r)
  236. # Vertical quad
  237. r.drawPolygon([
  238. (16, 10),
  239. (20, 10),
  240. (20, 28),
  241. (16, 28)
  242. ])
  243. # Teeth
  244. r.drawPolygon([
  245. (12, 19),
  246. (16, 19),
  247. (16, 21),
  248. (12, 21)
  249. ])
  250. r.drawPolygon([
  251. (12, 26),
  252. (16, 26),
  253. (16, 28),
  254. (12, 28)
  255. ])
  256. r.drawCircle(18, 9, 6)
  257. r.setLineColor(0, 0, 0)
  258. r.setColor(0, 0, 0)
  259. r.drawCircle(18, 9, 2)
  260. class Ball(WorldObj):
  261. def __init__(self, color='blue'):
  262. super(Ball, self).__init__('ball', color)
  263. def can_pickup(self):
  264. return True
  265. def render(self, r):
  266. self._set_color(r)
  267. r.drawCircle(CELL_PIXELS * 0.5, CELL_PIXELS * 0.5, 10)
  268. class Box(WorldObj):
  269. def __init__(self, color, contains=None):
  270. super(Box, self).__init__('box', color)
  271. self.contains = contains
  272. def can_pickup(self):
  273. return True
  274. def render(self, r):
  275. c = COLORS[self.color]
  276. r.setLineColor(c[0], c[1], c[2])
  277. r.setColor(0, 0, 0)
  278. r.setLineWidth(2)
  279. r.drawPolygon([
  280. (4 , CELL_PIXELS-4),
  281. (CELL_PIXELS-4, CELL_PIXELS-4),
  282. (CELL_PIXELS-4, 4),
  283. (4 , 4)
  284. ])
  285. r.drawLine(
  286. 4,
  287. CELL_PIXELS / 2,
  288. CELL_PIXELS - 4,
  289. CELL_PIXELS / 2
  290. )
  291. r.setLineWidth(1)
  292. def toggle(self, env, pos):
  293. # Replace the box by its contents
  294. env.grid.set(*pos, self.contains)
  295. return True
  296. class Grid:
  297. """
  298. Represent a grid and operations on it
  299. """
  300. def __init__(self, width, height):
  301. assert width >= 3
  302. assert height >= 3
  303. self.width = width
  304. self.height = height
  305. self.grid = [None] * width * height
  306. def __contains__(self, key):
  307. if isinstance(key, WorldObj):
  308. for e in self.grid:
  309. if e is key:
  310. return True
  311. elif isinstance(key, tuple):
  312. for e in self.grid:
  313. if e is None:
  314. continue
  315. if (e.color, e.type) == key:
  316. return True
  317. if key[0] is None and key[1] == e.type:
  318. return True
  319. return False
  320. def __eq__(self, other):
  321. grid1 = self.encode()
  322. grid2 = other.encode()
  323. return np.array_equal(grid2, grid1)
  324. def __ne__(self, other):
  325. return not self == other
  326. def copy(self):
  327. from copy import deepcopy
  328. return deepcopy(self)
  329. def set(self, i, j, v):
  330. assert i >= 0 and i < self.width
  331. assert j >= 0 and j < self.height
  332. self.grid[j * self.width + i] = v
  333. def get(self, i, j):
  334. assert i >= 0 and i < self.width
  335. assert j >= 0 and j < self.height
  336. return self.grid[j * self.width + i]
  337. def horz_wall(self, x, y, length=None):
  338. if length is None:
  339. length = self.width - x
  340. for i in range(0, length):
  341. self.set(x + i, y, Wall())
  342. def vert_wall(self, x, y, length=None):
  343. if length is None:
  344. length = self.height - y
  345. for j in range(0, length):
  346. self.set(x, y + j, Wall())
  347. def wall_rect(self, x, y, w, h):
  348. self.horz_wall(x, y, w)
  349. self.horz_wall(x, y+h-1, w)
  350. self.vert_wall(x, y, h)
  351. self.vert_wall(x+w-1, y, h)
  352. def rotate_left(self):
  353. """
  354. Rotate the grid to the left (counter-clockwise)
  355. """
  356. grid = Grid(self.height, self.width)
  357. for i in range(self.width):
  358. for j in range(self.height):
  359. v = self.get(i, j)
  360. grid.set(j, grid.height - 1 - i, v)
  361. return grid
  362. def slice(self, topX, topY, width, height):
  363. """
  364. Get a subset of the grid
  365. """
  366. grid = Grid(width, height)
  367. for j in range(0, height):
  368. for i in range(0, width):
  369. x = topX + i
  370. y = topY + j
  371. if x >= 0 and x < self.width and \
  372. y >= 0 and y < self.height:
  373. v = self.get(x, y)
  374. else:
  375. v = Wall()
  376. grid.set(i, j, v)
  377. return grid
  378. def render(self, r, tile_size):
  379. """
  380. Render this grid at a given scale
  381. :param r: target renderer object
  382. :param tile_size: tile size in pixels
  383. """
  384. assert r.width == self.width * tile_size
  385. assert r.height == self.height * tile_size
  386. # Total grid size at native scale
  387. widthPx = self.width * CELL_PIXELS
  388. heightPx = self.height * CELL_PIXELS
  389. r.push()
  390. # Internally, we draw at the "large" full-grid resolution, but we
  391. # use the renderer to scale back to the desired size
  392. r.scale(tile_size / CELL_PIXELS, tile_size / CELL_PIXELS)
  393. # Draw the background of the in-world cells black
  394. r.fillRect(
  395. 0,
  396. 0,
  397. widthPx,
  398. heightPx,
  399. 0, 0, 0
  400. )
  401. # Draw grid lines
  402. r.setLineColor(100, 100, 100)
  403. for rowIdx in range(0, self.height):
  404. y = CELL_PIXELS * rowIdx
  405. r.drawLine(0, y, widthPx, y)
  406. for colIdx in range(0, self.width):
  407. x = CELL_PIXELS * colIdx
  408. r.drawLine(x, 0, x, heightPx)
  409. # Render the grid
  410. for j in range(0, self.height):
  411. for i in range(0, self.width):
  412. cell = self.get(i, j)
  413. if cell == None:
  414. continue
  415. r.push()
  416. r.translate(i * CELL_PIXELS, j * CELL_PIXELS)
  417. cell.render(r)
  418. r.pop()
  419. r.pop()
  420. def encode(self, vis_mask=None):
  421. """
  422. Produce a compact numpy encoding of the grid
  423. """
  424. if vis_mask is None:
  425. vis_mask = np.ones((self.width, self.height), dtype=bool)
  426. array = np.zeros((self.width, self.height, 3), dtype='uint8')
  427. for i in range(self.width):
  428. for j in range(self.height):
  429. if vis_mask[i, j]:
  430. v = self.get(i, j)
  431. if v is None:
  432. array[i, j, 0] = OBJECT_TO_IDX['empty']
  433. array[i, j, 1] = 0
  434. array[i, j, 2] = 0
  435. else:
  436. # State, 0: open, 1: closed, 2: locked
  437. state = 0
  438. if hasattr(v, 'is_open') and not v.is_open:
  439. state = 1
  440. if hasattr(v, 'is_locked') and v.is_locked:
  441. state = 2
  442. array[i, j, 0] = OBJECT_TO_IDX[v.type]
  443. array[i, j, 1] = COLOR_TO_IDX[v.color]
  444. array[i, j, 2] = state
  445. return array
  446. @staticmethod
  447. def decode(array):
  448. """
  449. Decode an array grid encoding back into a grid
  450. """
  451. width, height, channels = array.shape
  452. assert channels == 3
  453. grid = Grid(width, height)
  454. for i in range(width):
  455. for j in range(height):
  456. typeIdx, colorIdx, state = array[i, j]
  457. if typeIdx == OBJECT_TO_IDX['unseen'] or \
  458. typeIdx == OBJECT_TO_IDX['empty']:
  459. continue
  460. objType = IDX_TO_OBJECT[typeIdx]
  461. color = IDX_TO_COLOR[colorIdx]
  462. # State, 0: open, 1: closed, 2: locked
  463. is_open = state == 0
  464. is_locked = state == 2
  465. if objType == 'wall':
  466. v = Wall(color)
  467. elif objType == 'floor':
  468. v = Floor(color)
  469. elif objType == 'ball':
  470. v = Ball(color)
  471. elif objType == 'key':
  472. v = Key(color)
  473. elif objType == 'box':
  474. v = Box(color)
  475. elif objType == 'door':
  476. v = Door(color, is_open, is_locked)
  477. elif objType == 'goal':
  478. v = Goal()
  479. elif objType == 'lava':
  480. v = Lava()
  481. else:
  482. assert False, "unknown obj type in decode '%s'" % objType
  483. grid.set(i, j, v)
  484. return grid
  485. def process_vis(grid, agent_pos):
  486. mask = np.zeros(shape=(grid.width, grid.height), dtype=np.bool)
  487. mask[agent_pos[0], agent_pos[1]] = True
  488. for j in reversed(range(0, grid.height)):
  489. for i in range(0, grid.width-1):
  490. if not mask[i, j]:
  491. continue
  492. cell = grid.get(i, j)
  493. if cell and not cell.see_behind():
  494. continue
  495. mask[i+1, j] = True
  496. if j > 0:
  497. mask[i+1, j-1] = True
  498. mask[i, j-1] = True
  499. for i in reversed(range(1, grid.width)):
  500. if not mask[i, j]:
  501. continue
  502. cell = grid.get(i, j)
  503. if cell and not cell.see_behind():
  504. continue
  505. mask[i-1, j] = True
  506. if j > 0:
  507. mask[i-1, j-1] = True
  508. mask[i, j-1] = True
  509. for j in range(0, grid.height):
  510. for i in range(0, grid.width):
  511. if not mask[i, j]:
  512. grid.set(i, j, None)
  513. return mask
  514. class MiniGridEnv(gym.Env):
  515. """
  516. 2D grid world game environment
  517. """
  518. metadata = {
  519. 'render.modes': ['human', 'rgb_array', 'pixmap'],
  520. 'video.frames_per_second' : 10
  521. }
  522. # Enumeration of possible actions
  523. class Actions(IntEnum):
  524. # Turn left, turn right, move forward
  525. left = 0
  526. right = 1
  527. forward = 2
  528. # Pick up an object
  529. pickup = 3
  530. # Drop an object
  531. drop = 4
  532. # Toggle/activate an object
  533. toggle = 5
  534. # Done completing task
  535. done = 6
  536. def __init__(
  537. self,
  538. grid_size=None,
  539. width=None,
  540. height=None,
  541. max_steps=100,
  542. see_through_walls=False,
  543. seed=1337,
  544. agent_view_size=7
  545. ):
  546. # Can't set both grid_size and width/height
  547. if grid_size:
  548. assert width == None and height == None
  549. width = grid_size
  550. height = grid_size
  551. # Action enumeration for this environment
  552. self.actions = MiniGridEnv.Actions
  553. # Actions are discrete integer values
  554. self.action_space = spaces.Discrete(len(self.actions))
  555. # Number of cells (width and height) in the agent view
  556. self.agent_view_size = agent_view_size
  557. # Observations are dictionaries containing an
  558. # encoding of the grid and a textual 'mission' string
  559. self.observation_space = spaces.Box(
  560. low=0,
  561. high=255,
  562. shape=(self.agent_view_size, self.agent_view_size, 3),
  563. dtype='uint8'
  564. )
  565. self.observation_space = spaces.Dict({
  566. 'image': self.observation_space
  567. })
  568. # Range of possible rewards
  569. self.reward_range = (0, 1)
  570. # Renderer object used to render the whole grid (full-scale)
  571. self.grid_render = None
  572. # Renderer used to render observations (small-scale agent view)
  573. self.obs_render = None
  574. # Environment configuration
  575. self.width = width
  576. self.height = height
  577. self.max_steps = max_steps
  578. self.see_through_walls = see_through_walls
  579. # Starting position and direction for the agent
  580. self.start_pos = None
  581. self.start_dir = None
  582. # Initialize the RNG
  583. self.seed(seed=seed)
  584. # Initialize the state
  585. self.reset()
  586. def reset(self):
  587. # Generate a new random grid at the start of each episode
  588. # To keep the same grid for each episode, call env.seed() with
  589. # the same seed before calling env.reset()
  590. self._gen_grid(self.width, self.height)
  591. # These fields should be defined by _gen_grid
  592. assert self.start_pos is not None
  593. assert self.start_dir is not None
  594. # Check that the agent doesn't overlap with an object
  595. start_cell = self.grid.get(*self.start_pos)
  596. assert start_cell is None or start_cell.can_overlap()
  597. # Place the agent in the starting position and direction
  598. self.agent_pos = self.start_pos
  599. self.agent_dir = self.start_dir
  600. # Item picked up, being carried, initially nothing
  601. self.carrying = None
  602. # Step count since episode start
  603. self.step_count = 0
  604. # Return first observation
  605. obs = self.gen_obs()
  606. return obs
  607. def seed(self, seed=1337):
  608. # Seed the random number generator
  609. self.np_random, _ = seeding.np_random(seed)
  610. return [seed]
  611. @property
  612. def steps_remaining(self):
  613. return self.max_steps - self.step_count
  614. def __str__(self):
  615. """
  616. Produce a pretty string of the environment's grid along with the agent.
  617. A grid cell is represented by 2-character string, the first one for
  618. the object and the second one for the color.
  619. """
  620. # Map of object types to short string
  621. OBJECT_TO_STR = {
  622. 'wall' : 'W',
  623. 'floor' : 'F',
  624. 'door' : 'D',
  625. 'key' : 'K',
  626. 'ball' : 'A',
  627. 'box' : 'B',
  628. 'goal' : 'G',
  629. 'lava' : 'V',
  630. }
  631. # Short string for opened door
  632. OPENDED_DOOR_IDS = '_'
  633. # Map agent's direction to short string
  634. AGENT_DIR_TO_STR = {
  635. 0: '>',
  636. 1: 'V',
  637. 2: '<',
  638. 3: '^'
  639. }
  640. str = ''
  641. for j in range(self.grid.height):
  642. for i in range(self.grid.width):
  643. if i == self.agent_pos[0] and j == self.agent_pos[1]:
  644. str += 2 * AGENT_DIR_TO_STR[self.agent_dir]
  645. continue
  646. c = self.grid.get(i, j)
  647. if c == None:
  648. str += ' '
  649. continue
  650. if c.type == 'door':
  651. if c.is_open:
  652. str += '__'
  653. elif c.is_locked:
  654. str += 'L' + c.color[0].upper()
  655. else:
  656. str += 'D' + c.color[0].upper()
  657. continue
  658. str += OBJECT_TO_STR[c.type] + c.color[0].upper()
  659. if j < self.grid.height - 1:
  660. str += '\n'
  661. return str
  662. def _gen_grid(self, width, height):
  663. assert False, "_gen_grid needs to be implemented by each environment"
  664. def _reward(self):
  665. """
  666. Compute the reward to be given upon success
  667. """
  668. return 1 - 0.9 * (self.step_count / self.max_steps)
  669. def _rand_int(self, low, high):
  670. """
  671. Generate random integer in [low,high[
  672. """
  673. return self.np_random.randint(low, high)
  674. def _rand_float(self, low, high):
  675. """
  676. Generate random float in [low,high[
  677. """
  678. return self.np_random.uniform(low, high)
  679. def _rand_bool(self):
  680. """
  681. Generate random boolean value
  682. """
  683. return (self.np_random.randint(0, 2) == 0)
  684. def _rand_elem(self, iterable):
  685. """
  686. Pick a random element in a list
  687. """
  688. lst = list(iterable)
  689. idx = self._rand_int(0, len(lst))
  690. return lst[idx]
  691. def _rand_subset(self, iterable, num_elems):
  692. """
  693. Sample a random subset of distinct elements of a list
  694. """
  695. lst = list(iterable)
  696. assert num_elems <= len(lst)
  697. out = []
  698. while len(out) < num_elems:
  699. elem = self._rand_elem(lst)
  700. lst.remove(elem)
  701. out.append(elem)
  702. return out
  703. def _rand_color(self):
  704. """
  705. Generate a random color name (string)
  706. """
  707. return self._rand_elem(COLOR_NAMES)
  708. def _rand_pos(self, xLow, xHigh, yLow, yHigh):
  709. """
  710. Generate a random (x,y) position tuple
  711. """
  712. return (
  713. self.np_random.randint(xLow, xHigh),
  714. self.np_random.randint(yLow, yHigh)
  715. )
  716. def place_obj(self,
  717. obj,
  718. top=None,
  719. size=None,
  720. reject_fn=None,
  721. max_tries=math.inf
  722. ):
  723. """
  724. Place an object at an empty position in the grid
  725. :param top: top-left position of the rectangle where to place
  726. :param size: size of the rectangle where to place
  727. :param reject_fn: function to filter out potential positions
  728. """
  729. if top is None:
  730. top = (0, 0)
  731. if size is None:
  732. size = (self.grid.width, self.grid.height)
  733. num_tries = 0
  734. while True:
  735. # This is to handle with rare cases where rejection sampling
  736. # gets stuck in an infinite loop
  737. if num_tries > max_tries:
  738. raise RecursionError('rejection sampling failed in place_obj')
  739. num_tries += 1
  740. pos = np.array((
  741. self._rand_int(top[0], top[0] + size[0]),
  742. self._rand_int(top[1], top[1] + size[1])
  743. ))
  744. # Don't place the object on top of another object
  745. if self.grid.get(*pos) != None:
  746. continue
  747. # Don't place the object where the agent is
  748. if np.array_equal(pos, self.start_pos):
  749. continue
  750. # Check if there is a filtering criterion
  751. if reject_fn and reject_fn(self, pos):
  752. continue
  753. break
  754. self.grid.set(*pos, obj)
  755. if obj is not None:
  756. obj.init_pos = pos
  757. obj.cur_pos = pos
  758. return pos
  759. def place_agent(
  760. self,
  761. top=None,
  762. size=None,
  763. rand_dir=True,
  764. max_tries=math.inf
  765. ):
  766. """
  767. Set the agent's starting point at an empty position in the grid
  768. """
  769. self.start_pos = None
  770. pos = self.place_obj(None, top, size, max_tries=max_tries)
  771. self.start_pos = pos
  772. if rand_dir:
  773. self.start_dir = self._rand_int(0, 4)
  774. return pos
  775. @property
  776. def dir_vec(self):
  777. """
  778. Get the direction vector for the agent, pointing in the direction
  779. of forward movement.
  780. """
  781. assert self.agent_dir >= 0 and self.agent_dir < 4
  782. return DIR_TO_VEC[self.agent_dir]
  783. @property
  784. def right_vec(self):
  785. """
  786. Get the vector pointing to the right of the agent.
  787. """
  788. dx, dy = self.dir_vec
  789. return np.array((-dy, dx))
  790. @property
  791. def front_pos(self):
  792. """
  793. Get the position of the cell that is right in front of the agent
  794. """
  795. return self.agent_pos + self.dir_vec
  796. def get_view_coords(self, i, j):
  797. """
  798. Translate and rotate absolute grid coordinates (i, j) into the
  799. agent's partially observable view (sub-grid). Note that the resulting
  800. coordinates may be negative or outside of the agent's view size.
  801. """
  802. ax, ay = self.agent_pos
  803. dx, dy = self.dir_vec
  804. rx, ry = self.right_vec
  805. # Compute the absolute coordinates of the top-left view corner
  806. sz = self.agent_view_size
  807. hs = self.agent_view_size // 2
  808. tx = ax + (dx * (sz-1)) - (rx * hs)
  809. ty = ay + (dy * (sz-1)) - (ry * hs)
  810. lx = i - tx
  811. ly = j - ty
  812. # Project the coordinates of the object relative to the top-left
  813. # corner onto the agent's own coordinate system
  814. vx = (rx*lx + ry*ly)
  815. vy = -(dx*lx + dy*ly)
  816. return vx, vy
  817. def get_view_exts(self):
  818. """
  819. Get the extents of the square set of tiles visible to the agent
  820. Note: the bottom extent indices are not included in the set
  821. """
  822. # Facing right
  823. if self.agent_dir == 0:
  824. topX = self.agent_pos[0]
  825. topY = self.agent_pos[1] - self.agent_view_size // 2
  826. # Facing down
  827. elif self.agent_dir == 1:
  828. topX = self.agent_pos[0] - self.agent_view_size // 2
  829. topY = self.agent_pos[1]
  830. # Facing left
  831. elif self.agent_dir == 2:
  832. topX = self.agent_pos[0] - self.agent_view_size + 1
  833. topY = self.agent_pos[1] - self.agent_view_size // 2
  834. # Facing up
  835. elif self.agent_dir == 3:
  836. topX = self.agent_pos[0] - self.agent_view_size // 2
  837. topY = self.agent_pos[1] - self.agent_view_size + 1
  838. else:
  839. assert False, "invalid agent direction"
  840. botX = topX + self.agent_view_size
  841. botY = topY + self.agent_view_size
  842. return (topX, topY, botX, botY)
  843. def relative_coords(self, x, y):
  844. """
  845. Check if a grid position belongs to the agent's field of view, and returns the corresponding coordinates
  846. """
  847. vx, vy = self.get_view_coords(x, y)
  848. if vx < 0 or vy < 0 or vx >= self.agent_view_size or vy >= self.agent_view_size:
  849. return None
  850. return vx, vy
  851. def in_view(self, x, y):
  852. """
  853. check if a grid position is visible to the agent
  854. """
  855. return self.relative_coords(x, y) is not None
  856. def agent_sees(self, x, y):
  857. """
  858. Check if a non-empty grid position is visible to the agent
  859. """
  860. coordinates = self.relative_coords(x, y)
  861. if coordinates is None:
  862. return False
  863. vx, vy = coordinates
  864. obs = self.gen_obs()
  865. obs_grid = Grid.decode(obs['image'])
  866. obs_cell = obs_grid.get(vx, vy)
  867. world_cell = self.grid.get(x, y)
  868. return obs_cell is not None and obs_cell.type == world_cell.type
  869. def step(self, action):
  870. self.step_count += 1
  871. reward = 0
  872. done = False
  873. # Get the position in front of the agent
  874. fwd_pos = self.front_pos
  875. # Get the contents of the cell in front of the agent
  876. fwd_cell = self.grid.get(*fwd_pos)
  877. # Rotate left
  878. if action == self.actions.left:
  879. self.agent_dir -= 1
  880. if self.agent_dir < 0:
  881. self.agent_dir += 4
  882. # Rotate right
  883. elif action == self.actions.right:
  884. self.agent_dir = (self.agent_dir + 1) % 4
  885. # Move forward
  886. elif action == self.actions.forward:
  887. if fwd_cell == None or fwd_cell.can_overlap():
  888. self.agent_pos = fwd_pos
  889. if fwd_cell != None and fwd_cell.type == 'goal':
  890. done = True
  891. reward = self._reward()
  892. if fwd_cell != None and fwd_cell.type == 'lava':
  893. done = True
  894. # Pick up an object
  895. elif action == self.actions.pickup:
  896. if fwd_cell and fwd_cell.can_pickup():
  897. if self.carrying is None:
  898. self.carrying = fwd_cell
  899. self.carrying.cur_pos = np.array([-1, -1])
  900. self.grid.set(*fwd_pos, None)
  901. # Drop an object
  902. elif action == self.actions.drop:
  903. if not fwd_cell and self.carrying:
  904. self.grid.set(*fwd_pos, self.carrying)
  905. self.carrying.cur_pos = fwd_pos
  906. self.carrying = None
  907. # Toggle/activate an object
  908. elif action == self.actions.toggle:
  909. if fwd_cell:
  910. fwd_cell.toggle(self, fwd_pos)
  911. # Done action (not used by default)
  912. elif action == self.actions.done:
  913. pass
  914. else:
  915. assert False, "unknown action"
  916. if self.step_count >= self.max_steps:
  917. done = True
  918. obs = self.gen_obs()
  919. return obs, reward, done, {}
  920. def gen_obs_grid(self):
  921. """
  922. Generate the sub-grid observed by the agent.
  923. This method also outputs a visibility mask telling us which grid
  924. cells the agent can actually see.
  925. """
  926. topX, topY, botX, botY = self.get_view_exts()
  927. grid = self.grid.slice(topX, topY, self.agent_view_size, self.agent_view_size)
  928. for i in range(self.agent_dir + 1):
  929. grid = grid.rotate_left()
  930. # Process occluders and visibility
  931. # Note that this incurs some performance cost
  932. if not self.see_through_walls:
  933. vis_mask = grid.process_vis(agent_pos=(self.agent_view_size // 2 , self.agent_view_size - 1))
  934. else:
  935. vis_mask = np.ones(shape=(grid.width, grid.height), dtype=np.bool)
  936. # Make it so the agent sees what it's carrying
  937. # We do this by placing the carried object at the agent's position
  938. # in the agent's partially observable view
  939. agent_pos = grid.width // 2, grid.height - 1
  940. if self.carrying:
  941. grid.set(*agent_pos, self.carrying)
  942. else:
  943. grid.set(*agent_pos, None)
  944. return grid, vis_mask
  945. def gen_obs(self):
  946. """
  947. Generate the agent's view (partially observable, low-resolution encoding)
  948. """
  949. grid, vis_mask = self.gen_obs_grid()
  950. # Encode the partially observable view into a numpy array
  951. image = grid.encode(vis_mask)
  952. assert hasattr(self, 'mission'), "environments must define a textual mission string"
  953. # Observations are dictionaries containing:
  954. # - an image (partially observable view of the environment)
  955. # - the agent's direction/orientation (acting as a compass)
  956. # - a textual mission string (instructions for the agent)
  957. obs = {
  958. 'image': image,
  959. 'direction': self.agent_dir,
  960. 'mission': self.mission
  961. }
  962. return obs
  963. def get_obs_render(self, obs, tile_pixels=CELL_PIXELS//2):
  964. """
  965. Render an agent observation for visualization
  966. """
  967. if self.obs_render == None:
  968. from gym_minigrid.rendering import Renderer
  969. self.obs_render = Renderer(
  970. self.agent_view_size * tile_pixels,
  971. self.agent_view_size * tile_pixels
  972. )
  973. r = self.obs_render
  974. r.beginFrame()
  975. grid = Grid.decode(obs)
  976. # Render the whole grid
  977. grid.render(r, tile_pixels)
  978. # Draw the agent
  979. ratio = tile_pixels / CELL_PIXELS
  980. r.push()
  981. r.scale(ratio, ratio)
  982. r.translate(
  983. CELL_PIXELS * (0.5 + self.agent_view_size // 2),
  984. CELL_PIXELS * (self.agent_view_size - 0.5)
  985. )
  986. r.rotate(3 * 90)
  987. r.setLineColor(255, 0, 0)
  988. r.setColor(255, 0, 0)
  989. r.drawPolygon([
  990. (-12, 10),
  991. ( 12, 0),
  992. (-12, -10)
  993. ])
  994. r.pop()
  995. r.endFrame()
  996. return r.getPixmap()
  997. def render(self, mode='human', close=False):
  998. """
  999. Render the whole-grid human view
  1000. """
  1001. if close:
  1002. if self.grid_render:
  1003. self.grid_render.close()
  1004. return
  1005. if self.grid_render is None:
  1006. from gym_minigrid.rendering import Renderer
  1007. self.grid_render = Renderer(
  1008. self.width * CELL_PIXELS,
  1009. self.height * CELL_PIXELS,
  1010. True if mode == 'human' else False
  1011. )
  1012. r = self.grid_render
  1013. if r.window:
  1014. r.window.setText(self.mission)
  1015. r.beginFrame()
  1016. # Render the whole grid
  1017. self.grid.render(r, CELL_PIXELS)
  1018. # Draw the agent
  1019. r.push()
  1020. r.translate(
  1021. CELL_PIXELS * (self.agent_pos[0] + 0.5),
  1022. CELL_PIXELS * (self.agent_pos[1] + 0.5)
  1023. )
  1024. r.rotate(self.agent_dir * 90)
  1025. r.setLineColor(255, 0, 0)
  1026. r.setColor(255, 0, 0)
  1027. r.drawPolygon([
  1028. (-12, 10),
  1029. ( 12, 0),
  1030. (-12, -10)
  1031. ])
  1032. r.pop()
  1033. # Compute which cells are visible to the agent
  1034. _, vis_mask = self.gen_obs_grid()
  1035. # Compute the absolute coordinates of the bottom-left corner
  1036. # of the agent's view area
  1037. f_vec = self.dir_vec
  1038. r_vec = self.right_vec
  1039. top_left = self.agent_pos + f_vec * (self.agent_view_size-1) - r_vec * (self.agent_view_size // 2)
  1040. # For each cell in the visibility mask
  1041. for vis_j in range(0, self.agent_view_size):
  1042. for vis_i in range(0, self.agent_view_size):
  1043. # If this cell is not visible, don't highlight it
  1044. if not vis_mask[vis_i, vis_j]:
  1045. continue
  1046. # Compute the world coordinates of this cell
  1047. abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)
  1048. # Highlight the cell
  1049. r.fillRect(
  1050. abs_i * CELL_PIXELS,
  1051. abs_j * CELL_PIXELS,
  1052. CELL_PIXELS,
  1053. CELL_PIXELS,
  1054. 255, 255, 255, 75
  1055. )
  1056. r.endFrame()
  1057. if mode == 'rgb_array':
  1058. return r.getArray()
  1059. elif mode == 'pixmap':
  1060. return r.getPixmap()
  1061. return r