minigrid.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397
  1. import math
  2. import gym
  3. from enum import IntEnum
  4. import numpy as np
  5. from gym import error, spaces, utils
  6. from gym.utils import seeding
  7. # Size in pixels of a tile in the full-scale human view
  8. TILE_PIXELS = 32
  9. # Map of color names to RGB values
  10. COLORS = {
  11. 'red' : np.array([255, 0, 0]),
  12. 'green' : np.array([0, 255, 0]),
  13. 'blue' : np.array([0, 0, 255]),
  14. 'purple': np.array([112, 39, 195]),
  15. 'yellow': np.array([255, 255, 0]),
  16. 'grey' : np.array([100, 100, 100])
  17. }
  18. COLOR_NAMES = sorted(list(COLORS.keys()))
  19. # Used to map colors to integers
  20. COLOR_TO_IDX = {
  21. 'red' : 0,
  22. 'green' : 1,
  23. 'blue' : 2,
  24. 'purple': 3,
  25. 'yellow': 4,
  26. 'grey' : 5
  27. }
  28. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  29. # Map of object type to integers
  30. OBJECT_TO_IDX = {
  31. 'unseen' : 0,
  32. 'empty' : 1,
  33. 'wall' : 2,
  34. 'floor' : 3,
  35. 'door' : 4,
  36. 'key' : 5,
  37. 'ball' : 6,
  38. 'box' : 7,
  39. 'goal' : 8,
  40. 'lava' : 9,
  41. 'agent' : 10,
  42. }
  43. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  44. # Map of state names to integers
  45. STATE_TO_IDX = {
  46. 'open' : 0,
  47. 'closed': 1,
  48. 'locked': 2,
  49. }
  50. # Map of agent direction indices to vectors
  51. DIR_TO_VEC = [
  52. # Pointing right (positive X)
  53. np.array((1, 0)),
  54. # Down (positive Y)
  55. np.array((0, 1)),
  56. # Pointing left (negative X)
  57. np.array((-1, 0)),
  58. # Up (negative Y)
  59. np.array((0, -1)),
  60. ]
  61. class WorldObj:
  62. """
  63. Base class for grid world objects
  64. """
  65. def __init__(self, type, color):
  66. assert type in OBJECT_TO_IDX, type
  67. assert color in COLOR_TO_IDX, color
  68. self.type = type
  69. self.color = color
  70. self.contains = None
  71. # Initial position of the object
  72. self.init_pos = None
  73. # Current position of the object
  74. self.cur_pos = None
  75. def can_overlap(self):
  76. """Can the agent overlap with this?"""
  77. return False
  78. def can_pickup(self):
  79. """Can the agent pick this up?"""
  80. return False
  81. def can_contain(self):
  82. """Can this contain another object?"""
  83. return False
  84. def see_behind(self):
  85. """Can the agent see behind this object?"""
  86. return True
  87. def toggle(self, env, pos):
  88. """Method to trigger/toggle an action this object performs"""
  89. return False
  90. def encode(self):
  91. """Encode the a description of this object as a 3-tuple of integers"""
  92. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0)
  93. def decode(type_idx, color_idx, state):
  94. """Create an object from a 3-tuple state description"""
  95. obj_type = IDX_TO_OBJECT[type_idx]
  96. color = IDX_TO_COLOR[color_idx]
  97. if obj_type == 'empty' or obj_type == 'unseen':
  98. return None
  99. # State, 0: open, 1: closed, 2: locked
  100. is_open = state == 0
  101. is_locked = state == 2
  102. if obj_type == 'wall':
  103. v = Wall(color)
  104. elif obj_type == 'floor':
  105. v = Floor(color)
  106. elif obj_type == 'ball':
  107. v = Ball(color)
  108. elif obj_type == 'key':
  109. v = Key(color)
  110. elif obj_type == 'box':
  111. v = Box(color)
  112. elif obj_type == 'door':
  113. v = Door(color, is_open, is_locked)
  114. elif obj_type == 'goal':
  115. v = Goal()
  116. elif obj_type == 'lava':
  117. v = Lava()
  118. else:
  119. assert False, "unknown object type in decode '%s'" % objType
  120. return v
  121. def render(self, r):
  122. """Draw this object with the given renderer"""
  123. raise NotImplementedError
  124. def _set_color(self, r):
  125. """Set the color of this object as the active drawing color"""
  126. c = COLORS[self.color]
  127. r.setLineColor(c[0], c[1], c[2])
  128. r.setColor(c[0], c[1], c[2])
  129. class Goal(WorldObj):
  130. def __init__(self):
  131. super().__init__('goal', 'green')
  132. def can_overlap(self):
  133. return True
  134. def render(self, r):
  135. self._set_color(r)
  136. r.drawPolygon([
  137. (0 , TILE_PIXELS),
  138. (TILE_PIXELS, TILE_PIXELS),
  139. (TILE_PIXELS, 0),
  140. (0 , 0)
  141. ])
  142. class Floor(WorldObj):
  143. """
  144. Colored floor tile the agent can walk over
  145. """
  146. def __init__(self, color='blue'):
  147. super().__init__('floor', color)
  148. def can_overlap(self):
  149. return True
  150. def render(self, r):
  151. # Give the floor a pale color
  152. c = COLORS[self.color]
  153. r.setLineColor(100, 100, 100, 0)
  154. r.setColor(*c/2)
  155. r.drawPolygon([
  156. (1 , TILE_PIXELS),
  157. (TILE_PIXELS, TILE_PIXELS),
  158. (TILE_PIXELS, 1),
  159. (1 , 1)
  160. ])
  161. class Lava(WorldObj):
  162. def __init__(self):
  163. super().__init__('lava', 'red')
  164. def can_overlap(self):
  165. return True
  166. def render(self, r):
  167. orange = 255, 128, 0
  168. r.setLineColor(*orange)
  169. r.setColor(*orange)
  170. r.drawPolygon([
  171. (0 , TILE_PIXELS),
  172. (TILE_PIXELS, TILE_PIXELS),
  173. (TILE_PIXELS, 0),
  174. (0 , 0)
  175. ])
  176. # drawing the waves
  177. r.setLineColor(0, 0, 0)
  178. r.drawPolyline([
  179. (.1 * TILE_PIXELS, .3 * TILE_PIXELS),
  180. (.3 * TILE_PIXELS, .4 * TILE_PIXELS),
  181. (.5 * TILE_PIXELS, .3 * TILE_PIXELS),
  182. (.7 * TILE_PIXELS, .4 * TILE_PIXELS),
  183. (.9 * TILE_PIXELS, .3 * TILE_PIXELS),
  184. ])
  185. r.drawPolyline([
  186. (.1 * TILE_PIXELS, .5 * TILE_PIXELS),
  187. (.3 * TILE_PIXELS, .6 * TILE_PIXELS),
  188. (.5 * TILE_PIXELS, .5 * TILE_PIXELS),
  189. (.7 * TILE_PIXELS, .6 * TILE_PIXELS),
  190. (.9 * TILE_PIXELS, .5 * TILE_PIXELS),
  191. ])
  192. r.drawPolyline([
  193. (.1 * TILE_PIXELS, .7 * TILE_PIXELS),
  194. (.3 * TILE_PIXELS, .8 * TILE_PIXELS),
  195. (.5 * TILE_PIXELS, .7 * TILE_PIXELS),
  196. (.7 * TILE_PIXELS, .8 * TILE_PIXELS),
  197. (.9 * TILE_PIXELS, .7 * TILE_PIXELS),
  198. ])
  199. class Wall(WorldObj):
  200. def __init__(self, color='grey'):
  201. super().__init__('wall', color)
  202. def see_behind(self):
  203. return False
  204. def render(self, r):
  205. self._set_color(r)
  206. r.drawPolygon([
  207. (0 , TILE_PIXELS),
  208. (TILE_PIXELS, TILE_PIXELS),
  209. (TILE_PIXELS, 0),
  210. (0 , 0)
  211. ])
  212. class Door(WorldObj):
  213. def __init__(self, color, is_open=False, is_locked=False):
  214. super().__init__('door', color)
  215. self.is_open = is_open
  216. self.is_locked = is_locked
  217. def can_overlap(self):
  218. """The agent can only walk over this cell when the door is open"""
  219. return self.is_open
  220. def see_behind(self):
  221. return self.is_open
  222. def toggle(self, env, pos):
  223. # If the player has the right key to open the door
  224. if self.is_locked:
  225. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  226. self.is_locked = False
  227. self.is_open = True
  228. return True
  229. return False
  230. self.is_open = not self.is_open
  231. return True
  232. def encode(self):
  233. """Encode the a description of this object as a 3-tuple of integers"""
  234. # State, 0: open, 1: closed, 2: locked
  235. if self.is_open:
  236. state = 0
  237. elif self.is_locked:
  238. state = 2
  239. elif not self.is_open:
  240. state = 1
  241. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state)
  242. def render(self, r):
  243. c = COLORS[self.color]
  244. r.setLineColor(c[0], c[1], c[2])
  245. r.setColor(c[0], c[1], c[2], 50 if self.is_locked else 0)
  246. if self.is_open:
  247. r.drawPolygon([
  248. (TILE_PIXELS-2, TILE_PIXELS),
  249. (TILE_PIXELS , TILE_PIXELS),
  250. (TILE_PIXELS , 0),
  251. (TILE_PIXELS-2, 0)
  252. ])
  253. return
  254. r.drawPolygon([
  255. (0 , TILE_PIXELS),
  256. (TILE_PIXELS, TILE_PIXELS),
  257. (TILE_PIXELS, 0),
  258. (0 , 0)
  259. ])
  260. r.drawPolygon([
  261. (2 , TILE_PIXELS-2),
  262. (TILE_PIXELS-2, TILE_PIXELS-2),
  263. (TILE_PIXELS-2, 2),
  264. (2 , 2)
  265. ])
  266. if self.is_locked:
  267. # Draw key slot
  268. r.drawLine(
  269. TILE_PIXELS * 0.55,
  270. TILE_PIXELS * 0.5,
  271. TILE_PIXELS * 0.75,
  272. TILE_PIXELS * 0.5
  273. )
  274. else:
  275. # Draw door handle
  276. r.drawCircle(TILE_PIXELS * 0.75, TILE_PIXELS * 0.5, 2)
  277. class Key(WorldObj):
  278. def __init__(self, color='blue'):
  279. super(Key, self).__init__('key', color)
  280. def can_pickup(self):
  281. return True
  282. def render(self, r):
  283. self._set_color(r)
  284. # Vertical quad
  285. r.drawPolygon([
  286. (16, 10),
  287. (20, 10),
  288. (20, 28),
  289. (16, 28)
  290. ])
  291. # Teeth
  292. r.drawPolygon([
  293. (12, 19),
  294. (16, 19),
  295. (16, 21),
  296. (12, 21)
  297. ])
  298. r.drawPolygon([
  299. (12, 26),
  300. (16, 26),
  301. (16, 28),
  302. (12, 28)
  303. ])
  304. r.drawCircle(18, 9, 6)
  305. r.setLineColor(0, 0, 0)
  306. r.setColor(0, 0, 0)
  307. r.drawCircle(18, 9, 2)
  308. class Ball(WorldObj):
  309. def __init__(self, color='blue'):
  310. super(Ball, self).__init__('ball', color)
  311. def can_pickup(self):
  312. return True
  313. def render(self, r):
  314. self._set_color(r)
  315. r.drawCircle(TILE_PIXELS * 0.5, TILE_PIXELS * 0.5, 10)
  316. class Box(WorldObj):
  317. def __init__(self, color, contains=None):
  318. super(Box, self).__init__('box', color)
  319. self.contains = contains
  320. def can_pickup(self):
  321. return True
  322. def render(self, r):
  323. c = COLORS[self.color]
  324. r.setLineColor(c[0], c[1], c[2])
  325. r.setColor(0, 0, 0)
  326. r.setLineWidth(2)
  327. r.drawPolygon([
  328. (4 , TILE_PIXELS-4),
  329. (TILE_PIXELS-4, TILE_PIXELS-4),
  330. (TILE_PIXELS-4, 4),
  331. (4 , 4)
  332. ])
  333. r.drawLine(
  334. 4,
  335. TILE_PIXELS / 2,
  336. TILE_PIXELS - 4,
  337. TILE_PIXELS / 2
  338. )
  339. r.setLineWidth(1)
  340. def toggle(self, env, pos):
  341. # Replace the box by its contents
  342. env.grid.set(*pos, self.contains)
  343. return True
  344. def render_tile(
  345. obj,
  346. agent_dir=None,
  347. highlight=False,
  348. tile_size=TILE_PIXELS
  349. ):
  350. """
  351. Render a tile and cache the result
  352. """
  353. pass
  354. class Grid:
  355. """
  356. Represent a grid and operations on it
  357. """
  358. def __init__(self, width, height):
  359. assert width >= 3
  360. assert height >= 3
  361. self.width = width
  362. self.height = height
  363. self.grid = [None] * width * height
  364. def __contains__(self, key):
  365. if isinstance(key, WorldObj):
  366. for e in self.grid:
  367. if e is key:
  368. return True
  369. elif isinstance(key, tuple):
  370. for e in self.grid:
  371. if e is None:
  372. continue
  373. if (e.color, e.type) == key:
  374. return True
  375. if key[0] is None and key[1] == e.type:
  376. return True
  377. return False
  378. def __eq__(self, other):
  379. grid1 = self.encode()
  380. grid2 = other.encode()
  381. return np.array_equal(grid2, grid1)
  382. def __ne__(self, other):
  383. return not self == other
  384. def copy(self):
  385. from copy import deepcopy
  386. return deepcopy(self)
  387. def set(self, i, j, v):
  388. assert i >= 0 and i < self.width
  389. assert j >= 0 and j < self.height
  390. self.grid[j * self.width + i] = v
  391. def get(self, i, j):
  392. assert i >= 0 and i < self.width
  393. assert j >= 0 and j < self.height
  394. return self.grid[j * self.width + i]
  395. def horz_wall(self, x, y, length=None, obj_type=Wall):
  396. if length is None:
  397. length = self.width - x
  398. for i in range(0, length):
  399. self.set(x + i, y, obj_type())
  400. def vert_wall(self, x, y, length=None, obj_type=Wall):
  401. if length is None:
  402. length = self.height - y
  403. for j in range(0, length):
  404. self.set(x, y + j, obj_type())
  405. def wall_rect(self, x, y, w, h):
  406. self.horz_wall(x, y, w)
  407. self.horz_wall(x, y+h-1, w)
  408. self.vert_wall(x, y, h)
  409. self.vert_wall(x+w-1, y, h)
  410. def rotate_left(self):
  411. """
  412. Rotate the grid to the left (counter-clockwise)
  413. """
  414. grid = Grid(self.height, self.width)
  415. for i in range(self.width):
  416. for j in range(self.height):
  417. v = self.get(i, j)
  418. grid.set(j, grid.height - 1 - i, v)
  419. return grid
  420. def slice(self, topX, topY, width, height):
  421. """
  422. Get a subset of the grid
  423. """
  424. grid = Grid(width, height)
  425. for j in range(0, height):
  426. for i in range(0, width):
  427. x = topX + i
  428. y = topY + j
  429. if x >= 0 and x < self.width and \
  430. y >= 0 and y < self.height:
  431. v = self.get(x, y)
  432. else:
  433. v = Wall()
  434. grid.set(i, j, v)
  435. return grid
  436. def render(self, r, tile_size):
  437. """
  438. Render this grid at a given scale
  439. :param r: target renderer object
  440. :param tile_size: tile size in pixels
  441. """
  442. assert r.width == self.width * tile_size
  443. assert r.height == self.height * tile_size
  444. # Total grid size at native scale
  445. widthPx = self.width * TILE_PIXELS
  446. heightPx = self.height * TILE_PIXELS
  447. r.push()
  448. # Internally, we draw at the "large" full-grid resolution, but we
  449. # use the renderer to scale back to the desired size
  450. r.scale(tile_size / TILE_PIXELS, tile_size / TILE_PIXELS)
  451. # Draw the background of the in-world cells black
  452. r.fillRect(
  453. 0,
  454. 0,
  455. widthPx,
  456. heightPx,
  457. 0, 0, 0
  458. )
  459. # Draw grid lines
  460. r.setLineColor(100, 100, 100)
  461. for rowIdx in range(0, self.height):
  462. y = TILE_PIXELS * rowIdx
  463. r.drawLine(0, y, widthPx, y)
  464. for colIdx in range(0, self.width):
  465. x = TILE_PIXELS * colIdx
  466. r.drawLine(x, 0, x, heightPx)
  467. # Render the grid
  468. for j in range(0, self.height):
  469. for i in range(0, self.width):
  470. cell = self.get(i, j)
  471. if cell == None:
  472. continue
  473. r.push()
  474. r.translate(i * TILE_PIXELS, j * TILE_PIXELS)
  475. cell.render(r)
  476. r.pop()
  477. r.pop()
  478. def encode(self, vis_mask=None):
  479. """
  480. Produce a compact numpy encoding of the grid
  481. """
  482. if vis_mask is None:
  483. vis_mask = np.ones((self.width, self.height), dtype=bool)
  484. array = np.zeros((self.width, self.height, 3), dtype='uint8')
  485. for i in range(self.width):
  486. for j in range(self.height):
  487. if vis_mask[i, j]:
  488. v = self.get(i, j)
  489. if v is None:
  490. array[i, j, 0] = OBJECT_TO_IDX['empty']
  491. array[i, j, 1] = 0
  492. array[i, j, 2] = 0
  493. else:
  494. array[i, j, :] = v.encode()
  495. return array
  496. @staticmethod
  497. def decode(array):
  498. """
  499. Decode an array grid encoding back into a grid
  500. """
  501. width, height, channels = array.shape
  502. assert channels == 3
  503. grid = Grid(width, height)
  504. for i in range(width):
  505. for j in range(height):
  506. type_idx, color_idx, state = array[i, j]
  507. v = WorldObj.decode(type_idx, color_idx, state)
  508. grid.set(i, j, v)
  509. return grid
  510. def process_vis(grid, agent_pos):
  511. mask = np.zeros(shape=(grid.width, grid.height), dtype=np.bool)
  512. mask[agent_pos[0], agent_pos[1]] = True
  513. for j in reversed(range(0, grid.height)):
  514. for i in range(0, grid.width-1):
  515. if not mask[i, j]:
  516. continue
  517. cell = grid.get(i, j)
  518. if cell and not cell.see_behind():
  519. continue
  520. mask[i+1, j] = True
  521. if j > 0:
  522. mask[i+1, j-1] = True
  523. mask[i, j-1] = True
  524. for i in reversed(range(1, grid.width)):
  525. if not mask[i, j]:
  526. continue
  527. cell = grid.get(i, j)
  528. if cell and not cell.see_behind():
  529. continue
  530. mask[i-1, j] = True
  531. if j > 0:
  532. mask[i-1, j-1] = True
  533. mask[i, j-1] = True
  534. for j in range(0, grid.height):
  535. for i in range(0, grid.width):
  536. if not mask[i, j]:
  537. grid.set(i, j, None)
  538. return mask
  539. class MiniGridEnv(gym.Env):
  540. """
  541. 2D grid world game environment
  542. """
  543. metadata = {
  544. 'render.modes': ['human', 'rgb_array', 'pixmap'],
  545. 'video.frames_per_second' : 10
  546. }
  547. # Enumeration of possible actions
  548. class Actions(IntEnum):
  549. # Turn left, turn right, move forward
  550. left = 0
  551. right = 1
  552. forward = 2
  553. # Pick up an object
  554. pickup = 3
  555. # Drop an object
  556. drop = 4
  557. # Toggle/activate an object
  558. toggle = 5
  559. # Done completing task
  560. done = 6
  561. def __init__(
  562. self,
  563. grid_size=None,
  564. width=None,
  565. height=None,
  566. max_steps=100,
  567. see_through_walls=False,
  568. seed=1337,
  569. agent_view_size=7
  570. ):
  571. # Can't set both grid_size and width/height
  572. if grid_size:
  573. assert width == None and height == None
  574. width = grid_size
  575. height = grid_size
  576. # Action enumeration for this environment
  577. self.actions = MiniGridEnv.Actions
  578. # Actions are discrete integer values
  579. self.action_space = spaces.Discrete(len(self.actions))
  580. # Number of cells (width and height) in the agent view
  581. self.agent_view_size = agent_view_size
  582. # Observations are dictionaries containing an
  583. # encoding of the grid and a textual 'mission' string
  584. self.observation_space = spaces.Box(
  585. low=0,
  586. high=255,
  587. shape=(self.agent_view_size, self.agent_view_size, 3),
  588. dtype='uint8'
  589. )
  590. self.observation_space = spaces.Dict({
  591. 'image': self.observation_space
  592. })
  593. # Range of possible rewards
  594. self.reward_range = (0, 1)
  595. # Renderer object used to render the whole grid (full-scale)
  596. self.grid_render = None
  597. # Renderer used to render observations (small-scale agent view)
  598. self.obs_render = None
  599. # Environment configuration
  600. self.width = width
  601. self.height = height
  602. self.max_steps = max_steps
  603. self.see_through_walls = see_through_walls
  604. # Current position and direction of the agent
  605. self.agent_pos = None
  606. self.agent_dir = None
  607. # Initialize the RNG
  608. self.seed(seed=seed)
  609. # Initialize the state
  610. self.reset()
  611. def reset(self):
  612. # Current position and direction of the agent
  613. self.agent_pos = None
  614. self.agent_dir = None
  615. # Generate a new random grid at the start of each episode
  616. # To keep the same grid for each episode, call env.seed() with
  617. # the same seed before calling env.reset()
  618. self._gen_grid(self.width, self.height)
  619. # These fields should be defined by _gen_grid
  620. assert self.agent_pos is not None
  621. assert self.agent_dir is not None
  622. # Check that the agent doesn't overlap with an object
  623. start_cell = self.grid.get(*self.agent_pos)
  624. assert start_cell is None or start_cell.can_overlap()
  625. # Item picked up, being carried, initially nothing
  626. self.carrying = None
  627. # Step count since episode start
  628. self.step_count = 0
  629. # Return first observation
  630. obs = self.gen_obs()
  631. return obs
  632. def seed(self, seed=1337):
  633. # Seed the random number generator
  634. self.np_random, _ = seeding.np_random(seed)
  635. return [seed]
  636. @property
  637. def steps_remaining(self):
  638. return self.max_steps - self.step_count
  639. def __str__(self):
  640. """
  641. Produce a pretty string of the environment's grid along with the agent.
  642. A grid cell is represented by 2-character string, the first one for
  643. the object and the second one for the color.
  644. """
  645. # Map of object types to short string
  646. OBJECT_TO_STR = {
  647. 'wall' : 'W',
  648. 'floor' : 'F',
  649. 'door' : 'D',
  650. 'key' : 'K',
  651. 'ball' : 'A',
  652. 'box' : 'B',
  653. 'goal' : 'G',
  654. 'lava' : 'V',
  655. }
  656. # Short string for opened door
  657. OPENDED_DOOR_IDS = '_'
  658. # Map agent's direction to short string
  659. AGENT_DIR_TO_STR = {
  660. 0: '>',
  661. 1: 'V',
  662. 2: '<',
  663. 3: '^'
  664. }
  665. str = ''
  666. for j in range(self.grid.height):
  667. for i in range(self.grid.width):
  668. if i == self.agent_pos[0] and j == self.agent_pos[1]:
  669. str += 2 * AGENT_DIR_TO_STR[self.agent_dir]
  670. continue
  671. c = self.grid.get(i, j)
  672. if c == None:
  673. str += ' '
  674. continue
  675. if c.type == 'door':
  676. if c.is_open:
  677. str += '__'
  678. elif c.is_locked:
  679. str += 'L' + c.color[0].upper()
  680. else:
  681. str += 'D' + c.color[0].upper()
  682. continue
  683. str += OBJECT_TO_STR[c.type] + c.color[0].upper()
  684. if j < self.grid.height - 1:
  685. str += '\n'
  686. return str
  687. def _gen_grid(self, width, height):
  688. assert False, "_gen_grid needs to be implemented by each environment"
  689. def _reward(self):
  690. """
  691. Compute the reward to be given upon success
  692. """
  693. return 1 - 0.9 * (self.step_count / self.max_steps)
  694. def _rand_int(self, low, high):
  695. """
  696. Generate random integer in [low,high[
  697. """
  698. return self.np_random.randint(low, high)
  699. def _rand_float(self, low, high):
  700. """
  701. Generate random float in [low,high[
  702. """
  703. return self.np_random.uniform(low, high)
  704. def _rand_bool(self):
  705. """
  706. Generate random boolean value
  707. """
  708. return (self.np_random.randint(0, 2) == 0)
  709. def _rand_elem(self, iterable):
  710. """
  711. Pick a random element in a list
  712. """
  713. lst = list(iterable)
  714. idx = self._rand_int(0, len(lst))
  715. return lst[idx]
  716. def _rand_subset(self, iterable, num_elems):
  717. """
  718. Sample a random subset of distinct elements of a list
  719. """
  720. lst = list(iterable)
  721. assert num_elems <= len(lst)
  722. out = []
  723. while len(out) < num_elems:
  724. elem = self._rand_elem(lst)
  725. lst.remove(elem)
  726. out.append(elem)
  727. return out
  728. def _rand_color(self):
  729. """
  730. Generate a random color name (string)
  731. """
  732. return self._rand_elem(COLOR_NAMES)
  733. def _rand_pos(self, xLow, xHigh, yLow, yHigh):
  734. """
  735. Generate a random (x,y) position tuple
  736. """
  737. return (
  738. self.np_random.randint(xLow, xHigh),
  739. self.np_random.randint(yLow, yHigh)
  740. )
  741. def place_obj(self,
  742. obj,
  743. top=None,
  744. size=None,
  745. reject_fn=None,
  746. max_tries=math.inf
  747. ):
  748. """
  749. Place an object at an empty position in the grid
  750. :param top: top-left position of the rectangle where to place
  751. :param size: size of the rectangle where to place
  752. :param reject_fn: function to filter out potential positions
  753. """
  754. if top is None:
  755. top = (0, 0)
  756. else:
  757. top = (max(top[0], 0), max(top[1], 0))
  758. if size is None:
  759. size = (self.grid.width, self.grid.height)
  760. num_tries = 0
  761. while True:
  762. # This is to handle with rare cases where rejection sampling
  763. # gets stuck in an infinite loop
  764. if num_tries > max_tries:
  765. raise RecursionError('rejection sampling failed in place_obj')
  766. num_tries += 1
  767. pos = np.array((
  768. self._rand_int(top[0], min(top[0] + size[0], self.grid.width)),
  769. self._rand_int(top[1], min(top[1] + size[1], self.grid.height))
  770. ))
  771. # Don't place the object on top of another object
  772. if self.grid.get(*pos) != None:
  773. continue
  774. # Don't place the object where the agent is
  775. if np.array_equal(pos, self.agent_pos):
  776. continue
  777. # Check if there is a filtering criterion
  778. if reject_fn and reject_fn(self, pos):
  779. continue
  780. break
  781. self.grid.set(*pos, obj)
  782. if obj is not None:
  783. obj.init_pos = pos
  784. obj.cur_pos = pos
  785. return pos
  786. def put_obj(self, obj, i, j):
  787. """
  788. Put an object at a specific position in the grid
  789. """
  790. self.grid.set(i, j, obj)
  791. obj.init_pos = (i, j)
  792. obj.cur_pos = (i, j)
  793. def place_agent(
  794. self,
  795. top=None,
  796. size=None,
  797. rand_dir=True,
  798. max_tries=math.inf
  799. ):
  800. """
  801. Set the agent's starting point at an empty position in the grid
  802. """
  803. self.agent_pos = None
  804. pos = self.place_obj(None, top, size, max_tries=max_tries)
  805. self.agent_pos = pos
  806. if rand_dir:
  807. self.agent_dir = self._rand_int(0, 4)
  808. return pos
  809. @property
  810. def dir_vec(self):
  811. """
  812. Get the direction vector for the agent, pointing in the direction
  813. of forward movement.
  814. """
  815. assert self.agent_dir >= 0 and self.agent_dir < 4
  816. return DIR_TO_VEC[self.agent_dir]
  817. @property
  818. def right_vec(self):
  819. """
  820. Get the vector pointing to the right of the agent.
  821. """
  822. dx, dy = self.dir_vec
  823. return np.array((-dy, dx))
  824. @property
  825. def front_pos(self):
  826. """
  827. Get the position of the cell that is right in front of the agent
  828. """
  829. return self.agent_pos + self.dir_vec
  830. def get_view_coords(self, i, j):
  831. """
  832. Translate and rotate absolute grid coordinates (i, j) into the
  833. agent's partially observable view (sub-grid). Note that the resulting
  834. coordinates may be negative or outside of the agent's view size.
  835. """
  836. ax, ay = self.agent_pos
  837. dx, dy = self.dir_vec
  838. rx, ry = self.right_vec
  839. # Compute the absolute coordinates of the top-left view corner
  840. sz = self.agent_view_size
  841. hs = self.agent_view_size // 2
  842. tx = ax + (dx * (sz-1)) - (rx * hs)
  843. ty = ay + (dy * (sz-1)) - (ry * hs)
  844. lx = i - tx
  845. ly = j - ty
  846. # Project the coordinates of the object relative to the top-left
  847. # corner onto the agent's own coordinate system
  848. vx = (rx*lx + ry*ly)
  849. vy = -(dx*lx + dy*ly)
  850. return vx, vy
  851. def get_view_exts(self):
  852. """
  853. Get the extents of the square set of tiles visible to the agent
  854. Note: the bottom extent indices are not included in the set
  855. """
  856. # Facing right
  857. if self.agent_dir == 0:
  858. topX = self.agent_pos[0]
  859. topY = self.agent_pos[1] - self.agent_view_size // 2
  860. # Facing down
  861. elif self.agent_dir == 1:
  862. topX = self.agent_pos[0] - self.agent_view_size // 2
  863. topY = self.agent_pos[1]
  864. # Facing left
  865. elif self.agent_dir == 2:
  866. topX = self.agent_pos[0] - self.agent_view_size + 1
  867. topY = self.agent_pos[1] - self.agent_view_size // 2
  868. # Facing up
  869. elif self.agent_dir == 3:
  870. topX = self.agent_pos[0] - self.agent_view_size // 2
  871. topY = self.agent_pos[1] - self.agent_view_size + 1
  872. else:
  873. assert False, "invalid agent direction"
  874. botX = topX + self.agent_view_size
  875. botY = topY + self.agent_view_size
  876. return (topX, topY, botX, botY)
  877. def relative_coords(self, x, y):
  878. """
  879. Check if a grid position belongs to the agent's field of view, and returns the corresponding coordinates
  880. """
  881. vx, vy = self.get_view_coords(x, y)
  882. if vx < 0 or vy < 0 or vx >= self.agent_view_size or vy >= self.agent_view_size:
  883. return None
  884. return vx, vy
  def in_view(self, x, y):
      """
      Tell whether world position (x, y) falls inside the agent's square
      view window, i.e. whether relative_coords maps it to view coordinates.
      """
      view_coords = self.relative_coords(x, y)
      return view_coords is not None
  def agent_sees(self, x, y):
      """
      Check whether the non-empty world cell at (x, y) is visible to the agent.

      Returns False when (x, y) is outside the agent's view window, when the
      world cell there is empty, or when the decoded observation does not
      show an object of the same type at the corresponding view position.
      """
      coordinates = self.relative_coords(x, y)
      if coordinates is None:
          return False

      # Only non-empty cells can be "seen". Checking this first also avoids
      # calling `.type` on None: gen_obs_grid places the carried object on the
      # agent's own view tile, which may be empty in the world grid.
      world_cell = self.grid.get(x, y)
      if world_cell is None:
          return False

      vx, vy = coordinates
      # Compare against the encoded observation (built by gen_obs with the
      # visibility mask) rather than the raw world grid.
      obs = self.gen_obs()
      obs_grid = Grid.decode(obs['image'])
      obs_cell = obs_grid.get(vx, vy)
      return obs_cell is not None and obs_cell.type == world_cell.type
  def step(self, action):
      """
      Apply one agent action and advance the episode by one step.

      Returns (obs, reward, done, info):
          obs    -- the result of self.gen_obs()
          reward -- 0, or self._reward() when the agent moves onto a 'goal' cell
          done   -- True after stepping onto a 'goal' or 'lava' cell, or once
                    step_count reaches max_steps
          info   -- an empty dict
      """
      self.step_count += 1
      reward = 0
      done = False

      # Position in front of the agent and whatever currently occupies it
      fwd_pos = self.front_pos
      fwd_cell = self.grid.get(*fwd_pos)

      if action == self.actions.left:
          # Rotate left; modulo keeps the heading in 0..3 (same form as right)
          self.agent_dir = (self.agent_dir - 1) % 4

      elif action == self.actions.right:
          self.agent_dir = (self.agent_dir + 1) % 4

      elif action == self.actions.forward:
          if fwd_cell is None or fwd_cell.can_overlap():
              self.agent_pos = fwd_pos
          if fwd_cell is not None and fwd_cell.type == 'goal':
              done = True
              reward = self._reward()
          if fwd_cell is not None and fwd_cell.type == 'lava':
              done = True

      elif action == self.actions.pickup:
          if fwd_cell and fwd_cell.can_pickup():
              if self.carrying is None:
                  self.carrying = fwd_cell
                  # A carried object has no position on the grid
                  self.carrying.cur_pos = np.array([-1, -1])
                  self.grid.set(*fwd_pos, None)

      elif action == self.actions.drop:
          if not fwd_cell and self.carrying:
              self.grid.set(*fwd_pos, self.carrying)
              self.carrying.cur_pos = fwd_pos
              self.carrying = None

      elif action == self.actions.toggle:
          if fwd_cell:
              fwd_cell.toggle(self, fwd_pos)

      elif action == self.actions.done:
          # Done action (not used by default)
          pass

      else:
          assert False, "unknown action"

      if self.step_count >= self.max_steps:
          done = True

      obs = self.gen_obs()
      return obs, reward, done, {}
  def gen_obs_grid(self):
      """
      Generate the sub-grid observed by the agent, plus a visibility mask.

      The agent_view_size x agent_view_size slice of the world grid is rotated
      (agent_dir + 1) quarter-turns to the left into the agent's frame, in
      which the agent occupies cell (width // 2, height - 1).

      Returns (grid, vis_mask). vis_mask comes from grid.process_vis(), or is
      an all-True boolean array of shape (grid.width, grid.height) when
      see_through_walls is set.
      """
      topX, topY, _, _ = self.get_view_exts()
      grid = self.grid.slice(topX, topY, self.agent_view_size, self.agent_view_size)

      for _ in range(self.agent_dir + 1):
          grid = grid.rotate_left()

      # Process occluders and visibility
      # Note that this incurs some performance cost
      if not self.see_through_walls:
          vis_mask = grid.process_vis(agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1))
      else:
          # Builtin `bool`: the `np.bool` alias was deprecated in NumPy 1.20
          # and removed in NumPy 1.24 (AttributeError on current NumPy).
          vis_mask = np.ones(shape=(grid.width, grid.height), dtype=bool)

      # Make the agent see what it is carrying by placing the carried object
      # (or nothing) at the agent's position in its partially observable view
      agent_pos = grid.width // 2, grid.height - 1
      grid.set(*agent_pos, self.carrying if self.carrying else None)

      return grid, vis_mask
  def gen_obs(self):
      """
      Build the agent's observation dictionary.

      Keys, in order:
          'image'     -- the partially observable view from gen_obs_grid(),
                         encoded via grid.encode(vis_mask)
          'direction' -- the agent's current heading (acts as a compass)
          'mission'   -- the textual mission string (instructions for the agent)
      """
      view_grid, visible = self.gen_obs_grid()
      encoded_view = view_grid.encode(visible)

      assert hasattr(self, 'mission'), "environments must define a textual mission string"

      return dict(
          image=encoded_view,
          direction=self.agent_dir,
          mission=self.mission,
      )
  def get_obs_render(self, obs, tile_size=TILE_PIXELS//2, mode='pixmap'):
      """
      Render an agent observation for visualization.

      Args:
          obs: the encoded observation image (e.g. the 'image' entry built by
              gen_obs), decoded here with Grid.decode.
          tile_size: pixel size of one tile in the output.
          mode: 'rgb_array' returns r.getArray(), 'pixmap' returns
              r.getPixmap(); any other value returns the renderer itself.
      """
      view_px = self.agent_view_size * tile_size

      # Lazily create the offscreen renderer, and rebuild it whenever the
      # requested canvas size changes (e.g. a different tile_size). Otherwise
      # a renderer cached by an earlier call keeps its old size. render()
      # applies the same width check to grid_render.
      if self.obs_render is None or self.obs_render.width != view_px:
          from gym_minigrid.rendering import Renderer
          self.obs_render = Renderer(view_px, view_px)

      r = self.obs_render

      r.beginFrame()

      grid = Grid.decode(obs)

      # Render the whole grid
      grid.render(r, tile_size)

      # Draw the agent as a red triangle centred on the bottom-middle tile.
      # The rotation is 3 * 90 degrees, the same rotation render() uses for
      # agent_dir == 3. Polygon coordinates are in TILE_PIXELS units and are
      # scaled to tile_size.
      ratio = tile_size / TILE_PIXELS
      r.push()
      r.scale(ratio, ratio)
      r.translate(
          TILE_PIXELS * (0.5 + self.agent_view_size // 2),
          TILE_PIXELS * (self.agent_view_size - 0.5)
      )
      r.rotate(3 * 90)
      r.setLineColor(255, 0, 0)
      r.setColor(255, 0, 0)
      r.drawPolygon([
          (-12, 10),
          ( 12, 0),
          (-12, -10)
      ])
      r.pop()

      r.endFrame()

      if mode == 'rgb_array':
          return r.getArray()
      elif mode == 'pixmap':
          return r.getPixmap()

      return r
  def render(self, mode='human', close=False, highlight=True, tile_size=TILE_PIXELS):
      """
      Render the whole-grid human view.

      Args:
          mode: 'human' passes True as the Renderer's third constructor
              argument (presumably requesting its own window). 'rgb_array'
              returns r.getArray(), 'pixmap' returns r.getPixmap(), and any
              other value returns the renderer itself.
          close: if True, close the existing renderer (if any) and return None.
          highlight: if True, shade the cells currently visible to the agent.
          tile_size: pixel size of one tile in the output.
      """
      if close:
          if self.grid_render:
              self.grid_render.close()
          return

      # (Re)create the renderer on first use, when its window is gone, or
      # when the requested canvas width changed (e.g. a different tile_size).
      if (self.grid_render is None
              or self.grid_render.window is None
              or self.grid_render.width != self.width * tile_size):
          from gym_minigrid.rendering import Renderer
          self.grid_render = Renderer(
              self.width * tile_size,
              self.height * tile_size,
              mode == 'human'
          )

      r = self.grid_render

      if r.window:
          r.window.setText(self.mission)

      r.beginFrame()

      # Render the whole grid
      self.grid.render(r, tile_size)

      # Draw the agent as a red triangle in its cell, rotated to its heading.
      # Polygon coordinates are in TILE_PIXELS units and are scaled to tile_size.
      ratio = tile_size / TILE_PIXELS
      r.push()
      r.scale(ratio, ratio)
      r.translate(
          TILE_PIXELS * (self.agent_pos[0] + 0.5),
          TILE_PIXELS * (self.agent_pos[1] + 0.5)
      )
      r.rotate(self.agent_dir * 90)
      r.setLineColor(255, 0, 0)
      r.setColor(255, 0, 0)
      r.drawPolygon([
          (-12, 10),
          ( 12, 0),
          (-12, -10)
      ])
      r.pop()

      if highlight:
          # Visibility is only needed for highlighting, so skip the
          # slice/rotate/occlusion work of gen_obs_grid() otherwise
          _, vis_mask = self.gen_obs_grid()

          # Absolute world coordinates of the view's top-left corner: the
          # farthest row ahead of the agent (agent_view_size - 1 steps
          # forward), shifted half the view width to the agent's left
          f_vec = self.dir_vec
          r_vec = self.right_vec
          top_left = self.agent_pos + f_vec * (self.agent_view_size - 1) - r_vec * (self.agent_view_size // 2)

          for vis_j in range(self.agent_view_size):
              for vis_i in range(self.agent_view_size):
                  # If this cell is not visible, don't highlight it
                  if not vis_mask[vis_i, vis_j]:
                      continue

                  # Compute the world coordinates of this cell
                  abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)

                  # Highlight the cell
                  r.fillRect(
                      abs_i * tile_size,
                      abs_j * tile_size,
                      tile_size,
                      tile_size,
                      255, 255, 255, 75
                  )

      r.endFrame()

      if mode == 'rgb_array':
          return r.getArray()
      elif mode == 'pixmap':
          return r.getPixmap()

      return r