minigrid.py 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312
  1. import math
  2. import gym
  3. from enum import IntEnum
  4. import numpy as np
  5. from gym import error, spaces, utils
  6. from gym.utils import seeding
  7. # Size in pixels of a cell in the full-scale human view
  8. CELL_PIXELS = 32
  9. # Number of cells (width and height) in the agent view
  10. AGENT_VIEW_SIZE = 7
  11. # Size of the array given as an observation to the agent
  12. OBS_ARRAY_SIZE = (AGENT_VIEW_SIZE, AGENT_VIEW_SIZE, 3)
  13. # Map of color names to RGB values
  14. COLORS = {
  15. 'red' : np.array([255, 0, 0]),
  16. 'green' : np.array([0, 255, 0]),
  17. 'blue' : np.array([0, 0, 255]),
  18. 'purple': np.array([112, 39, 195]),
  19. 'yellow': np.array([255, 255, 0]),
  20. 'grey' : np.array([100, 100, 100])
  21. }
  22. COLOR_NAMES = sorted(list(COLORS.keys()))
  23. # Used to map colors to integers
  24. COLOR_TO_IDX = {
  25. 'red' : 0,
  26. 'green' : 1,
  27. 'blue' : 2,
  28. 'purple': 3,
  29. 'yellow': 4,
  30. 'grey' : 5
  31. }
  32. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  33. # Map of object type to integers
  34. OBJECT_TO_IDX = {
  35. 'empty' : 0,
  36. 'wall' : 1,
  37. 'floor' : 2,
  38. 'door' : 3,
  39. 'locked_door' : 4,
  40. 'key' : 5,
  41. 'ball' : 6,
  42. 'box' : 7,
  43. 'goal' : 8
  44. }
  45. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  46. # Map of agent direction indices to vectors
  47. DIR_TO_VEC = [
  48. # Pointing right (positive X)
  49. np.array((1, 0)),
  50. # Down (positive Y)
  51. np.array((0, 1)),
  52. # Pointing left (negative X)
  53. np.array((-1, 0)),
  54. # Up (negative Y)
  55. np.array((0, -1)),
  56. ]
  57. class WorldObj:
  58. """
  59. Base class for grid world objects
  60. """
  61. def __init__(self, type, color):
  62. assert type in OBJECT_TO_IDX, type
  63. assert color in COLOR_TO_IDX, color
  64. self.type = type
  65. self.color = color
  66. self.contains = None
  67. # Initial position of the object
  68. self.init_pos = None
  69. # Current position of the object
  70. self.cur_pos = None
  71. def can_overlap(self):
  72. """Can the agent overlap with this?"""
  73. return False
  74. def can_pickup(self):
  75. """Can the agent pick this up?"""
  76. return False
  77. def can_contain(self):
  78. """Can this contain another object?"""
  79. return False
  80. def see_behind(self):
  81. """Can the agent see behind this object?"""
  82. return True
  83. def toggle(self, env, pos):
  84. """Method to trigger/toggle an action this object performs"""
  85. return False
  86. def render(self, r):
  87. """Draw this object with the given renderer"""
  88. raise NotImplementedError
  89. def _set_color(self, r):
  90. """Set the color of this object as the active drawing color"""
  91. c = COLORS[self.color]
  92. r.setLineColor(c[0], c[1], c[2])
  93. r.setColor(c[0], c[1], c[2])
  94. class Goal(WorldObj):
  95. def __init__(self):
  96. super().__init__('goal', 'green')
  97. def can_overlap(self):
  98. return True
  99. def render(self, r):
  100. self._set_color(r)
  101. r.drawPolygon([
  102. (0 , CELL_PIXELS),
  103. (CELL_PIXELS, CELL_PIXELS),
  104. (CELL_PIXELS, 0),
  105. (0 , 0)
  106. ])
  107. class Floor(WorldObj):
  108. """
  109. Colored floor tile the agent can walk over
  110. """
  111. def __init__(self, color='blue'):
  112. super().__init__('floor', color)
  113. def can_overlap(self):
  114. return True
  115. def render(self, r):
  116. # Give the floor a pale color
  117. c = COLORS[self.color]
  118. r.setLineColor(100, 100, 100, 0)
  119. r.setColor(*c/2)
  120. r.drawPolygon([
  121. (1 , CELL_PIXELS),
  122. (CELL_PIXELS, CELL_PIXELS),
  123. (CELL_PIXELS, 1),
  124. (1 , 1)
  125. ])
  126. class Wall(WorldObj):
  127. def __init__(self, color='grey'):
  128. super().__init__('wall', color)
  129. def see_behind(self):
  130. return False
  131. def render(self, r):
  132. self._set_color(r)
  133. r.drawPolygon([
  134. (0 , CELL_PIXELS),
  135. (CELL_PIXELS, CELL_PIXELS),
  136. (CELL_PIXELS, 0),
  137. (0 , 0)
  138. ])
  139. class Door(WorldObj):
  140. def __init__(self, color, is_open=False):
  141. super().__init__('door', color)
  142. self.is_open = is_open
  143. def can_overlap(self):
  144. """The agent can only walk over this cell when the door is open"""
  145. return self.is_open
  146. def see_behind(self):
  147. return self.is_open
  148. def toggle(self, env, pos):
  149. self.is_open = not self.is_open
  150. return True
  151. def render(self, r):
  152. c = COLORS[self.color]
  153. r.setLineColor(c[0], c[1], c[2])
  154. r.setColor(0, 0, 0)
  155. if self.is_open:
  156. r.drawPolygon([
  157. (CELL_PIXELS-2, CELL_PIXELS),
  158. (CELL_PIXELS , CELL_PIXELS),
  159. (CELL_PIXELS , 0),
  160. (CELL_PIXELS-2, 0)
  161. ])
  162. return
  163. r.drawPolygon([
  164. (0 , CELL_PIXELS),
  165. (CELL_PIXELS, CELL_PIXELS),
  166. (CELL_PIXELS, 0),
  167. (0 , 0)
  168. ])
  169. r.drawPolygon([
  170. (2 , CELL_PIXELS-2),
  171. (CELL_PIXELS-2, CELL_PIXELS-2),
  172. (CELL_PIXELS-2, 2),
  173. (2 , 2)
  174. ])
  175. r.drawCircle(CELL_PIXELS * 0.75, CELL_PIXELS * 0.5, 2)
  176. class LockedDoor(WorldObj):
  177. def __init__(self, color, is_open=False):
  178. super(LockedDoor, self).__init__('locked_door', color)
  179. self.is_open = is_open
  180. def toggle(self, env, pos):
  181. # If the player has the right key to open the door
  182. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  183. self.is_open = True
  184. # The key has been used, remove it from the agent
  185. env.carrying = None
  186. return True
  187. return False
  188. def can_overlap(self):
  189. """The agent can only walk over this cell when the door is open"""
  190. return self.is_open
  191. def see_behind(self):
  192. return self.is_open
  193. def render(self, r):
  194. c = COLORS[self.color]
  195. r.setLineColor(c[0], c[1], c[2])
  196. r.setColor(c[0], c[1], c[2], 50)
  197. if self.is_open:
  198. r.drawPolygon([
  199. (CELL_PIXELS-2, CELL_PIXELS),
  200. (CELL_PIXELS , CELL_PIXELS),
  201. (CELL_PIXELS , 0),
  202. (CELL_PIXELS-2, 0)
  203. ])
  204. return
  205. r.drawPolygon([
  206. (0 , CELL_PIXELS),
  207. (CELL_PIXELS, CELL_PIXELS),
  208. (CELL_PIXELS, 0),
  209. (0 , 0)
  210. ])
  211. r.drawPolygon([
  212. (2 , CELL_PIXELS-2),
  213. (CELL_PIXELS-2, CELL_PIXELS-2),
  214. (CELL_PIXELS-2, 2),
  215. (2 , 2)
  216. ])
  217. r.drawLine(
  218. CELL_PIXELS * 0.55,
  219. CELL_PIXELS * 0.5,
  220. CELL_PIXELS * 0.75,
  221. CELL_PIXELS * 0.5
  222. )
  223. class Key(WorldObj):
  224. def __init__(self, color='blue'):
  225. super(Key, self).__init__('key', color)
  226. def can_pickup(self):
  227. return True
  228. def render(self, r):
  229. self._set_color(r)
  230. # Vertical quad
  231. r.drawPolygon([
  232. (16, 10),
  233. (20, 10),
  234. (20, 28),
  235. (16, 28)
  236. ])
  237. # Teeth
  238. r.drawPolygon([
  239. (12, 19),
  240. (16, 19),
  241. (16, 21),
  242. (12, 21)
  243. ])
  244. r.drawPolygon([
  245. (12, 26),
  246. (16, 26),
  247. (16, 28),
  248. (12, 28)
  249. ])
  250. r.drawCircle(18, 9, 6)
  251. r.setLineColor(0, 0, 0)
  252. r.setColor(0, 0, 0)
  253. r.drawCircle(18, 9, 2)
  254. class Ball(WorldObj):
  255. def __init__(self, color='blue'):
  256. super(Ball, self).__init__('ball', color)
  257. def can_pickup(self):
  258. return True
  259. def render(self, r):
  260. self._set_color(r)
  261. r.drawCircle(CELL_PIXELS * 0.5, CELL_PIXELS * 0.5, 10)
  262. class Box(WorldObj):
  263. def __init__(self, color, contains=None):
  264. super(Box, self).__init__('box', color)
  265. self.contains = contains
  266. def can_pickup(self):
  267. return True
  268. def render(self, r):
  269. c = COLORS[self.color]
  270. r.setLineColor(c[0], c[1], c[2])
  271. r.setColor(0, 0, 0)
  272. r.setLineWidth(2)
  273. r.drawPolygon([
  274. (4 , CELL_PIXELS-4),
  275. (CELL_PIXELS-4, CELL_PIXELS-4),
  276. (CELL_PIXELS-4, 4),
  277. (4 , 4)
  278. ])
  279. r.drawLine(
  280. 4,
  281. CELL_PIXELS / 2,
  282. CELL_PIXELS - 4,
  283. CELL_PIXELS / 2
  284. )
  285. r.setLineWidth(1)
  286. def toggle(self, env, pos):
  287. # Replace the box by its contents
  288. env.grid.set(*pos, self.contains)
  289. return True
  290. class Grid:
  291. """
  292. Represent a grid and operations on it
  293. """
  294. def __init__(self, width, height):
  295. assert width >= 4
  296. assert height >= 4
  297. self.width = width
  298. self.height = height
  299. self.grid = [None] * width * height
  300. def __contains__(self, key):
  301. if isinstance(key, WorldObj):
  302. for e in self.grid:
  303. if e is key:
  304. return True
  305. elif isinstance(key, tuple):
  306. for e in self.grid:
  307. if e is None:
  308. continue
  309. if (e.color, e.type) == key:
  310. return True
  311. if key[0] is None and key[1] == e.type:
  312. return True
  313. return False
  314. def __eq__(self, other):
  315. grid1 = self.encode()
  316. grid2 = other.encode()
  317. return np.array_equal(grid2, grid1)
  318. def __ne__(self, other):
  319. return not self == other
  320. def copy(self):
  321. from copy import deepcopy
  322. return deepcopy(self)
  323. def set(self, i, j, v):
  324. assert i >= 0 and i < self.width
  325. assert j >= 0 and j < self.height
  326. self.grid[j * self.width + i] = v
  327. def get(self, i, j):
  328. assert i >= 0 and i < self.width
  329. assert j >= 0 and j < self.height
  330. return self.grid[j * self.width + i]
  331. def horz_wall(self, x, y, length=None):
  332. if length is None:
  333. length = self.width - x
  334. for i in range(0, length):
  335. self.set(x + i, y, Wall())
  336. def vert_wall(self, x, y, length=None):
  337. if length is None:
  338. length = self.height - y
  339. for j in range(0, length):
  340. self.set(x, y + j, Wall())
  341. def wall_rect(self, x, y, w, h):
  342. self.horz_wall(x, y, w)
  343. self.horz_wall(x, y+h-1, w)
  344. self.vert_wall(x, y, h)
  345. self.vert_wall(x+w-1, y, h)
  346. def rotate_left(self):
  347. """
  348. Rotate the grid to the left (counter-clockwise)
  349. """
  350. grid = Grid(self.width, self.height)
  351. for j in range(0, self.height):
  352. for i in range(0, self.width):
  353. v = self.get(self.width - 1 - j, i)
  354. grid.set(i, j, v)
  355. return grid
  356. def slice(self, topX, topY, width, height):
  357. """
  358. Get a subset of the grid
  359. """
  360. grid = Grid(width, height)
  361. for j in range(0, height):
  362. for i in range(0, width):
  363. x = topX + i
  364. y = topY + j
  365. if x >= 0 and x < self.width and \
  366. y >= 0 and y < self.height:
  367. v = self.get(x, y)
  368. else:
  369. v = Wall()
  370. grid.set(i, j, v)
  371. return grid
  372. def render(self, r, tile_size):
  373. """
  374. Render this grid at a given scale
  375. :param r: target renderer object
  376. :param tile_size: tile size in pixels
  377. """
  378. assert r.width == self.width * tile_size
  379. assert r.height == self.height * tile_size
  380. # Total grid size at native scale
  381. widthPx = self.width * CELL_PIXELS
  382. heightPx = self.height * CELL_PIXELS
  383. r.push()
  384. # Internally, we draw at the "large" full-grid resolution, but we
  385. # use the renderer to scale back to the desired size
  386. r.scale(tile_size / CELL_PIXELS, tile_size / CELL_PIXELS)
  387. # Draw the background of the in-world cells black
  388. r.fillRect(
  389. 0,
  390. 0,
  391. widthPx,
  392. heightPx,
  393. 0, 0, 0
  394. )
  395. # Draw grid lines
  396. r.setLineColor(100, 100, 100)
  397. for rowIdx in range(0, self.height):
  398. y = CELL_PIXELS * rowIdx
  399. r.drawLine(0, y, widthPx, y)
  400. for colIdx in range(0, self.width):
  401. x = CELL_PIXELS * colIdx
  402. r.drawLine(x, 0, x, heightPx)
  403. # Render the grid
  404. for j in range(0, self.height):
  405. for i in range(0, self.width):
  406. cell = self.get(i, j)
  407. if cell == None:
  408. continue
  409. r.push()
  410. r.translate(i * CELL_PIXELS, j * CELL_PIXELS)
  411. cell.render(r)
  412. r.pop()
  413. r.pop()
  414. def encode(self):
  415. """
  416. Produce a compact numpy encoding of the grid
  417. """
  418. codeSize = self.width * self.height * 3
  419. array = np.zeros(shape=(self.width, self.height, 3), dtype='uint8')
  420. for j in range(0, self.height):
  421. for i in range(0, self.width):
  422. v = self.get(i, j)
  423. if v == None:
  424. continue
  425. array[i, j, 0] = OBJECT_TO_IDX[v.type]
  426. array[i, j, 1] = COLOR_TO_IDX[v.color]
  427. if hasattr(v, 'is_open') and v.is_open:
  428. array[i, j, 2] = 1
  429. return array
  430. def decode(array):
  431. """
  432. Decode an array grid encoding back into a grid
  433. """
  434. width = array.shape[0]
  435. height = array.shape[1]
  436. assert array.shape[2] == 3
  437. grid = Grid(width, height)
  438. for j in range(0, height):
  439. for i in range(0, width):
  440. typeIdx = array[i, j, 0]
  441. colorIdx = array[i, j, 1]
  442. openIdx = array[i, j, 2]
  443. if typeIdx == 0:
  444. continue
  445. objType = IDX_TO_OBJECT[typeIdx]
  446. color = IDX_TO_COLOR[colorIdx]
  447. is_open = True if openIdx == 1 else 0
  448. if objType == 'wall':
  449. v = Wall(color)
  450. elif objType == 'floor':
  451. v = Floor(color)
  452. elif objType == 'ball':
  453. v = Ball(color)
  454. elif objType == 'key':
  455. v = Key(color)
  456. elif objType == 'box':
  457. v = Box(color)
  458. elif objType == 'door':
  459. v = Door(color, is_open)
  460. elif objType == 'locked_door':
  461. v = LockedDoor(color, is_open)
  462. elif objType == 'goal':
  463. v = Goal()
  464. else:
  465. assert False, "unknown obj type in decode '%s'" % objType
  466. grid.set(i, j, v)
  467. return grid
  468. def process_vis(grid, agent_pos):
  469. mask = np.zeros(shape=(grid.width, grid.height), dtype=np.bool)
  470. mask[agent_pos[0], agent_pos[1]] = True
  471. for j in reversed(range(1, grid.height)):
  472. for i in range(0, grid.width-1):
  473. if not mask[i, j]:
  474. continue
  475. cell = grid.get(i, j)
  476. if cell and not cell.see_behind():
  477. continue
  478. mask[i+1, j] = True
  479. mask[i+1, j-1] = True
  480. mask[i, j-1] = True
  481. for i in reversed(range(1, grid.width)):
  482. if not mask[i, j]:
  483. continue
  484. cell = grid.get(i, j)
  485. if cell and not cell.see_behind():
  486. continue
  487. mask[i-1, j-1] = True
  488. mask[i-1, j] = True
  489. mask[i, j-1] = True
  490. for j in range(0, grid.height):
  491. for i in range(0, grid.width):
  492. if not mask[i, j]:
  493. grid.set(i, j, None)
  494. return mask
  495. class MiniGridEnv(gym.Env):
  496. """
  497. 2D grid world game environment
  498. """
  499. metadata = {
  500. 'render.modes': ['human', 'rgb_array', 'pixmap'],
  501. 'video.frames_per_second' : 10
  502. }
  503. # Enumeration of possible actions
  504. class Actions(IntEnum):
  505. # Turn left, turn right, move forward
  506. left = 0
  507. right = 1
  508. forward = 2
  509. # Pick up an object
  510. pickup = 3
  511. # Drop an object
  512. drop = 4
  513. # Toggle/activate an object
  514. toggle = 5
  515. # Done completing task
  516. done = 6
  517. def __init__(
  518. self,
  519. grid_size=16,
  520. max_steps=100,
  521. see_through_walls=False,
  522. seed=1337
  523. ):
  524. # Action enumeration for this environment
  525. self.actions = MiniGridEnv.Actions
  526. # Actions are discrete integer values
  527. self.action_space = spaces.Discrete(len(self.actions))
  528. # Observations are dictionaries containing an
  529. # encoding of the grid and a textual 'mission' string
  530. self.observation_space = spaces.Box(
  531. low=0,
  532. high=255,
  533. shape=OBS_ARRAY_SIZE,
  534. dtype='uint8'
  535. )
  536. self.observation_space = spaces.Dict({
  537. 'image': self.observation_space
  538. })
  539. # Range of possible rewards
  540. self.reward_range = (0, 1)
  541. # Renderer object used to render the whole grid (full-scale)
  542. self.grid_render = None
  543. # Renderer used to render observations (small-scale agent view)
  544. self.obs_render = None
  545. # Environment configuration
  546. self.grid_size = grid_size
  547. self.max_steps = max_steps
  548. self.see_through_walls = see_through_walls
  549. # Starting position and direction for the agent
  550. self.start_pos = None
  551. self.start_dir = None
  552. # Initialize the RNG
  553. self.seed(seed=seed)
  554. # Initialize the state
  555. self.reset()
  556. def reset(self):
  557. # Generate a new random grid at the start of each episode
  558. # To keep the same grid for each episode, call env.seed() with
  559. # the same seed before calling env.reset()
  560. self._gen_grid(self.grid_size, self.grid_size)
  561. # These fields should be defined by _gen_grid
  562. assert self.start_pos is not None
  563. assert self.start_dir is not None
  564. # Check that the agent doesn't overlap with an object
  565. start_cell = self.grid.get(*self.start_pos)
  566. assert start_cell is None or start_cell.can_overlap()
  567. # Place the agent in the starting position and direction
  568. self.agent_pos = self.start_pos
  569. self.agent_dir = self.start_dir
  570. # Item picked up, being carried, initially nothing
  571. self.carrying = None
  572. # Step count since episode start
  573. self.step_count = 0
  574. # Return first observation
  575. obs = self.gen_obs()
  576. return obs
  577. def seed(self, seed=1337):
  578. # Seed the random number generator
  579. self.np_random, _ = seeding.np_random(seed)
  580. return [seed]
  581. @property
  582. def steps_remaining(self):
  583. return self.max_steps - self.step_count
  584. def __str__(self):
  585. """
  586. Produce a pretty string of the environment's grid along with the agent.
  587. The agent is represented by `⏩`. A grid pixel is represented by 2-character
  588. string, the first one for the object and the second one for the color.
  589. """
  590. from copy import deepcopy
  591. def rotate_left(array):
  592. new_array = deepcopy(array)
  593. for i in range(len(array)):
  594. for j in range(len(array[0])):
  595. new_array[j][len(array[0])-1-i] = array[i][j]
  596. return new_array
  597. def vertically_symmetrize(array):
  598. new_array = deepcopy(array)
  599. for i in range(len(array)):
  600. for j in range(len(array[0])):
  601. new_array[i][len(array[0])-1-j] = array[i][j]
  602. return new_array
  603. # Map of object id to short string
  604. OBJECT_IDX_TO_IDS = {
  605. 0: ' ',
  606. 1: 'W',
  607. 2: 'D',
  608. 3: 'L',
  609. 4: 'K',
  610. 5: 'B',
  611. 6: 'X',
  612. 7: 'G'
  613. }
  614. # Short string for opened door
  615. OPENDED_DOOR_IDS = '_'
  616. # Map of color id to short string
  617. COLOR_IDX_TO_IDS = {
  618. 0: 'R',
  619. 1: 'G',
  620. 2: 'B',
  621. 3: 'P',
  622. 4: 'Y',
  623. 5: 'E'
  624. }
  625. # Map agent's direction to short string
  626. AGENT_DIR_TO_IDS = {
  627. 0: '⏩ ',
  628. 1: '⏬ ',
  629. 2: '⏪ ',
  630. 3: '⏫ '
  631. }
  632. array = self.grid.encode()
  633. array = rotate_left(array)
  634. array = vertically_symmetrize(array)
  635. new_array = []
  636. for line in array:
  637. new_line = []
  638. for pixel in line:
  639. # If the door is opened
  640. if pixel[0] in [2, 3] and pixel[2] == 1:
  641. object_ids = OPENDED_DOOR_IDS
  642. else:
  643. object_ids = OBJECT_IDX_TO_IDS[pixel[0]]
  644. # If no object
  645. if pixel[0] == 0:
  646. color_ids = ' '
  647. else:
  648. color_ids = COLOR_IDX_TO_IDS[pixel[1]]
  649. new_line.append(object_ids + color_ids)
  650. new_array.append(new_line)
  651. # Add the agent
  652. new_array[self.agent_pos[1]][self.agent_pos[0]] = AGENT_DIR_TO_IDS[self.agent_dir]
  653. return "\n".join([" ".join(line) for line in new_array])
  654. def _gen_grid(self, width, height):
  655. assert False, "_gen_grid needs to be implemented by each environment"
  656. def _reward(self):
  657. """
  658. Compute the reward to be given upon success
  659. """
  660. return 1 - 0.9 * (self.step_count / self.max_steps)
  661. def _rand_int(self, low, high):
  662. """
  663. Generate random integer in [low,high[
  664. """
  665. return self.np_random.randint(low, high)
  666. def _rand_float(self, low, high):
  667. """
  668. Generate random float in [low,high[
  669. """
  670. return self.np_random.uniform(low, high)
  671. def _rand_bool(self):
  672. """
  673. Generate random boolean value
  674. """
  675. return (self.np_random.randint(0, 2) == 0)
  676. def _rand_elem(self, iterable):
  677. """
  678. Pick a random element in a list
  679. """
  680. lst = list(iterable)
  681. idx = self._rand_int(0, len(lst))
  682. return lst[idx]
  683. def _rand_subset(self, iterable, num_elems):
  684. """
  685. Sample a random subset of distinct elements of a list
  686. """
  687. lst = list(iterable)
  688. assert num_elems <= len(lst)
  689. out = []
  690. while len(out) < num_elems:
  691. elem = self._rand_elem(lst)
  692. lst.remove(elem)
  693. out.append(elem)
  694. return out
  695. def _rand_color(self):
  696. """
  697. Generate a random color name (string)
  698. """
  699. return self._rand_elem(COLOR_NAMES)
  700. def _rand_pos(self, xLow, xHigh, yLow, yHigh):
  701. """
  702. Generate a random (x,y) position tuple
  703. """
  704. return (
  705. self.np_random.randint(xLow, xHigh),
  706. self.np_random.randint(yLow, yHigh)
  707. )
  708. def place_obj(self,
  709. obj,
  710. top=None,
  711. size=None,
  712. reject_fn=None,
  713. max_tries=math.inf
  714. ):
  715. """
  716. Place an object at an empty position in the grid
  717. :param top: top-left position of the rectangle where to place
  718. :param size: size of the rectangle where to place
  719. :param reject_fn: function to filter out potential positions
  720. """
  721. if top is None:
  722. top = (0, 0)
  723. if size is None:
  724. size = (self.grid.width, self.grid.height)
  725. num_tries = 0
  726. while True:
  727. # This is to handle with rare cases where rejection sampling
  728. # gets stuck in an infinite loop
  729. if num_tries > max_tries:
  730. raise RecursionError('rejection sampling failed in place_obj')
  731. num_tries += 1
  732. pos = np.array((
  733. self._rand_int(top[0], top[0] + size[0]),
  734. self._rand_int(top[1], top[1] + size[1])
  735. ))
  736. # Don't place the object on top of another object
  737. if self.grid.get(*pos) != None:
  738. continue
  739. # Don't place the object where the agent is
  740. if np.array_equal(pos, self.start_pos):
  741. continue
  742. # Check if there is a filtering criterion
  743. if reject_fn and reject_fn(self, pos):
  744. continue
  745. break
  746. self.grid.set(*pos, obj)
  747. if obj is not None:
  748. obj.init_pos = pos
  749. obj.cur_pos = pos
  750. return pos
  751. def place_agent(
  752. self,
  753. top=None,
  754. size=None,
  755. rand_dir=True,
  756. max_tries=math.inf
  757. ):
  758. """
  759. Set the agent's starting point at an empty position in the grid
  760. """
  761. self.start_pos = None
  762. pos = self.place_obj(None, top, size, max_tries=max_tries)
  763. self.start_pos = pos
  764. if rand_dir:
  765. self.start_dir = self._rand_int(0, 4)
  766. return pos
  767. @property
  768. def dir_vec(self):
  769. """
  770. Get the direction vector for the agent, pointing in the direction
  771. of forward movement.
  772. """
  773. assert self.agent_dir >= 0 and self.agent_dir < 4
  774. return DIR_TO_VEC[self.agent_dir]
  775. @property
  776. def right_vec(self):
  777. """
  778. Get the vector pointing to the right of the agent.
  779. """
  780. dx, dy = self.dir_vec
  781. return np.array((-dy, dx))
  782. @property
  783. def front_pos(self):
  784. """
  785. Get the position of the cell that is right in front of the agent
  786. """
  787. return self.agent_pos + self.dir_vec
  788. def get_view_coords(self, i, j):
  789. """
  790. Translate and rotate absolute grid coordinates (i, j) into the
  791. agent's partially observable view (sub-grid). Note that the resulting
  792. coordinates may be negative or outside of the agent's view size.
  793. """
  794. ax, ay = self.agent_pos
  795. dx, dy = self.dir_vec
  796. rx, ry = self.right_vec
  797. # Compute the absolute coordinates of the top-left view corner
  798. sz = AGENT_VIEW_SIZE
  799. hs = AGENT_VIEW_SIZE // 2
  800. tx = ax + (dx * (sz-1)) - (rx * hs)
  801. ty = ay + (dy * (sz-1)) - (ry * hs)
  802. lx = i - tx
  803. ly = j - ty
  804. # Project the coordinates of the object relative to the top-left
  805. # corner onto the agent's own coordinate system
  806. vx = (rx*lx + ry*ly)
  807. vy = -(dx*lx + dy*ly)
  808. return vx, vy
  809. def get_view_exts(self):
  810. """
  811. Get the extents of the square set of tiles visible to the agent
  812. Note: the bottom extent indices are not included in the set
  813. """
  814. # Facing right
  815. if self.agent_dir == 0:
  816. topX = self.agent_pos[0]
  817. topY = self.agent_pos[1] - AGENT_VIEW_SIZE // 2
  818. # Facing down
  819. elif self.agent_dir == 1:
  820. topX = self.agent_pos[0] - AGENT_VIEW_SIZE // 2
  821. topY = self.agent_pos[1]
  822. # Facing left
  823. elif self.agent_dir == 2:
  824. topX = self.agent_pos[0] - AGENT_VIEW_SIZE + 1
  825. topY = self.agent_pos[1] - AGENT_VIEW_SIZE // 2
  826. # Facing up
  827. elif self.agent_dir == 3:
  828. topX = self.agent_pos[0] - AGENT_VIEW_SIZE // 2
  829. topY = self.agent_pos[1] - AGENT_VIEW_SIZE + 1
  830. else:
  831. assert False, "invalid agent direction"
  832. botX = topX + AGENT_VIEW_SIZE
  833. botY = topY + AGENT_VIEW_SIZE
  834. return (topX, topY, botX, botY)
  835. def agent_sees(self, x, y):
  836. """
  837. Check if a grid position is visible to the agent
  838. """
  839. vx, vy = self.get_view_coords(x, y)
  840. if vx < 0 or vy < 0 or vx >= AGENT_VIEW_SIZE or vy >= AGENT_VIEW_SIZE:
  841. return False
  842. obs = self.gen_obs()
  843. obs_grid = Grid.decode(obs['image'])
  844. obs_cell = obs_grid.get(vx, vy)
  845. world_cell = self.grid.get(x, y)
  846. return obs_cell is not None and obs_cell.type == world_cell.type
  def step(self, action):
      """
      Advance the environment by one time step.

      The step counter is incremented, the requested action is applied
      relative to the cell directly in front of the agent, and a fresh
      observation is generated.

      :param action: one of the values in ``self.actions`` (left, right,
          forward, pickup, drop, toggle, done)
      :return: tuple ``(obs, reward, done, info)`` where ``obs`` comes from
          ``self.gen_obs()``, ``reward`` is 0 unless the agent moves
          forward while the cell in front is of type 'goal', ``done`` is
          True on reaching a goal or once ``step_count`` reaches
          ``max_steps``, and ``info`` is an empty dict
      :raises AssertionError: if ``action`` is not a known action
      """
      self.step_count += 1

      reward = 0
      done = False

      # Position and contents of the cell in front of the agent
      fwd_pos = self.front_pos
      fwd_cell = self.grid.get(*fwd_pos)

      # Rotate left
      if action == self.actions.left:
          self.agent_dir = (self.agent_dir - 1) % 4

      # Rotate right
      elif action == self.actions.right:
          self.agent_dir = (self.agent_dir + 1) % 4

      # Move forward
      elif action == self.actions.forward:
          if fwd_cell is None or fwd_cell.can_overlap():
              self.agent_pos = fwd_pos
          if fwd_cell is not None and fwd_cell.type == 'goal':
              done = True
              reward = self._reward()

      # Pick up an object (only when not already carrying one)
      elif action == self.actions.pickup:
          if fwd_cell and fwd_cell.can_pickup():
              if self.carrying is None:
                  self.carrying = fwd_cell
                  # A carried object has no position on the grid
                  self.carrying.cur_pos = np.array([-1, -1])
                  self.grid.set(*fwd_pos, None)

      # Drop the carried object into the empty cell in front
      elif action == self.actions.drop:
          if not fwd_cell and self.carrying:
              self.grid.set(*fwd_pos, self.carrying)
              self.carrying.cur_pos = fwd_pos
              self.carrying = None

      # Toggle/activate an object
      elif action == self.actions.toggle:
          if fwd_cell:
              fwd_cell.toggle(self, fwd_pos)

      # Done action (not used by default)
      elif action == self.actions.done:
          pass

      else:
          # Raised explicitly so the check is not stripped under `python -O`
          raise AssertionError("unknown action")

      if self.step_count >= self.max_steps:
          done = True

      obs = self.gen_obs()

      return obs, reward, done, {}
  def gen_obs_grid(self):
      """
      Generate the sub-grid observed by the agent.

      This method also outputs a visibility mask telling us which grid
      cells the agent can actually see.

      :return: tuple ``(grid, vis_mask)``: the sliced and rotated view
          grid, and the mask returned by ``grid.process_vis`` (or an
          all-True boolean array of shape ``(grid.width, grid.height)``
          when ``see_through_walls`` is set)
      """
      topX, topY, _, _ = self.get_view_exts()

      grid = self.grid.slice(topX, topY, AGENT_VIEW_SIZE, AGENT_VIEW_SIZE)

      # Rotate the slice into the agent's frame of reference
      for _ in range(self.agent_dir + 1):
          grid = grid.rotate_left()

      # Process occluders and visibility
      # Note that this incurs some performance cost
      if not self.see_through_walls:
          vis_mask = grid.process_vis(agent_pos=(AGENT_VIEW_SIZE // 2, AGENT_VIEW_SIZE - 1))
      else:
          # The builtin `bool` is used because the `np.bool` alias was
          # removed from recent NumPy releases
          vis_mask = np.ones(shape=(grid.width, grid.height), dtype=bool)

      # Make it so the agent sees what it's carrying
      # We do this by placing the carried object at the agent's position
      # in the agent's partially observable view
      agent_pos = grid.width // 2, grid.height - 1
      if self.carrying:
          grid.set(*agent_pos, self.carrying)
      else:
          grid.set(*agent_pos, None)

      return grid, vis_mask
  def gen_obs(self):
      """
      Build the agent's observation (partially observable, low-resolution
      encoding of its view).

      :return: dict with three entries:
          'image' is the encoded view grid,
          'direction' is the agent's current direction (acts as a compass),
          'mission' is the textual mission string for the agent
      """
      view, _ = self.gen_obs_grid()

      # Turn the partially observable view into its array encoding
      encoded_view = view.encode()

      assert hasattr(self, 'mission'), "environments must define a textual mission string"

      return {
          'image': encoded_view,
          'direction': self.agent_dir,
          'mission': self.mission,
      }
  def get_obs_render(self, obs, tile_pixels=CELL_PIXELS // 2):
      """
      Render an agent observation for visualization.

      :param obs: encoded view, in the form accepted by ``Grid.decode``
          (e.g. the 'image' entry of the dict returned by ``gen_obs``)
      :param tile_pixels: side length, in pixels, of one rendered tile
      :return: the pixmap produced by the observation renderer
      """
      if self.obs_render is None:
          # Imported lazily so the rendering module is only loaded when
          # something is actually drawn
          from gym_minigrid.rendering import Renderer
          # NOTE(review): the renderer is created once and cached, so its
          # size comes from the first call's tile_pixels. Confirm callers
          # never vary tile_pixels between calls.
          self.obs_render = Renderer(
              AGENT_VIEW_SIZE * tile_pixels,
              AGENT_VIEW_SIZE * tile_pixels
          )

      r = self.obs_render

      r.beginFrame()

      grid = Grid.decode(obs)

      # Render the whole grid
      grid.render(r, tile_pixels)

      # Draw the agent; the polygon is specified in full-size cell units,
      # so scale it down to the requested tile size
      ratio = tile_pixels / CELL_PIXELS
      r.push()
      r.scale(ratio, ratio)
      r.translate(
          CELL_PIXELS * (0.5 + AGENT_VIEW_SIZE // 2),
          CELL_PIXELS * (AGENT_VIEW_SIZE - 0.5)
      )
      # The agent is drawn with a fixed rotation in its own view
      r.rotate(3 * 90)
      r.setLineColor(255, 0, 0)
      r.setColor(255, 0, 0)
      r.drawPolygon([
          (-12, 10),
          ( 12, 0),
          (-12, -10)
      ])
      r.pop()

      r.endFrame()

      return r.getPixmap()
  def render(self, mode='human', close=False):
      """
      Draw the full grid, the agent, and a highlight over every cell the
      agent can currently see.

      :param mode: 'human', 'rgb_array' or 'pixmap'; selects the return
          value, and 'human' is passed as a True flag to the renderer when
          it is first created
      :param close: when True, close the existing renderer (if any) and
          return None without drawing
      :return: ``r.getArray()`` for 'rgb_array', ``r.getPixmap()`` for
          'pixmap', otherwise the renderer object itself
      """
      if close:
          if self.grid_render:
              self.grid_render.close()
          return

      # Create the renderer lazily on first use
      if self.grid_render is None:
          from gym_minigrid.rendering import Renderer
          side_px = self.grid_size * CELL_PIXELS
          self.grid_render = Renderer(side_px, side_px, mode == 'human')

      r = self.grid_render

      r.beginFrame()

      # Whole-grid pass
      self.grid.render(r, CELL_PIXELS)

      # Agent marker: a red triangle centred on the agent's cell and
      # rotated to its current direction
      agent_x, agent_y = self.agent_pos[0], self.agent_pos[1]
      r.push()
      r.translate(
          CELL_PIXELS * (agent_x + 0.5),
          CELL_PIXELS * (agent_y + 0.5)
      )
      r.rotate(self.agent_dir * 90)
      r.setLineColor(255, 0, 0)
      r.setColor(255, 0, 0)
      r.drawPolygon([(-12, 10), (12, 0), (-12, -10)])
      r.pop()

      # Which cells of the view are visible to the agent
      _, vis_mask = self.gen_obs_grid()

      # World coordinates of the view cell with view index (0, 0), found
      # from the agent's forward and right direction vectors
      forward = self.dir_vec
      right = self.right_vec
      origin = self.agent_pos + forward * (AGENT_VIEW_SIZE - 1) - right * (AGENT_VIEW_SIZE // 2)

      # Highlight every visible cell of the view
      for view_j in range(AGENT_VIEW_SIZE):
          for view_i in range(AGENT_VIEW_SIZE):
              if vis_mask[view_i, view_j]:
                  # Map the view index back to world coordinates
                  world_i, world_j = origin - (forward * view_j) + (right * view_i)
                  r.fillRect(
                      world_i * CELL_PIXELS,
                      world_j * CELL_PIXELS,
                      CELL_PIXELS,
                      CELL_PIXELS,
                      255, 255, 255, 75
                  )

      r.endFrame()

      if mode == 'rgb_array':
          return r.getArray()
      if mode == 'pixmap':
          return r.getPixmap()
      return r