minigrid.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338
  1. import math
  2. import gym
  3. from enum import IntEnum
  4. import numpy as np
  5. from gym import error, spaces, utils
  6. from gym.utils import seeding
  7. # Size in pixels of a cell in the full-scale human view
  8. CELL_PIXELS = 32
  9. # Number of cells (width and height) in the agent view
  10. AGENT_VIEW_SIZE = 7
  11. # Size of the array given as an observation to the agent
  12. OBS_ARRAY_SIZE = (AGENT_VIEW_SIZE, AGENT_VIEW_SIZE, 3)
  13. # Map of color names to RGB values
  14. COLORS = {
  15. 'red' : np.array([255, 0, 0]),
  16. 'green' : np.array([0, 255, 0]),
  17. 'blue' : np.array([0, 0, 255]),
  18. 'purple': np.array([112, 39, 195]),
  19. 'yellow': np.array([255, 255, 0]),
  20. 'grey' : np.array([100, 100, 100])
  21. }
  22. COLOR_NAMES = sorted(list(COLORS.keys()))
  23. # Used to map colors to integers
  24. COLOR_TO_IDX = {
  25. 'red' : 0,
  26. 'green' : 1,
  27. 'blue' : 2,
  28. 'purple': 3,
  29. 'yellow': 4,
  30. 'grey' : 5
  31. }
  32. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  33. # Map of object type to integers
  34. OBJECT_TO_IDX = {
  35. 'unseen' : 0,
  36. 'empty' : 1,
  37. 'wall' : 2,
  38. 'floor' : 3,
  39. 'door' : 4,
  40. 'key' : 5,
  41. 'ball' : 6,
  42. 'box' : 7,
  43. 'goal' : 8,
  44. 'lava' : 9
  45. }
  46. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  47. # Map of agent direction indices to vectors
  48. DIR_TO_VEC = [
  49. # Pointing right (positive X)
  50. np.array((1, 0)),
  51. # Down (positive Y)
  52. np.array((0, 1)),
  53. # Pointing left (negative X)
  54. np.array((-1, 0)),
  55. # Up (negative Y)
  56. np.array((0, -1)),
  57. ]
  58. class WorldObj:
  59. """
  60. Base class for grid world objects
  61. """
  62. def __init__(self, type, color):
  63. assert type in OBJECT_TO_IDX, type
  64. assert color in COLOR_TO_IDX, color
  65. self.type = type
  66. self.color = color
  67. self.contains = None
  68. # Initial position of the object
  69. self.init_pos = None
  70. # Current position of the object
  71. self.cur_pos = None
  72. def can_overlap(self):
  73. """Can the agent overlap with this?"""
  74. return False
  75. def can_pickup(self):
  76. """Can the agent pick this up?"""
  77. return False
  78. def can_contain(self):
  79. """Can this contain another object?"""
  80. return False
  81. def see_behind(self):
  82. """Can the agent see behind this object?"""
  83. return True
  84. def toggle(self, env, pos):
  85. """Method to trigger/toggle an action this object performs"""
  86. return False
  87. def render(self, r):
  88. """Draw this object with the given renderer"""
  89. raise NotImplementedError
  90. def _set_color(self, r):
  91. """Set the color of this object as the active drawing color"""
  92. c = COLORS[self.color]
  93. r.setLineColor(c[0], c[1], c[2])
  94. r.setColor(c[0], c[1], c[2])
  95. class Goal(WorldObj):
  96. def __init__(self):
  97. super().__init__('goal', 'green')
  98. def can_overlap(self):
  99. return True
  100. def render(self, r):
  101. self._set_color(r)
  102. r.drawPolygon([
  103. (0 , CELL_PIXELS),
  104. (CELL_PIXELS, CELL_PIXELS),
  105. (CELL_PIXELS, 0),
  106. (0 , 0)
  107. ])
  108. class Floor(WorldObj):
  109. """
  110. Colored floor tile the agent can walk over
  111. """
  112. def __init__(self, color='blue'):
  113. super().__init__('floor', color)
  114. def can_overlap(self):
  115. return True
  116. def render(self, r):
  117. # Give the floor a pale color
  118. c = COLORS[self.color]
  119. r.setLineColor(100, 100, 100, 0)
  120. r.setColor(*c/2)
  121. r.drawPolygon([
  122. (1 , CELL_PIXELS),
  123. (CELL_PIXELS, CELL_PIXELS),
  124. (CELL_PIXELS, 1),
  125. (1 , 1)
  126. ])
  127. class Lava(WorldObj):
  128. def __init__(self):
  129. super().__init__('lava', 'red')
  130. def can_overlap(self):
  131. return True
  132. def render(self, r):
  133. orange = 255, 128, 0
  134. r.setLineColor(*orange)
  135. r.setColor(*orange)
  136. r.drawPolygon([
  137. (0 , CELL_PIXELS),
  138. (CELL_PIXELS, CELL_PIXELS),
  139. (CELL_PIXELS, 0),
  140. (0 , 0)
  141. ])
  142. # drawing the waves
  143. r.setLineColor(0, 0, 0)
  144. r.drawPolyline([
  145. (.1 * CELL_PIXELS, .3 * CELL_PIXELS),
  146. (.3 * CELL_PIXELS, .4 * CELL_PIXELS),
  147. (.5 * CELL_PIXELS, .3 * CELL_PIXELS),
  148. (.7 * CELL_PIXELS, .4 * CELL_PIXELS),
  149. (.9 * CELL_PIXELS, .3 * CELL_PIXELS),
  150. ])
  151. r.drawPolyline([
  152. (.1 * CELL_PIXELS, .5 * CELL_PIXELS),
  153. (.3 * CELL_PIXELS, .6 * CELL_PIXELS),
  154. (.5 * CELL_PIXELS, .5 * CELL_PIXELS),
  155. (.7 * CELL_PIXELS, .6 * CELL_PIXELS),
  156. (.9 * CELL_PIXELS, .5 * CELL_PIXELS),
  157. ])
  158. r.drawPolyline([
  159. (.1 * CELL_PIXELS, .7 * CELL_PIXELS),
  160. (.3 * CELL_PIXELS, .8 * CELL_PIXELS),
  161. (.5 * CELL_PIXELS, .7 * CELL_PIXELS),
  162. (.7 * CELL_PIXELS, .8 * CELL_PIXELS),
  163. (.9 * CELL_PIXELS, .7 * CELL_PIXELS),
  164. ])
  165. class Wall(WorldObj):
  166. def __init__(self, color='grey'):
  167. super().__init__('wall', color)
  168. def see_behind(self):
  169. return False
  170. def render(self, r):
  171. self._set_color(r)
  172. r.drawPolygon([
  173. (0 , CELL_PIXELS),
  174. (CELL_PIXELS, CELL_PIXELS),
  175. (CELL_PIXELS, 0),
  176. (0 , 0)
  177. ])
  178. class Door(WorldObj):
  179. def __init__(self, color, is_open=False, is_locked=False):
  180. super().__init__('door', color)
  181. self.is_open = is_open
  182. self.is_locked = is_locked
  183. def can_overlap(self):
  184. """The agent can only walk over this cell when the door is open"""
  185. return self.is_open
  186. def see_behind(self):
  187. return self.is_open
  188. def toggle(self, env, pos):
  189. # If the player has the right key to open the door
  190. if self.is_locked:
  191. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  192. self.is_locked = False
  193. self.is_open = True
  194. return True
  195. return False
  196. self.is_open = not self.is_open
  197. return True
  198. def render(self, r):
  199. c = COLORS[self.color]
  200. r.setLineColor(c[0], c[1], c[2])
  201. r.setColor(c[0], c[1], c[2], 50 if self.is_locked else 0)
  202. if self.is_open:
  203. r.drawPolygon([
  204. (CELL_PIXELS-2, CELL_PIXELS),
  205. (CELL_PIXELS , CELL_PIXELS),
  206. (CELL_PIXELS , 0),
  207. (CELL_PIXELS-2, 0)
  208. ])
  209. return
  210. r.drawPolygon([
  211. (0 , CELL_PIXELS),
  212. (CELL_PIXELS, CELL_PIXELS),
  213. (CELL_PIXELS, 0),
  214. (0 , 0)
  215. ])
  216. r.drawPolygon([
  217. (2 , CELL_PIXELS-2),
  218. (CELL_PIXELS-2, CELL_PIXELS-2),
  219. (CELL_PIXELS-2, 2),
  220. (2 , 2)
  221. ])
  222. if self.is_locked:
  223. # Draw key slot
  224. r.drawLine(
  225. CELL_PIXELS * 0.55,
  226. CELL_PIXELS * 0.5,
  227. CELL_PIXELS * 0.75,
  228. CELL_PIXELS * 0.5
  229. )
  230. else:
  231. # Draw door handle
  232. r.drawCircle(CELL_PIXELS * 0.75, CELL_PIXELS * 0.5, 2)
  233. class Key(WorldObj):
  234. def __init__(self, color='blue'):
  235. super(Key, self).__init__('key', color)
  236. def can_pickup(self):
  237. return True
  238. def render(self, r):
  239. self._set_color(r)
  240. # Vertical quad
  241. r.drawPolygon([
  242. (16, 10),
  243. (20, 10),
  244. (20, 28),
  245. (16, 28)
  246. ])
  247. # Teeth
  248. r.drawPolygon([
  249. (12, 19),
  250. (16, 19),
  251. (16, 21),
  252. (12, 21)
  253. ])
  254. r.drawPolygon([
  255. (12, 26),
  256. (16, 26),
  257. (16, 28),
  258. (12, 28)
  259. ])
  260. r.drawCircle(18, 9, 6)
  261. r.setLineColor(0, 0, 0)
  262. r.setColor(0, 0, 0)
  263. r.drawCircle(18, 9, 2)
  264. class Ball(WorldObj):
  265. def __init__(self, color='blue'):
  266. super(Ball, self).__init__('ball', color)
  267. def can_pickup(self):
  268. return True
  269. def render(self, r):
  270. self._set_color(r)
  271. r.drawCircle(CELL_PIXELS * 0.5, CELL_PIXELS * 0.5, 10)
  272. class Box(WorldObj):
  273. def __init__(self, color, contains=None):
  274. super(Box, self).__init__('box', color)
  275. self.contains = contains
  276. def can_pickup(self):
  277. return True
  278. def render(self, r):
  279. c = COLORS[self.color]
  280. r.setLineColor(c[0], c[1], c[2])
  281. r.setColor(0, 0, 0)
  282. r.setLineWidth(2)
  283. r.drawPolygon([
  284. (4 , CELL_PIXELS-4),
  285. (CELL_PIXELS-4, CELL_PIXELS-4),
  286. (CELL_PIXELS-4, 4),
  287. (4 , 4)
  288. ])
  289. r.drawLine(
  290. 4,
  291. CELL_PIXELS / 2,
  292. CELL_PIXELS - 4,
  293. CELL_PIXELS / 2
  294. )
  295. r.setLineWidth(1)
  296. def toggle(self, env, pos):
  297. # Replace the box by its contents
  298. env.grid.set(*pos, self.contains)
  299. return True
  300. class Grid:
  301. """
  302. Represent a grid and operations on it
  303. """
  304. def __init__(self, width, height):
  305. assert width >= 3
  306. assert height >= 3
  307. self.width = width
  308. self.height = height
  309. self.grid = [None] * width * height
  310. def __contains__(self, key):
  311. if isinstance(key, WorldObj):
  312. for e in self.grid:
  313. if e is key:
  314. return True
  315. elif isinstance(key, tuple):
  316. for e in self.grid:
  317. if e is None:
  318. continue
  319. if (e.color, e.type) == key:
  320. return True
  321. if key[0] is None and key[1] == e.type:
  322. return True
  323. return False
  324. def __eq__(self, other):
  325. grid1 = self.encode()
  326. grid2 = other.encode()
  327. return np.array_equal(grid2, grid1)
  328. def __ne__(self, other):
  329. return not self == other
  330. def copy(self):
  331. from copy import deepcopy
  332. return deepcopy(self)
  333. def set(self, i, j, v):
  334. assert i >= 0 and i < self.width
  335. assert j >= 0 and j < self.height
  336. self.grid[j * self.width + i] = v
  337. def get(self, i, j):
  338. assert i >= 0 and i < self.width
  339. assert j >= 0 and j < self.height
  340. return self.grid[j * self.width + i]
  341. def horz_wall(self, x, y, length=None):
  342. if length is None:
  343. length = self.width - x
  344. for i in range(0, length):
  345. self.set(x + i, y, Wall())
  346. def vert_wall(self, x, y, length=None):
  347. if length is None:
  348. length = self.height - y
  349. for j in range(0, length):
  350. self.set(x, y + j, Wall())
  351. def wall_rect(self, x, y, w, h):
  352. self.horz_wall(x, y, w)
  353. self.horz_wall(x, y+h-1, w)
  354. self.vert_wall(x, y, h)
  355. self.vert_wall(x+w-1, y, h)
  356. def rotate_left(self):
  357. """
  358. Rotate the grid to the left (counter-clockwise)
  359. """
  360. grid = Grid(self.height, self.width)
  361. for i in range(self.width):
  362. for j in range(self.height):
  363. v = self.get(i, j)
  364. grid.set(j, grid.height - 1 - i, v)
  365. return grid
  366. def slice(self, topX, topY, width, height):
  367. """
  368. Get a subset of the grid
  369. """
  370. grid = Grid(width, height)
  371. for j in range(0, height):
  372. for i in range(0, width):
  373. x = topX + i
  374. y = topY + j
  375. if x >= 0 and x < self.width and \
  376. y >= 0 and y < self.height:
  377. v = self.get(x, y)
  378. else:
  379. v = Wall()
  380. grid.set(i, j, v)
  381. return grid
  382. def render(self, r, tile_size):
  383. """
  384. Render this grid at a given scale
  385. :param r: target renderer object
  386. :param tile_size: tile size in pixels
  387. """
  388. assert r.width == self.width * tile_size
  389. assert r.height == self.height * tile_size
  390. # Total grid size at native scale
  391. widthPx = self.width * CELL_PIXELS
  392. heightPx = self.height * CELL_PIXELS
  393. r.push()
  394. # Internally, we draw at the "large" full-grid resolution, but we
  395. # use the renderer to scale back to the desired size
  396. r.scale(tile_size / CELL_PIXELS, tile_size / CELL_PIXELS)
  397. # Draw the background of the in-world cells black
  398. r.fillRect(
  399. 0,
  400. 0,
  401. widthPx,
  402. heightPx,
  403. 0, 0, 0
  404. )
  405. # Draw grid lines
  406. r.setLineColor(100, 100, 100)
  407. for rowIdx in range(0, self.height):
  408. y = CELL_PIXELS * rowIdx
  409. r.drawLine(0, y, widthPx, y)
  410. for colIdx in range(0, self.width):
  411. x = CELL_PIXELS * colIdx
  412. r.drawLine(x, 0, x, heightPx)
  413. # Render the grid
  414. for j in range(0, self.height):
  415. for i in range(0, self.width):
  416. cell = self.get(i, j)
  417. if cell == None:
  418. continue
  419. r.push()
  420. r.translate(i * CELL_PIXELS, j * CELL_PIXELS)
  421. cell.render(r)
  422. r.pop()
  423. r.pop()
  424. def encode(self, vis_mask=None):
  425. """
  426. Produce a compact numpy encoding of the grid
  427. """
  428. if vis_mask is None:
  429. vis_mask = np.ones((self.width, self.height), dtype=bool)
  430. array = np.zeros((self.width, self.height, 3), dtype='uint8')
  431. for i in range(self.width):
  432. for j in range(self.height):
  433. if vis_mask[i, j]:
  434. v = self.get(i, j)
  435. if v is None:
  436. array[i, j, 0] = OBJECT_TO_IDX['empty']
  437. array[i, j, 1] = 0
  438. array[i, j, 2] = 0
  439. else:
  440. # State, 0: open, 1: closed, 2: locked
  441. state = 0
  442. if hasattr(v, 'is_open') and not v.is_open:
  443. state = 1
  444. if hasattr(v, 'is_locked') and v.is_locked:
  445. state = 2
  446. array[i, j, 0] = OBJECT_TO_IDX[v.type]
  447. array[i, j, 1] = COLOR_TO_IDX[v.color]
  448. array[i, j, 2] = state
  449. return array
  450. @staticmethod
  451. def decode(array):
  452. """
  453. Decode an array grid encoding back into a grid
  454. """
  455. width, height, channels = array.shape
  456. assert channels == 3
  457. grid = Grid(width, height)
  458. for i in range(width):
  459. for j in range(height):
  460. typeIdx, colorIdx, state = array[i, j]
  461. if typeIdx == OBJECT_TO_IDX['unseen'] or \
  462. typeIdx == OBJECT_TO_IDX['empty']:
  463. continue
  464. objType = IDX_TO_OBJECT[typeIdx]
  465. color = IDX_TO_COLOR[colorIdx]
  466. # State, 0: open, 1: closed, 2: locked
  467. is_open = state == 0
  468. is_locked = state == 2
  469. if objType == 'wall':
  470. v = Wall(color)
  471. elif objType == 'floor':
  472. v = Floor(color)
  473. elif objType == 'ball':
  474. v = Ball(color)
  475. elif objType == 'key':
  476. v = Key(color)
  477. elif objType == 'box':
  478. v = Box(color)
  479. elif objType == 'door':
  480. v = Door(color, is_open, is_locked)
  481. elif objType == 'goal':
  482. v = Goal()
  483. elif objType == 'lava':
  484. v = Lava()
  485. else:
  486. assert False, "unknown obj type in decode '%s'" % objType
  487. grid.set(i, j, v)
  488. return grid
  489. def process_vis(grid, agent_pos):
  490. mask = np.zeros(shape=(grid.width, grid.height), dtype=np.bool)
  491. mask[agent_pos[0], agent_pos[1]] = True
  492. for j in reversed(range(0, grid.height)):
  493. for i in range(0, grid.width-1):
  494. if not mask[i, j]:
  495. continue
  496. cell = grid.get(i, j)
  497. if cell and not cell.see_behind():
  498. continue
  499. mask[i+1, j] = True
  500. if j > 0:
  501. mask[i+1, j-1] = True
  502. mask[i, j-1] = True
  503. for i in reversed(range(1, grid.width)):
  504. if not mask[i, j]:
  505. continue
  506. cell = grid.get(i, j)
  507. if cell and not cell.see_behind():
  508. continue
  509. mask[i-1, j] = True
  510. if j > 0:
  511. mask[i-1, j-1] = True
  512. mask[i, j-1] = True
  513. for j in range(0, grid.height):
  514. for i in range(0, grid.width):
  515. if not mask[i, j]:
  516. grid.set(i, j, None)
  517. return mask
  518. class MiniGridEnv(gym.Env):
  519. """
  520. 2D grid world game environment
  521. """
  522. metadata = {
  523. 'render.modes': ['human', 'rgb_array', 'pixmap'],
  524. 'video.frames_per_second' : 10
  525. }
  526. # Enumeration of possible actions
  527. class Actions(IntEnum):
  528. # Turn left, turn right, move forward
  529. left = 0
  530. right = 1
  531. forward = 2
  532. # Pick up an object
  533. pickup = 3
  534. # Drop an object
  535. drop = 4
  536. # Toggle/activate an object
  537. toggle = 5
  538. # Done completing task
  539. done = 6
  540. def __init__(
  541. self,
  542. grid_size=None,
  543. width=None,
  544. height=None,
  545. max_steps=100,
  546. see_through_walls=False,
  547. seed=1337
  548. ):
  549. # Can't set both grid_size and width/height
  550. if grid_size:
  551. assert width == None and height == None
  552. width = grid_size
  553. height = grid_size
  554. # Action enumeration for this environment
  555. self.actions = MiniGridEnv.Actions
  556. # Actions are discrete integer values
  557. self.action_space = spaces.Discrete(len(self.actions))
  558. # Observations are dictionaries containing an
  559. # encoding of the grid and a textual 'mission' string
  560. self.observation_space = spaces.Box(
  561. low=0,
  562. high=255,
  563. shape=OBS_ARRAY_SIZE,
  564. dtype='uint8'
  565. )
  566. self.observation_space = spaces.Dict({
  567. 'image': self.observation_space
  568. })
  569. # Range of possible rewards
  570. self.reward_range = (0, 1)
  571. # Renderer object used to render the whole grid (full-scale)
  572. self.grid_render = None
  573. # Renderer used to render observations (small-scale agent view)
  574. self.obs_render = None
  575. # Environment configuration
  576. self.width = width
  577. self.height = height
  578. self.max_steps = max_steps
  579. self.see_through_walls = see_through_walls
  580. # Starting position and direction for the agent
  581. self.start_pos = None
  582. self.start_dir = None
  583. # Initialize the RNG
  584. self.seed(seed=seed)
  585. # Initialize the state
  586. self.reset()
  587. def reset(self):
  588. # Generate a new random grid at the start of each episode
  589. # To keep the same grid for each episode, call env.seed() with
  590. # the same seed before calling env.reset()
  591. self._gen_grid(self.width, self.height)
  592. # These fields should be defined by _gen_grid
  593. assert self.start_pos is not None
  594. assert self.start_dir is not None
  595. # Check that the agent doesn't overlap with an object
  596. start_cell = self.grid.get(*self.start_pos)
  597. assert start_cell is None or start_cell.can_overlap()
  598. # Place the agent in the starting position and direction
  599. self.agent_pos = self.start_pos
  600. self.agent_dir = self.start_dir
  601. # Item picked up, being carried, initially nothing
  602. self.carrying = None
  603. # Step count since episode start
  604. self.step_count = 0
  605. # Return first observation
  606. obs = self.gen_obs()
  607. return obs
  608. def seed(self, seed=1337):
  609. # Seed the random number generator
  610. self.np_random, _ = seeding.np_random(seed)
  611. return [seed]
  612. @property
  613. def steps_remaining(self):
  614. return self.max_steps - self.step_count
  615. def __str__(self):
  616. """
  617. Produce a pretty string of the environment's grid along with the agent.
  618. A grid cell is represented by 2-character string, the first one for
  619. the object and the second one for the color.
  620. """
  621. # Map of object types to short string
  622. OBJECT_TO_STR = {
  623. 'wall' : 'W',
  624. 'floor' : 'F',
  625. 'door' : 'D',
  626. 'key' : 'K',
  627. 'ball' : 'A',
  628. 'box' : 'B',
  629. 'goal' : 'G',
  630. 'lava' : 'V',
  631. }
  632. # Short string for opened door
  633. OPENDED_DOOR_IDS = '_'
  634. # Map agent's direction to short string
  635. AGENT_DIR_TO_STR = {
  636. 0: '>',
  637. 1: 'V',
  638. 2: '<',
  639. 3: '^'
  640. }
  641. str = ''
  642. for j in range(self.grid.height):
  643. for i in range(self.grid.width):
  644. if i == self.agent_pos[0] and j == self.agent_pos[1]:
  645. str += 2 * AGENT_DIR_TO_STR[self.agent_dir]
  646. continue
  647. c = self.grid.get(i, j)
  648. if c == None:
  649. str += ' '
  650. continue
  651. if c.type == 'door':
  652. if c.is_open:
  653. str += '__'
  654. elif c.is_locked:
  655. str += 'L' + c.color[0].upper()
  656. else:
  657. str += 'D' + c.color[0].upper()
  658. continue
  659. str += OBJECT_TO_STR[c.type] + c.color[0].upper()
  660. if j < self.grid.height - 1:
  661. str += '\n'
  662. return str
  663. def _gen_grid(self, width, height):
  664. assert False, "_gen_grid needs to be implemented by each environment"
  665. def _reward(self):
  666. """
  667. Compute the reward to be given upon success
  668. """
  669. return 1 - 0.9 * (self.step_count / self.max_steps)
  670. def _rand_int(self, low, high):
  671. """
  672. Generate random integer in [low,high[
  673. """
  674. return self.np_random.randint(low, high)
  675. def _rand_float(self, low, high):
  676. """
  677. Generate random float in [low,high[
  678. """
  679. return self.np_random.uniform(low, high)
  680. def _rand_bool(self):
  681. """
  682. Generate random boolean value
  683. """
  684. return (self.np_random.randint(0, 2) == 0)
  685. def _rand_elem(self, iterable):
  686. """
  687. Pick a random element in a list
  688. """
  689. lst = list(iterable)
  690. idx = self._rand_int(0, len(lst))
  691. return lst[idx]
  692. def _rand_subset(self, iterable, num_elems):
  693. """
  694. Sample a random subset of distinct elements of a list
  695. """
  696. lst = list(iterable)
  697. assert num_elems <= len(lst)
  698. out = []
  699. while len(out) < num_elems:
  700. elem = self._rand_elem(lst)
  701. lst.remove(elem)
  702. out.append(elem)
  703. return out
  704. def _rand_color(self):
  705. """
  706. Generate a random color name (string)
  707. """
  708. return self._rand_elem(COLOR_NAMES)
  709. def _rand_pos(self, xLow, xHigh, yLow, yHigh):
  710. """
  711. Generate a random (x,y) position tuple
  712. """
  713. return (
  714. self.np_random.randint(xLow, xHigh),
  715. self.np_random.randint(yLow, yHigh)
  716. )
  717. def place_obj(self,
  718. obj,
  719. top=None,
  720. size=None,
  721. reject_fn=None,
  722. max_tries=math.inf
  723. ):
  724. """
  725. Place an object at an empty position in the grid
  726. :param top: top-left position of the rectangle where to place
  727. :param size: size of the rectangle where to place
  728. :param reject_fn: function to filter out potential positions
  729. """
  730. if top is None:
  731. top = (0, 0)
  732. if size is None:
  733. size = (self.grid.width, self.grid.height)
  734. num_tries = 0
  735. while True:
  736. # This is to handle with rare cases where rejection sampling
  737. # gets stuck in an infinite loop
  738. if num_tries > max_tries:
  739. raise RecursionError('rejection sampling failed in place_obj')
  740. num_tries += 1
  741. pos = np.array((
  742. self._rand_int(top[0], top[0] + size[0]),
  743. self._rand_int(top[1], top[1] + size[1])
  744. ))
  745. # Don't place the object on top of another object
  746. if self.grid.get(*pos) != None:
  747. continue
  748. # Don't place the object where the agent is
  749. if np.array_equal(pos, self.start_pos):
  750. continue
  751. # Check if there is a filtering criterion
  752. if reject_fn and reject_fn(self, pos):
  753. continue
  754. break
  755. self.grid.set(*pos, obj)
  756. if obj is not None:
  757. obj.init_pos = pos
  758. obj.cur_pos = pos
  759. return pos
  760. def place_agent(
  761. self,
  762. top=None,
  763. size=None,
  764. rand_dir=True,
  765. max_tries=math.inf
  766. ):
  767. """
  768. Set the agent's starting point at an empty position in the grid
  769. """
  770. self.start_pos = None
  771. pos = self.place_obj(None, top, size, max_tries=max_tries)
  772. self.start_pos = pos
  773. if rand_dir:
  774. self.start_dir = self._rand_int(0, 4)
  775. return pos
  776. @property
  777. def dir_vec(self):
  778. """
  779. Get the direction vector for the agent, pointing in the direction
  780. of forward movement.
  781. """
  782. assert self.agent_dir >= 0 and self.agent_dir < 4
  783. return DIR_TO_VEC[self.agent_dir]
  784. @property
  785. def right_vec(self):
  786. """
  787. Get the vector pointing to the right of the agent.
  788. """
  789. dx, dy = self.dir_vec
  790. return np.array((-dy, dx))
  791. @property
  792. def front_pos(self):
  793. """
  794. Get the position of the cell that is right in front of the agent
  795. """
  796. return self.agent_pos + self.dir_vec
  797. def get_view_coords(self, i, j):
  798. """
  799. Translate and rotate absolute grid coordinates (i, j) into the
  800. agent's partially observable view (sub-grid). Note that the resulting
  801. coordinates may be negative or outside of the agent's view size.
  802. """
  803. ax, ay = self.agent_pos
  804. dx, dy = self.dir_vec
  805. rx, ry = self.right_vec
  806. # Compute the absolute coordinates of the top-left view corner
  807. sz = AGENT_VIEW_SIZE
  808. hs = AGENT_VIEW_SIZE // 2
  809. tx = ax + (dx * (sz-1)) - (rx * hs)
  810. ty = ay + (dy * (sz-1)) - (ry * hs)
  811. lx = i - tx
  812. ly = j - ty
  813. # Project the coordinates of the object relative to the top-left
  814. # corner onto the agent's own coordinate system
  815. vx = (rx*lx + ry*ly)
  816. vy = -(dx*lx + dy*ly)
  817. return vx, vy
  818. def get_view_exts(self):
  819. """
  820. Get the extents of the square set of tiles visible to the agent
  821. Note: the bottom extent indices are not included in the set
  822. """
  823. # Facing right
  824. if self.agent_dir == 0:
  825. topX = self.agent_pos[0]
  826. topY = self.agent_pos[1] - AGENT_VIEW_SIZE // 2
  827. # Facing down
  828. elif self.agent_dir == 1:
  829. topX = self.agent_pos[0] - AGENT_VIEW_SIZE // 2
  830. topY = self.agent_pos[1]
  831. # Facing left
  832. elif self.agent_dir == 2:
  833. topX = self.agent_pos[0] - AGENT_VIEW_SIZE + 1
  834. topY = self.agent_pos[1] - AGENT_VIEW_SIZE // 2
  835. # Facing up
  836. elif self.agent_dir == 3:
  837. topX = self.agent_pos[0] - AGENT_VIEW_SIZE // 2
  838. topY = self.agent_pos[1] - AGENT_VIEW_SIZE + 1
  839. else:
  840. assert False, "invalid agent direction"
  841. botX = topX + AGENT_VIEW_SIZE
  842. botY = topY + AGENT_VIEW_SIZE
  843. return (topX, topY, botX, botY)
  844. def relative_coords(self, x, y):
  845. """
  846. Check if a grid position belongs to the agent's field of view, and returns the corresponding coordinates
  847. """
  848. vx, vy = self.get_view_coords(x, y)
  849. if vx < 0 or vy < 0 or vx >= AGENT_VIEW_SIZE or vy >= AGENT_VIEW_SIZE:
  850. return None
  851. return vx, vy
  852. def in_view(self, x, y):
  853. """
  854. check if a grid position is visible to the agent
  855. """
  856. return self.relative_coords(x, y) is not None
  857. def agent_sees(self, x, y):
  858. """
  859. Check if a non-empty grid position is visible to the agent
  860. """
  861. coordinates = self.relative_coords(x, y)
  862. if coordinates is None:
  863. return False
  864. vx, vy = coordinates
  865. obs = self.gen_obs()
  866. obs_grid = Grid.decode(obs['image'])
  867. obs_cell = obs_grid.get(vx, vy)
  868. world_cell = self.grid.get(x, y)
  869. return obs_cell is not None and obs_cell.type == world_cell.type
  def step(self, action):
      """
      Advance the environment by one time step.

      Args:
          action: one of ``self.actions`` (left, right, forward, pickup,
              drop, toggle, done).

      Returns:
          (obs, reward, done, info):
            - obs: the new observation from ``self.gen_obs()``;
            - reward: ``self._reward()`` when the agent steps onto a goal,
              otherwise 0;
            - done: True when a goal or lava cell is entered, or when
              ``step_count`` reaches ``max_steps``;
            - info: an empty dict.

      Raises:
          ValueError: if ``action`` is not one of the known actions. This used
              to be an ``assert``, which is stripped under ``python -O``.
      """
      self.step_count += 1

      reward = 0
      done = False

      # Position of, and object in, the cell directly in front of the agent
      fwd_pos = self.front_pos
      fwd_cell = self.grid.get(*fwd_pos)

      # Rotate left (modulo keeps the direction in 0..3, mirroring the right turn)
      if action == self.actions.left:
          self.agent_dir = (self.agent_dir - 1) % 4

      # Rotate right
      elif action == self.actions.right:
          self.agent_dir = (self.agent_dir + 1) % 4

      # Move forward into empty or overlappable cells
      elif action == self.actions.forward:
          if fwd_cell is None or fwd_cell.can_overlap():
              self.agent_pos = fwd_pos
          if fwd_cell is not None and fwd_cell.type == 'goal':
              done = True
              reward = self._reward()
          if fwd_cell is not None and fwd_cell.type == 'lava':
              done = True

      # Pick up an object (only when the agent's hands are empty)
      elif action == self.actions.pickup:
          if fwd_cell and fwd_cell.can_pickup() and self.carrying is None:
              self.carrying = fwd_cell
              # (-1, -1) marks the object as no longer placed on the grid
              self.carrying.cur_pos = np.array([-1, -1])
              self.grid.set(*fwd_pos, None)

      # Drop the carried object into an empty cell in front
      elif action == self.actions.drop:
          if not fwd_cell and self.carrying:
              self.grid.set(*fwd_pos, self.carrying)
              self.carrying.cur_pos = fwd_pos
              self.carrying = None

      # Toggle/activate an object
      elif action == self.actions.toggle:
          if fwd_cell:
              fwd_cell.toggle(self, fwd_pos)

      # Done action (not used by default)
      elif action == self.actions.done:
          pass

      else:
          raise ValueError("unknown action")

      if self.step_count >= self.max_steps:
          done = True

      obs = self.gen_obs()

      return obs, reward, done, {}
  def gen_obs_grid(self):
      """
      Generate the sub-grid observed by the agent.

      The AGENT_VIEW_SIZE x AGENT_VIEW_SIZE slice given by get_view_exts()
      is rotated left (agent_dir + 1) times. The agent sits at the
      bottom-centre cell of the result, and the carried object (or None) is
      written into that cell.

      Returns:
          (grid, vis_mask): the rotated view grid, and a boolean array of
          shape (grid.width, grid.height) telling which cells the agent can
          actually see. When see_through_walls is set, every cell is visible.
      """
      topX, topY, _, _ = self.get_view_exts()

      grid = self.grid.slice(topX, topY, AGENT_VIEW_SIZE, AGENT_VIEW_SIZE)

      for _ in range(self.agent_dir + 1):
          grid = grid.rotate_left()

      # Process occluders and visibility
      # Note that this incurs some performance cost
      if not self.see_through_walls:
          vis_mask = grid.process_vis(agent_pos=(AGENT_VIEW_SIZE // 2, AGENT_VIEW_SIZE - 1))
      else:
          # `np.bool` was deprecated in NumPy 1.20 and removed in 1.24; the
          # builtin `bool` selects the same numpy boolean dtype.
          vis_mask = np.ones(shape=(grid.width, grid.height), dtype=bool)

      # Make it so the agent sees what it's carrying: place the carried object
      # (or clear the cell) at the agent's position in its own view.
      agent_pos = grid.width // 2, grid.height - 1
      grid.set(*agent_pos, self.carrying if self.carrying else None)

      return grid, vis_mask
  def gen_obs(self):
      """
      Build the agent's partially observable, low-resolution observation.

      Returns a dict with:
        'image'     -- the observed sub-grid encoded together with its
                       visibility mask
        'direction' -- the agent's current orientation (acts as a compass)
        'mission'   -- the environment's textual mission string

      Raises AssertionError if the environment has no `mission` attribute.
      """
      view_grid, visibility = self.gen_obs_grid()
      encoded_view = view_grid.encode(visibility)

      assert hasattr(self, 'mission'), "environments must define a textual mission string"

      return dict(
          image=encoded_view,
          direction=self.agent_dir,
          mission=self.mission,
      )
  def get_obs_render(self, obs, tile_pixels=CELL_PIXELS//2):
      """
      Render an agent observation for visualization.

      Args:
          obs: an encoded observation grid that Grid.decode accepts
              (presumably the 'image' entry returned by gen_obs -- confirm
              with callers).
          tile_pixels: size in pixels of each rendered cell.

      Returns:
          The renderer's pixmap (r.getPixmap()). It shows the decoded view with
          the agent drawn as a red triangle at the bottom-centre cell, facing up.

      NOTE(review): the renderer is created only on the first call and sized
      for that call's tile_pixels; later calls with a different tile_pixels
      reuse the same renderer.
      """
      if self.obs_render is None:
          from gym_minigrid.rendering import Renderer
          self.obs_render = Renderer(
              AGENT_VIEW_SIZE * tile_pixels,
              AGENT_VIEW_SIZE * tile_pixels
          )

      r = self.obs_render

      r.beginFrame()

      grid = Grid.decode(obs)

      # Render the whole grid
      grid.render(r, tile_pixels)

      # Draw the agent. Coordinates below are in CELL_PIXELS units, so scale
      # them down to the requested tile size.
      ratio = tile_pixels / CELL_PIXELS
      r.push()
      r.scale(ratio, ratio)
      # Centre of the bottom-centre cell, where the agent sits in its own view
      r.translate(
          CELL_PIXELS * (0.5 + AGENT_VIEW_SIZE // 2),
          CELL_PIXELS * (AGENT_VIEW_SIZE - 0.5)
      )
      # 3 * 90 degrees: the same rotation render() uses for direction 3 (up)
      r.rotate(3 * 90)
      r.setLineColor(255, 0, 0)
      r.setColor(255, 0, 0)
      r.drawPolygon([
          (-12, 10),
          ( 12, 0),
          (-12, -10)
      ])
      r.pop()

      r.endFrame()

      return r.getPixmap()
  def render(self, mode='human', close=False):
      """
      Render the whole-grid human view.

      Draws every grid cell, then the agent as a red triangle rotated by
      agent_dir * 90 degrees, then a translucent white highlight over each
      cell the agent can currently see (per gen_obs_grid's visibility mask).

      Args:
          mode: 'human' passes True as the Renderer's third argument on first
              use (presumably to open an on-screen window). 'rgb_array'
              returns r.getArray(), 'pixmap' returns r.getPixmap(), and any
              other mode returns the renderer itself.
          close: if true, close the existing renderer (if any) and return
              None without drawing.
      """
      if close:
          if self.grid_render:
              self.grid_render.close()
          return

      if self.grid_render is None:
          from gym_minigrid.rendering import Renderer
          self.grid_render = Renderer(
              self.width * CELL_PIXELS,
              self.height * CELL_PIXELS,
              mode == 'human'
          )

      r = self.grid_render

      if r.window:
          r.window.setText(self.mission)

      r.beginFrame()

      # Render the whole grid
      self.grid.render(r, CELL_PIXELS)

      # Draw the agent, centred in its cell and rotated to its facing direction
      r.push()
      r.translate(
          CELL_PIXELS * (self.agent_pos[0] + 0.5),
          CELL_PIXELS * (self.agent_pos[1] + 0.5)
      )
      r.rotate(self.agent_dir * 90)
      r.setLineColor(255, 0, 0)
      r.setColor(255, 0, 0)
      r.drawPolygon([
          (-12, 10),
          ( 12, 0),
          (-12, -10)
      ])
      r.pop()

      # Compute which cells are visible to the agent
      _, vis_mask = self.gen_obs_grid()

      # Absolute coordinates of the top-left corner of the agent's view area:
      # the cell AGENT_VIEW_SIZE - 1 steps ahead of the agent and
      # AGENT_VIEW_SIZE // 2 steps to its left.
      f_vec = self.dir_vec
      r_vec = self.right_vec
      top_left = self.agent_pos + f_vec * (AGENT_VIEW_SIZE - 1) - r_vec * (AGENT_VIEW_SIZE // 2)

      # For each cell in the visibility mask
      for vis_j in range(AGENT_VIEW_SIZE):
          for vis_i in range(AGENT_VIEW_SIZE):
              # If this cell is not visible, don't highlight it
              if not vis_mask[vis_i, vis_j]:
                  continue

              # World coordinates: vis_j rows back towards the agent,
              # vis_i columns to the agent's right
              abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)

              # Highlight the cell
              r.fillRect(
                  abs_i * CELL_PIXELS,
                  abs_j * CELL_PIXELS,
                  CELL_PIXELS,
                  CELL_PIXELS,
                  255, 255, 255, 75
              )

      r.endFrame()

      if mode == 'rgb_array':
          return r.getArray()
      elif mode == 'pixmap':
          return r.getPixmap()

      return r