minigrid.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343
  1. import math
  2. import gym
  3. from enum import IntEnum
  4. import numpy as np
  5. from gym import error, spaces, utils
  6. from gym.utils import seeding
  7. # Size in pixels of a cell in the full-scale human view
  8. CELL_PIXELS = 32
  9. # Map of color names to RGB values
  10. COLORS = {
  11. 'red' : np.array([255, 0, 0]),
  12. 'green' : np.array([0, 255, 0]),
  13. 'blue' : np.array([0, 0, 255]),
  14. 'purple': np.array([112, 39, 195]),
  15. 'yellow': np.array([255, 255, 0]),
  16. 'grey' : np.array([100, 100, 100])
  17. }
  18. COLOR_NAMES = sorted(list(COLORS.keys()))
  19. # Used to map colors to integers
  20. COLOR_TO_IDX = {
  21. 'red' : 0,
  22. 'green' : 1,
  23. 'blue' : 2,
  24. 'purple': 3,
  25. 'yellow': 4,
  26. 'grey' : 5
  27. }
  28. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  29. # Map of object type to integers
  30. OBJECT_TO_IDX = {
  31. 'unseen' : 0,
  32. 'empty' : 1,
  33. 'wall' : 2,
  34. 'floor' : 3,
  35. 'door' : 4,
  36. 'key' : 5,
  37. 'ball' : 6,
  38. 'box' : 7,
  39. 'goal' : 8,
  40. 'lava' : 9,
  41. 'agent' : 10,
  42. }
  43. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  44. # Map of agent direction indices to vectors
  45. DIR_TO_VEC = [
  46. # Pointing right (positive X)
  47. np.array((1, 0)),
  48. # Down (positive Y)
  49. np.array((0, 1)),
  50. # Pointing left (negative X)
  51. np.array((-1, 0)),
  52. # Up (negative Y)
  53. np.array((0, -1)),
  54. ]
  55. class WorldObj:
  56. """
  57. Base class for grid world objects
  58. """
  59. def __init__(self, type, color):
  60. assert type in OBJECT_TO_IDX, type
  61. assert color in COLOR_TO_IDX, color
  62. self.type = type
  63. self.color = color
  64. self.contains = None
  65. # Initial position of the object
  66. self.init_pos = None
  67. # Current position of the object
  68. self.cur_pos = None
  69. def can_overlap(self):
  70. """Can the agent overlap with this?"""
  71. return False
  72. def can_pickup(self):
  73. """Can the agent pick this up?"""
  74. return False
  75. def can_contain(self):
  76. """Can this contain another object?"""
  77. return False
  78. def see_behind(self):
  79. """Can the agent see behind this object?"""
  80. return True
  81. def toggle(self, env, pos):
  82. """Method to trigger/toggle an action this object performs"""
  83. return False
  84. def render(self, r):
  85. """Draw this object with the given renderer"""
  86. raise NotImplementedError
  87. def _set_color(self, r):
  88. """Set the color of this object as the active drawing color"""
  89. c = COLORS[self.color]
  90. r.setLineColor(c[0], c[1], c[2])
  91. r.setColor(c[0], c[1], c[2])
  92. class Goal(WorldObj):
  93. def __init__(self):
  94. super().__init__('goal', 'green')
  95. def can_overlap(self):
  96. return True
  97. def render(self, r):
  98. self._set_color(r)
  99. r.drawPolygon([
  100. (0 , CELL_PIXELS),
  101. (CELL_PIXELS, CELL_PIXELS),
  102. (CELL_PIXELS, 0),
  103. (0 , 0)
  104. ])
  105. class Floor(WorldObj):
  106. """
  107. Colored floor tile the agent can walk over
  108. """
  109. def __init__(self, color='blue'):
  110. super().__init__('floor', color)
  111. def can_overlap(self):
  112. return True
  113. def render(self, r):
  114. # Give the floor a pale color
  115. c = COLORS[self.color]
  116. r.setLineColor(100, 100, 100, 0)
  117. r.setColor(*c/2)
  118. r.drawPolygon([
  119. (1 , CELL_PIXELS),
  120. (CELL_PIXELS, CELL_PIXELS),
  121. (CELL_PIXELS, 1),
  122. (1 , 1)
  123. ])
  124. class Lava(WorldObj):
  125. def __init__(self):
  126. super().__init__('lava', 'red')
  127. def can_overlap(self):
  128. return True
  129. def render(self, r):
  130. orange = 255, 128, 0
  131. r.setLineColor(*orange)
  132. r.setColor(*orange)
  133. r.drawPolygon([
  134. (0 , CELL_PIXELS),
  135. (CELL_PIXELS, CELL_PIXELS),
  136. (CELL_PIXELS, 0),
  137. (0 , 0)
  138. ])
  139. # drawing the waves
  140. r.setLineColor(0, 0, 0)
  141. r.drawPolyline([
  142. (.1 * CELL_PIXELS, .3 * CELL_PIXELS),
  143. (.3 * CELL_PIXELS, .4 * CELL_PIXELS),
  144. (.5 * CELL_PIXELS, .3 * CELL_PIXELS),
  145. (.7 * CELL_PIXELS, .4 * CELL_PIXELS),
  146. (.9 * CELL_PIXELS, .3 * CELL_PIXELS),
  147. ])
  148. r.drawPolyline([
  149. (.1 * CELL_PIXELS, .5 * CELL_PIXELS),
  150. (.3 * CELL_PIXELS, .6 * CELL_PIXELS),
  151. (.5 * CELL_PIXELS, .5 * CELL_PIXELS),
  152. (.7 * CELL_PIXELS, .6 * CELL_PIXELS),
  153. (.9 * CELL_PIXELS, .5 * CELL_PIXELS),
  154. ])
  155. r.drawPolyline([
  156. (.1 * CELL_PIXELS, .7 * CELL_PIXELS),
  157. (.3 * CELL_PIXELS, .8 * CELL_PIXELS),
  158. (.5 * CELL_PIXELS, .7 * CELL_PIXELS),
  159. (.7 * CELL_PIXELS, .8 * CELL_PIXELS),
  160. (.9 * CELL_PIXELS, .7 * CELL_PIXELS),
  161. ])
  162. class Wall(WorldObj):
  163. def __init__(self, color='grey'):
  164. super().__init__('wall', color)
  165. def see_behind(self):
  166. return False
  167. def render(self, r):
  168. self._set_color(r)
  169. r.drawPolygon([
  170. (0 , CELL_PIXELS),
  171. (CELL_PIXELS, CELL_PIXELS),
  172. (CELL_PIXELS, 0),
  173. (0 , 0)
  174. ])
  175. class Door(WorldObj):
  176. def __init__(self, color, is_open=False, is_locked=False):
  177. super().__init__('door', color)
  178. self.is_open = is_open
  179. self.is_locked = is_locked
  180. def can_overlap(self):
  181. """The agent can only walk over this cell when the door is open"""
  182. return self.is_open
  183. def see_behind(self):
  184. return self.is_open
  185. def toggle(self, env, pos):
  186. # If the player has the right key to open the door
  187. if self.is_locked:
  188. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  189. self.is_locked = False
  190. self.is_open = True
  191. return True
  192. return False
  193. self.is_open = not self.is_open
  194. return True
  195. def render(self, r):
  196. c = COLORS[self.color]
  197. r.setLineColor(c[0], c[1], c[2])
  198. r.setColor(c[0], c[1], c[2], 50 if self.is_locked else 0)
  199. if self.is_open:
  200. r.drawPolygon([
  201. (CELL_PIXELS-2, CELL_PIXELS),
  202. (CELL_PIXELS , CELL_PIXELS),
  203. (CELL_PIXELS , 0),
  204. (CELL_PIXELS-2, 0)
  205. ])
  206. return
  207. r.drawPolygon([
  208. (0 , CELL_PIXELS),
  209. (CELL_PIXELS, CELL_PIXELS),
  210. (CELL_PIXELS, 0),
  211. (0 , 0)
  212. ])
  213. r.drawPolygon([
  214. (2 , CELL_PIXELS-2),
  215. (CELL_PIXELS-2, CELL_PIXELS-2),
  216. (CELL_PIXELS-2, 2),
  217. (2 , 2)
  218. ])
  219. if self.is_locked:
  220. # Draw key slot
  221. r.drawLine(
  222. CELL_PIXELS * 0.55,
  223. CELL_PIXELS * 0.5,
  224. CELL_PIXELS * 0.75,
  225. CELL_PIXELS * 0.5
  226. )
  227. else:
  228. # Draw door handle
  229. r.drawCircle(CELL_PIXELS * 0.75, CELL_PIXELS * 0.5, 2)
  230. class Key(WorldObj):
  231. def __init__(self, color='blue'):
  232. super(Key, self).__init__('key', color)
  233. def can_pickup(self):
  234. return True
  235. def render(self, r):
  236. self._set_color(r)
  237. # Vertical quad
  238. r.drawPolygon([
  239. (16, 10),
  240. (20, 10),
  241. (20, 28),
  242. (16, 28)
  243. ])
  244. # Teeth
  245. r.drawPolygon([
  246. (12, 19),
  247. (16, 19),
  248. (16, 21),
  249. (12, 21)
  250. ])
  251. r.drawPolygon([
  252. (12, 26),
  253. (16, 26),
  254. (16, 28),
  255. (12, 28)
  256. ])
  257. r.drawCircle(18, 9, 6)
  258. r.setLineColor(0, 0, 0)
  259. r.setColor(0, 0, 0)
  260. r.drawCircle(18, 9, 2)
  261. class Ball(WorldObj):
  262. def __init__(self, color='blue'):
  263. super(Ball, self).__init__('ball', color)
  264. def can_pickup(self):
  265. return True
  266. def render(self, r):
  267. self._set_color(r)
  268. r.drawCircle(CELL_PIXELS * 0.5, CELL_PIXELS * 0.5, 10)
  269. class Box(WorldObj):
  270. def __init__(self, color, contains=None):
  271. super(Box, self).__init__('box', color)
  272. self.contains = contains
  273. def can_pickup(self):
  274. return True
  275. def render(self, r):
  276. c = COLORS[self.color]
  277. r.setLineColor(c[0], c[1], c[2])
  278. r.setColor(0, 0, 0)
  279. r.setLineWidth(2)
  280. r.drawPolygon([
  281. (4 , CELL_PIXELS-4),
  282. (CELL_PIXELS-4, CELL_PIXELS-4),
  283. (CELL_PIXELS-4, 4),
  284. (4 , 4)
  285. ])
  286. r.drawLine(
  287. 4,
  288. CELL_PIXELS / 2,
  289. CELL_PIXELS - 4,
  290. CELL_PIXELS / 2
  291. )
  292. r.setLineWidth(1)
  293. def toggle(self, env, pos):
  294. # Replace the box by its contents
  295. env.grid.set(*pos, self.contains)
  296. return True
  297. class Grid:
  298. """
  299. Represent a grid and operations on it
  300. """
  301. def __init__(self, width, height):
  302. assert width >= 3
  303. assert height >= 3
  304. self.width = width
  305. self.height = height
  306. self.grid = [None] * width * height
  307. def __contains__(self, key):
  308. if isinstance(key, WorldObj):
  309. for e in self.grid:
  310. if e is key:
  311. return True
  312. elif isinstance(key, tuple):
  313. for e in self.grid:
  314. if e is None:
  315. continue
  316. if (e.color, e.type) == key:
  317. return True
  318. if key[0] is None and key[1] == e.type:
  319. return True
  320. return False
  321. def __eq__(self, other):
  322. grid1 = self.encode()
  323. grid2 = other.encode()
  324. return np.array_equal(grid2, grid1)
  325. def __ne__(self, other):
  326. return not self == other
  327. def copy(self):
  328. from copy import deepcopy
  329. return deepcopy(self)
  330. def set(self, i, j, v):
  331. assert i >= 0 and i < self.width
  332. assert j >= 0 and j < self.height
  333. self.grid[j * self.width + i] = v
  334. def get(self, i, j):
  335. assert i >= 0 and i < self.width
  336. assert j >= 0 and j < self.height
  337. return self.grid[j * self.width + i]
  338. def horz_wall(self, x, y, length=None):
  339. if length is None:
  340. length = self.width - x
  341. for i in range(0, length):
  342. self.set(x + i, y, Wall())
  343. def vert_wall(self, x, y, length=None):
  344. if length is None:
  345. length = self.height - y
  346. for j in range(0, length):
  347. self.set(x, y + j, Wall())
  348. def wall_rect(self, x, y, w, h):
  349. self.horz_wall(x, y, w)
  350. self.horz_wall(x, y+h-1, w)
  351. self.vert_wall(x, y, h)
  352. self.vert_wall(x+w-1, y, h)
  353. def rotate_left(self):
  354. """
  355. Rotate the grid to the left (counter-clockwise)
  356. """
  357. grid = Grid(self.height, self.width)
  358. for i in range(self.width):
  359. for j in range(self.height):
  360. v = self.get(i, j)
  361. grid.set(j, grid.height - 1 - i, v)
  362. return grid
  363. def slice(self, topX, topY, width, height):
  364. """
  365. Get a subset of the grid
  366. """
  367. grid = Grid(width, height)
  368. for j in range(0, height):
  369. for i in range(0, width):
  370. x = topX + i
  371. y = topY + j
  372. if x >= 0 and x < self.width and \
  373. y >= 0 and y < self.height:
  374. v = self.get(x, y)
  375. else:
  376. v = Wall()
  377. grid.set(i, j, v)
  378. return grid
  379. def render(self, r, tile_size):
  380. """
  381. Render this grid at a given scale
  382. :param r: target renderer object
  383. :param tile_size: tile size in pixels
  384. """
  385. assert r.width == self.width * tile_size
  386. assert r.height == self.height * tile_size
  387. # Total grid size at native scale
  388. widthPx = self.width * CELL_PIXELS
  389. heightPx = self.height * CELL_PIXELS
  390. r.push()
  391. # Internally, we draw at the "large" full-grid resolution, but we
  392. # use the renderer to scale back to the desired size
  393. r.scale(tile_size / CELL_PIXELS, tile_size / CELL_PIXELS)
  394. # Draw the background of the in-world cells black
  395. r.fillRect(
  396. 0,
  397. 0,
  398. widthPx,
  399. heightPx,
  400. 0, 0, 0
  401. )
  402. # Draw grid lines
  403. r.setLineColor(100, 100, 100)
  404. for rowIdx in range(0, self.height):
  405. y = CELL_PIXELS * rowIdx
  406. r.drawLine(0, y, widthPx, y)
  407. for colIdx in range(0, self.width):
  408. x = CELL_PIXELS * colIdx
  409. r.drawLine(x, 0, x, heightPx)
  410. # Render the grid
  411. for j in range(0, self.height):
  412. for i in range(0, self.width):
  413. cell = self.get(i, j)
  414. if cell == None:
  415. continue
  416. r.push()
  417. r.translate(i * CELL_PIXELS, j * CELL_PIXELS)
  418. cell.render(r)
  419. r.pop()
  420. r.pop()
  421. def encode(self, vis_mask=None):
  422. """
  423. Produce a compact numpy encoding of the grid
  424. """
  425. if vis_mask is None:
  426. vis_mask = np.ones((self.width, self.height), dtype=bool)
  427. array = np.zeros((self.width, self.height, 3), dtype='uint8')
  428. for i in range(self.width):
  429. for j in range(self.height):
  430. if vis_mask[i, j]:
  431. v = self.get(i, j)
  432. if v is None:
  433. array[i, j, 0] = OBJECT_TO_IDX['empty']
  434. array[i, j, 1] = 0
  435. array[i, j, 2] = 0
  436. else:
  437. # State, 0: open, 1: closed, 2: locked
  438. state = 0
  439. if hasattr(v, 'is_open') and not v.is_open:
  440. state = 1
  441. if hasattr(v, 'is_locked') and v.is_locked:
  442. state = 2
  443. array[i, j, 0] = OBJECT_TO_IDX[v.type]
  444. array[i, j, 1] = COLOR_TO_IDX[v.color]
  445. array[i, j, 2] = state
  446. return array
  447. @staticmethod
  448. def decode(array):
  449. """
  450. Decode an array grid encoding back into a grid
  451. """
  452. width, height, channels = array.shape
  453. assert channels == 3
  454. grid = Grid(width, height)
  455. for i in range(width):
  456. for j in range(height):
  457. typeIdx, colorIdx, state = array[i, j]
  458. if typeIdx == OBJECT_TO_IDX['unseen'] or \
  459. typeIdx == OBJECT_TO_IDX['empty']:
  460. continue
  461. objType = IDX_TO_OBJECT[typeIdx]
  462. color = IDX_TO_COLOR[colorIdx]
  463. # State, 0: open, 1: closed, 2: locked
  464. is_open = state == 0
  465. is_locked = state == 2
  466. if objType == 'wall':
  467. v = Wall(color)
  468. elif objType == 'floor':
  469. v = Floor(color)
  470. elif objType == 'ball':
  471. v = Ball(color)
  472. elif objType == 'key':
  473. v = Key(color)
  474. elif objType == 'box':
  475. v = Box(color)
  476. elif objType == 'door':
  477. v = Door(color, is_open, is_locked)
  478. elif objType == 'goal':
  479. v = Goal()
  480. elif objType == 'lava':
  481. v = Lava()
  482. else:
  483. assert False, "unknown obj type in decode '%s'" % objType
  484. grid.set(i, j, v)
  485. return grid
  486. def process_vis(grid, agent_pos):
  487. mask = np.zeros(shape=(grid.width, grid.height), dtype=np.bool)
  488. mask[agent_pos[0], agent_pos[1]] = True
  489. for j in reversed(range(0, grid.height)):
  490. for i in range(0, grid.width-1):
  491. if not mask[i, j]:
  492. continue
  493. cell = grid.get(i, j)
  494. if cell and not cell.see_behind():
  495. continue
  496. mask[i+1, j] = True
  497. if j > 0:
  498. mask[i+1, j-1] = True
  499. mask[i, j-1] = True
  500. for i in reversed(range(1, grid.width)):
  501. if not mask[i, j]:
  502. continue
  503. cell = grid.get(i, j)
  504. if cell and not cell.see_behind():
  505. continue
  506. mask[i-1, j] = True
  507. if j > 0:
  508. mask[i-1, j-1] = True
  509. mask[i, j-1] = True
  510. for j in range(0, grid.height):
  511. for i in range(0, grid.width):
  512. if not mask[i, j]:
  513. grid.set(i, j, None)
  514. return mask
  515. class MiniGridEnv(gym.Env):
  516. """
  517. 2D grid world game environment
  518. """
  519. metadata = {
  520. 'render.modes': ['human', 'rgb_array', 'pixmap'],
  521. 'video.frames_per_second' : 10
  522. }
  523. # Enumeration of possible actions
  524. class Actions(IntEnum):
  525. # Turn left, turn right, move forward
  526. left = 0
  527. right = 1
  528. forward = 2
  529. # Pick up an object
  530. pickup = 3
  531. # Drop an object
  532. drop = 4
  533. # Toggle/activate an object
  534. toggle = 5
  535. # Done completing task
  536. done = 6
  537. def __init__(
  538. self,
  539. grid_size=None,
  540. width=None,
  541. height=None,
  542. max_steps=100,
  543. see_through_walls=False,
  544. seed=1337,
  545. agent_view_size=7
  546. ):
  547. # Can't set both grid_size and width/height
  548. if grid_size:
  549. assert width == None and height == None
  550. width = grid_size
  551. height = grid_size
  552. # Action enumeration for this environment
  553. self.actions = MiniGridEnv.Actions
  554. # Actions are discrete integer values
  555. self.action_space = spaces.Discrete(len(self.actions))
  556. # Number of cells (width and height) in the agent view
  557. self.agent_view_size = agent_view_size
  558. # Observations are dictionaries containing an
  559. # encoding of the grid and a textual 'mission' string
  560. self.observation_space = spaces.Box(
  561. low=0,
  562. high=255,
  563. shape=(self.agent_view_size, self.agent_view_size, 3),
  564. dtype='uint8'
  565. )
  566. self.observation_space = spaces.Dict({
  567. 'image': self.observation_space
  568. })
  569. # Range of possible rewards
  570. self.reward_range = (0, 1)
  571. # Renderer object used to render the whole grid (full-scale)
  572. self.grid_render = None
  573. # Renderer used to render observations (small-scale agent view)
  574. self.obs_render = None
  575. # Environment configuration
  576. self.width = width
  577. self.height = height
  578. self.max_steps = max_steps
  579. self.see_through_walls = see_through_walls
  580. # Current position and direction of the agent
  581. self.agent_pos = None
  582. self.agent_dir = None
  583. # Initialize the RNG
  584. self.seed(seed=seed)
  585. # Initialize the state
  586. self.reset()
  587. def reset(self):
  588. # Current position and direction of the agent
  589. self.agent_pos = None
  590. self.agent_dir = None
  591. # Generate a new random grid at the start of each episode
  592. # To keep the same grid for each episode, call env.seed() with
  593. # the same seed before calling env.reset()
  594. self._gen_grid(self.width, self.height)
  595. # These fields should be defined by _gen_grid
  596. assert self.agent_pos is not None
  597. assert self.agent_dir is not None
  598. # Check that the agent doesn't overlap with an object
  599. start_cell = self.grid.get(*self.agent_pos)
  600. assert start_cell is None or start_cell.can_overlap()
  601. # Item picked up, being carried, initially nothing
  602. self.carrying = None
  603. # Step count since episode start
  604. self.step_count = 0
  605. # Return first observation
  606. obs = self.gen_obs()
  607. return obs
  608. def seed(self, seed=1337):
  609. # Seed the random number generator
  610. self.np_random, _ = seeding.np_random(seed)
  611. return [seed]
  612. @property
  613. def steps_remaining(self):
  614. return self.max_steps - self.step_count
  615. def __str__(self):
  616. """
  617. Produce a pretty string of the environment's grid along with the agent.
  618. A grid cell is represented by 2-character string, the first one for
  619. the object and the second one for the color.
  620. """
  621. # Map of object types to short string
  622. OBJECT_TO_STR = {
  623. 'wall' : 'W',
  624. 'floor' : 'F',
  625. 'door' : 'D',
  626. 'key' : 'K',
  627. 'ball' : 'A',
  628. 'box' : 'B',
  629. 'goal' : 'G',
  630. 'lava' : 'V',
  631. }
  632. # Short string for opened door
  633. OPENDED_DOOR_IDS = '_'
  634. # Map agent's direction to short string
  635. AGENT_DIR_TO_STR = {
  636. 0: '>',
  637. 1: 'V',
  638. 2: '<',
  639. 3: '^'
  640. }
  641. str = ''
  642. for j in range(self.grid.height):
  643. for i in range(self.grid.width):
  644. if i == self.agent_pos[0] and j == self.agent_pos[1]:
  645. str += 2 * AGENT_DIR_TO_STR[self.agent_dir]
  646. continue
  647. c = self.grid.get(i, j)
  648. if c == None:
  649. str += ' '
  650. continue
  651. if c.type == 'door':
  652. if c.is_open:
  653. str += '__'
  654. elif c.is_locked:
  655. str += 'L' + c.color[0].upper()
  656. else:
  657. str += 'D' + c.color[0].upper()
  658. continue
  659. str += OBJECT_TO_STR[c.type] + c.color[0].upper()
  660. if j < self.grid.height - 1:
  661. str += '\n'
  662. return str
  663. def _gen_grid(self, width, height):
  664. assert False, "_gen_grid needs to be implemented by each environment"
  665. def _reward(self):
  666. """
  667. Compute the reward to be given upon success
  668. """
  669. return 1 - 0.9 * (self.step_count / self.max_steps)
  670. def _rand_int(self, low, high):
  671. """
  672. Generate random integer in [low,high[
  673. """
  674. return self.np_random.randint(low, high)
  675. def _rand_float(self, low, high):
  676. """
  677. Generate random float in [low,high[
  678. """
  679. return self.np_random.uniform(low, high)
  680. def _rand_bool(self):
  681. """
  682. Generate random boolean value
  683. """
  684. return (self.np_random.randint(0, 2) == 0)
  685. def _rand_elem(self, iterable):
  686. """
  687. Pick a random element in a list
  688. """
  689. lst = list(iterable)
  690. idx = self._rand_int(0, len(lst))
  691. return lst[idx]
  692. def _rand_subset(self, iterable, num_elems):
  693. """
  694. Sample a random subset of distinct elements of a list
  695. """
  696. lst = list(iterable)
  697. assert num_elems <= len(lst)
  698. out = []
  699. while len(out) < num_elems:
  700. elem = self._rand_elem(lst)
  701. lst.remove(elem)
  702. out.append(elem)
  703. return out
  704. def _rand_color(self):
  705. """
  706. Generate a random color name (string)
  707. """
  708. return self._rand_elem(COLOR_NAMES)
  709. def _rand_pos(self, xLow, xHigh, yLow, yHigh):
  710. """
  711. Generate a random (x,y) position tuple
  712. """
  713. return (
  714. self.np_random.randint(xLow, xHigh),
  715. self.np_random.randint(yLow, yHigh)
  716. )
  717. def place_obj(self,
  718. obj,
  719. top=None,
  720. size=None,
  721. reject_fn=None,
  722. max_tries=math.inf
  723. ):
  724. """
  725. Place an object at an empty position in the grid
  726. :param top: top-left position of the rectangle where to place
  727. :param size: size of the rectangle where to place
  728. :param reject_fn: function to filter out potential positions
  729. """
  730. if top is None:
  731. top = (0, 0)
  732. else:
  733. top = (max(top[0], 0), max(top[1], 0))
  734. if size is None:
  735. size = (self.grid.width, self.grid.height)
  736. num_tries = 0
  737. while True:
  738. # This is to handle with rare cases where rejection sampling
  739. # gets stuck in an infinite loop
  740. if num_tries > max_tries:
  741. raise RecursionError('rejection sampling failed in place_obj')
  742. num_tries += 1
  743. pos = np.array((
  744. self._rand_int(top[0], min(top[0] + size[0], self.grid.width)),
  745. self._rand_int(top[1], min(top[1] + size[1], self.grid.height))
  746. ))
  747. # Don't place the object on top of another object
  748. if self.grid.get(*pos) != None:
  749. continue
  750. # Don't place the object where the agent is
  751. if np.array_equal(pos, self.agent_pos):
  752. continue
  753. # Check if there is a filtering criterion
  754. if reject_fn and reject_fn(self, pos):
  755. continue
  756. break
  757. self.grid.set(*pos, obj)
  758. if obj is not None:
  759. obj.init_pos = pos
  760. obj.cur_pos = pos
  761. return pos
  762. def place_agent(
  763. self,
  764. top=None,
  765. size=None,
  766. rand_dir=True,
  767. max_tries=math.inf
  768. ):
  769. """
  770. Set the agent's starting point at an empty position in the grid
  771. """
  772. self.agent_pos = None
  773. pos = self.place_obj(None, top, size, max_tries=max_tries)
  774. self.agent_pos = pos
  775. if rand_dir:
  776. self.agent_dir = self._rand_int(0, 4)
  777. return pos
  778. @property
  779. def dir_vec(self):
  780. """
  781. Get the direction vector for the agent, pointing in the direction
  782. of forward movement.
  783. """
  784. assert self.agent_dir >= 0 and self.agent_dir < 4
  785. return DIR_TO_VEC[self.agent_dir]
  786. @property
  787. def right_vec(self):
  788. """
  789. Get the vector pointing to the right of the agent.
  790. """
  791. dx, dy = self.dir_vec
  792. return np.array((-dy, dx))
  793. @property
  794. def front_pos(self):
  795. """
  796. Get the position of the cell that is right in front of the agent
  797. """
  798. return self.agent_pos + self.dir_vec
  799. def get_view_coords(self, i, j):
  800. """
  801. Translate and rotate absolute grid coordinates (i, j) into the
  802. agent's partially observable view (sub-grid). Note that the resulting
  803. coordinates may be negative or outside of the agent's view size.
  804. """
  805. ax, ay = self.agent_pos
  806. dx, dy = self.dir_vec
  807. rx, ry = self.right_vec
  808. # Compute the absolute coordinates of the top-left view corner
  809. sz = self.agent_view_size
  810. hs = self.agent_view_size // 2
  811. tx = ax + (dx * (sz-1)) - (rx * hs)
  812. ty = ay + (dy * (sz-1)) - (ry * hs)
  813. lx = i - tx
  814. ly = j - ty
  815. # Project the coordinates of the object relative to the top-left
  816. # corner onto the agent's own coordinate system
  817. vx = (rx*lx + ry*ly)
  818. vy = -(dx*lx + dy*ly)
  819. return vx, vy
  820. def get_view_exts(self):
  821. """
  822. Get the extents of the square set of tiles visible to the agent
  823. Note: the bottom extent indices are not included in the set
  824. """
  825. # Facing right
  826. if self.agent_dir == 0:
  827. topX = self.agent_pos[0]
  828. topY = self.agent_pos[1] - self.agent_view_size // 2
  829. # Facing down
  830. elif self.agent_dir == 1:
  831. topX = self.agent_pos[0] - self.agent_view_size // 2
  832. topY = self.agent_pos[1]
  833. # Facing left
  834. elif self.agent_dir == 2:
  835. topX = self.agent_pos[0] - self.agent_view_size + 1
  836. topY = self.agent_pos[1] - self.agent_view_size // 2
  837. # Facing up
  838. elif self.agent_dir == 3:
  839. topX = self.agent_pos[0] - self.agent_view_size // 2
  840. topY = self.agent_pos[1] - self.agent_view_size + 1
  841. else:
  842. assert False, "invalid agent direction"
  843. botX = topX + self.agent_view_size
  844. botY = topY + self.agent_view_size
  845. return (topX, topY, botX, botY)
  846. def relative_coords(self, x, y):
  847. """
  848. Check if a grid position belongs to the agent's field of view, and returns the corresponding coordinates
  849. """
  850. vx, vy = self.get_view_coords(x, y)
  851. if vx < 0 or vy < 0 or vx >= self.agent_view_size or vy >= self.agent_view_size:
  852. return None
  853. return vx, vy
  854. def in_view(self, x, y):
  855. """
  856. check if a grid position is visible to the agent
  857. """
  858. return self.relative_coords(x, y) is not None
  859. def agent_sees(self, x, y):
  860. """
  861. Check if a non-empty grid position is visible to the agent
  862. """
  863. coordinates = self.relative_coords(x, y)
  864. if coordinates is None:
  865. return False
  866. vx, vy = coordinates
  867. obs = self.gen_obs()
  868. obs_grid = Grid.decode(obs['image'])
  869. obs_cell = obs_grid.get(vx, vy)
  870. world_cell = self.grid.get(x, y)
  871. return obs_cell is not None and obs_cell.type == world_cell.type
  872. def step(self, action):
  873. self.step_count += 1
  874. reward = 0
  875. done = False
  876. # Get the position in front of the agent
  877. fwd_pos = self.front_pos
  878. # Get the contents of the cell in front of the agent
  879. fwd_cell = self.grid.get(*fwd_pos)
  880. # Rotate left
  881. if action == self.actions.left:
  882. self.agent_dir -= 1
  883. if self.agent_dir < 0:
  884. self.agent_dir += 4
  885. # Rotate right
  886. elif action == self.actions.right:
  887. self.agent_dir = (self.agent_dir + 1) % 4
  888. # Move forward
  889. elif action == self.actions.forward:
  890. if fwd_cell == None or fwd_cell.can_overlap():
  891. self.agent_pos = fwd_pos
  892. if fwd_cell != None and fwd_cell.type == 'goal':
  893. done = True
  894. reward = self._reward()
  895. if fwd_cell != None and fwd_cell.type == 'lava':
  896. done = True
  897. # Pick up an object
  898. elif action == self.actions.pickup:
  899. if fwd_cell and fwd_cell.can_pickup():
  900. if self.carrying is None:
  901. self.carrying = fwd_cell
  902. self.carrying.cur_pos = np.array([-1, -1])
  903. self.grid.set(*fwd_pos, None)
  904. # Drop an object
  905. elif action == self.actions.drop:
  906. if not fwd_cell and self.carrying:
  907. self.grid.set(*fwd_pos, self.carrying)
  908. self.carrying.cur_pos = fwd_pos
  909. self.carrying = None
  910. # Toggle/activate an object
  911. elif action == self.actions.toggle:
  912. if fwd_cell:
  913. fwd_cell.toggle(self, fwd_pos)
  914. # Done action (not used by default)
  915. elif action == self.actions.done:
  916. pass
  917. else:
  918. assert False, "unknown action"
  919. if self.step_count >= self.max_steps:
  920. done = True
  921. obs = self.gen_obs()
  922. return obs, reward, done, {}
  923. def gen_obs_grid(self):
  924. """
  925. Generate the sub-grid observed by the agent.
  926. This method also outputs a visibility mask telling us which grid
  927. cells the agent can actually see.
  928. """
  929. topX, topY, botX, botY = self.get_view_exts()
  930. grid = self.grid.slice(topX, topY, self.agent_view_size, self.agent_view_size)
  931. for i in range(self.agent_dir + 1):
  932. grid = grid.rotate_left()
  933. # Process occluders and visibility
  934. # Note that this incurs some performance cost
  935. if not self.see_through_walls:
  936. vis_mask = grid.process_vis(agent_pos=(self.agent_view_size // 2 , self.agent_view_size - 1))
  937. else:
  938. vis_mask = np.ones(shape=(grid.width, grid.height), dtype=np.bool)
  939. # Make it so the agent sees what it's carrying
  940. # We do this by placing the carried object at the agent's position
  941. # in the agent's partially observable view
  942. agent_pos = grid.width // 2, grid.height - 1
  943. if self.carrying:
  944. grid.set(*agent_pos, self.carrying)
  945. else:
  946. grid.set(*agent_pos, None)
  947. return grid, vis_mask
  948. def gen_obs(self):
  949. """
  950. Generate the agent's view (partially observable, low-resolution encoding)
  951. """
  952. grid, vis_mask = self.gen_obs_grid()
  953. # Encode the partially observable view into a numpy array
  954. image = grid.encode(vis_mask)
  955. assert hasattr(self, 'mission'), "environments must define a textual mission string"
  956. # Observations are dictionaries containing:
  957. # - an image (partially observable view of the environment)
  958. # - the agent's direction/orientation (acting as a compass)
  959. # - a textual mission string (instructions for the agent)
  960. obs = {
  961. 'image': image,
  962. 'direction': self.agent_dir,
  963. 'mission': self.mission
  964. }
  965. return obs
  966. def get_obs_render(self, obs, tile_pixels=CELL_PIXELS//2, mode='pixmap'):
  967. """
  968. Render an agent observation for visualization
  969. """
  970. if self.obs_render == None:
  971. from gym_minigrid.rendering import Renderer
  972. self.obs_render = Renderer(
  973. self.agent_view_size * tile_pixels,
  974. self.agent_view_size * tile_pixels
  975. )
  976. r = self.obs_render
  977. r.beginFrame()
  978. grid = Grid.decode(obs)
  979. # Render the whole grid
  980. grid.render(r, tile_pixels)
  981. # Draw the agent
  982. ratio = tile_pixels / CELL_PIXELS
  983. r.push()
  984. r.scale(ratio, ratio)
  985. r.translate(
  986. CELL_PIXELS * (0.5 + self.agent_view_size // 2),
  987. CELL_PIXELS * (self.agent_view_size - 0.5)
  988. )
  989. r.rotate(3 * 90)
  990. r.setLineColor(255, 0, 0)
  991. r.setColor(255, 0, 0)
  992. r.drawPolygon([
  993. (-12, 10),
  994. ( 12, 0),
  995. (-12, -10)
  996. ])
  997. r.pop()
  998. r.endFrame()
  999. if mode == 'rgb_array':
  1000. return r.getArray()
  1001. elif mode == 'pixmap':
  1002. return r.getPixmap()
  1003. return r
  1004. def render(self, mode='human', close=False, highlight=True):
  1005. """
  1006. Render the whole-grid human view
  1007. """
  1008. if close:
  1009. if self.grid_render:
  1010. self.grid_render.close()
  1011. return
  1012. if self.grid_render is None or self.grid_render.window is None:
  1013. from gym_minigrid.rendering import Renderer
  1014. self.grid_render = Renderer(
  1015. self.width * CELL_PIXELS,
  1016. self.height * CELL_PIXELS,
  1017. True if mode == 'human' else False
  1018. )
  1019. r = self.grid_render
  1020. if r.window:
  1021. r.window.setText(self.mission)
  1022. r.beginFrame()
  1023. # Render the whole grid
  1024. self.grid.render(r, CELL_PIXELS)
  1025. # Draw the agent
  1026. r.push()
  1027. r.translate(
  1028. CELL_PIXELS * (self.agent_pos[0] + 0.5),
  1029. CELL_PIXELS * (self.agent_pos[1] + 0.5)
  1030. )
  1031. r.rotate(self.agent_dir * 90)
  1032. r.setLineColor(255, 0, 0)
  1033. r.setColor(255, 0, 0)
  1034. r.drawPolygon([
  1035. (-12, 10),
  1036. ( 12, 0),
  1037. (-12, -10)
  1038. ])
  1039. r.pop()
  1040. # Compute which cells are visible to the agent
  1041. _, vis_mask = self.gen_obs_grid()
  1042. # Compute the absolute coordinates of the bottom-left corner
  1043. # of the agent's view area
  1044. f_vec = self.dir_vec
  1045. r_vec = self.right_vec
  1046. top_left = self.agent_pos + f_vec * (self.agent_view_size-1) - r_vec * (self.agent_view_size // 2)
  1047. # For each cell in the visibility mask
  1048. if highlight:
  1049. for vis_j in range(0, self.agent_view_size):
  1050. for vis_i in range(0, self.agent_view_size):
  1051. # If this cell is not visible, don't highlight it
  1052. if not vis_mask[vis_i, vis_j]:
  1053. continue
  1054. # Compute the world coordinates of this cell
  1055. abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)
  1056. # Highlight the cell
  1057. r.fillRect(
  1058. abs_i * CELL_PIXELS,
  1059. abs_j * CELL_PIXELS,
  1060. CELL_PIXELS,
  1061. CELL_PIXELS,
  1062. 255, 255, 255, 75
  1063. )
  1064. r.endFrame()
  1065. if mode == 'rgb_array':
  1066. return r.getArray()
  1067. elif mode == 'pixmap':
  1068. return r.getPixmap()
  1069. return r