minigrid.py 33 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245
  1. import math
  2. import gym
  3. from enum import IntEnum
  4. import numpy as np
  5. from gym import error, spaces, utils
  6. from gym.utils import seeding
  7. from gym_minigrid.rendering import *
  8. # Size in pixels of a cell in the full-scale human view
  9. CELL_PIXELS = 32
  10. # Number of cells (width and height) in the agent view
  11. AGENT_VIEW_SIZE = 7
  12. # Size of the array given as an observation to the agent
  13. OBS_ARRAY_SIZE = (AGENT_VIEW_SIZE, AGENT_VIEW_SIZE, 3)
  14. # Map of color names to RGB values
  15. COLORS = {
  16. 'red' : (255, 0, 0),
  17. 'green' : (0, 255, 0),
  18. 'blue' : (0, 0, 255),
  19. 'purple': (112, 39, 195),
  20. 'yellow': (255, 255, 0),
  21. 'grey' : (100, 100, 100)
  22. }
  23. COLOR_NAMES = sorted(list(COLORS.keys()))
  24. # Used to map colors to integers
  25. COLOR_TO_IDX = {
  26. 'red' : 0,
  27. 'green' : 1,
  28. 'blue' : 2,
  29. 'purple': 3,
  30. 'yellow': 4,
  31. 'grey' : 5
  32. }
  33. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  34. # Map of object type to integers
  35. OBJECT_TO_IDX = {
  36. 'empty' : 0,
  37. 'wall' : 1,
  38. 'door' : 2,
  39. 'locked_door' : 3,
  40. 'key' : 4,
  41. 'ball' : 5,
  42. 'box' : 6,
  43. 'goal' : 7
  44. }
  45. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  46. # Map of agent direction indices to vectors
  47. DIR_TO_VEC = [
  48. # Pointing right (positive X)
  49. np.array((1, 0)),
  50. # Down (positive Y)
  51. np.array((0, 1)),
  52. # Pointing left (negative X)
  53. np.array((-1, 0)),
  54. # Up (negative Y)
  55. np.array((0, -1)),
  56. ]
  57. class WorldObj:
  58. """
  59. Base class for grid world objects
  60. """
  61. def __init__(self, type, color):
  62. assert type in OBJECT_TO_IDX, type
  63. assert color in COLOR_TO_IDX, color
  64. self.type = type
  65. self.color = color
  66. self.contains = None
  67. def can_overlap(self):
  68. """Can the agent overlap with this?"""
  69. return False
  70. def can_pickup(self):
  71. """Can the agent pick this up?"""
  72. return False
  73. def can_contain(self):
  74. """Can this contain another object?"""
  75. return False
  76. def see_behind(self):
  77. """Can the agent see behind this object?"""
  78. return True
  79. def toggle(self, env, pos):
  80. """Method to trigger/toggle an action this object performs"""
  81. return False
  82. def render(self, r):
  83. """Draw this object with the given renderer"""
  84. raise NotImplementedError
  85. def _set_color(self, r):
  86. """Set the color of this object as the active drawing color"""
  87. c = COLORS[self.color]
  88. r.setLineColor(c[0], c[1], c[2])
  89. r.setColor(c[0], c[1], c[2])
  90. class Goal(WorldObj):
  91. def __init__(self):
  92. super(Goal, self).__init__('goal', 'green')
  93. def can_overlap(self):
  94. return True
  95. def render(self, r):
  96. self._set_color(r)
  97. r.drawPolygon([
  98. (0 , CELL_PIXELS),
  99. (CELL_PIXELS, CELL_PIXELS),
  100. (CELL_PIXELS, 0),
  101. (0 , 0)
  102. ])
  103. class Wall(WorldObj):
  104. def __init__(self, color='grey'):
  105. super(Wall, self).__init__('wall', color)
  106. def see_behind(self):
  107. return False
  108. def render(self, r):
  109. self._set_color(r)
  110. r.drawPolygon([
  111. (0 , CELL_PIXELS),
  112. (CELL_PIXELS, CELL_PIXELS),
  113. (CELL_PIXELS, 0),
  114. (0 , 0)
  115. ])
  116. class Door(WorldObj):
  117. def __init__(self, color, is_open=False):
  118. super(Door, self).__init__('door', color)
  119. self.is_open = is_open
  120. def can_overlap(self):
  121. """The agent can only walk over this cell when the door is open"""
  122. return self.is_open
  123. def see_behind(self):
  124. return self.is_open
  125. def toggle(self, env, pos):
  126. if not self.is_open:
  127. self.is_open = True
  128. return True
  129. return False
  130. def render(self, r):
  131. c = COLORS[self.color]
  132. r.setLineColor(c[0], c[1], c[2])
  133. r.setColor(0, 0, 0)
  134. if self.is_open:
  135. r.drawPolygon([
  136. (CELL_PIXELS-2, CELL_PIXELS),
  137. (CELL_PIXELS , CELL_PIXELS),
  138. (CELL_PIXELS , 0),
  139. (CELL_PIXELS-2, 0)
  140. ])
  141. return
  142. r.drawPolygon([
  143. (0 , CELL_PIXELS),
  144. (CELL_PIXELS, CELL_PIXELS),
  145. (CELL_PIXELS, 0),
  146. (0 , 0)
  147. ])
  148. r.drawPolygon([
  149. (2 , CELL_PIXELS-2),
  150. (CELL_PIXELS-2, CELL_PIXELS-2),
  151. (CELL_PIXELS-2, 2),
  152. (2 , 2)
  153. ])
  154. r.drawCircle(CELL_PIXELS * 0.75, CELL_PIXELS * 0.5, 2)
  155. class LockedDoor(WorldObj):
  156. def __init__(self, color, is_open=False):
  157. super(LockedDoor, self).__init__('locked_door', color)
  158. self.is_open = is_open
  159. def toggle(self, env, pos):
  160. # If the player has the right key to open the door
  161. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  162. self.is_open = True
  163. # The key has been used, remove it from the agent
  164. env.carrying = None
  165. return True
  166. return False
  167. def can_overlap(self):
  168. """The agent can only walk over this cell when the door is open"""
  169. return self.is_open
  170. def see_behind(self):
  171. return self.is_open
  172. def render(self, r):
  173. c = COLORS[self.color]
  174. r.setLineColor(c[0], c[1], c[2])
  175. r.setColor(c[0], c[1], c[2], 50)
  176. if self.is_open:
  177. r.drawPolygon([
  178. (CELL_PIXELS-2, CELL_PIXELS),
  179. (CELL_PIXELS , CELL_PIXELS),
  180. (CELL_PIXELS , 0),
  181. (CELL_PIXELS-2, 0)
  182. ])
  183. return
  184. r.drawPolygon([
  185. (0 , CELL_PIXELS),
  186. (CELL_PIXELS, CELL_PIXELS),
  187. (CELL_PIXELS, 0),
  188. (0 , 0)
  189. ])
  190. r.drawPolygon([
  191. (2 , CELL_PIXELS-2),
  192. (CELL_PIXELS-2, CELL_PIXELS-2),
  193. (CELL_PIXELS-2, 2),
  194. (2 , 2)
  195. ])
  196. r.drawLine(
  197. CELL_PIXELS * 0.55,
  198. CELL_PIXELS * 0.5,
  199. CELL_PIXELS * 0.75,
  200. CELL_PIXELS * 0.5
  201. )
  202. class Key(WorldObj):
  203. def __init__(self, color='blue'):
  204. super(Key, self).__init__('key', color)
  205. def can_pickup(self):
  206. return True
  207. def render(self, r):
  208. self._set_color(r)
  209. # Vertical quad
  210. r.drawPolygon([
  211. (16, 10),
  212. (20, 10),
  213. (20, 28),
  214. (16, 28)
  215. ])
  216. # Teeth
  217. r.drawPolygon([
  218. (12, 19),
  219. (16, 19),
  220. (16, 21),
  221. (12, 21)
  222. ])
  223. r.drawPolygon([
  224. (12, 26),
  225. (16, 26),
  226. (16, 28),
  227. (12, 28)
  228. ])
  229. r.drawCircle(18, 9, 6)
  230. r.setLineColor(0, 0, 0)
  231. r.setColor(0, 0, 0)
  232. r.drawCircle(18, 9, 2)
  233. class Ball(WorldObj):
  234. def __init__(self, color='blue'):
  235. super(Ball, self).__init__('ball', color)
  236. def can_pickup(self):
  237. return True
  238. def render(self, r):
  239. self._set_color(r)
  240. r.drawCircle(CELL_PIXELS * 0.5, CELL_PIXELS * 0.5, 10)
  241. class Box(WorldObj):
  242. def __init__(self, color, contains=None):
  243. super(Box, self).__init__('box', color)
  244. self.contains = contains
  245. def can_pickup(self):
  246. return True
  247. def render(self, r):
  248. c = COLORS[self.color]
  249. r.setLineColor(c[0], c[1], c[2])
  250. r.setColor(0, 0, 0)
  251. r.setLineWidth(2)
  252. r.drawPolygon([
  253. (4 , CELL_PIXELS-4),
  254. (CELL_PIXELS-4, CELL_PIXELS-4),
  255. (CELL_PIXELS-4, 4),
  256. (4 , 4)
  257. ])
  258. r.drawLine(
  259. 4,
  260. CELL_PIXELS / 2,
  261. CELL_PIXELS - 4,
  262. CELL_PIXELS / 2
  263. )
  264. r.setLineWidth(1)
  265. def toggle(self, env, pos):
  266. # Replace the box by its contents
  267. env.grid.set(*pos, self.contains)
  268. return True
  269. class Grid:
  270. """
  271. Represent a grid and operations on it
  272. """
  273. def __init__(self, width, height):
  274. assert width >= 4
  275. assert height >= 4
  276. self.width = width
  277. self.height = height
  278. self.grid = [None] * width * height
  279. def __contains__(self, key):
  280. if isinstance(key, WorldObj):
  281. for e in self.grid:
  282. if e is key:
  283. return True
  284. elif isinstance(key, tuple):
  285. for e in self.grid:
  286. if e is None:
  287. continue
  288. if (e.color, e.type) == key:
  289. return True
  290. return False
  291. def __eq__(self, other):
  292. grid1 = self.encode()
  293. grid2 = other.encode()
  294. return np.array_equal(grid2, grid1)
  295. def __ne__(self, other):
  296. return not self == other
  297. def copy(self):
  298. from copy import deepcopy
  299. return deepcopy(self)
  300. def set(self, i, j, v):
  301. assert i >= 0 and i < self.width
  302. assert j >= 0 and j < self.height
  303. self.grid[j * self.width + i] = v
  304. def get(self, i, j):
  305. assert i >= 0 and i < self.width
  306. assert j >= 0 and j < self.height
  307. return self.grid[j * self.width + i]
  308. def horz_wall(self, x, y, length=None):
  309. if length is None:
  310. length = self.width - x
  311. for i in range(0, length):
  312. self.set(x + i, y, Wall())
  313. def vert_wall(self, x, y, length=None):
  314. if length is None:
  315. length = self.height - y
  316. for j in range(0, length):
  317. self.set(x, y + j, Wall())
  318. def wall_rect(self, x, y, w, h):
  319. self.horz_wall(x, y, w)
  320. self.horz_wall(x, y+h-1, w)
  321. self.vert_wall(x, y, h)
  322. self.vert_wall(x+w-1, y, h)
  323. def rotate_left(self):
  324. """
  325. Rotate the grid to the left (counter-clockwise)
  326. """
  327. grid = Grid(self.width, self.height)
  328. for j in range(0, self.height):
  329. for i in range(0, self.width):
  330. v = self.get(self.width - 1 - j, i)
  331. grid.set(i, j, v)
  332. return grid
  333. def slice(self, topX, topY, width, height):
  334. """
  335. Get a subset of the grid
  336. """
  337. grid = Grid(width, height)
  338. for j in range(0, height):
  339. for i in range(0, width):
  340. x = topX + i
  341. y = topY + j
  342. if x >= 0 and x < self.width and \
  343. y >= 0 and y < self.height:
  344. v = self.get(x, y)
  345. else:
  346. v = Wall()
  347. grid.set(i, j, v)
  348. return grid
  349. def render(self, r, tile_size):
  350. """
  351. Render this grid at a given scale
  352. :param r: target renderer object
  353. :param tile_size: tile size in pixels
  354. """
  355. assert r.width == self.width * tile_size
  356. assert r.height == self.height * tile_size
  357. # Total grid size at native scale
  358. widthPx = self.width * CELL_PIXELS
  359. heightPx = self.height * CELL_PIXELS
  360. """
  361. # Draw background (out-of-world) tiles the same colors as walls
  362. # so the agent understands these areas are not reachable
  363. c = COLORS['grey']
  364. r.setLineColor(c[0], c[1], c[2])
  365. r.setColor(c[0], c[1], c[2])
  366. r.drawPolygon([
  367. (0 , heightPx),
  368. (widthPx, heightPx),
  369. (widthPx, 0),
  370. (0 , 0)
  371. ])
  372. """
  373. r.push()
  374. # Internally, we draw at the "large" full-grid resolution, but we
  375. # use the renderer to scale back to the desired size
  376. r.scale(tile_size / CELL_PIXELS, tile_size / CELL_PIXELS)
  377. # Draw the background of the in-world cells black
  378. r.fillRect(
  379. 0,
  380. 0,
  381. widthPx,
  382. heightPx,
  383. 0, 0, 0
  384. )
  385. # Draw grid lines
  386. r.setLineColor(100, 100, 100)
  387. for rowIdx in range(0, self.height):
  388. y = CELL_PIXELS * rowIdx
  389. r.drawLine(0, y, widthPx, y)
  390. for colIdx in range(0, self.width):
  391. x = CELL_PIXELS * colIdx
  392. r.drawLine(x, 0, x, heightPx)
  393. # Render the grid
  394. for j in range(0, self.height):
  395. for i in range(0, self.width):
  396. cell = self.get(i, j)
  397. if cell == None:
  398. continue
  399. r.push()
  400. r.translate(i * CELL_PIXELS, j * CELL_PIXELS)
  401. cell.render(r)
  402. r.pop()
  403. r.pop()
  404. def encode(self):
  405. """
  406. Produce a compact numpy encoding of the grid
  407. """
  408. codeSize = self.width * self.height * 3
  409. array = np.zeros(shape=(self.width, self.height, 3), dtype='uint8')
  410. for j in range(0, self.height):
  411. for i in range(0, self.width):
  412. v = self.get(i, j)
  413. if v == None:
  414. continue
  415. array[i, j, 0] = OBJECT_TO_IDX[v.type]
  416. array[i, j, 1] = COLOR_TO_IDX[v.color]
  417. if hasattr(v, 'is_open') and v.is_open:
  418. array[i, j, 2] = 1
  419. return array
  420. def decode(array):
  421. """
  422. Decode an array grid encoding back into a grid
  423. """
  424. width = array.shape[0]
  425. height = array.shape[1]
  426. assert array.shape[2] == 3
  427. grid = Grid(width, height)
  428. for j in range(0, height):
  429. for i in range(0, width):
  430. typeIdx = array[i, j, 0]
  431. colorIdx = array[i, j, 1]
  432. openIdx = array[i, j, 2]
  433. if typeIdx == 0:
  434. continue
  435. objType = IDX_TO_OBJECT[typeIdx]
  436. color = IDX_TO_COLOR[colorIdx]
  437. is_open = True if openIdx == 1 else 0
  438. if objType == 'wall':
  439. v = Wall(color)
  440. elif objType == 'ball':
  441. v = Ball(color)
  442. elif objType == 'key':
  443. v = Key(color)
  444. elif objType == 'box':
  445. v = Box(color)
  446. elif objType == 'door':
  447. v = Door(color, is_open)
  448. elif objType == 'locked_door':
  449. v = LockedDoor(color, is_open)
  450. elif objType == 'goal':
  451. v = Goal()
  452. else:
  453. assert False, "unknown obj type in decode '%s'" % objType
  454. grid.set(i, j, v)
  455. return grid
  456. def process_vis(
  457. grid,
  458. agent_pos,
  459. n_rays = 32,
  460. n_steps = 24,
  461. a_min = math.pi,
  462. a_max = 2 * math.pi
  463. ):
  464. """
  465. Use ray casting to determine the visibility of each grid cell
  466. """
  467. mask = np.zeros(shape=(grid.width, grid.height), dtype=np.bool)
  468. ang_step = (a_max - a_min) / n_rays
  469. dst_step = math.sqrt(grid.width ** 2 + grid.height ** 2) / n_steps
  470. ax = agent_pos[0] + 0.5
  471. ay = agent_pos[1] + 0.5
  472. for ray_idx in range(0, n_rays):
  473. angle = a_min + ang_step * ray_idx
  474. dx = dst_step * math.cos(angle)
  475. dy = dst_step * math.sin(angle)
  476. for step_idx in range(0, n_steps):
  477. x = ax + (step_idx * dx)
  478. y = ay + (step_idx * dy)
  479. i = math.floor(x)
  480. j = math.floor(y)
  481. # If we're outside of the grid, stop
  482. if i < 0 or i >= grid.width or j < 0 or j >= grid.height:
  483. break
  484. # Mark this cell as visible
  485. mask[i, j] = True
  486. # If we hit the obstructor, stop
  487. cell = grid.get(i, j)
  488. if cell and not cell.see_behind():
  489. break
  490. for j in range(0, grid.height):
  491. for i in range(0, grid.width):
  492. if not mask[i, j]:
  493. grid.set(i, j, None)
  494. #grid.set(i, j, Wall('red'))
  495. return mask
  496. def process_vis_prop(
  497. grid,
  498. agent_pos
  499. ):
  500. mask = np.zeros(shape=(grid.width, grid.height), dtype=np.bool)
  501. mask[agent_pos[0], agent_pos[1]] = True
  502. for j in reversed(range(1, grid.height)):
  503. for i in range(0, grid.width-1):
  504. if not mask[i, j]:
  505. continue
  506. cell = grid.get(i, j)
  507. if cell and not cell.see_behind():
  508. continue
  509. mask[i+1, j] = True
  510. mask[i+1, j-1] = True
  511. mask[i, j-1] = True
  512. for i in reversed(range(1, grid.width)):
  513. if not mask[i, j]:
  514. continue
  515. cell = grid.get(i, j)
  516. if cell and not cell.see_behind():
  517. continue
  518. mask[i-1, j-1] = True
  519. mask[i-1, j] = True
  520. mask[i, j-1] = True
  521. for j in range(0, grid.height):
  522. for i in range(0, grid.width):
  523. if not mask[i, j]:
  524. grid.set(i, j, None)
  525. #grid.set(i, j, Wall('red'))
  526. class MiniGridEnv(gym.Env):
  527. """
  528. 2D grid world game environment
  529. """
  530. metadata = {
  531. 'render.modes': ['human', 'rgb_array', 'pixmap'],
  532. 'video.frames_per_second' : 10
  533. }
  534. # Enumeration of possible actions
  535. class Actions(IntEnum):
  536. # Turn left, turn right, move forward
  537. left = 0
  538. right = 1
  539. forward = 2
  540. # Pick up an object
  541. pickup = 3
  542. # Drop an object
  543. drop = 4
  544. # Toggle/activate an object
  545. toggle = 5
  546. # Wait/stay put/do nothing
  547. wait = 6
  548. def __init__(
  549. self,
  550. grid_size=16,
  551. max_steps=100,
  552. see_through_walls=False,
  553. seed=1337
  554. ):
  555. # Action enumeration for this environment
  556. self.actions = MiniGridEnv.Actions
  557. # Actions are discrete integer values
  558. self.action_space = spaces.Discrete(len(self.actions))
  559. # Observations are dictionaries containing an
  560. # encoding of the grid and a textual 'mission' string
  561. self.observation_space = spaces.Box(
  562. low=0,
  563. high=255,
  564. shape=OBS_ARRAY_SIZE,
  565. dtype='uint8'
  566. )
  567. self.observation_space = spaces.Dict({
  568. 'image': self.observation_space
  569. })
  570. # Range of possible rewards
  571. self.reward_range = (0, 1)
  572. # Renderer object used to render the whole grid (full-scale)
  573. self.grid_render = None
  574. # Renderer used to render observations (small-scale agent view)
  575. self.obs_render = None
  576. # Environment configuration
  577. self.grid_size = grid_size
  578. self.max_steps = max_steps
  579. self.see_through_walls = see_through_walls
  580. # Starting position and direction for the agent
  581. self.start_pos = None
  582. self.start_dir = None
  583. # Initialize the RNG
  584. self.seed(seed=seed)
  585. # Initialize the state
  586. self.reset()
  587. def reset(self):
  588. # Generate a new random grid at the start of each episode
  589. # To keep the same grid for each episode, call env.seed() with
  590. # the same seed before calling env.reset()
  591. self._gen_grid(self.grid_size, self.grid_size)
  592. # These fields should be defined by _gen_grid
  593. assert self.start_pos is not None
  594. assert self.start_dir is not None
  595. # Check that the agent doesn't overlap with an object
  596. assert self.grid.get(*self.start_pos) is None
  597. # Place the agent in the starting position and direction
  598. self.agent_pos = self.start_pos
  599. self.agent_dir = self.start_dir
  600. # Item picked up, being carried, initially nothing
  601. self.carrying = None
  602. # Step count since episode start
  603. self.step_count = 0
  604. # Return first observation
  605. obs = self.gen_obs()
  606. return obs
  607. def seed(self, seed=1337):
  608. # Seed the random number generator
  609. self.np_random, _ = seeding.np_random(seed)
  610. return [seed]
  611. @property
  612. def steps_remaining(self):
  613. return self.max_steps - self.step_count
  614. def __str__(self):
  615. """
  616. Produce a pretty string of the environment's grid along with the agent.
  617. The agent is represented by `⏩`. A grid pixel is represented by 2-character
  618. string, the first one for the object and the second one for the color.
  619. """
  620. from copy import deepcopy
  621. def rotate_left(array):
  622. new_array = deepcopy(array)
  623. for i in range(len(array)):
  624. for j in range(len(array[0])):
  625. new_array[j][len(array[0])-1-i] = array[i][j]
  626. return new_array
  627. def vertically_symmetrize(array):
  628. new_array = deepcopy(array)
  629. for i in range(len(array)):
  630. for j in range(len(array[0])):
  631. new_array[i][len(array[0])-1-j] = array[i][j]
  632. return new_array
  633. # Map of object id to short string
  634. OBJECT_IDX_TO_IDS = {
  635. 0: ' ',
  636. 1: 'W',
  637. 2: 'D',
  638. 3: 'L',
  639. 4: 'K',
  640. 5: 'B',
  641. 6: 'X',
  642. 7: 'G'
  643. }
  644. # Short string for opened door
  645. OPENDED_DOOR_IDS = '_'
  646. # Map of color id to short string
  647. COLOR_IDX_TO_IDS = {
  648. 0: 'R',
  649. 1: 'G',
  650. 2: 'B',
  651. 3: 'P',
  652. 4: 'Y',
  653. 5: 'E'
  654. }
  655. # Map agent's direction to short string
  656. AGENT_DIR_TO_IDS = {
  657. 0: '⏩',
  658. 1: '⏬',
  659. 2: '⏪',
  660. 3: '⏫'
  661. }
  662. array = self.grid.encode()
  663. array = rotate_left(array)
  664. array = vertically_symmetrize(array)
  665. new_array = []
  666. for line in array:
  667. new_line = []
  668. for pixel in line:
  669. # If the door is opened
  670. if pixel[0] in [2, 3] and pixel[2] == 1:
  671. object_ids = OPENDED_DOOR_IDS
  672. else:
  673. object_ids = OBJECT_IDX_TO_IDS[pixel[0]]
  674. # If no object
  675. if pixel[0] == 0:
  676. color_ids = ' '
  677. else:
  678. color_ids = COLOR_IDX_TO_IDS[pixel[1]]
  679. new_line.append(object_ids + color_ids)
  680. new_array.append(new_line)
  681. # Add the agent
  682. new_array[self.agent_pos[1]][self.agent_pos[0]] = AGENT_DIR_TO_IDS[self.agent_dir]
  683. return "\n".join([" ".join(line) for line in new_array])
  684. def _gen_grid(self, width, height):
  685. assert False, "_gen_grid needs to be implemented by each environment"
  686. def _rand_int(self, low, high):
  687. """
  688. Generate random integer in [low,high[
  689. """
  690. return self.np_random.randint(low, high)
  691. def _rand_bool(self):
  692. """
  693. Generate random boolean value
  694. """
  695. return (self.np_random.randint(0, 2) == 0)
  696. def _rand_elem(self, iterable):
  697. """
  698. Pick a random element in a list
  699. """
  700. lst = list(iterable)
  701. idx = self._rand_int(0, len(lst))
  702. return lst[idx]
  703. def _rand_color(self):
  704. """
  705. Generate a random color name (string)
  706. """
  707. return self._rand_elem(COLOR_NAMES)
  708. def _rand_pos(self, xLow, xHigh, yLow, yHigh):
  709. """
  710. Generate a random (x,y) position tuple
  711. """
  712. return (
  713. self.np_random.randint(xLow, xHigh),
  714. self.np_random.randint(yLow, yHigh)
  715. )
  716. def place_obj(self, obj, top=None, size=None, reject_fn=None):
  717. """
  718. Place an object at an empty position in the grid
  719. :param top: top-left position of the rectangle where to place
  720. :param size: size of the rectangle where to place
  721. :param reject_fn: function to filter out potential positions
  722. """
  723. if top is None:
  724. top = (0, 0)
  725. if size is None:
  726. size = (self.grid.width, self.grid.height)
  727. while True:
  728. pos = np.array((
  729. self._rand_int(top[0], top[0] + size[0]),
  730. self._rand_int(top[1], top[1] + size[1])
  731. ))
  732. # Don't place the object on top of another object
  733. if self.grid.get(*pos) != None:
  734. continue
  735. # Don't place the object where the agent is
  736. if np.array_equal(pos, self.start_pos):
  737. continue
  738. # Check if there is a filtering criterion
  739. if reject_fn and reject_fn(self, pos):
  740. continue
  741. break
  742. self.grid.set(*pos, obj)
  743. return pos
  744. def place_agent(self, top=None, size=None, rand_dir=True):
  745. """
  746. Set the agent's starting point at an empty position in the grid
  747. """
  748. pos = self.place_obj(None, top, size)
  749. self.start_pos = pos
  750. if rand_dir:
  751. self.start_dir = self._rand_int(0, 4)
  752. return pos
  753. def get_dir_vec(self):
  754. """
  755. Get the direction vector for the agent, pointing in the direction
  756. of forward movement.
  757. """
  758. assert self.agent_dir >= 0 and self.agent_dir < 4
  759. return DIR_TO_VEC[self.agent_dir]
  760. def get_right_vec(self):
  761. """
  762. Get the vector pointing to the right of the agent.
  763. """
  764. dx, dy = self.get_dir_vec()
  765. return np.array((-dy, dx))
  766. def get_view_coords(self, i, j):
  767. """
  768. Translate and rotate absolute grid coordinates (i, j) into the
  769. agent's partially observable view (sub-grid). Note that the resulting
  770. coordinates may be negative or outside of the agent's view size.
  771. """
  772. ax, ay = self.agent_pos
  773. dx, dy = self.get_dir_vec()
  774. rx, ry = self.get_right_vec()
  775. # Compute the absolute coordinates of the top-left view corner
  776. sz = AGENT_VIEW_SIZE
  777. hs = AGENT_VIEW_SIZE // 2
  778. tx = ax + (dx * (sz-1)) - (rx * hs)
  779. ty = ay + (dy * (sz-1)) - (ry * hs)
  780. lx = i - tx
  781. ly = j - ty
  782. # Project the coordinates of the object relative to the top-left
  783. # corner onto the agent's own coordinate system
  784. vx = (rx*lx + ry*ly)
  785. vy = -(dx*lx + dy*ly)
  786. return vx, vy
  787. def get_view_exts(self):
  788. """
  789. Get the extents of the square set of tiles visible to the agent
  790. Note: the bottom extent indices are not included in the set
  791. """
  792. # Facing right
  793. if self.agent_dir == 0:
  794. topX = self.agent_pos[0]
  795. topY = self.agent_pos[1] - AGENT_VIEW_SIZE // 2
  796. # Facing down
  797. elif self.agent_dir == 1:
  798. topX = self.agent_pos[0] - AGENT_VIEW_SIZE // 2
  799. topY = self.agent_pos[1]
  800. # Facing left
  801. elif self.agent_dir == 2:
  802. topX = self.agent_pos[0] - AGENT_VIEW_SIZE + 1
  803. topY = self.agent_pos[1] - AGENT_VIEW_SIZE // 2
  804. # Facing up
  805. elif self.agent_dir == 3:
  806. topX = self.agent_pos[0] - AGENT_VIEW_SIZE // 2
  807. topY = self.agent_pos[1] - AGENT_VIEW_SIZE + 1
  808. else:
  809. assert False, "invalid agent direction"
  810. botX = topX + AGENT_VIEW_SIZE
  811. botY = topY + AGENT_VIEW_SIZE
  812. return (topX, topY, botX, botY)
  813. def agent_sees(self, x, y):
  814. """
  815. Check if a grid position is visible to the agent
  816. """
  817. vx, vy = self.get_view_coords(x, y)
  818. if vx < 0 or vy < 0 or vx >= AGENT_VIEW_SIZE or vy >= AGENT_VIEW_SIZE:
  819. return False
  820. obs = self.gen_obs()
  821. obs_grid = Grid.decode(obs['image'])
  822. obs_cell = obs_grid.get(vx, vy)
  823. world_cell = self.grid.get(x, y)
  824. return obs_cell is not None and obs_cell.type == world_cell.type
  825. def step(self, action):
  826. self.step_count += 1
  827. reward = 0
  828. done = False
  829. # Get the position in front of the agent
  830. fwd_pos = self.agent_pos + self.get_dir_vec()
  831. # Get the contents of the cell in front of the agent
  832. fwd_cell = self.grid.get(*fwd_pos)
  833. # Rotate left
  834. if action == self.actions.left:
  835. self.agent_dir -= 1
  836. if self.agent_dir < 0:
  837. self.agent_dir += 4
  838. # Rotate right
  839. elif action == self.actions.right:
  840. self.agent_dir = (self.agent_dir + 1) % 4
  841. # Move forward
  842. elif action == self.actions.forward:
  843. if fwd_cell == None or fwd_cell.can_overlap():
  844. self.agent_pos = fwd_pos
  845. if fwd_cell != None and fwd_cell.type == 'goal':
  846. done = True
  847. reward = 1
  848. # Pick up an object
  849. elif action == self.actions.pickup:
  850. if fwd_cell and fwd_cell.can_pickup():
  851. if self.carrying is None:
  852. self.carrying = fwd_cell
  853. self.grid.set(*fwd_pos, None)
  854. # Drop an object
  855. elif action == self.actions.drop:
  856. if not fwd_cell and self.carrying:
  857. self.grid.set(*fwd_pos, self.carrying)
  858. self.carrying = None
  859. # Toggle/activate an object
  860. elif action == self.actions.toggle:
  861. if fwd_cell:
  862. fwd_cell.toggle(self, fwd_pos)
  863. # Wait/do nothing
  864. elif action == self.actions.wait:
  865. pass
  866. else:
  867. assert False, "unknown action"
  868. if self.step_count >= self.max_steps:
  869. done = True
  870. obs = self.gen_obs()
  871. return obs, reward, done, {}
  872. def gen_obs(self):
  873. """
  874. Generate the agent's view (partially observable, low-resolution encoding)
  875. """
  876. topX, topY, botX, botY = self.get_view_exts()
  877. grid = self.grid.slice(topX, topY, AGENT_VIEW_SIZE, AGENT_VIEW_SIZE)
  878. for i in range(self.agent_dir + 1):
  879. grid = grid.rotate_left()
  880. # Make it so the agent sees what it's carrying
  881. # We do this by placing the carried object at the agent's position
  882. # in the agent's partially observable view
  883. agent_pos = grid.width // 2, grid.height - 1
  884. if self.carrying:
  885. grid.set(*agent_pos, self.carrying)
  886. else:
  887. grid.set(*agent_pos, None)
  888. # Process occluders and visibility
  889. # Note that this incurs some performance cost
  890. if not self.see_through_walls:
  891. grid.process_vis_prop(agent_pos=(3, 6))
  892. # Encode the partially observable view into a numpy array
  893. image = grid.encode()
  894. assert hasattr(self, 'mission'), "environments must define a textual mission string"
  895. # Observations are dictionaries containing:
  896. # - an image (partially observable view of the environment)
  897. # - the agent's direction/orientation (acting as a compass)
  898. # - a textual mission string (instructions for the agent)
  899. obs = {
  900. 'image': image,
  901. 'direction': self.agent_dir,
  902. 'mission': self.mission
  903. }
  904. return obs
  905. def get_obs_render(self, obs):
  906. """
  907. Render an agent observation for visualization
  908. """
  909. if self.obs_render == None:
  910. self.obs_render = Renderer(
  911. AGENT_VIEW_SIZE * CELL_PIXELS // 2,
  912. AGENT_VIEW_SIZE * CELL_PIXELS // 2
  913. )
  914. r = self.obs_render
  915. r.beginFrame()
  916. grid = Grid.decode(obs)
  917. # Render the whole grid
  918. grid.render(r, CELL_PIXELS // 2)
  919. # Draw the agent
  920. r.push()
  921. r.scale(0.5, 0.5)
  922. r.translate(
  923. CELL_PIXELS * (0.5 + AGENT_VIEW_SIZE // 2),
  924. CELL_PIXELS * (AGENT_VIEW_SIZE - 0.5)
  925. )
  926. r.rotate(3 * 90)
  927. r.setLineColor(255, 0, 0)
  928. r.setColor(255, 0, 0)
  929. r.drawPolygon([
  930. (-12, 10),
  931. ( 12, 0),
  932. (-12, -10)
  933. ])
  934. r.pop()
  935. r.endFrame()
  936. return r.getPixmap()
  937. def render(self, mode='human', close=False):
  938. """
  939. Render the whole-grid human view
  940. """
  941. if close:
  942. if self.grid_render:
  943. self.grid_render.close()
  944. return
  945. if self.grid_render is None:
  946. self.grid_render = Renderer(
  947. self.grid_size * CELL_PIXELS,
  948. self.grid_size * CELL_PIXELS,
  949. True if mode == 'human' else False
  950. )
  951. r = self.grid_render
  952. r.beginFrame()
  953. # Render the whole grid
  954. self.grid.render(r, CELL_PIXELS)
  955. # Draw the agent
  956. r.push()
  957. r.translate(
  958. CELL_PIXELS * (self.agent_pos[0] + 0.5),
  959. CELL_PIXELS * (self.agent_pos[1] + 0.5)
  960. )
  961. r.rotate(self.agent_dir * 90)
  962. r.setLineColor(255, 0, 0)
  963. r.setColor(255, 0, 0)
  964. r.drawPolygon([
  965. (-12, 10),
  966. ( 12, 0),
  967. (-12, -10)
  968. ])
  969. r.pop()
  970. # Highlight what the agent can see
  971. topX, topY, botX, botY = self.get_view_exts()
  972. r.fillRect(
  973. topX * CELL_PIXELS,
  974. topY * CELL_PIXELS,
  975. AGENT_VIEW_SIZE * CELL_PIXELS,
  976. AGENT_VIEW_SIZE * CELL_PIXELS,
  977. 200, 200, 200, 75
  978. )
  979. r.endFrame()
  980. if mode == 'rgb_array':
  981. return r.getArray()
  982. elif mode == 'pixmap':
  983. return r.getPixmap()
  984. return r