minigrid.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331
  1. import math
  2. import gym
  3. from enum import IntEnum
  4. import numpy as np
  5. from gym import error, spaces, utils
  6. from gym.utils import seeding
  7. # Size in pixels of a cell in the full-scale human view
  8. CELL_PIXELS = 32
  9. # Number of cells (width and height) in the agent view
  10. AGENT_VIEW_SIZE = 7
  11. # Size of the array given as an observation to the agent
  12. OBS_ARRAY_SIZE = (AGENT_VIEW_SIZE, AGENT_VIEW_SIZE, 3)
  13. # Map of color names to RGB values
  14. COLORS = {
  15. 'red' : np.array([255, 0, 0]),
  16. 'green' : np.array([0, 255, 0]),
  17. 'blue' : np.array([0, 0, 255]),
  18. 'purple': np.array([112, 39, 195]),
  19. 'yellow': np.array([255, 255, 0]),
  20. 'grey' : np.array([100, 100, 100])
  21. }
  22. COLOR_NAMES = sorted(list(COLORS.keys()))
  23. # Used to map colors to integers
  24. COLOR_TO_IDX = {
  25. 'red' : 0,
  26. 'green' : 1,
  27. 'blue' : 2,
  28. 'purple': 3,
  29. 'yellow': 4,
  30. 'grey' : 5
  31. }
  32. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  33. # Map of object type to integers
  34. OBJECT_TO_IDX = {
  35. 'unseen' : 0,
  36. 'empty' : 1,
  37. 'wall' : 2,
  38. 'floor' : 3,
  39. 'door' : 4,
  40. 'key' : 5,
  41. 'ball' : 6,
  42. 'box' : 7,
  43. 'goal' : 8,
  44. 'lava' : 9
  45. }
  46. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  47. # Map of agent direction indices to vectors
  48. DIR_TO_VEC = [
  49. # Pointing right (positive X)
  50. np.array((1, 0)),
  51. # Down (positive Y)
  52. np.array((0, 1)),
  53. # Pointing left (negative X)
  54. np.array((-1, 0)),
  55. # Up (negative Y)
  56. np.array((0, -1)),
  57. ]
  58. class WorldObj:
  59. """
  60. Base class for grid world objects
  61. """
  62. def __init__(self, type, color):
  63. assert type in OBJECT_TO_IDX, type
  64. assert color in COLOR_TO_IDX, color
  65. self.type = type
  66. self.color = color
  67. self.contains = None
  68. # Initial position of the object
  69. self.init_pos = None
  70. # Current position of the object
  71. self.cur_pos = None
  72. def can_overlap(self):
  73. """Can the agent overlap with this?"""
  74. return False
  75. def can_pickup(self):
  76. """Can the agent pick this up?"""
  77. return False
  78. def can_contain(self):
  79. """Can this contain another object?"""
  80. return False
  81. def see_behind(self):
  82. """Can the agent see behind this object?"""
  83. return True
  84. def toggle(self, env, pos):
  85. """Method to trigger/toggle an action this object performs"""
  86. return False
  87. def render(self, r):
  88. """Draw this object with the given renderer"""
  89. raise NotImplementedError
  90. def _set_color(self, r):
  91. """Set the color of this object as the active drawing color"""
  92. c = COLORS[self.color]
  93. r.setLineColor(c[0], c[1], c[2])
  94. r.setColor(c[0], c[1], c[2])
  95. class Goal(WorldObj):
  96. def __init__(self):
  97. super().__init__('goal', 'green')
  98. def can_overlap(self):
  99. return True
  100. def render(self, r):
  101. self._set_color(r)
  102. r.drawPolygon([
  103. (0 , CELL_PIXELS),
  104. (CELL_PIXELS, CELL_PIXELS),
  105. (CELL_PIXELS, 0),
  106. (0 , 0)
  107. ])
  108. class Floor(WorldObj):
  109. """
  110. Colored floor tile the agent can walk over
  111. """
  112. def __init__(self, color='blue'):
  113. super().__init__('floor', color)
  114. def can_overlap(self):
  115. return True
  116. def render(self, r):
  117. # Give the floor a pale color
  118. c = COLORS[self.color]
  119. r.setLineColor(100, 100, 100, 0)
  120. r.setColor(*c/2)
  121. r.drawPolygon([
  122. (1 , CELL_PIXELS),
  123. (CELL_PIXELS, CELL_PIXELS),
  124. (CELL_PIXELS, 1),
  125. (1 , 1)
  126. ])
  127. class Lava(WorldObj):
  128. def __init__(self):
  129. super().__init__('lava', 'red')
  130. def can_overlap(self):
  131. return True
  132. def render(self, r):
  133. orange = 255, 128, 0
  134. r.setLineColor(*orange)
  135. r.setColor(*orange)
  136. r.drawPolygon([
  137. (0 , CELL_PIXELS),
  138. (CELL_PIXELS, CELL_PIXELS),
  139. (CELL_PIXELS, 0),
  140. (0 , 0)
  141. ])
  142. # drawing the waves
  143. r.setLineColor(0, 0, 0)
  144. r.drawPolyline([
  145. (.1 * CELL_PIXELS, .3 * CELL_PIXELS),
  146. (.3 * CELL_PIXELS, .4 * CELL_PIXELS),
  147. (.5 * CELL_PIXELS, .3 * CELL_PIXELS),
  148. (.7 * CELL_PIXELS, .4 * CELL_PIXELS),
  149. (.9 * CELL_PIXELS, .3 * CELL_PIXELS),
  150. ])
  151. r.drawPolyline([
  152. (.1 * CELL_PIXELS, .5 * CELL_PIXELS),
  153. (.3 * CELL_PIXELS, .6 * CELL_PIXELS),
  154. (.5 * CELL_PIXELS, .5 * CELL_PIXELS),
  155. (.7 * CELL_PIXELS, .6 * CELL_PIXELS),
  156. (.9 * CELL_PIXELS, .5 * CELL_PIXELS),
  157. ])
  158. r.drawPolyline([
  159. (.1 * CELL_PIXELS, .7 * CELL_PIXELS),
  160. (.3 * CELL_PIXELS, .8 * CELL_PIXELS),
  161. (.5 * CELL_PIXELS, .7 * CELL_PIXELS),
  162. (.7 * CELL_PIXELS, .8 * CELL_PIXELS),
  163. (.9 * CELL_PIXELS, .7 * CELL_PIXELS),
  164. ])
  165. class Wall(WorldObj):
  166. def __init__(self, color='grey'):
  167. super().__init__('wall', color)
  168. def see_behind(self):
  169. return False
  170. def render(self, r):
  171. self._set_color(r)
  172. r.drawPolygon([
  173. (0 , CELL_PIXELS),
  174. (CELL_PIXELS, CELL_PIXELS),
  175. (CELL_PIXELS, 0),
  176. (0 , 0)
  177. ])
  178. class Door(WorldObj):
  179. def __init__(self, color, is_open=False, is_locked=False):
  180. super().__init__('door', color)
  181. self.is_open = is_open
  182. self.is_locked = is_locked
  183. def can_overlap(self):
  184. """The agent can only walk over this cell when the door is open"""
  185. return self.is_open
  186. def see_behind(self):
  187. return self.is_open
  188. def toggle(self, env, pos):
  189. # If the player has the right key to open the door
  190. if self.is_locked:
  191. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  192. self.is_locked = False
  193. self.is_open = True
  194. # The key has been used, remove it from the agent
  195. env.carrying = None
  196. return True
  197. return False
  198. self.is_open = not self.is_open
  199. return True
  200. def render(self, r):
  201. c = COLORS[self.color]
  202. r.setLineColor(c[0], c[1], c[2])
  203. r.setColor(c[0], c[1], c[2], 50 if self.is_locked else 0)
  204. if self.is_open:
  205. r.drawPolygon([
  206. (CELL_PIXELS-2, CELL_PIXELS),
  207. (CELL_PIXELS , CELL_PIXELS),
  208. (CELL_PIXELS , 0),
  209. (CELL_PIXELS-2, 0)
  210. ])
  211. return
  212. r.drawPolygon([
  213. (0 , CELL_PIXELS),
  214. (CELL_PIXELS, CELL_PIXELS),
  215. (CELL_PIXELS, 0),
  216. (0 , 0)
  217. ])
  218. r.drawPolygon([
  219. (2 , CELL_PIXELS-2),
  220. (CELL_PIXELS-2, CELL_PIXELS-2),
  221. (CELL_PIXELS-2, 2),
  222. (2 , 2)
  223. ])
  224. if self.is_locked:
  225. # Draw key slot
  226. r.drawLine(
  227. CELL_PIXELS * 0.55,
  228. CELL_PIXELS * 0.5,
  229. CELL_PIXELS * 0.75,
  230. CELL_PIXELS * 0.5
  231. )
  232. else:
  233. # Draw door handle
  234. r.drawCircle(CELL_PIXELS * 0.75, CELL_PIXELS * 0.5, 2)
  235. class Key(WorldObj):
  236. def __init__(self, color='blue'):
  237. super(Key, self).__init__('key', color)
  238. def can_pickup(self):
  239. return True
  240. def render(self, r):
  241. self._set_color(r)
  242. # Vertical quad
  243. r.drawPolygon([
  244. (16, 10),
  245. (20, 10),
  246. (20, 28),
  247. (16, 28)
  248. ])
  249. # Teeth
  250. r.drawPolygon([
  251. (12, 19),
  252. (16, 19),
  253. (16, 21),
  254. (12, 21)
  255. ])
  256. r.drawPolygon([
  257. (12, 26),
  258. (16, 26),
  259. (16, 28),
  260. (12, 28)
  261. ])
  262. r.drawCircle(18, 9, 6)
  263. r.setLineColor(0, 0, 0)
  264. r.setColor(0, 0, 0)
  265. r.drawCircle(18, 9, 2)
  266. class Ball(WorldObj):
  267. def __init__(self, color='blue'):
  268. super(Ball, self).__init__('ball', color)
  269. def can_pickup(self):
  270. return True
  271. def render(self, r):
  272. self._set_color(r)
  273. r.drawCircle(CELL_PIXELS * 0.5, CELL_PIXELS * 0.5, 10)
  274. class Box(WorldObj):
  275. def __init__(self, color, contains=None):
  276. super(Box, self).__init__('box', color)
  277. self.contains = contains
  278. def can_pickup(self):
  279. return True
  280. def render(self, r):
  281. c = COLORS[self.color]
  282. r.setLineColor(c[0], c[1], c[2])
  283. r.setColor(0, 0, 0)
  284. r.setLineWidth(2)
  285. r.drawPolygon([
  286. (4 , CELL_PIXELS-4),
  287. (CELL_PIXELS-4, CELL_PIXELS-4),
  288. (CELL_PIXELS-4, 4),
  289. (4 , 4)
  290. ])
  291. r.drawLine(
  292. 4,
  293. CELL_PIXELS / 2,
  294. CELL_PIXELS - 4,
  295. CELL_PIXELS / 2
  296. )
  297. r.setLineWidth(1)
  298. def toggle(self, env, pos):
  299. # Replace the box by its contents
  300. env.grid.set(*pos, self.contains)
  301. return True
  302. class Grid:
  303. """
  304. Represent a grid and operations on it
  305. """
  306. def __init__(self, width, height):
  307. assert width >= 4
  308. assert height >= 4
  309. self.width = width
  310. self.height = height
  311. self.grid = [None] * width * height
  312. def __contains__(self, key):
  313. if isinstance(key, WorldObj):
  314. for e in self.grid:
  315. if e is key:
  316. return True
  317. elif isinstance(key, tuple):
  318. for e in self.grid:
  319. if e is None:
  320. continue
  321. if (e.color, e.type) == key:
  322. return True
  323. if key[0] is None and key[1] == e.type:
  324. return True
  325. return False
  326. def __eq__(self, other):
  327. grid1 = self.encode()
  328. grid2 = other.encode()
  329. return np.array_equal(grid2, grid1)
  330. def __ne__(self, other):
  331. return not self == other
  332. def copy(self):
  333. from copy import deepcopy
  334. return deepcopy(self)
  335. def set(self, i, j, v):
  336. assert i >= 0 and i < self.width
  337. assert j >= 0 and j < self.height
  338. self.grid[j * self.width + i] = v
  339. def get(self, i, j):
  340. assert i >= 0 and i < self.width
  341. assert j >= 0 and j < self.height
  342. return self.grid[j * self.width + i]
  343. def horz_wall(self, x, y, length=None):
  344. if length is None:
  345. length = self.width - x
  346. for i in range(0, length):
  347. self.set(x + i, y, Wall())
  348. def vert_wall(self, x, y, length=None):
  349. if length is None:
  350. length = self.height - y
  351. for j in range(0, length):
  352. self.set(x, y + j, Wall())
  353. def wall_rect(self, x, y, w, h):
  354. self.horz_wall(x, y, w)
  355. self.horz_wall(x, y+h-1, w)
  356. self.vert_wall(x, y, h)
  357. self.vert_wall(x+w-1, y, h)
  358. def rotate_left(self):
  359. """
  360. Rotate the grid to the left (counter-clockwise)
  361. """
  362. grid = Grid(self.height, self.width)
  363. for i in range(self.width):
  364. for j in range(self.height):
  365. v = self.get(i, j)
  366. grid.set(j, grid.height - 1 - i, v)
  367. return grid
  368. def slice(self, topX, topY, width, height):
  369. """
  370. Get a subset of the grid
  371. """
  372. grid = Grid(width, height)
  373. for j in range(0, height):
  374. for i in range(0, width):
  375. x = topX + i
  376. y = topY + j
  377. if x >= 0 and x < self.width and \
  378. y >= 0 and y < self.height:
  379. v = self.get(x, y)
  380. else:
  381. v = Wall()
  382. grid.set(i, j, v)
  383. return grid
  384. def render(self, r, tile_size):
  385. """
  386. Render this grid at a given scale
  387. :param r: target renderer object
  388. :param tile_size: tile size in pixels
  389. """
  390. assert r.width == self.width * tile_size
  391. assert r.height == self.height * tile_size
  392. # Total grid size at native scale
  393. widthPx = self.width * CELL_PIXELS
  394. heightPx = self.height * CELL_PIXELS
  395. r.push()
  396. # Internally, we draw at the "large" full-grid resolution, but we
  397. # use the renderer to scale back to the desired size
  398. r.scale(tile_size / CELL_PIXELS, tile_size / CELL_PIXELS)
  399. # Draw the background of the in-world cells black
  400. r.fillRect(
  401. 0,
  402. 0,
  403. widthPx,
  404. heightPx,
  405. 0, 0, 0
  406. )
  407. # Draw grid lines
  408. r.setLineColor(100, 100, 100)
  409. for rowIdx in range(0, self.height):
  410. y = CELL_PIXELS * rowIdx
  411. r.drawLine(0, y, widthPx, y)
  412. for colIdx in range(0, self.width):
  413. x = CELL_PIXELS * colIdx
  414. r.drawLine(x, 0, x, heightPx)
  415. # Render the grid
  416. for j in range(0, self.height):
  417. for i in range(0, self.width):
  418. cell = self.get(i, j)
  419. if cell == None:
  420. continue
  421. r.push()
  422. r.translate(i * CELL_PIXELS, j * CELL_PIXELS)
  423. cell.render(r)
  424. r.pop()
  425. r.pop()
  426. def encode(self, vis_mask=None):
  427. """
  428. Produce a compact numpy encoding of the grid
  429. """
  430. if vis_mask is None:
  431. vis_mask = np.ones((self.width, self.height), dtype=bool)
  432. array = np.zeros((self.width, self.height, 3), dtype='uint8')
  433. for i in range(self.width):
  434. for j in range(self.height):
  435. if vis_mask[i, j]:
  436. v = self.get(i, j)
  437. if v is None:
  438. array[i, j, 0] = OBJECT_TO_IDX['empty']
  439. array[i, j, 1] = 0
  440. array[i, j, 2] = 0
  441. else:
  442. # State, 0: open, 1: closed, 2: locked
  443. state = 0
  444. if hasattr(v, 'is_open') and not v.is_open:
  445. state = 1
  446. if hasattr(v, 'is_locked') and v.is_locked:
  447. state = 2
  448. array[i, j, 0] = OBJECT_TO_IDX[v.type]
  449. array[i, j, 1] = COLOR_TO_IDX[v.color]
  450. array[i, j, 2] = state
  451. return array
  452. @staticmethod
  453. def decode(array):
  454. """
  455. Decode an array grid encoding back into a grid
  456. """
  457. width, height, channels = array.shape
  458. assert channels == 3
  459. grid = Grid(width, height)
  460. for i in range(width):
  461. for j in range(height):
  462. typeIdx, colorIdx, state = array[i, j]
  463. if typeIdx == OBJECT_TO_IDX['unseen'] or \
  464. typeIdx == OBJECT_TO_IDX['empty']:
  465. continue
  466. objType = IDX_TO_OBJECT[typeIdx]
  467. color = IDX_TO_COLOR[colorIdx]
  468. # State, 0: open, 1: closed, 2: locked
  469. is_open = state == 0
  470. is_locked = state == 2
  471. if objType == 'wall':
  472. v = Wall(color)
  473. elif objType == 'floor':
  474. v = Floor(color)
  475. elif objType == 'ball':
  476. v = Ball(color)
  477. elif objType == 'key':
  478. v = Key(color)
  479. elif objType == 'box':
  480. v = Box(color)
  481. elif objType == 'door':
  482. v = Door(color, is_open, is_locked)
  483. elif objType == 'goal':
  484. v = Goal()
  485. elif objType == 'lava':
  486. v = Lava()
  487. else:
  488. assert False, "unknown obj type in decode '%s'" % objType
  489. grid.set(i, j, v)
  490. return grid
  491. def process_vis(grid, agent_pos):
  492. mask = np.zeros(shape=(grid.width, grid.height), dtype=np.bool)
  493. mask[agent_pos[0], agent_pos[1]] = True
  494. for j in reversed(range(0, grid.height)):
  495. for i in range(0, grid.width-1):
  496. if not mask[i, j]:
  497. continue
  498. cell = grid.get(i, j)
  499. if cell and not cell.see_behind():
  500. continue
  501. mask[i+1, j] = True
  502. if j > 0:
  503. mask[i+1, j-1] = True
  504. mask[i, j-1] = True
  505. for i in reversed(range(1, grid.width)):
  506. if not mask[i, j]:
  507. continue
  508. cell = grid.get(i, j)
  509. if cell and not cell.see_behind():
  510. continue
  511. mask[i-1, j] = True
  512. if j > 0:
  513. mask[i-1, j-1] = True
  514. mask[i, j-1] = True
  515. for j in range(0, grid.height):
  516. for i in range(0, grid.width):
  517. if not mask[i, j]:
  518. grid.set(i, j, None)
  519. return mask
  520. class MiniGridEnv(gym.Env):
  521. """
  522. 2D grid world game environment
  523. """
  524. metadata = {
  525. 'render.modes': ['human', 'rgb_array', 'pixmap'],
  526. 'video.frames_per_second' : 10
  527. }
  528. # Enumeration of possible actions
  529. class Actions(IntEnum):
  530. # Turn left, turn right, move forward
  531. left = 0
  532. right = 1
  533. forward = 2
  534. # Pick up an object
  535. pickup = 3
  536. # Drop an object
  537. drop = 4
  538. # Toggle/activate an object
  539. toggle = 5
  540. # Done completing task
  541. done = 6
  542. def __init__(
  543. self,
  544. grid_size=16,
  545. max_steps=100,
  546. see_through_walls=False,
  547. seed=1337
  548. ):
  549. # Action enumeration for this environment
  550. self.actions = MiniGridEnv.Actions
  551. # Actions are discrete integer values
  552. self.action_space = spaces.Discrete(len(self.actions))
  553. # Observations are dictionaries containing an
  554. # encoding of the grid and a textual 'mission' string
  555. self.observation_space = spaces.Box(
  556. low=0,
  557. high=255,
  558. shape=OBS_ARRAY_SIZE,
  559. dtype='uint8'
  560. )
  561. self.observation_space = spaces.Dict({
  562. 'image': self.observation_space
  563. })
  564. # Range of possible rewards
  565. self.reward_range = (0, 1)
  566. # Renderer object used to render the whole grid (full-scale)
  567. self.grid_render = None
  568. # Renderer used to render observations (small-scale agent view)
  569. self.obs_render = None
  570. # Environment configuration
  571. self.grid_size = grid_size
  572. self.max_steps = max_steps
  573. self.see_through_walls = see_through_walls
  574. # Starting position and direction for the agent
  575. self.start_pos = None
  576. self.start_dir = None
  577. # Initialize the RNG
  578. self.seed(seed=seed)
  579. # Initialize the state
  580. self.reset()
  581. def reset(self):
  582. # Generate a new random grid at the start of each episode
  583. # To keep the same grid for each episode, call env.seed() with
  584. # the same seed before calling env.reset()
  585. self._gen_grid(self.grid_size, self.grid_size)
  586. # These fields should be defined by _gen_grid
  587. assert self.start_pos is not None
  588. assert self.start_dir is not None
  589. # Check that the agent doesn't overlap with an object
  590. start_cell = self.grid.get(*self.start_pos)
  591. assert start_cell is None or start_cell.can_overlap()
  592. # Place the agent in the starting position and direction
  593. self.agent_pos = self.start_pos
  594. self.agent_dir = self.start_dir
  595. # Item picked up, being carried, initially nothing
  596. self.carrying = None
  597. # Step count since episode start
  598. self.step_count = 0
  599. # Return first observation
  600. obs = self.gen_obs()
  601. return obs
  602. def seed(self, seed=1337):
  603. # Seed the random number generator
  604. self.np_random, _ = seeding.np_random(seed)
  605. return [seed]
  606. @property
  607. def steps_remaining(self):
  608. return self.max_steps - self.step_count
  609. def __str__(self):
  610. """
  611. Produce a pretty string of the environment's grid along with the agent.
  612. A grid cell is represented by 2-character string, the first one for
  613. the object and the second one for the color.
  614. """
  615. # Map of object types to short string
  616. OBJECT_TO_STR = {
  617. 'wall' : 'W',
  618. 'floor' : 'F',
  619. 'door' : 'D',
  620. 'key' : 'K',
  621. 'ball' : 'A',
  622. 'box' : 'B',
  623. 'goal' : 'G',
  624. 'lava' : 'V',
  625. }
  626. # Short string for opened door
  627. OPENDED_DOOR_IDS = '_'
  628. # Map agent's direction to short string
  629. AGENT_DIR_TO_STR = {
  630. 0: '>',
  631. 1: 'V',
  632. 2: '<',
  633. 3: '^'
  634. }
  635. str = ''
  636. for j in range(self.grid.height):
  637. for i in range(self.grid.width):
  638. if i == self.agent_pos[0] and j == self.agent_pos[1]:
  639. str += 2 * AGENT_DIR_TO_STR[self.agent_dir]
  640. continue
  641. c = self.grid.get(i, j)
  642. if c == None:
  643. str += ' '
  644. continue
  645. if c.type == 'door':
  646. if c.is_open:
  647. str += '__'
  648. elif c.is_locked:
  649. str += 'L' + c.color[0].upper()
  650. else:
  651. str += 'D' + c.color[0].upper()
  652. continue
  653. str += OBJECT_TO_STR[c.type] + c.color[0].upper()
  654. if j < self.grid.height - 1:
  655. str += '\n'
  656. return str
  657. def _gen_grid(self, width, height):
  658. assert False, "_gen_grid needs to be implemented by each environment"
  659. def _reward(self):
  660. """
  661. Compute the reward to be given upon success
  662. """
  663. return 1 - 0.9 * (self.step_count / self.max_steps)
  664. def _rand_int(self, low, high):
  665. """
  666. Generate random integer in [low,high[
  667. """
  668. return self.np_random.randint(low, high)
  669. def _rand_float(self, low, high):
  670. """
  671. Generate random float in [low,high[
  672. """
  673. return self.np_random.uniform(low, high)
  674. def _rand_bool(self):
  675. """
  676. Generate random boolean value
  677. """
  678. return (self.np_random.randint(0, 2) == 0)
  679. def _rand_elem(self, iterable):
  680. """
  681. Pick a random element in a list
  682. """
  683. lst = list(iterable)
  684. idx = self._rand_int(0, len(lst))
  685. return lst[idx]
  686. def _rand_subset(self, iterable, num_elems):
  687. """
  688. Sample a random subset of distinct elements of a list
  689. """
  690. lst = list(iterable)
  691. assert num_elems <= len(lst)
  692. out = []
  693. while len(out) < num_elems:
  694. elem = self._rand_elem(lst)
  695. lst.remove(elem)
  696. out.append(elem)
  697. return out
  698. def _rand_color(self):
  699. """
  700. Generate a random color name (string)
  701. """
  702. return self._rand_elem(COLOR_NAMES)
  703. def _rand_pos(self, xLow, xHigh, yLow, yHigh):
  704. """
  705. Generate a random (x,y) position tuple
  706. """
  707. return (
  708. self.np_random.randint(xLow, xHigh),
  709. self.np_random.randint(yLow, yHigh)
  710. )
  711. def place_obj(self,
  712. obj,
  713. top=None,
  714. size=None,
  715. reject_fn=None,
  716. max_tries=math.inf
  717. ):
  718. """
  719. Place an object at an empty position in the grid
  720. :param top: top-left position of the rectangle where to place
  721. :param size: size of the rectangle where to place
  722. :param reject_fn: function to filter out potential positions
  723. """
  724. if top is None:
  725. top = (0, 0)
  726. if size is None:
  727. size = (self.grid.width, self.grid.height)
  728. num_tries = 0
  729. while True:
  730. # This is to handle with rare cases where rejection sampling
  731. # gets stuck in an infinite loop
  732. if num_tries > max_tries:
  733. raise RecursionError('rejection sampling failed in place_obj')
  734. num_tries += 1
  735. pos = np.array((
  736. self._rand_int(top[0], top[0] + size[0]),
  737. self._rand_int(top[1], top[1] + size[1])
  738. ))
  739. # Don't place the object on top of another object
  740. if self.grid.get(*pos) != None:
  741. continue
  742. # Don't place the object where the agent is
  743. if np.array_equal(pos, self.start_pos):
  744. continue
  745. # Check if there is a filtering criterion
  746. if reject_fn and reject_fn(self, pos):
  747. continue
  748. break
  749. self.grid.set(*pos, obj)
  750. if obj is not None:
  751. obj.init_pos = pos
  752. obj.cur_pos = pos
  753. return pos
  754. def place_agent(
  755. self,
  756. top=None,
  757. size=None,
  758. rand_dir=True,
  759. max_tries=math.inf
  760. ):
  761. """
  762. Set the agent's starting point at an empty position in the grid
  763. """
  764. self.start_pos = None
  765. pos = self.place_obj(None, top, size, max_tries=max_tries)
  766. self.start_pos = pos
  767. if rand_dir:
  768. self.start_dir = self._rand_int(0, 4)
  769. return pos
  770. @property
  771. def dir_vec(self):
  772. """
  773. Get the direction vector for the agent, pointing in the direction
  774. of forward movement.
  775. """
  776. assert self.agent_dir >= 0 and self.agent_dir < 4
  777. return DIR_TO_VEC[self.agent_dir]
  778. @property
  779. def right_vec(self):
  780. """
  781. Get the vector pointing to the right of the agent.
  782. """
  783. dx, dy = self.dir_vec
  784. return np.array((-dy, dx))
  785. @property
  786. def front_pos(self):
  787. """
  788. Get the position of the cell that is right in front of the agent
  789. """
  790. return self.agent_pos + self.dir_vec
  791. def get_view_coords(self, i, j):
  792. """
  793. Translate and rotate absolute grid coordinates (i, j) into the
  794. agent's partially observable view (sub-grid). Note that the resulting
  795. coordinates may be negative or outside of the agent's view size.
  796. """
  797. ax, ay = self.agent_pos
  798. dx, dy = self.dir_vec
  799. rx, ry = self.right_vec
  800. # Compute the absolute coordinates of the top-left view corner
  801. sz = AGENT_VIEW_SIZE
  802. hs = AGENT_VIEW_SIZE // 2
  803. tx = ax + (dx * (sz-1)) - (rx * hs)
  804. ty = ay + (dy * (sz-1)) - (ry * hs)
  805. lx = i - tx
  806. ly = j - ty
  807. # Project the coordinates of the object relative to the top-left
  808. # corner onto the agent's own coordinate system
  809. vx = (rx*lx + ry*ly)
  810. vy = -(dx*lx + dy*ly)
  811. return vx, vy
  812. def get_view_exts(self):
  813. """
  814. Get the extents of the square set of tiles visible to the agent
  815. Note: the bottom extent indices are not included in the set
  816. """
  817. # Facing right
  818. if self.agent_dir == 0:
  819. topX = self.agent_pos[0]
  820. topY = self.agent_pos[1] - AGENT_VIEW_SIZE // 2
  821. # Facing down
  822. elif self.agent_dir == 1:
  823. topX = self.agent_pos[0] - AGENT_VIEW_SIZE // 2
  824. topY = self.agent_pos[1]
  825. # Facing left
  826. elif self.agent_dir == 2:
  827. topX = self.agent_pos[0] - AGENT_VIEW_SIZE + 1
  828. topY = self.agent_pos[1] - AGENT_VIEW_SIZE // 2
  829. # Facing up
  830. elif self.agent_dir == 3:
  831. topX = self.agent_pos[0] - AGENT_VIEW_SIZE // 2
  832. topY = self.agent_pos[1] - AGENT_VIEW_SIZE + 1
  833. else:
  834. assert False, "invalid agent direction"
  835. botX = topX + AGENT_VIEW_SIZE
  836. botY = topY + AGENT_VIEW_SIZE
  837. return (topX, topY, botX, botY)
  838. def relative_coords(self, x, y):
  839. """
  840. Check if a grid position belongs to the agent's field of view, and returns the corresponding coordinates
  841. """
  842. vx, vy = self.get_view_coords(x, y)
  843. if vx < 0 or vy < 0 or vx >= AGENT_VIEW_SIZE or vy >= AGENT_VIEW_SIZE:
  844. return None
  845. return vx, vy
  846. def in_view(self, x, y):
  847. """
  848. check if a grid position is visible to the agent
  849. """
  850. return self.relative_coords(x, y) is not None
  851. def agent_sees(self, x, y):
  852. """
  853. Check if a non-empty grid position is visible to the agent
  854. """
  855. coordinates = self.relative_coords(x, y)
  856. if coordinates is None:
  857. return False
  858. vx, vy = coordinates
  859. obs = self.gen_obs()
  860. obs_grid = Grid.decode(obs['image'])
  861. obs_cell = obs_grid.get(vx, vy)
  862. world_cell = self.grid.get(x, y)
  863. return obs_cell is not None and obs_cell.type == world_cell.type
  def step(self, action):
      """
      Advance the environment by one action.

      Parameters
      ----------
      action : one of self.actions
          left / right rotate the agent by a quarter turn; forward moves it
          into the cell in front if that cell is empty or can be overlapped;
          pickup takes the object in front (if it can be picked up and the
          agent carries nothing); drop places the carried object in front
          (if that cell is empty); toggle calls the front object's toggle();
          done is a no-op.

      Returns
      -------
      (obs, reward, done, info)
          obs is the result of gen_obs(). reward is 0 unless a forward action
          faces a 'goal' cell, in which case it is self._reward(). done is set
          when a forward action faces a 'goal' or 'lava' cell, or once
          step_count reaches max_steps. info is always an empty dict.

      Raises
      ------
      ValueError
          If action is not one of the known actions.
      """
      self.step_count += 1

      reward = 0
      done = False

      # Position of, and contents of, the cell directly in front of the agent
      fwd_pos = self.front_pos
      fwd_cell = self.grid.get(*fwd_pos)

      if action == self.actions.left:
          # Directions live in 0..3; wrap around like the right rotation does
          self.agent_dir = (self.agent_dir - 1) % 4

      elif action == self.actions.right:
          self.agent_dir = (self.agent_dir + 1) % 4

      elif action == self.actions.forward:
          if fwd_cell is None or fwd_cell.can_overlap():
              self.agent_pos = fwd_pos
          if fwd_cell is not None and fwd_cell.type == 'goal':
              done = True
              reward = self._reward()
          if fwd_cell is not None and fwd_cell.type == 'lava':
              done = True

      elif action == self.actions.pickup:
          if fwd_cell and fwd_cell.can_pickup():
              if self.carrying is None:
                  self.carrying = fwd_cell
                  # (-1, -1) marks an object that is no longer on the grid
                  self.carrying.cur_pos = np.array([-1, -1])
                  self.grid.set(*fwd_pos, None)

      elif action == self.actions.drop:
          if not fwd_cell and self.carrying:
              self.grid.set(*fwd_pos, self.carrying)
              self.carrying.cur_pos = fwd_pos
              self.carrying = None

      elif action == self.actions.toggle:
          if fwd_cell:
              fwd_cell.toggle(self, fwd_pos)

      elif action == self.actions.done:
          # Not used by default
          pass

      else:
          # An explicit exception rather than `assert False`: assertions are
          # stripped under `python -O`, which would make unknown actions
          # silently behave like no-ops.
          raise ValueError("unknown action: {!r}".format(action))

      if self.step_count >= self.max_steps:
          done = True

      obs = self.gen_obs()
      return obs, reward, done, {}
  def gen_obs_grid(self):
      """
      Generate the sub-grid observed by the agent.

      Returns
      -------
      (grid, vis_mask)
          grid is the AGENT_VIEW_SIZE x AGENT_VIEW_SIZE slice of the world
          starting at the view extents' top corner, rotated left
          (agent_dir + 1) times. The agent's own cell in that view,
          (grid.width // 2, grid.height - 1), holds the carried object, or
          None when nothing is carried.
          vis_mask tells which cells the agent can actually see: it comes
          from grid.process_vis when walls occlude, and is an all-True
          boolean array of shape (grid.width, grid.height) when
          see_through_walls is set.
      """
      topX, topY, _, _ = self.get_view_exts()
      grid = self.grid.slice(topX, topY, AGENT_VIEW_SIZE, AGENT_VIEW_SIZE)

      # Express the view in the agent's own frame
      for _ in range(self.agent_dir + 1):
          grid = grid.rotate_left()

      # Process occluders and visibility
      # Note that this incurs some performance cost
      if not self.see_through_walls:
          vis_mask = grid.process_vis(agent_pos=(AGENT_VIEW_SIZE // 2, AGENT_VIEW_SIZE - 1))
      else:
          # Builtin `bool`, not the `np.bool` alias: that alias was deprecated
          # in NumPy 1.20 and removed in 1.24, where this branch crashed.
          vis_mask = np.ones(shape=(grid.width, grid.height), dtype=bool)

      # Make it so the agent sees what it's carrying: place the carried
      # object (or clear the cell) at the agent's position in its view
      agent_pos = grid.width // 2, grid.height - 1
      grid.set(*agent_pos, self.carrying if self.carrying else None)

      return grid, vis_mask
  def gen_obs(self):
      """
      Build the observation dictionary handed to the agent.

      Keys:
        'image'     -- the partially observable view from gen_obs_grid(),
                       encoded with its visibility mask
        'direction' -- self.agent_dir, acting as a compass
        'mission'   -- self.mission, the textual instruction

      Raises AssertionError if the environment never defined self.mission.
      """
      view_grid, visibility = self.gen_obs_grid()
      encoded_view = view_grid.encode(visibility)

      assert hasattr(self, 'mission'), "environments must define a textual mission string"

      observation = {
          'image': encoded_view,
          'direction': self.agent_dir,
          'mission': self.mission,
      }
      return observation
  def get_obs_render(self, obs, tile_pixels=CELL_PIXELS//2):
      """
      Render an agent observation for visualization.

      Parameters
      ----------
      obs
          Encoded observation image; it is passed straight to Grid.decode.
          NOTE(review): presumably obs['image'] from gen_obs() rather than
          the whole dict -- confirm against callers.
      tile_pixels : int
          Size in pixels of one rendered cell.

      Returns
      -------
      The renderer's pixmap (r.getPixmap()).
      """
      # Create the renderer lazily; the rendering module is only imported
      # once an observation is actually rendered.
      if self.obs_render is None:
          from gym_minigrid.rendering import Renderer
          self.obs_render = Renderer(
              AGENT_VIEW_SIZE * tile_pixels,
              AGENT_VIEW_SIZE * tile_pixels
          )
      # NOTE(review): the renderer is sized from the first call's
      # tile_pixels; later calls with a different tile_pixels reuse it.

      r = self.obs_render
      r.beginFrame()

      grid = Grid.decode(obs)

      # Render the whole grid
      grid.render(r, tile_pixels)

      # Draw the agent at the centre of the bottom-middle view cell, using
      # CELL_PIXELS coordinates scaled down to tile_pixels
      ratio = tile_pixels / CELL_PIXELS
      r.push()
      r.scale(ratio, ratio)
      r.translate(
          CELL_PIXELS * (0.5 + AGENT_VIEW_SIZE // 2),
          CELL_PIXELS * (AGENT_VIEW_SIZE - 0.5)
      )
      # Fixed rotation: in its own view the agent is always drawn with the
      # same orientation, independent of agent_dir
      r.rotate(3 * 90)
      r.setLineColor(255, 0, 0)
      r.setColor(255, 0, 0)
      r.drawPolygon([
          (-12, 10),
          ( 12, 0),
          (-12, -10)
      ])
      r.pop()

      r.endFrame()

      return r.getPixmap()
  def render(self, mode='human', close=False):
      """
      Render the whole-grid human view.

      Parameters
      ----------
      mode : str
          'rgb_array' returns r.getArray(), 'pixmap' returns r.getPixmap(),
          and any other mode (including 'human') returns the Renderer itself.
          The Renderer is created with its third argument True only for
          'human' (presumably to open a window).
      close : bool
          If True, close the existing renderer (if any), forget it so that a
          later call builds a fresh one, and return None without drawing.
      """
      if close:
          if self.grid_render:
              self.grid_render.close()
              # Drop the closed renderer so a subsequent render() creates a
              # new one instead of reusing a closed one.
              self.grid_render = None
          return

      if self.grid_render is None:
          from gym_minigrid.rendering import Renderer
          self.grid_render = Renderer(
              self.grid_size * CELL_PIXELS,
              self.grid_size * CELL_PIXELS,
              mode == 'human'
          )

      r = self.grid_render

      if r.window:
          r.window.setText(self.mission)

      r.beginFrame()

      # Render the whole grid
      self.grid.render(r, CELL_PIXELS)

      # Draw the agent as a red triangle centred on its cell, rotated by
      # agent_dir quarter turns
      r.push()
      r.translate(
          CELL_PIXELS * (self.agent_pos[0] + 0.5),
          CELL_PIXELS * (self.agent_pos[1] + 0.5)
      )
      r.rotate(self.agent_dir * 90)
      r.setLineColor(255, 0, 0)
      r.setColor(255, 0, 0)
      r.drawPolygon([
          (-12, 10),
          ( 12, 0),
          (-12, -10)
      ])
      r.pop()

      # Compute which cells are visible to the agent
      _, vis_mask = self.gen_obs_grid()

      # World coordinates of view cell (0, 0): AGENT_VIEW_SIZE - 1 steps
      # along dir_vec and AGENT_VIEW_SIZE // 2 steps against right_vec from
      # the agent, i.e. the far-left corner of the agent's view area
      f_vec = self.dir_vec
      r_vec = self.right_vec
      top_left = self.agent_pos + f_vec * (AGENT_VIEW_SIZE - 1) - r_vec * (AGENT_VIEW_SIZE // 2)

      # Highlight every view cell the agent can see
      for vis_j in range(AGENT_VIEW_SIZE):
          for vis_i in range(AGENT_VIEW_SIZE):
              if not vis_mask[vis_i, vis_j]:
                  continue

              # Increasing vis_j steps back against dir_vec; increasing
              # vis_i steps along right_vec
              abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)

              r.fillRect(
                  abs_i * CELL_PIXELS,
                  abs_j * CELL_PIXELS,
                  CELL_PIXELS,
                  CELL_PIXELS,
                  255, 255, 255, 75
              )

      r.endFrame()

      if mode == 'rgb_array':
          return r.getArray()
      elif mode == 'pixmap':
          return r.getPixmap()

      return r