minigrid.py 33 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230
  1. import math
  2. import gym
  3. from enum import IntEnum
  4. import numpy as np
  5. from gym import error, spaces, utils
  6. from gym.utils import seeding
  7. from gym_minigrid.rendering import *
  8. # Size in pixels of a cell in the full-scale human view
  9. CELL_PIXELS = 32
  10. # Number of cells (width and height) in the agent view
  11. AGENT_VIEW_SIZE = 7
  12. # Size of the array given as an observation to the agent
  13. OBS_ARRAY_SIZE = (AGENT_VIEW_SIZE, AGENT_VIEW_SIZE, 3)
  14. # Map of color names to RGB values
  15. COLORS = {
  16. 'red' : (255, 0, 0),
  17. 'green' : (0, 255, 0),
  18. 'blue' : (0, 0, 255),
  19. 'purple': (112, 39, 195),
  20. 'yellow': (255, 255, 0),
  21. 'grey' : (100, 100, 100)
  22. }
  23. COLOR_NAMES = sorted(list(COLORS.keys()))
  24. # Used to map colors to integers
  25. COLOR_TO_IDX = {
  26. 'red' : 0,
  27. 'green' : 1,
  28. 'blue' : 2,
  29. 'purple': 3,
  30. 'yellow': 4,
  31. 'grey' : 5
  32. }
  33. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  34. # Map of object type to integers
  35. OBJECT_TO_IDX = {
  36. 'empty' : 0,
  37. 'wall' : 1,
  38. 'door' : 2,
  39. 'locked_door' : 3,
  40. 'key' : 4,
  41. 'ball' : 5,
  42. 'box' : 6,
  43. 'goal' : 7
  44. }
  45. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  46. class WorldObj:
  47. """
  48. Base class for grid world objects
  49. """
  50. def __init__(self, type, color):
  51. assert type in OBJECT_TO_IDX, type
  52. assert color in COLOR_TO_IDX, color
  53. self.type = type
  54. self.color = color
  55. self.contains = None
  56. def can_overlap(self):
  57. """Can the agent overlap with this?"""
  58. return False
  59. def can_pickup(self):
  60. """Can the agent pick this up?"""
  61. return False
  62. def can_contain(self):
  63. """Can this contain another object?"""
  64. return False
  65. def see_behind(self):
  66. """Can the agent see behind this object?"""
  67. return True
  68. def toggle(self, env, pos):
  69. """Method to trigger/toggle an action this object performs"""
  70. return False
  71. def render(self, r):
  72. """Draw this object with the given renderer"""
  73. raise NotImplementedError
  74. def _set_color(self, r):
  75. """Set the color of this object as the active drawing color"""
  76. c = COLORS[self.color]
  77. r.setLineColor(c[0], c[1], c[2])
  78. r.setColor(c[0], c[1], c[2])
  79. class Goal(WorldObj):
  80. def __init__(self):
  81. super(Goal, self).__init__('goal', 'green')
  82. def can_overlap(self):
  83. return True
  84. def render(self, r):
  85. self._set_color(r)
  86. r.drawPolygon([
  87. (0 , CELL_PIXELS),
  88. (CELL_PIXELS, CELL_PIXELS),
  89. (CELL_PIXELS, 0),
  90. (0 , 0)
  91. ])
  92. class Wall(WorldObj):
  93. def __init__(self, color='grey'):
  94. super(Wall, self).__init__('wall', color)
  95. def see_behind(self):
  96. return False
  97. def render(self, r):
  98. self._set_color(r)
  99. r.drawPolygon([
  100. (0 , CELL_PIXELS),
  101. (CELL_PIXELS, CELL_PIXELS),
  102. (CELL_PIXELS, 0),
  103. (0 , 0)
  104. ])
  105. class Door(WorldObj):
  106. def __init__(self, color, is_open=False):
  107. super(Door, self).__init__('door', color)
  108. self.is_open = is_open
  109. def can_overlap(self):
  110. """The agent can only walk over this cell when the door is open"""
  111. return self.is_open
  112. def see_behind(self):
  113. return self.is_open
  114. def toggle(self, env, pos):
  115. if not self.is_open:
  116. self.is_open = True
  117. return True
  118. return False
  119. def render(self, r):
  120. c = COLORS[self.color]
  121. r.setLineColor(c[0], c[1], c[2])
  122. r.setColor(0, 0, 0)
  123. if self.is_open:
  124. r.drawPolygon([
  125. (CELL_PIXELS-2, CELL_PIXELS),
  126. (CELL_PIXELS , CELL_PIXELS),
  127. (CELL_PIXELS , 0),
  128. (CELL_PIXELS-2, 0)
  129. ])
  130. return
  131. r.drawPolygon([
  132. (0 , CELL_PIXELS),
  133. (CELL_PIXELS, CELL_PIXELS),
  134. (CELL_PIXELS, 0),
  135. (0 , 0)
  136. ])
  137. r.drawPolygon([
  138. (2 , CELL_PIXELS-2),
  139. (CELL_PIXELS-2, CELL_PIXELS-2),
  140. (CELL_PIXELS-2, 2),
  141. (2 , 2)
  142. ])
  143. r.drawCircle(CELL_PIXELS * 0.75, CELL_PIXELS * 0.5, 2)
  144. class LockedDoor(WorldObj):
  145. def __init__(self, color, is_open=False):
  146. super(LockedDoor, self).__init__('locked_door', color)
  147. self.is_open = is_open
  148. def toggle(self, env, pos):
  149. # If the player has the right key to open the door
  150. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  151. self.is_open = True
  152. # The key has been used, remove it from the agent
  153. env.carrying = None
  154. return True
  155. return False
  156. def can_overlap(self):
  157. """The agent can only walk over this cell when the door is open"""
  158. return self.is_open
  159. def see_behind(self):
  160. return self.is_open
  161. def render(self, r):
  162. c = COLORS[self.color]
  163. r.setLineColor(c[0], c[1], c[2])
  164. r.setColor(c[0], c[1], c[2], 50)
  165. if self.is_open:
  166. r.drawPolygon([
  167. (CELL_PIXELS-2, CELL_PIXELS),
  168. (CELL_PIXELS , CELL_PIXELS),
  169. (CELL_PIXELS , 0),
  170. (CELL_PIXELS-2, 0)
  171. ])
  172. return
  173. r.drawPolygon([
  174. (0 , CELL_PIXELS),
  175. (CELL_PIXELS, CELL_PIXELS),
  176. (CELL_PIXELS, 0),
  177. (0 , 0)
  178. ])
  179. r.drawPolygon([
  180. (2 , CELL_PIXELS-2),
  181. (CELL_PIXELS-2, CELL_PIXELS-2),
  182. (CELL_PIXELS-2, 2),
  183. (2 , 2)
  184. ])
  185. r.drawLine(
  186. CELL_PIXELS * 0.55,
  187. CELL_PIXELS * 0.5,
  188. CELL_PIXELS * 0.75,
  189. CELL_PIXELS * 0.5
  190. )
  191. class Key(WorldObj):
  192. def __init__(self, color='blue'):
  193. super(Key, self).__init__('key', color)
  194. def can_pickup(self):
  195. return True
  196. def render(self, r):
  197. self._set_color(r)
  198. # Vertical quad
  199. r.drawPolygon([
  200. (16, 10),
  201. (20, 10),
  202. (20, 28),
  203. (16, 28)
  204. ])
  205. # Teeth
  206. r.drawPolygon([
  207. (12, 19),
  208. (16, 19),
  209. (16, 21),
  210. (12, 21)
  211. ])
  212. r.drawPolygon([
  213. (12, 26),
  214. (16, 26),
  215. (16, 28),
  216. (12, 28)
  217. ])
  218. r.drawCircle(18, 9, 6)
  219. r.setLineColor(0, 0, 0)
  220. r.setColor(0, 0, 0)
  221. r.drawCircle(18, 9, 2)
  222. class Ball(WorldObj):
  223. def __init__(self, color='blue'):
  224. super(Ball, self).__init__('ball', color)
  225. def can_pickup(self):
  226. return True
  227. def render(self, r):
  228. self._set_color(r)
  229. r.drawCircle(CELL_PIXELS * 0.5, CELL_PIXELS * 0.5, 10)
  230. class Box(WorldObj):
  231. def __init__(self, color, contains=None):
  232. super(Box, self).__init__('box', color)
  233. self.contains = contains
  234. def can_pickup(self):
  235. return True
  236. def render(self, r):
  237. c = COLORS[self.color]
  238. r.setLineColor(c[0], c[1], c[2])
  239. r.setColor(0, 0, 0)
  240. r.setLineWidth(2)
  241. r.drawPolygon([
  242. (4 , CELL_PIXELS-4),
  243. (CELL_PIXELS-4, CELL_PIXELS-4),
  244. (CELL_PIXELS-4, 4),
  245. (4 , 4)
  246. ])
  247. r.drawLine(
  248. 4,
  249. CELL_PIXELS / 2,
  250. CELL_PIXELS - 4,
  251. CELL_PIXELS / 2
  252. )
  253. r.setLineWidth(1)
  254. def toggle(self, env, pos):
  255. # Replace the box by its contents
  256. env.grid.set(*pos, self.contains)
  257. return True
  258. class Grid:
  259. """
  260. Represent a grid and operations on it
  261. """
  262. def __init__(self, width, height):
  263. assert width >= 4
  264. assert height >= 4
  265. self.width = width
  266. self.height = height
  267. self.grid = [None] * width * height
  268. def __contains__(self, key):
  269. if isinstance(key, WorldObj):
  270. for e in self.grid:
  271. if e is key:
  272. return True
  273. elif isinstance(key, tuple):
  274. for e in self.grid:
  275. if e is None:
  276. continue
  277. if (e.color, e.type) == key:
  278. return True
  279. return False
  280. def __eq__(self, other):
  281. grid1 = self.encode()
  282. grid2 = other.encode()
  283. return np.array_equal(grid2, grid1)
  284. def __ne__(self, other):
  285. return not self == other
  286. def copy(self):
  287. from copy import deepcopy
  288. return deepcopy(self)
  289. def set(self, i, j, v):
  290. assert i >= 0 and i < self.width
  291. assert j >= 0 and j < self.height
  292. self.grid[j * self.width + i] = v
  293. def get(self, i, j):
  294. assert i >= 0 and i < self.width
  295. assert j >= 0 and j < self.height
  296. return self.grid[j * self.width + i]
  297. def horzWall(self, x, y, length=None):
  298. if length is None:
  299. length = self.width - x
  300. for i in range(0, length):
  301. self.set(x + i, y, Wall())
  302. def vertWall(self, x, y, length=None):
  303. if length is None:
  304. length = self.height - y
  305. for j in range(0, length):
  306. self.set(x, y + j, Wall())
  307. def wallRect(self, x, y, w, h):
  308. self.horzWall(x, y, w)
  309. self.horzWall(x, y+h-1, w)
  310. self.vertWall(x, y, h)
  311. self.vertWall(x+w-1, y, h)
  312. def rotateLeft(self):
  313. """
  314. Rotate the grid to the left (counter-clockwise)
  315. """
  316. grid = Grid(self.width, self.height)
  317. for j in range(0, self.height):
  318. for i in range(0, self.width):
  319. v = self.get(self.width - 1 - j, i)
  320. grid.set(i, j, v)
  321. return grid
  322. def slice(self, topX, topY, width, height):
  323. """
  324. Get a subset of the grid
  325. """
  326. grid = Grid(width, height)
  327. for j in range(0, height):
  328. for i in range(0, width):
  329. x = topX + i
  330. y = topY + j
  331. if x >= 0 and x < self.width and \
  332. y >= 0 and y < self.height:
  333. v = self.get(x, y)
  334. else:
  335. v = Wall()
  336. grid.set(i, j, v)
  337. return grid
  338. def render(self, r, tileSize):
  339. """
  340. Render this grid at a given scale
  341. :param r: target renderer object
  342. :param tileSize: tile size in pixels
  343. """
  344. assert r.width == self.width * tileSize
  345. assert r.height == self.height * tileSize
  346. # Total grid size at native scale
  347. widthPx = self.width * CELL_PIXELS
  348. heightPx = self.height * CELL_PIXELS
  349. """
  350. # Draw background (out-of-world) tiles the same colors as walls
  351. # so the agent understands these areas are not reachable
  352. c = COLORS['grey']
  353. r.setLineColor(c[0], c[1], c[2])
  354. r.setColor(c[0], c[1], c[2])
  355. r.drawPolygon([
  356. (0 , heightPx),
  357. (widthPx, heightPx),
  358. (widthPx, 0),
  359. (0 , 0)
  360. ])
  361. """
  362. r.push()
  363. # Internally, we draw at the "large" full-grid resolution, but we
  364. # use the renderer to scale back to the desired size
  365. r.scale(tileSize / CELL_PIXELS, tileSize / CELL_PIXELS)
  366. # Draw the background of the in-world cells black
  367. r.fillRect(
  368. 0,
  369. 0,
  370. widthPx,
  371. heightPx,
  372. 0, 0, 0
  373. )
  374. # Draw grid lines
  375. r.setLineColor(100, 100, 100)
  376. for rowIdx in range(0, self.height):
  377. y = CELL_PIXELS * rowIdx
  378. r.drawLine(0, y, widthPx, y)
  379. for colIdx in range(0, self.width):
  380. x = CELL_PIXELS * colIdx
  381. r.drawLine(x, 0, x, heightPx)
  382. # Render the grid
  383. for j in range(0, self.height):
  384. for i in range(0, self.width):
  385. cell = self.get(i, j)
  386. if cell == None:
  387. continue
  388. r.push()
  389. r.translate(i * CELL_PIXELS, j * CELL_PIXELS)
  390. cell.render(r)
  391. r.pop()
  392. r.pop()
  393. def encode(self):
  394. """
  395. Produce a compact numpy encoding of the grid
  396. """
  397. codeSize = self.width * self.height * 3
  398. array = np.zeros(shape=(self.width, self.height, 3), dtype='uint8')
  399. for j in range(0, self.height):
  400. for i in range(0, self.width):
  401. v = self.get(i, j)
  402. if v == None:
  403. continue
  404. array[i, j, 0] = OBJECT_TO_IDX[v.type]
  405. array[i, j, 1] = COLOR_TO_IDX[v.color]
  406. if hasattr(v, 'is_open') and v.is_open:
  407. array[i, j, 2] = 1
  408. return array
  409. def decode(array):
  410. """
  411. Decode an array grid encoding back into a grid
  412. """
  413. width = array.shape[0]
  414. height = array.shape[1]
  415. assert array.shape[2] == 3
  416. grid = Grid(width, height)
  417. for j in range(0, height):
  418. for i in range(0, width):
  419. typeIdx = array[i, j, 0]
  420. colorIdx = array[i, j, 1]
  421. openIdx = array[i, j, 2]
  422. if typeIdx == 0:
  423. continue
  424. objType = IDX_TO_OBJECT[typeIdx]
  425. color = IDX_TO_COLOR[colorIdx]
  426. is_open = True if openIdx == 1 else 0
  427. if objType == 'wall':
  428. v = Wall(color)
  429. elif objType == 'ball':
  430. v = Ball(color)
  431. elif objType == 'key':
  432. v = Key(color)
  433. elif objType == 'box':
  434. v = Box(color)
  435. elif objType == 'door':
  436. v = Door(color, is_open)
  437. elif objType == 'locked_door':
  438. v = LockedDoor(color, is_open)
  439. elif objType == 'goal':
  440. v = Goal()
  441. else:
  442. assert False, "unknown obj type in decode '%s'" % objType
  443. grid.set(i, j, v)
  444. return grid
  445. def process_vis(
  446. grid,
  447. agent_pos,
  448. n_rays = 32,
  449. n_steps = 24,
  450. a_min = math.pi,
  451. a_max = 2 * math.pi
  452. ):
  453. """
  454. Use ray casting to determine the visibility of each grid cell
  455. """
  456. mask = np.zeros(shape=(grid.width, grid.height), dtype=np.bool)
  457. ang_step = (a_max - a_min) / n_rays
  458. dst_step = math.sqrt(grid.width ** 2 + grid.height ** 2) / n_steps
  459. ax = agent_pos[0] + 0.5
  460. ay = agent_pos[1] + 0.5
  461. for ray_idx in range(0, n_rays):
  462. angle = a_min + ang_step * ray_idx
  463. dx = dst_step * math.cos(angle)
  464. dy = dst_step * math.sin(angle)
  465. for step_idx in range(0, n_steps):
  466. x = ax + (step_idx * dx)
  467. y = ay + (step_idx * dy)
  468. i = math.floor(x)
  469. j = math.floor(y)
  470. # If we're outside of the grid, stop
  471. if i < 0 or i >= grid.width or j < 0 or j >= grid.height:
  472. break
  473. # Mark this cell as visible
  474. mask[i, j] = True
  475. # If we hit the obstructor, stop
  476. cell = grid.get(i, j)
  477. if cell and not cell.see_behind():
  478. break
  479. for j in range(0, grid.height):
  480. for i in range(0, grid.width):
  481. if not mask[i, j]:
  482. grid.set(i, j, None)
  483. #grid.set(i, j, Wall('red'))
  484. return mask
  485. def process_vis_prop(
  486. grid,
  487. agent_pos
  488. ):
  489. mask = np.zeros(shape=(grid.width, grid.height), dtype=np.bool)
  490. mask[agent_pos[0], agent_pos[1]] = True
  491. for j in reversed(range(1, grid.height)):
  492. for i in range(0, grid.width-1):
  493. if not mask[i, j]:
  494. continue
  495. cell = grid.get(i, j)
  496. if cell and not cell.see_behind():
  497. continue
  498. mask[i+1, j] = True
  499. mask[i+1, j-1] = True
  500. mask[i, j-1] = True
  501. for i in reversed(range(1, grid.width)):
  502. if not mask[i, j]:
  503. continue
  504. cell = grid.get(i, j)
  505. if cell and not cell.see_behind():
  506. continue
  507. mask[i-1, j-1] = True
  508. mask[i-1, j] = True
  509. mask[i, j-1] = True
  510. for j in range(0, grid.height):
  511. for i in range(0, grid.width):
  512. if not mask[i, j]:
  513. grid.set(i, j, None)
  514. #grid.set(i, j, Wall('red'))
  515. class MiniGridEnv(gym.Env):
  516. """
  517. 2D grid world game environment
  518. """
  519. metadata = {
  520. 'render.modes': ['human', 'rgb_array', 'pixmap'],
  521. 'video.frames_per_second' : 10
  522. }
  523. # Enumeration of possible actions
  524. class Actions(IntEnum):
  525. # Turn left, turn right, move forward
  526. left = 0
  527. right = 1
  528. forward = 2
  529. # Pick up an object
  530. pickup = 3
  531. # Drop an object
  532. drop = 4
  533. # Toggle/activate an object
  534. toggle = 5
  535. # Wait/stay put/do nothing
  536. wait = 6
  537. def __init__(
  538. self,
  539. grid_size=16,
  540. max_steps=100,
  541. see_through_walls=False
  542. ):
  543. # Action enumeration for this environment
  544. self.actions = MiniGridEnv.Actions
  545. # Actions are discrete integer values
  546. self.action_space = spaces.Discrete(len(self.actions))
  547. # Observations are dictionaries containing an
  548. # encoding of the grid and a textual 'mission' string
  549. self.observation_space = spaces.Box(
  550. low=0,
  551. high=255,
  552. shape=OBS_ARRAY_SIZE,
  553. dtype='uint8'
  554. )
  555. self.observation_space = spaces.Dict({
  556. 'image': self.observation_space
  557. })
  558. # Range of possible rewards
  559. self.reward_range = (-1, 1000)
  560. # Renderer object used to render the whole grid (full-scale)
  561. self.grid_render = None
  562. # Renderer used to render observations (small-scale agent view)
  563. self.obs_render = None
  564. # Environment configuration
  565. self.grid_size = grid_size
  566. self.max_steps = max_steps
  567. self.see_through_walls = see_through_walls
  568. # Starting position and direction for the agent
  569. self.start_pos = None
  570. self.start_dir = None
  571. # Initialize the state
  572. self.seed()
  573. self.reset()
  574. def reset(self):
  575. # Generate a new random grid at the start of each episode
  576. # To keep the same grid for each episode, call env.seed() with
  577. # the same seed before calling env.reset()
  578. self._gen_grid(self.grid_size, self.grid_size)
  579. # These fields should be defined by _gen_grid
  580. assert self.start_pos != None
  581. assert self.start_dir != None
  582. # Check that the agent doesn't overlap with an object
  583. assert self.grid.get(*self.start_pos) is None
  584. # Place the agent in the starting position and direction
  585. self.agent_pos = self.start_pos
  586. self.agent_dir = self.start_dir
  587. # Item picked up, being carried, initially nothing
  588. self.carrying = None
  589. # Step count since episode start
  590. self.step_count = 0
  591. # Return first observation
  592. obs = self.gen_obs()
  593. return obs
  594. def seed(self, seed=1337):
  595. # Seed the random number generator
  596. self.np_random, _ = seeding.np_random(seed)
  597. return [seed]
  598. @property
  599. def steps_remaining(self):
  600. return self.max_steps - self.step_count
  601. def __str__(self):
  602. """
  603. Produce a pretty string of the environment's grid along with the agent.
  604. The agent is represented by `⏩`. A grid pixel is represented by 2-character
  605. string, the first one for the object and the second one for the color.
  606. """
  607. from copy import deepcopy
  608. def rotate_left(array):
  609. new_array = deepcopy(array)
  610. for i in range(len(array)):
  611. for j in range(len(array[0])):
  612. new_array[j][len(array[0])-1-i] = array[i][j]
  613. return new_array
  614. def vertically_symmetrize(array):
  615. new_array = deepcopy(array)
  616. for i in range(len(array)):
  617. for j in range(len(array[0])):
  618. new_array[i][len(array[0])-1-j] = array[i][j]
  619. return new_array
  620. # Map of object id to short string
  621. OBJECT_IDX_TO_IDS = {
  622. 0: ' ',
  623. 1: 'W',
  624. 2: 'D',
  625. 3: 'L',
  626. 4: 'K',
  627. 5: 'B',
  628. 6: 'X',
  629. 7: 'G'
  630. }
  631. # Short string for opened door
  632. OPENDED_DOOR_IDS = '_'
  633. # Map of color id to short string
  634. COLOR_IDX_TO_IDS = {
  635. 0: 'R',
  636. 1: 'G',
  637. 2: 'B',
  638. 3: 'P',
  639. 4: 'Y',
  640. 5: 'E'
  641. }
  642. # Map agent's direction to short string
  643. AGENT_DIR_TO_IDS = {
  644. 0: '⏩',
  645. 1: '⏬',
  646. 2: '⏪',
  647. 3: '⏫'
  648. }
  649. array = self.grid.encode()
  650. array = rotate_left(array)
  651. array = vertically_symmetrize(array)
  652. new_array = []
  653. for line in array:
  654. new_line = []
  655. for pixel in line:
  656. # If the door is opened
  657. if pixel[0] in [2, 3] and pixel[2] == 1:
  658. object_ids = OPENDED_DOOR_IDS
  659. else:
  660. object_ids = OBJECT_IDX_TO_IDS[pixel[0]]
  661. # If no object
  662. if pixel[0] == 0:
  663. color_ids = ' '
  664. else:
  665. color_ids = COLOR_IDX_TO_IDS[pixel[1]]
  666. new_line.append(object_ids + color_ids)
  667. new_array.append(new_line)
  668. # Add the agent
  669. new_array[self.agent_pos[1]][self.agent_pos[0]] = AGENT_DIR_TO_IDS[self.agent_dir]
  670. return "\n".join([" ".join(line) for line in new_array])
  671. def _gen_grid(self, width, height):
  672. assert False, "_gen_grid needs to be implemented by each environment"
  673. def _randInt(self, low, high):
  674. """
  675. Generate random integer in [low,high[
  676. """
  677. return self.np_random.randint(low, high)
  678. def _randElem(self, iterable):
  679. """
  680. Pick a random element in a list
  681. """
  682. lst = list(iterable)
  683. idx = self._randInt(0, len(lst))
  684. return lst[idx]
  685. def _randPos(self, xLow, xHigh, yLow, yHigh):
  686. """
  687. Generate a random (x,y) position tuple
  688. """
  689. return (
  690. self.np_random.randint(xLow, xHigh),
  691. self.np_random.randint(yLow, yHigh)
  692. )
  693. def placeObj(self, obj, top=None, size=None, reject_fn=None):
  694. """
  695. Place an object at an empty position in the grid
  696. :param top: top-left position of the rectangle where to place
  697. :param size: size of the rectangle where to place
  698. :param reject_fn: function to filter out potential positions
  699. """
  700. if top is None:
  701. top = (0, 0)
  702. if size is None:
  703. size = (self.grid.width, self.grid.height)
  704. while True:
  705. pos = (
  706. self._randInt(top[0], top[0] + size[0]),
  707. self._randInt(top[1], top[1] + size[1])
  708. )
  709. # Don't place the object on top of another object
  710. if self.grid.get(*pos) != None:
  711. continue
  712. # Don't place the object where the agent is
  713. if pos == self.start_pos:
  714. continue
  715. # Check if there is a filtering criterion
  716. if reject_fn and reject_fn(self, pos):
  717. continue
  718. break
  719. self.grid.set(*pos, obj)
  720. return pos
  721. def placeAgent(self, top=None, size=None, randDir=True):
  722. """
  723. Set the agent's starting point at an empty position in the grid
  724. """
  725. pos = self.placeObj(None, top, size)
  726. self.start_pos = pos
  727. if randDir:
  728. self.start_dir = self._randInt(0, 4)
  729. return pos
  730. def get_dir_vec(self):
  731. """
  732. Get the direction vector for the agent, pointing in the direction
  733. of forward movement.
  734. """
  735. # Pointing right
  736. if self.agent_dir == 0:
  737. return (1, 0)
  738. # Down (positive Y)
  739. elif self.agent_dir == 1:
  740. return (0, 1)
  741. # Pointing left
  742. elif self.agent_dir == 2:
  743. return (-1, 0)
  744. # Up (negative Y)
  745. elif self.agent_dir == 3:
  746. return (0, -1)
  747. else:
  748. assert False
  749. def get_right_vec(self):
  750. """
  751. Get the vector pointing to the right of the agent.
  752. """
  753. dx, dy = self.get_dir_vec()
  754. return -dy, dx
  755. def get_view_coords(self, i, j):
  756. """
  757. Translate and rotate absolute grid coordinates (i, j) into the
  758. agent's partially observable view (sub-grid). Note that the resulting
  759. coordinates may be negative or outside of the agent's view size.
  760. """
  761. ax, ay = self.agent_pos
  762. dx, dy = self.get_dir_vec()
  763. rx, ry = self.get_right_vec()
  764. # Compute the absolute coordinates of the top-left view corner
  765. sz = AGENT_VIEW_SIZE
  766. hs = AGENT_VIEW_SIZE // 2
  767. tx = ax + (dx * (sz-1)) - (rx * hs)
  768. ty = ay + (dy * (sz-1)) - (ry * hs)
  769. lx = i - tx
  770. ly = j - ty
  771. # Project the coordinates of the object relative to the top-left
  772. # corner onto the agent's own coordinate system
  773. vx = (rx*lx + ry*ly)
  774. vy = -(dx*lx + dy*ly)
  775. return vx, vy
  776. def get_view_exts(self):
  777. """
  778. Get the extents of the square set of tiles visible to the agent
  779. Note: the bottom extent indices are not included in the set
  780. """
  781. # Facing right
  782. if self.agent_dir == 0:
  783. topX = self.agent_pos[0]
  784. topY = self.agent_pos[1] - AGENT_VIEW_SIZE // 2
  785. # Facing down
  786. elif self.agent_dir == 1:
  787. topX = self.agent_pos[0] - AGENT_VIEW_SIZE // 2
  788. topY = self.agent_pos[1]
  789. # Facing left
  790. elif self.agent_dir == 2:
  791. topX = self.agent_pos[0] - AGENT_VIEW_SIZE + 1
  792. topY = self.agent_pos[1] - AGENT_VIEW_SIZE // 2
  793. # Facing up
  794. elif self.agent_dir == 3:
  795. topX = self.agent_pos[0] - AGENT_VIEW_SIZE // 2
  796. topY = self.agent_pos[1] - AGENT_VIEW_SIZE + 1
  797. else:
  798. assert False, "invalid agent direction"
  799. botX = topX + AGENT_VIEW_SIZE
  800. botY = topY + AGENT_VIEW_SIZE
  801. return (topX, topY, botX, botY)
  802. def agent_sees(self, x, y):
  803. """
  804. Check if a grid position is visible to the agent
  805. """
  806. vx, vy = self.get_view_coords(x, y)
  807. if vx < 0 or vy < 0 or vx >= AGENT_VIEW_SIZE or vy >= AGENT_VIEW_SIZE:
  808. return False
  809. obs = self.gen_obs()
  810. obs_grid = Grid.decode(obs['image'])
  811. obs_cell = obs_grid.get(vx, vy)
  812. world_cell = self.grid.get(x, y)
  813. return obs_cell is not None and obs_cell.type == world_cell.type
  814. def step(self, action):
  815. self.step_count += 1
  816. reward = 0
  817. done = False
  818. # Get the position in front of the agent
  819. u, v = self.get_dir_vec()
  820. fwdPos = (self.agent_pos[0] + u, self.agent_pos[1] + v)
  821. # Get the contents of the cell in front of the agent
  822. fwdCell = self.grid.get(*fwdPos)
  823. # Rotate left
  824. if action == self.actions.left:
  825. self.agent_dir -= 1
  826. if self.agent_dir < 0:
  827. self.agent_dir += 4
  828. # Rotate right
  829. elif action == self.actions.right:
  830. self.agent_dir = (self.agent_dir + 1) % 4
  831. # Move forward
  832. elif action == self.actions.forward:
  833. if fwdCell == None or fwdCell.can_overlap():
  834. self.agent_pos = fwdPos
  835. if fwdCell != None and fwdCell.type == 'goal':
  836. done = True
  837. reward = 1000 - self.step_count
  838. # Pick up an object
  839. elif action == self.actions.pickup:
  840. if fwdCell and fwdCell.can_pickup():
  841. if self.carrying is None:
  842. self.carrying = fwdCell
  843. self.grid.set(*fwdPos, None)
  844. # Drop an object
  845. elif action == self.actions.drop:
  846. if not fwdCell and self.carrying:
  847. self.grid.set(*fwdPos, self.carrying)
  848. self.carrying = None
  849. # Toggle/activate an object
  850. elif action == self.actions.toggle:
  851. if fwdCell:
  852. fwdCell.toggle(self, fwdPos)
  853. # Wait/do nothing
  854. elif action == self.actions.wait:
  855. pass
  856. else:
  857. assert False, "unknown action"
  858. if self.step_count >= self.max_steps:
  859. done = True
  860. obs = self.gen_obs()
  861. return obs, reward, done, {}
  862. def gen_obs(self):
  863. """
  864. Generate the agent's view (partially observable, low-resolution encoding)
  865. """
  866. topX, topY, botX, botY = self.get_view_exts()
  867. grid = self.grid.slice(topX, topY, AGENT_VIEW_SIZE, AGENT_VIEW_SIZE)
  868. for i in range(self.agent_dir + 1):
  869. grid = grid.rotateLeft()
  870. # Make it so the agent sees what it's carrying
  871. # We do this by placing the carried object at the agent's position
  872. # in the agent's partially observable view
  873. agent_pos = grid.width // 2, grid.height - 1
  874. if self.carrying:
  875. grid.set(*agent_pos, self.carrying)
  876. else:
  877. grid.set(*agent_pos, None)
  878. # Process occluders and visibility
  879. # Note that this incurs some performance cost
  880. if not self.see_through_walls:
  881. grid.process_vis_prop(agent_pos=(3, 6))
  882. # Encode the partially observable view into a numpy array
  883. image = grid.encode()
  884. assert hasattr(self, 'mission'), "environments must define a textual mission string"
  885. # Observations are dictionaries containing:
  886. # - an image (partially observable view of the environment)
  887. # - the agent's direction/orientation (acting as a compass)
  888. # - a textual mission string (instructions for the agent)
  889. obs = {
  890. 'image': image,
  891. 'direction': self.agent_dir,
  892. 'mission': self.mission
  893. }
  894. return obs
  895. def get_obs_render(self, obs):
  896. """
  897. Render an agent observation for visualization
  898. """
  899. if self.obs_render == None:
  900. self.obs_render = Renderer(
  901. AGENT_VIEW_SIZE * CELL_PIXELS // 2,
  902. AGENT_VIEW_SIZE * CELL_PIXELS // 2
  903. )
  904. r = self.obs_render
  905. r.beginFrame()
  906. grid = Grid.decode(obs)
  907. # Render the whole grid
  908. grid.render(r, CELL_PIXELS // 2)
  909. # Draw the agent
  910. r.push()
  911. r.scale(0.5, 0.5)
  912. r.translate(
  913. CELL_PIXELS * (0.5 + AGENT_VIEW_SIZE // 2),
  914. CELL_PIXELS * (AGENT_VIEW_SIZE - 0.5)
  915. )
  916. r.rotate(3 * 90)
  917. r.setLineColor(255, 0, 0)
  918. r.setColor(255, 0, 0)
  919. r.drawPolygon([
  920. (-12, 10),
  921. ( 12, 0),
  922. (-12, -10)
  923. ])
  924. r.pop()
  925. r.endFrame()
  926. return r.getPixmap()
  927. def render(self, mode='human', close=False):
  928. """
  929. Render the whole-grid human view
  930. """
  931. if close:
  932. if self.grid_render:
  933. self.grid_render.close()
  934. return
  935. if self.grid_render is None:
  936. self.grid_render = Renderer(
  937. self.grid_size * CELL_PIXELS,
  938. self.grid_size * CELL_PIXELS,
  939. True if mode == 'human' else False
  940. )
  941. r = self.grid_render
  942. r.beginFrame()
  943. # Render the whole grid
  944. self.grid.render(r, CELL_PIXELS)
  945. # Draw the agent
  946. r.push()
  947. r.translate(
  948. CELL_PIXELS * (self.agent_pos[0] + 0.5),
  949. CELL_PIXELS * (self.agent_pos[1] + 0.5)
  950. )
  951. r.rotate(self.agent_dir * 90)
  952. r.setLineColor(255, 0, 0)
  953. r.setColor(255, 0, 0)
  954. r.drawPolygon([
  955. (-12, 10),
  956. ( 12, 0),
  957. (-12, -10)
  958. ])
  959. r.pop()
  960. # Highlight what the agent can see
  961. topX, topY, botX, botY = self.get_view_exts()
  962. r.fillRect(
  963. topX * CELL_PIXELS,
  964. topY * CELL_PIXELS,
  965. AGENT_VIEW_SIZE * CELL_PIXELS,
  966. AGENT_VIEW_SIZE * CELL_PIXELS,
  967. 200, 200, 200, 75
  968. )
  969. r.endFrame()
  970. if mode == 'rgb_array':
  971. return r.getArray()
  972. elif mode == 'pixmap':
  973. return r.getPixmap()
  974. return r