minigrid.py 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311
  1. import math
  2. import gym
  3. from enum import IntEnum
  4. import numpy as np
  5. from gym import error, spaces, utils
  6. from gym.utils import seeding
  7. from gym_minigrid.rendering import *
  8. # Size in pixels of a cell in the full-scale human view
  9. CELL_PIXELS = 32
  10. # Number of cells (width and height) in the agent view
  11. AGENT_VIEW_SIZE = 7
  12. # Size of the array given as an observation to the agent
  13. OBS_ARRAY_SIZE = (AGENT_VIEW_SIZE, AGENT_VIEW_SIZE, 3)
  14. # Map of color names to RGB values
  15. COLORS = {
  16. 'red' : np.array([255, 0, 0]),
  17. 'green' : np.array([0, 255, 0]),
  18. 'blue' : np.array([0, 0, 255]),
  19. 'purple': np.array([112, 39, 195]),
  20. 'yellow': np.array([255, 255, 0]),
  21. 'grey' : np.array([100, 100, 100])
  22. }
  23. COLOR_NAMES = sorted(list(COLORS.keys()))
  24. # Used to map colors to integers
  25. COLOR_TO_IDX = {
  26. 'red' : 0,
  27. 'green' : 1,
  28. 'blue' : 2,
  29. 'purple': 3,
  30. 'yellow': 4,
  31. 'grey' : 5
  32. }
  33. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  34. # Map of object type to integers
  35. OBJECT_TO_IDX = {
  36. 'empty' : 0,
  37. 'wall' : 1,
  38. 'floor' : 2,
  39. 'door' : 3,
  40. 'locked_door' : 4,
  41. 'key' : 5,
  42. 'ball' : 6,
  43. 'box' : 7,
  44. 'goal' : 8
  45. }
  46. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  47. # Map of agent direction indices to vectors
  48. DIR_TO_VEC = [
  49. # Pointing right (positive X)
  50. np.array((1, 0)),
  51. # Down (positive Y)
  52. np.array((0, 1)),
  53. # Pointing left (negative X)
  54. np.array((-1, 0)),
  55. # Up (negative Y)
  56. np.array((0, -1)),
  57. ]
  58. class WorldObj:
  59. """
  60. Base class for grid world objects
  61. """
  62. def __init__(self, type, color):
  63. assert type in OBJECT_TO_IDX, type
  64. assert color in COLOR_TO_IDX, color
  65. self.type = type
  66. self.color = color
  67. self.contains = None
  68. # Initial position of the object
  69. self.init_pos = None
  70. # Current position of the object
  71. self.cur_pos = None
  72. def can_overlap(self):
  73. """Can the agent overlap with this?"""
  74. return False
  75. def can_pickup(self):
  76. """Can the agent pick this up?"""
  77. return False
  78. def can_contain(self):
  79. """Can this contain another object?"""
  80. return False
  81. def see_behind(self):
  82. """Can the agent see behind this object?"""
  83. return True
  84. def toggle(self, env, pos):
  85. """Method to trigger/toggle an action this object performs"""
  86. return False
  87. def render(self, r):
  88. """Draw this object with the given renderer"""
  89. raise NotImplementedError
  90. def _set_color(self, r):
  91. """Set the color of this object as the active drawing color"""
  92. c = COLORS[self.color]
  93. r.setLineColor(c[0], c[1], c[2])
  94. r.setColor(c[0], c[1], c[2])
  95. class Goal(WorldObj):
  96. def __init__(self):
  97. super().__init__('goal', 'green')
  98. def can_overlap(self):
  99. return True
  100. def render(self, r):
  101. self._set_color(r)
  102. r.drawPolygon([
  103. (0 , CELL_PIXELS),
  104. (CELL_PIXELS, CELL_PIXELS),
  105. (CELL_PIXELS, 0),
  106. (0 , 0)
  107. ])
  108. class Floor(WorldObj):
  109. """
  110. Colored floor tile the agent can walk over
  111. """
  112. def __init__(self, color='blue'):
  113. super().__init__('floor', color)
  114. def can_overlap(self):
  115. return True
  116. def render(self, r):
  117. # Give the floor a pale color
  118. c = COLORS[self.color]
  119. r.setLineColor(100, 100, 100, 0)
  120. r.setColor(*c/2)
  121. r.drawPolygon([
  122. (1 , CELL_PIXELS),
  123. (CELL_PIXELS, CELL_PIXELS),
  124. (CELL_PIXELS, 1),
  125. (1 , 1)
  126. ])
  127. class Wall(WorldObj):
  128. def __init__(self, color='grey'):
  129. super().__init__('wall', color)
  130. def see_behind(self):
  131. return False
  132. def render(self, r):
  133. self._set_color(r)
  134. r.drawPolygon([
  135. (0 , CELL_PIXELS),
  136. (CELL_PIXELS, CELL_PIXELS),
  137. (CELL_PIXELS, 0),
  138. (0 , 0)
  139. ])
  140. class Door(WorldObj):
  141. def __init__(self, color, is_open=False):
  142. super().__init__('door', color)
  143. self.is_open = is_open
  144. def can_overlap(self):
  145. """The agent can only walk over this cell when the door is open"""
  146. return self.is_open
  147. def see_behind(self):
  148. return self.is_open
  149. def toggle(self, env, pos):
  150. self.is_open = not self.is_open
  151. return True
  152. def render(self, r):
  153. c = COLORS[self.color]
  154. r.setLineColor(c[0], c[1], c[2])
  155. r.setColor(0, 0, 0)
  156. if self.is_open:
  157. r.drawPolygon([
  158. (CELL_PIXELS-2, CELL_PIXELS),
  159. (CELL_PIXELS , CELL_PIXELS),
  160. (CELL_PIXELS , 0),
  161. (CELL_PIXELS-2, 0)
  162. ])
  163. return
  164. r.drawPolygon([
  165. (0 , CELL_PIXELS),
  166. (CELL_PIXELS, CELL_PIXELS),
  167. (CELL_PIXELS, 0),
  168. (0 , 0)
  169. ])
  170. r.drawPolygon([
  171. (2 , CELL_PIXELS-2),
  172. (CELL_PIXELS-2, CELL_PIXELS-2),
  173. (CELL_PIXELS-2, 2),
  174. (2 , 2)
  175. ])
  176. r.drawCircle(CELL_PIXELS * 0.75, CELL_PIXELS * 0.5, 2)
  177. class LockedDoor(WorldObj):
  178. def __init__(self, color, is_open=False):
  179. super(LockedDoor, self).__init__('locked_door', color)
  180. self.is_open = is_open
  181. def toggle(self, env, pos):
  182. # If the player has the right key to open the door
  183. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  184. self.is_open = True
  185. # The key has been used, remove it from the agent
  186. env.carrying = None
  187. return True
  188. return False
  189. def can_overlap(self):
  190. """The agent can only walk over this cell when the door is open"""
  191. return self.is_open
  192. def see_behind(self):
  193. return self.is_open
  194. def render(self, r):
  195. c = COLORS[self.color]
  196. r.setLineColor(c[0], c[1], c[2])
  197. r.setColor(c[0], c[1], c[2], 50)
  198. if self.is_open:
  199. r.drawPolygon([
  200. (CELL_PIXELS-2, CELL_PIXELS),
  201. (CELL_PIXELS , CELL_PIXELS),
  202. (CELL_PIXELS , 0),
  203. (CELL_PIXELS-2, 0)
  204. ])
  205. return
  206. r.drawPolygon([
  207. (0 , CELL_PIXELS),
  208. (CELL_PIXELS, CELL_PIXELS),
  209. (CELL_PIXELS, 0),
  210. (0 , 0)
  211. ])
  212. r.drawPolygon([
  213. (2 , CELL_PIXELS-2),
  214. (CELL_PIXELS-2, CELL_PIXELS-2),
  215. (CELL_PIXELS-2, 2),
  216. (2 , 2)
  217. ])
  218. r.drawLine(
  219. CELL_PIXELS * 0.55,
  220. CELL_PIXELS * 0.5,
  221. CELL_PIXELS * 0.75,
  222. CELL_PIXELS * 0.5
  223. )
  224. class Key(WorldObj):
  225. def __init__(self, color='blue'):
  226. super(Key, self).__init__('key', color)
  227. def can_pickup(self):
  228. return True
  229. def render(self, r):
  230. self._set_color(r)
  231. # Vertical quad
  232. r.drawPolygon([
  233. (16, 10),
  234. (20, 10),
  235. (20, 28),
  236. (16, 28)
  237. ])
  238. # Teeth
  239. r.drawPolygon([
  240. (12, 19),
  241. (16, 19),
  242. (16, 21),
  243. (12, 21)
  244. ])
  245. r.drawPolygon([
  246. (12, 26),
  247. (16, 26),
  248. (16, 28),
  249. (12, 28)
  250. ])
  251. r.drawCircle(18, 9, 6)
  252. r.setLineColor(0, 0, 0)
  253. r.setColor(0, 0, 0)
  254. r.drawCircle(18, 9, 2)
  255. class Ball(WorldObj):
  256. def __init__(self, color='blue'):
  257. super(Ball, self).__init__('ball', color)
  258. def can_pickup(self):
  259. return True
  260. def render(self, r):
  261. self._set_color(r)
  262. r.drawCircle(CELL_PIXELS * 0.5, CELL_PIXELS * 0.5, 10)
  263. class Box(WorldObj):
  264. def __init__(self, color, contains=None):
  265. super(Box, self).__init__('box', color)
  266. self.contains = contains
  267. def can_pickup(self):
  268. return True
  269. def render(self, r):
  270. c = COLORS[self.color]
  271. r.setLineColor(c[0], c[1], c[2])
  272. r.setColor(0, 0, 0)
  273. r.setLineWidth(2)
  274. r.drawPolygon([
  275. (4 , CELL_PIXELS-4),
  276. (CELL_PIXELS-4, CELL_PIXELS-4),
  277. (CELL_PIXELS-4, 4),
  278. (4 , 4)
  279. ])
  280. r.drawLine(
  281. 4,
  282. CELL_PIXELS / 2,
  283. CELL_PIXELS - 4,
  284. CELL_PIXELS / 2
  285. )
  286. r.setLineWidth(1)
  287. def toggle(self, env, pos):
  288. # Replace the box by its contents
  289. env.grid.set(*pos, self.contains)
  290. return True
  291. class Grid:
  292. """
  293. Represent a grid and operations on it
  294. """
  295. def __init__(self, width, height):
  296. assert width >= 4
  297. assert height >= 4
  298. self.width = width
  299. self.height = height
  300. self.grid = [None] * width * height
  301. def __contains__(self, key):
  302. if isinstance(key, WorldObj):
  303. for e in self.grid:
  304. if e is key:
  305. return True
  306. elif isinstance(key, tuple):
  307. for e in self.grid:
  308. if e is None:
  309. continue
  310. if (e.color, e.type) == key:
  311. return True
  312. if key[0] is None and key[1] == e.type:
  313. return True
  314. return False
  315. def __eq__(self, other):
  316. grid1 = self.encode()
  317. grid2 = other.encode()
  318. return np.array_equal(grid2, grid1)
  319. def __ne__(self, other):
  320. return not self == other
  321. def copy(self):
  322. from copy import deepcopy
  323. return deepcopy(self)
  324. def set(self, i, j, v):
  325. assert i >= 0 and i < self.width
  326. assert j >= 0 and j < self.height
  327. self.grid[j * self.width + i] = v
  328. def get(self, i, j):
  329. assert i >= 0 and i < self.width
  330. assert j >= 0 and j < self.height
  331. return self.grid[j * self.width + i]
  332. def horz_wall(self, x, y, length=None):
  333. if length is None:
  334. length = self.width - x
  335. for i in range(0, length):
  336. self.set(x + i, y, Wall())
  337. def vert_wall(self, x, y, length=None):
  338. if length is None:
  339. length = self.height - y
  340. for j in range(0, length):
  341. self.set(x, y + j, Wall())
  342. def wall_rect(self, x, y, w, h):
  343. self.horz_wall(x, y, w)
  344. self.horz_wall(x, y+h-1, w)
  345. self.vert_wall(x, y, h)
  346. self.vert_wall(x+w-1, y, h)
  347. def rotate_left(self):
  348. """
  349. Rotate the grid to the left (counter-clockwise)
  350. """
  351. grid = Grid(self.width, self.height)
  352. for j in range(0, self.height):
  353. for i in range(0, self.width):
  354. v = self.get(self.width - 1 - j, i)
  355. grid.set(i, j, v)
  356. return grid
  357. def slice(self, topX, topY, width, height):
  358. """
  359. Get a subset of the grid
  360. """
  361. grid = Grid(width, height)
  362. for j in range(0, height):
  363. for i in range(0, width):
  364. x = topX + i
  365. y = topY + j
  366. if x >= 0 and x < self.width and \
  367. y >= 0 and y < self.height:
  368. v = self.get(x, y)
  369. else:
  370. v = Wall()
  371. grid.set(i, j, v)
  372. return grid
  373. def render(self, r, tile_size):
  374. """
  375. Render this grid at a given scale
  376. :param r: target renderer object
  377. :param tile_size: tile size in pixels
  378. """
  379. assert r.width == self.width * tile_size
  380. assert r.height == self.height * tile_size
  381. # Total grid size at native scale
  382. widthPx = self.width * CELL_PIXELS
  383. heightPx = self.height * CELL_PIXELS
  384. r.push()
  385. # Internally, we draw at the "large" full-grid resolution, but we
  386. # use the renderer to scale back to the desired size
  387. r.scale(tile_size / CELL_PIXELS, tile_size / CELL_PIXELS)
  388. # Draw the background of the in-world cells black
  389. r.fillRect(
  390. 0,
  391. 0,
  392. widthPx,
  393. heightPx,
  394. 0, 0, 0
  395. )
  396. # Draw grid lines
  397. r.setLineColor(100, 100, 100)
  398. for rowIdx in range(0, self.height):
  399. y = CELL_PIXELS * rowIdx
  400. r.drawLine(0, y, widthPx, y)
  401. for colIdx in range(0, self.width):
  402. x = CELL_PIXELS * colIdx
  403. r.drawLine(x, 0, x, heightPx)
  404. # Render the grid
  405. for j in range(0, self.height):
  406. for i in range(0, self.width):
  407. cell = self.get(i, j)
  408. if cell == None:
  409. continue
  410. r.push()
  411. r.translate(i * CELL_PIXELS, j * CELL_PIXELS)
  412. cell.render(r)
  413. r.pop()
  414. r.pop()
  415. def encode(self):
  416. """
  417. Produce a compact numpy encoding of the grid
  418. """
  419. codeSize = self.width * self.height * 3
  420. array = np.zeros(shape=(self.width, self.height, 3), dtype='uint8')
  421. for j in range(0, self.height):
  422. for i in range(0, self.width):
  423. v = self.get(i, j)
  424. if v == None:
  425. continue
  426. array[i, j, 0] = OBJECT_TO_IDX[v.type]
  427. array[i, j, 1] = COLOR_TO_IDX[v.color]
  428. if hasattr(v, 'is_open') and v.is_open:
  429. array[i, j, 2] = 1
  430. return array
  431. def decode(array):
  432. """
  433. Decode an array grid encoding back into a grid
  434. """
  435. width = array.shape[0]
  436. height = array.shape[1]
  437. assert array.shape[2] == 3
  438. grid = Grid(width, height)
  439. for j in range(0, height):
  440. for i in range(0, width):
  441. typeIdx = array[i, j, 0]
  442. colorIdx = array[i, j, 1]
  443. openIdx = array[i, j, 2]
  444. if typeIdx == 0:
  445. continue
  446. objType = IDX_TO_OBJECT[typeIdx]
  447. color = IDX_TO_COLOR[colorIdx]
  448. is_open = True if openIdx == 1 else 0
  449. if objType == 'wall':
  450. v = Wall(color)
  451. elif objType == 'floor':
  452. v = Floor(color)
  453. elif objType == 'ball':
  454. v = Ball(color)
  455. elif objType == 'key':
  456. v = Key(color)
  457. elif objType == 'box':
  458. v = Box(color)
  459. elif objType == 'door':
  460. v = Door(color, is_open)
  461. elif objType == 'locked_door':
  462. v = LockedDoor(color, is_open)
  463. elif objType == 'goal':
  464. v = Goal()
  465. else:
  466. assert False, "unknown obj type in decode '%s'" % objType
  467. grid.set(i, j, v)
  468. return grid
  469. def process_vis(grid, agent_pos):
  470. mask = np.zeros(shape=(grid.width, grid.height), dtype=np.bool)
  471. mask[agent_pos[0], agent_pos[1]] = True
  472. for j in reversed(range(1, grid.height)):
  473. for i in range(0, grid.width-1):
  474. if not mask[i, j]:
  475. continue
  476. cell = grid.get(i, j)
  477. if cell and not cell.see_behind():
  478. continue
  479. mask[i+1, j] = True
  480. mask[i+1, j-1] = True
  481. mask[i, j-1] = True
  482. for i in reversed(range(1, grid.width)):
  483. if not mask[i, j]:
  484. continue
  485. cell = grid.get(i, j)
  486. if cell and not cell.see_behind():
  487. continue
  488. mask[i-1, j-1] = True
  489. mask[i-1, j] = True
  490. mask[i, j-1] = True
  491. for j in range(0, grid.height):
  492. for i in range(0, grid.width):
  493. if not mask[i, j]:
  494. grid.set(i, j, None)
  495. return mask
  496. class MiniGridEnv(gym.Env):
  497. """
  498. 2D grid world game environment
  499. """
  500. metadata = {
  501. 'render.modes': ['human', 'rgb_array', 'pixmap'],
  502. 'video.frames_per_second' : 10
  503. }
  504. # Enumeration of possible actions
  505. class Actions(IntEnum):
  506. # Turn left, turn right, move forward
  507. left = 0
  508. right = 1
  509. forward = 2
  510. # Pick up an object
  511. pickup = 3
  512. # Drop an object
  513. drop = 4
  514. # Toggle/activate an object
  515. toggle = 5
  516. # Done completing task
  517. done = 6
  518. def __init__(
  519. self,
  520. grid_size=16,
  521. max_steps=100,
  522. see_through_walls=False,
  523. seed=1337
  524. ):
  525. # Action enumeration for this environment
  526. self.actions = MiniGridEnv.Actions
  527. # Actions are discrete integer values
  528. self.action_space = spaces.Discrete(len(self.actions))
  529. # Observations are dictionaries containing an
  530. # encoding of the grid and a textual 'mission' string
  531. self.observation_space = spaces.Box(
  532. low=0,
  533. high=255,
  534. shape=OBS_ARRAY_SIZE,
  535. dtype='uint8'
  536. )
  537. self.observation_space = spaces.Dict({
  538. 'image': self.observation_space
  539. })
  540. # Range of possible rewards
  541. self.reward_range = (0, 1)
  542. # Renderer object used to render the whole grid (full-scale)
  543. self.grid_render = None
  544. # Renderer used to render observations (small-scale agent view)
  545. self.obs_render = None
  546. # Environment configuration
  547. self.grid_size = grid_size
  548. self.max_steps = max_steps
  549. self.see_through_walls = see_through_walls
  550. # Starting position and direction for the agent
  551. self.start_pos = None
  552. self.start_dir = None
  553. # Initialize the RNG
  554. self.seed(seed=seed)
  555. # Initialize the state
  556. self.reset()
  557. def reset(self):
  558. # Generate a new random grid at the start of each episode
  559. # To keep the same grid for each episode, call env.seed() with
  560. # the same seed before calling env.reset()
  561. self._gen_grid(self.grid_size, self.grid_size)
  562. # These fields should be defined by _gen_grid
  563. assert self.start_pos is not None
  564. assert self.start_dir is not None
  565. # Check that the agent doesn't overlap with an object
  566. start_cell = self.grid.get(*self.start_pos)
  567. assert start_cell is None or start_cell.can_overlap()
  568. # Place the agent in the starting position and direction
  569. self.agent_pos = self.start_pos
  570. self.agent_dir = self.start_dir
  571. # Item picked up, being carried, initially nothing
  572. self.carrying = None
  573. # Step count since episode start
  574. self.step_count = 0
  575. # Return first observation
  576. obs = self.gen_obs()
  577. return obs
  578. def seed(self, seed=1337):
  579. # Seed the random number generator
  580. self.np_random, _ = seeding.np_random(seed)
  581. return [seed]
  582. @property
  583. def steps_remaining(self):
  584. return self.max_steps - self.step_count
  585. def __str__(self):
  586. """
  587. Produce a pretty string of the environment's grid along with the agent.
  588. The agent is represented by `⏩`. A grid pixel is represented by 2-character
  589. string, the first one for the object and the second one for the color.
  590. """
  591. from copy import deepcopy
  592. def rotate_left(array):
  593. new_array = deepcopy(array)
  594. for i in range(len(array)):
  595. for j in range(len(array[0])):
  596. new_array[j][len(array[0])-1-i] = array[i][j]
  597. return new_array
  598. def vertically_symmetrize(array):
  599. new_array = deepcopy(array)
  600. for i in range(len(array)):
  601. for j in range(len(array[0])):
  602. new_array[i][len(array[0])-1-j] = array[i][j]
  603. return new_array
  604. # Map of object id to short string
  605. OBJECT_IDX_TO_IDS = {
  606. 0: ' ',
  607. 1: 'W',
  608. 2: 'D',
  609. 3: 'L',
  610. 4: 'K',
  611. 5: 'B',
  612. 6: 'X',
  613. 7: 'G'
  614. }
  615. # Short string for opened door
  616. OPENDED_DOOR_IDS = '_'
  617. # Map of color id to short string
  618. COLOR_IDX_TO_IDS = {
  619. 0: 'R',
  620. 1: 'G',
  621. 2: 'B',
  622. 3: 'P',
  623. 4: 'Y',
  624. 5: 'E'
  625. }
  626. # Map agent's direction to short string
  627. AGENT_DIR_TO_IDS = {
  628. 0: '⏩ ',
  629. 1: '⏬ ',
  630. 2: '⏪ ',
  631. 3: '⏫ '
  632. }
  633. array = self.grid.encode()
  634. array = rotate_left(array)
  635. array = vertically_symmetrize(array)
  636. new_array = []
  637. for line in array:
  638. new_line = []
  639. for pixel in line:
  640. # If the door is opened
  641. if pixel[0] in [2, 3] and pixel[2] == 1:
  642. object_ids = OPENDED_DOOR_IDS
  643. else:
  644. object_ids = OBJECT_IDX_TO_IDS[pixel[0]]
  645. # If no object
  646. if pixel[0] == 0:
  647. color_ids = ' '
  648. else:
  649. color_ids = COLOR_IDX_TO_IDS[pixel[1]]
  650. new_line.append(object_ids + color_ids)
  651. new_array.append(new_line)
  652. # Add the agent
  653. new_array[self.agent_pos[1]][self.agent_pos[0]] = AGENT_DIR_TO_IDS[self.agent_dir]
  654. return "\n".join([" ".join(line) for line in new_array])
  655. def _gen_grid(self, width, height):
  656. assert False, "_gen_grid needs to be implemented by each environment"
  657. def _reward(self):
  658. """
  659. Compute the reward to be given upon success
  660. """
  661. return 1 - 0.9 * (self.step_count / self.max_steps)
  662. def _rand_int(self, low, high):
  663. """
  664. Generate random integer in [low,high[
  665. """
  666. return self.np_random.randint(low, high)
  667. def _rand_float(self, low, high):
  668. """
  669. Generate random float in [low,high[
  670. """
  671. return self.np_random.uniform(low, high)
  672. def _rand_bool(self):
  673. """
  674. Generate random boolean value
  675. """
  676. return (self.np_random.randint(0, 2) == 0)
  677. def _rand_elem(self, iterable):
  678. """
  679. Pick a random element in a list
  680. """
  681. lst = list(iterable)
  682. idx = self._rand_int(0, len(lst))
  683. return lst[idx]
  684. def _rand_subset(self, iterable, num_elems):
  685. """
  686. Sample a random subset of distinct elements of a list
  687. """
  688. lst = list(iterable)
  689. assert num_elems <= len(lst)
  690. out = []
  691. while len(out) < num_elems:
  692. elem = self._rand_elem(lst)
  693. lst.remove(elem)
  694. out.append(elem)
  695. return out
  696. def _rand_color(self):
  697. """
  698. Generate a random color name (string)
  699. """
  700. return self._rand_elem(COLOR_NAMES)
  701. def _rand_pos(self, xLow, xHigh, yLow, yHigh):
  702. """
  703. Generate a random (x,y) position tuple
  704. """
  705. return (
  706. self.np_random.randint(xLow, xHigh),
  707. self.np_random.randint(yLow, yHigh)
  708. )
  709. def place_obj(self,
  710. obj,
  711. top=None,
  712. size=None,
  713. reject_fn=None,
  714. max_tries=math.inf
  715. ):
  716. """
  717. Place an object at an empty position in the grid
  718. :param top: top-left position of the rectangle where to place
  719. :param size: size of the rectangle where to place
  720. :param reject_fn: function to filter out potential positions
  721. """
  722. if top is None:
  723. top = (0, 0)
  724. if size is None:
  725. size = (self.grid.width, self.grid.height)
  726. num_tries = 0
  727. while True:
  728. # This is to handle with rare cases where rejection sampling
  729. # gets stuck in an infinite loop
  730. if num_tries > max_tries:
  731. raise RecursionError('rejection sampling failed in place_obj')
  732. num_tries += 1
  733. pos = np.array((
  734. self._rand_int(top[0], top[0] + size[0]),
  735. self._rand_int(top[1], top[1] + size[1])
  736. ))
  737. # Don't place the object on top of another object
  738. if self.grid.get(*pos) != None:
  739. continue
  740. # Don't place the object where the agent is
  741. if np.array_equal(pos, self.start_pos):
  742. continue
  743. # Check if there is a filtering criterion
  744. if reject_fn and reject_fn(self, pos):
  745. continue
  746. break
  747. self.grid.set(*pos, obj)
  748. if obj is not None:
  749. obj.init_pos = pos
  750. obj.cur_pos = pos
  751. return pos
  752. def place_agent(
  753. self,
  754. top=None,
  755. size=None,
  756. rand_dir=True,
  757. max_tries=math.inf
  758. ):
  759. """
  760. Set the agent's starting point at an empty position in the grid
  761. """
  762. self.start_pos = None
  763. pos = self.place_obj(None, top, size, max_tries=max_tries)
  764. self.start_pos = pos
  765. if rand_dir:
  766. self.start_dir = self._rand_int(0, 4)
  767. return pos
  768. @property
  769. def dir_vec(self):
  770. """
  771. Get the direction vector for the agent, pointing in the direction
  772. of forward movement.
  773. """
  774. assert self.agent_dir >= 0 and self.agent_dir < 4
  775. return DIR_TO_VEC[self.agent_dir]
  776. @property
  777. def right_vec(self):
  778. """
  779. Get the vector pointing to the right of the agent.
  780. """
  781. dx, dy = self.dir_vec
  782. return np.array((-dy, dx))
  783. @property
  784. def front_pos(self):
  785. """
  786. Get the position of the cell that is right in front of the agent
  787. """
  788. return self.agent_pos + self.dir_vec
  789. def get_view_coords(self, i, j):
  790. """
  791. Translate and rotate absolute grid coordinates (i, j) into the
  792. agent's partially observable view (sub-grid). Note that the resulting
  793. coordinates may be negative or outside of the agent's view size.
  794. """
  795. ax, ay = self.agent_pos
  796. dx, dy = self.dir_vec
  797. rx, ry = self.right_vec
  798. # Compute the absolute coordinates of the top-left view corner
  799. sz = AGENT_VIEW_SIZE
  800. hs = AGENT_VIEW_SIZE // 2
  801. tx = ax + (dx * (sz-1)) - (rx * hs)
  802. ty = ay + (dy * (sz-1)) - (ry * hs)
  803. lx = i - tx
  804. ly = j - ty
  805. # Project the coordinates of the object relative to the top-left
  806. # corner onto the agent's own coordinate system
  807. vx = (rx*lx + ry*ly)
  808. vy = -(dx*lx + dy*ly)
  809. return vx, vy
  810. def get_view_exts(self):
  811. """
  812. Get the extents of the square set of tiles visible to the agent
  813. Note: the bottom extent indices are not included in the set
  814. """
  815. # Facing right
  816. if self.agent_dir == 0:
  817. topX = self.agent_pos[0]
  818. topY = self.agent_pos[1] - AGENT_VIEW_SIZE // 2
  819. # Facing down
  820. elif self.agent_dir == 1:
  821. topX = self.agent_pos[0] - AGENT_VIEW_SIZE // 2
  822. topY = self.agent_pos[1]
  823. # Facing left
  824. elif self.agent_dir == 2:
  825. topX = self.agent_pos[0] - AGENT_VIEW_SIZE + 1
  826. topY = self.agent_pos[1] - AGENT_VIEW_SIZE // 2
  827. # Facing up
  828. elif self.agent_dir == 3:
  829. topX = self.agent_pos[0] - AGENT_VIEW_SIZE // 2
  830. topY = self.agent_pos[1] - AGENT_VIEW_SIZE + 1
  831. else:
  832. assert False, "invalid agent direction"
  833. botX = topX + AGENT_VIEW_SIZE
  834. botY = topY + AGENT_VIEW_SIZE
  835. return (topX, topY, botX, botY)
  836. def agent_sees(self, x, y):
  837. """
  838. Check if a grid position is visible to the agent
  839. """
  840. vx, vy = self.get_view_coords(x, y)
  841. if vx < 0 or vy < 0 or vx >= AGENT_VIEW_SIZE or vy >= AGENT_VIEW_SIZE:
  842. return False
  843. obs = self.gen_obs()
  844. obs_grid = Grid.decode(obs['image'])
  845. obs_cell = obs_grid.get(vx, vy)
  846. world_cell = self.grid.get(x, y)
  847. return obs_cell is not None and obs_cell.type == world_cell.type
  def step(self, action):
      """
      Advance the environment by one time step

      The step counter is incremented, then the action is applied:
      - left / right: rotate the agent by a quarter turn
      - forward: move into the front cell if it is empty or overlappable;
        a front cell of type 'goal' ends the episode with self._reward()
      - pickup: take the front object if it can be picked up and the agent
        is not already carrying something
      - drop: put the carried object into the front cell if that cell is empty
      - toggle: call toggle() on the front object, if any
      - done: no effect

      Returns a tuple (obs, reward, done, info) where obs comes from
      self.gen_obs() and info is an empty dict. done is also set once
      step_count reaches max_steps.

      Raises AssertionError if the action matches none of the above.
      """

      self.step_count += 1

      reward = 0
      done = False

      # Position and contents of the cell in front of the agent
      fwd_pos = self.front_pos
      fwd_cell = self.grid.get(*fwd_pos)

      # Rotate left
      if action == self.actions.left:
          self.agent_dir = (self.agent_dir - 1) % 4

      # Rotate right
      elif action == self.actions.right:
          self.agent_dir = (self.agent_dir + 1) % 4

      # Move forward
      elif action == self.actions.forward:
          if fwd_cell is None or fwd_cell.can_overlap():
              self.agent_pos = fwd_pos
          if fwd_cell is not None and fwd_cell.type == 'goal':
              done = True
              reward = self._reward()

      # Pick up an object
      elif action == self.actions.pickup:
          if fwd_cell and fwd_cell.can_pickup():
              if self.carrying is None:
                  self.carrying = fwd_cell
                  # A carried object no longer has a position on the grid
                  self.carrying.cur_pos = np.array([-1, -1])
                  self.grid.set(*fwd_pos, None)

      # Drop an object
      elif action == self.actions.drop:
          if not fwd_cell and self.carrying:
              self.grid.set(*fwd_pos, self.carrying)
              self.carrying.cur_pos = fwd_pos
              self.carrying = None

      # Toggle/activate an object
      elif action == self.actions.toggle:
          if fwd_cell:
              fwd_cell.toggle(self, fwd_pos)

      # Done action (not used by default)
      elif action == self.actions.done:
          pass

      else:
          # Raised explicitly so the check is not stripped when Python
          # runs with -O (a bare assert statement would be removed)
          raise AssertionError("unknown action")

      if self.step_count >= self.max_steps:
          done = True

      obs = self.gen_obs()

      return obs, reward, done, {}
  def gen_obs_grid(self):
      """
      Generate the sub-grid observed by the agent.
      This method also outputs a visibility mask telling us which grid
      cells the agent can actually see.

      Returns a tuple (grid, vis_mask): the AGENT_VIEW_SIZE x AGENT_VIEW_SIZE
      slice of the world grid, rotated agent_dir + 1 times to the left, and
      the matching visibility mask. The carried object (or None) is written
      at the agent's own cell of the returned grid.
      """

      # Only the top-left corner of the view extents is needed for slicing
      topX, topY, _, _ = self.get_view_exts()

      grid = self.grid.slice(topX, topY, AGENT_VIEW_SIZE, AGENT_VIEW_SIZE)

      for _ in range(self.agent_dir + 1):
          grid = grid.rotate_left()

      # Process occluders and visibility
      # Note that this incurs some performance cost
      if not self.see_through_walls:
          vis_mask = grid.process_vis(agent_pos=(AGENT_VIEW_SIZE // 2, AGENT_VIEW_SIZE - 1))
      else:
          # The builtin bool is used as dtype: the np.bool alias was
          # deprecated in NumPy 1.20 and removed in NumPy 1.24
          vis_mask = np.ones(shape=(grid.width, grid.height), dtype=bool)

      # Make it so the agent sees what it's carrying
      # We do this by placing the carried object at the agent's position
      # in the agent's partially observable view
      agent_pos = grid.width // 2, grid.height - 1
      if self.carrying:
          grid.set(*agent_pos, self.carrying)
      else:
          grid.set(*agent_pos, None)

      return grid, vis_mask
  def gen_obs(self):
      """
      Build the agent's observation (partially observable, low-resolution encoding)

      Returns a dictionary with three entries:
      - 'image': the encoded partially observable view of the environment
      - 'direction': the agent's direction/orientation (acting as a compass)
      - 'mission': the textual mission string (instructions for the agent)

      The environment must define a mission attribute.
      """

      # The visibility mask is not part of the observation
      view_grid, _ = self.gen_obs_grid()

      # Turn the partial view into its array encoding
      encoded_view = view_grid.encode()

      assert hasattr(self, 'mission'), "environments must define a textual mission string"

      return dict(
          image=encoded_view,
          direction=self.agent_dir,
          mission=self.mission,
      )
  def get_obs_render(self, obs, tile_pixels=CELL_PIXELS//2):
      """
      Render an agent observation for visualization

      obs is handed directly to Grid.decode(), so it is presumably the
      encoded 'image' entry of an observation rather than the whole
      observation dictionary (TODO confirm against callers).
      tile_pixels is the size in pixels of one rendered grid cell.

      The renderer is created on first use and cached in self.obs_render.
      Returns the pixmap produced by the renderer.
      """

      # Identity check, consistent with the grid_render check in render()
      if self.obs_render is None:
          self.obs_render = Renderer(
              AGENT_VIEW_SIZE * tile_pixels,
              AGENT_VIEW_SIZE * tile_pixels
          )

      r = self.obs_render

      r.beginFrame()

      grid = Grid.decode(obs)

      # Render the whole grid
      grid.render(r, tile_pixels)

      # Draw the agent
      # The polygon is specified in CELL_PIXELS units, so scale it down
      # to the requested tile size
      ratio = tile_pixels / CELL_PIXELS
      r.push()
      r.scale(ratio, ratio)
      # Centre of the agent's cell: middle column, last row of the view
      r.translate(
          CELL_PIXELS * (0.5 + AGENT_VIEW_SIZE // 2),
          CELL_PIXELS * (AGENT_VIEW_SIZE - 0.5)
      )
      # The agent is always drawn with the same fixed rotation in its own view
      r.rotate(3 * 90)
      r.setLineColor(255, 0, 0)
      r.setColor(255, 0, 0)
      r.drawPolygon([
          (-12, 10),
          ( 12, 0),
          (-12, -10)
      ])
      r.pop()

      r.endFrame()

      return r.getPixmap()
  def render(self, mode='human', close=False):
      """
      Draw the full grid, the agent and a highlight over the cells it sees

      With close=True the existing renderer (if any) is closed and nothing
      is drawn or returned. Otherwise the renderer is created on first use
      and cached in self.grid_render.

      Returns r.getArray() for mode 'rgb_array', r.getPixmap() for mode
      'pixmap', and the renderer itself for any other mode.
      """

      if close:
          if self.grid_render:
              self.grid_render.close()
          return

      if self.grid_render is None:
          window_size = self.grid_size * CELL_PIXELS
          show_window = bool(mode == 'human')
          self.grid_render = Renderer(window_size, window_size, show_window)

      r = self.grid_render

      r.beginFrame()

      # Draw every cell of the world grid
      self.grid.render(r, CELL_PIXELS)

      # Draw the agent as a red triangle centred on its cell,
      # rotated a quarter turn per unit of agent_dir
      agent_triangle = [
          (-12, 10),
          ( 12, 0),
          (-12, -10)
      ]
      r.push()
      r.translate(
          CELL_PIXELS * (self.agent_pos[0] + 0.5),
          CELL_PIXELS * (self.agent_pos[1] + 0.5)
      )
      r.rotate(self.agent_dir * 90)
      r.setLineColor(255, 0, 0)
      r.setColor(255, 0, 0)
      r.drawPolygon(agent_triangle)
      r.pop()

      # Visibility mask of the agent's view; the view grid itself is unused
      _, vis_mask = self.gen_obs_grid()

      # World coordinates of view cell (0, 0): AGENT_VIEW_SIZE - 1 cells
      # ahead of the agent and AGENT_VIEW_SIZE // 2 cells to its left
      f_vec = self.dir_vec
      r_vec = self.right_vec
      top_left = self.agent_pos + f_vec * (AGENT_VIEW_SIZE-1) - r_vec * (AGENT_VIEW_SIZE // 2)

      # Highlight each visible cell with a translucent white rectangle
      for view_row in range(AGENT_VIEW_SIZE):
          # World position of the first cell in this row of the view
          row_start = top_left - (f_vec * view_row)

          for view_col in range(AGENT_VIEW_SIZE):
              if vis_mask[view_col, view_row]:
                  world_x, world_y = row_start + (r_vec * view_col)
                  r.fillRect(
                      world_x * CELL_PIXELS,
                      world_y * CELL_PIXELS,
                      CELL_PIXELS,
                      CELL_PIXELS,
                      255, 255, 255, 75
                  )

      r.endFrame()

      if mode == 'rgb_array':
          return r.getArray()
      if mode == 'pixmap':
          return r.getPixmap()
      return r