# minigrid.py
import math
from copy import deepcopy
from enum import IntEnum

import gym
import numpy as np
from gym import error, spaces, utils
from gym.utils import seeding

from gym_minigrid.rendering import *
  8. # Size in pixels of a cell in the full-scale human view
  9. CELL_PIXELS = 32
  10. # Number of cells (width and height) in the agent view
  11. AGENT_VIEW_SIZE = 7
  12. # Size of the array given as an observation to the agent
  13. OBS_ARRAY_SIZE = (AGENT_VIEW_SIZE, AGENT_VIEW_SIZE, 3)
  14. # Map of color names to RGB values
  15. COLORS = {
  16. 'red' : (255, 0, 0),
  17. 'green' : (0, 255, 0),
  18. 'blue' : (0, 0, 255),
  19. 'purple': (112, 39, 195),
  20. 'yellow': (255, 255, 0),
  21. 'grey' : (100, 100, 100)
  22. }
  23. COLOR_NAMES = sorted(list(COLORS.keys()))
  24. # Used to map colors to integers
  25. COLOR_TO_IDX = {
  26. 'red' : 0,
  27. 'green' : 1,
  28. 'blue' : 2,
  29. 'purple': 3,
  30. 'yellow': 4,
  31. 'grey' : 5
  32. }
  33. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  34. # Map of object type to integers
  35. OBJECT_TO_IDX = {
  36. 'empty' : 0,
  37. 'wall' : 1,
  38. 'door' : 2,
  39. 'locked_door' : 3,
  40. 'key' : 4,
  41. 'ball' : 5,
  42. 'box' : 6,
  43. 'goal' : 7
  44. }
  45. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  46. class WorldObj:
  47. """
  48. Base class for grid world objects
  49. """
  50. def __init__(self, type, color):
  51. assert type in OBJECT_TO_IDX, type
  52. assert color in COLOR_TO_IDX, color
  53. self.type = type
  54. self.color = color
  55. self.contains = None
  56. def can_overlap(self):
  57. """Can the agent overlap with this?"""
  58. return False
  59. def can_pickup(self):
  60. """Can the agent pick this up?"""
  61. return False
  62. def can_contain(self):
  63. """Can this contain another object?"""
  64. return False
  65. def see_behind(self):
  66. """Can the agent see behind this object?"""
  67. return True
  68. def toggle(self, env, pos):
  69. """Method to trigger/toggle an action this object performs"""
  70. return False
  71. def render(self, r):
  72. """Draw this object with the given renderer"""
  73. raise NotImplementedError
  74. def _set_color(self, r):
  75. """Set the color of this object as the active drawing color"""
  76. c = COLORS[self.color]
  77. r.setLineColor(c[0], c[1], c[2])
  78. r.setColor(c[0], c[1], c[2])
  79. class Goal(WorldObj):
  80. def __init__(self):
  81. super(Goal, self).__init__('goal', 'green')
  82. def can_overlap(self):
  83. return True
  84. def render(self, r):
  85. self._set_color(r)
  86. r.drawPolygon([
  87. (0 , CELL_PIXELS),
  88. (CELL_PIXELS, CELL_PIXELS),
  89. (CELL_PIXELS, 0),
  90. (0 , 0)
  91. ])
  92. class Wall(WorldObj):
  93. def __init__(self, color='grey'):
  94. super(Wall, self).__init__('wall', color)
  95. def see_behind(self):
  96. return False
  97. def render(self, r):
  98. self._set_color(r)
  99. r.drawPolygon([
  100. (0 , CELL_PIXELS),
  101. (CELL_PIXELS, CELL_PIXELS),
  102. (CELL_PIXELS, 0),
  103. (0 , 0)
  104. ])
  105. class Door(WorldObj):
  106. def __init__(self, color, is_open=False):
  107. super(Door, self).__init__('door', color)
  108. self.is_open = is_open
  109. def can_overlap(self):
  110. """The agent can only walk over this cell when the door is open"""
  111. return self.is_open
  112. def see_behind(self):
  113. return self.is_open
  114. def toggle(self, env, pos):
  115. if not self.is_open:
  116. self.is_open = True
  117. return True
  118. return False
  119. def render(self, r):
  120. c = COLORS[self.color]
  121. r.setLineColor(c[0], c[1], c[2])
  122. r.setColor(0, 0, 0)
  123. if self.is_open:
  124. r.drawPolygon([
  125. (CELL_PIXELS-2, CELL_PIXELS),
  126. (CELL_PIXELS , CELL_PIXELS),
  127. (CELL_PIXELS , 0),
  128. (CELL_PIXELS-2, 0)
  129. ])
  130. return
  131. r.drawPolygon([
  132. (0 , CELL_PIXELS),
  133. (CELL_PIXELS, CELL_PIXELS),
  134. (CELL_PIXELS, 0),
  135. (0 , 0)
  136. ])
  137. r.drawPolygon([
  138. (2 , CELL_PIXELS-2),
  139. (CELL_PIXELS-2, CELL_PIXELS-2),
  140. (CELL_PIXELS-2, 2),
  141. (2 , 2)
  142. ])
  143. r.drawCircle(CELL_PIXELS * 0.75, CELL_PIXELS * 0.5, 2)
  144. class LockedDoor(WorldObj):
  145. def __init__(self, color, is_open=False):
  146. super(LockedDoor, self).__init__('locked_door', color)
  147. self.is_open = is_open
  148. def toggle(self, env, pos):
  149. # If the player has the right key to open the door
  150. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  151. self.is_open = True
  152. # The key has been used, remove it from the agent
  153. env.carrying = None
  154. return True
  155. return False
  156. def can_overlap(self):
  157. """The agent can only walk over this cell when the door is open"""
  158. return self.is_open
  159. def see_behind(self):
  160. return self.is_open
  161. def render(self, r):
  162. c = COLORS[self.color]
  163. r.setLineColor(c[0], c[1], c[2])
  164. r.setColor(c[0], c[1], c[2], 50)
  165. if self.is_open:
  166. r.drawPolygon([
  167. (CELL_PIXELS-2, CELL_PIXELS),
  168. (CELL_PIXELS , CELL_PIXELS),
  169. (CELL_PIXELS , 0),
  170. (CELL_PIXELS-2, 0)
  171. ])
  172. return
  173. r.drawPolygon([
  174. (0 , CELL_PIXELS),
  175. (CELL_PIXELS, CELL_PIXELS),
  176. (CELL_PIXELS, 0),
  177. (0 , 0)
  178. ])
  179. r.drawPolygon([
  180. (2 , CELL_PIXELS-2),
  181. (CELL_PIXELS-2, CELL_PIXELS-2),
  182. (CELL_PIXELS-2, 2),
  183. (2 , 2)
  184. ])
  185. r.drawLine(
  186. CELL_PIXELS * 0.55,
  187. CELL_PIXELS * 0.5,
  188. CELL_PIXELS * 0.75,
  189. CELL_PIXELS * 0.5
  190. )
  191. class Key(WorldObj):
  192. def __init__(self, color='blue'):
  193. super(Key, self).__init__('key', color)
  194. def can_pickup(self):
  195. return True
  196. def render(self, r):
  197. self._set_color(r)
  198. # Vertical quad
  199. r.drawPolygon([
  200. (16, 10),
  201. (20, 10),
  202. (20, 28),
  203. (16, 28)
  204. ])
  205. # Teeth
  206. r.drawPolygon([
  207. (12, 19),
  208. (16, 19),
  209. (16, 21),
  210. (12, 21)
  211. ])
  212. r.drawPolygon([
  213. (12, 26),
  214. (16, 26),
  215. (16, 28),
  216. (12, 28)
  217. ])
  218. r.drawCircle(18, 9, 6)
  219. r.setLineColor(0, 0, 0)
  220. r.setColor(0, 0, 0)
  221. r.drawCircle(18, 9, 2)
  222. class Ball(WorldObj):
  223. def __init__(self, color='blue'):
  224. super(Ball, self).__init__('ball', color)
  225. def can_pickup(self):
  226. return True
  227. def render(self, r):
  228. self._set_color(r)
  229. r.drawCircle(CELL_PIXELS * 0.5, CELL_PIXELS * 0.5, 10)
  230. class Box(WorldObj):
  231. def __init__(self, color, contains=None):
  232. super(Box, self).__init__('box', color)
  233. self.contains = contains
  234. def can_pickup(self):
  235. return True
  236. def render(self, r):
  237. c = COLORS[self.color]
  238. r.setLineColor(c[0], c[1], c[2])
  239. r.setColor(0, 0, 0)
  240. r.setLineWidth(2)
  241. r.drawPolygon([
  242. (4 , CELL_PIXELS-4),
  243. (CELL_PIXELS-4, CELL_PIXELS-4),
  244. (CELL_PIXELS-4, 4),
  245. (4 , 4)
  246. ])
  247. r.drawLine(
  248. 4,
  249. CELL_PIXELS / 2,
  250. CELL_PIXELS - 4,
  251. CELL_PIXELS / 2
  252. )
  253. r.setLineWidth(1)
  254. def toggle(self, env, pos):
  255. # Replace the box by its contents
  256. env.grid.set(*pos, self.contains)
  257. return True
  258. class Grid:
  259. """
  260. Represent a grid and operations on it
  261. """
  262. def __init__(self, width, height):
  263. assert width >= 4
  264. assert height >= 4
  265. self.width = width
  266. self.height = height
  267. self.grid = [None] * width * height
  268. def __contains__(self, key):
  269. if isinstance(key, WorldObj):
  270. for e in self.grid:
  271. if e is key:
  272. return True
  273. elif isinstance(key, tuple):
  274. for e in self.grid:
  275. if e is None:
  276. continue
  277. if (e.color, e.type) == key:
  278. return True
  279. return False
  280. def __eq__(self, other):
  281. grid1 = self.encode()
  282. grid2 = other.encode()
  283. return np.array_equal(grid2, grid1)
  284. def __ne__(self, other):
  285. return not self == other
  286. def copy(self):
  287. from copy import deepcopy
  288. return deepcopy(self)
  289. def set(self, i, j, v):
  290. assert i >= 0 and i < self.width
  291. assert j >= 0 and j < self.height
  292. self.grid[j * self.width + i] = v
  293. def get(self, i, j):
  294. assert i >= 0 and i < self.width
  295. assert j >= 0 and j < self.height
  296. return self.grid[j * self.width + i]
  297. def horz_wall(self, x, y, length=None):
  298. if length is None:
  299. length = self.width - x
  300. for i in range(0, length):
  301. self.set(x + i, y, Wall())
  302. def vert_wall(self, x, y, length=None):
  303. if length is None:
  304. length = self.height - y
  305. for j in range(0, length):
  306. self.set(x, y + j, Wall())
  307. def wall_rect(self, x, y, w, h):
  308. self.horz_wall(x, y, w)
  309. self.horz_wall(x, y+h-1, w)
  310. self.vert_wall(x, y, h)
  311. self.vert_wall(x+w-1, y, h)
  312. def rotate_left(self):
  313. """
  314. Rotate the grid to the left (counter-clockwise)
  315. """
  316. grid = Grid(self.width, self.height)
  317. for j in range(0, self.height):
  318. for i in range(0, self.width):
  319. v = self.get(self.width - 1 - j, i)
  320. grid.set(i, j, v)
  321. return grid
  322. def slice(self, topX, topY, width, height):
  323. """
  324. Get a subset of the grid
  325. """
  326. grid = Grid(width, height)
  327. for j in range(0, height):
  328. for i in range(0, width):
  329. x = topX + i
  330. y = topY + j
  331. if x >= 0 and x < self.width and \
  332. y >= 0 and y < self.height:
  333. v = self.get(x, y)
  334. else:
  335. v = Wall()
  336. grid.set(i, j, v)
  337. return grid
  338. def render(self, r, tile_size):
  339. """
  340. Render this grid at a given scale
  341. :param r: target renderer object
  342. :param tile_size: tile size in pixels
  343. """
  344. assert r.width == self.width * tile_size
  345. assert r.height == self.height * tile_size
  346. # Total grid size at native scale
  347. widthPx = self.width * CELL_PIXELS
  348. heightPx = self.height * CELL_PIXELS
  349. """
  350. # Draw background (out-of-world) tiles the same colors as walls
  351. # so the agent understands these areas are not reachable
  352. c = COLORS['grey']
  353. r.setLineColor(c[0], c[1], c[2])
  354. r.setColor(c[0], c[1], c[2])
  355. r.drawPolygon([
  356. (0 , heightPx),
  357. (widthPx, heightPx),
  358. (widthPx, 0),
  359. (0 , 0)
  360. ])
  361. """
  362. r.push()
  363. # Internally, we draw at the "large" full-grid resolution, but we
  364. # use the renderer to scale back to the desired size
  365. r.scale(tile_size / CELL_PIXELS, tile_size / CELL_PIXELS)
  366. # Draw the background of the in-world cells black
  367. r.fillRect(
  368. 0,
  369. 0,
  370. widthPx,
  371. heightPx,
  372. 0, 0, 0
  373. )
  374. # Draw grid lines
  375. r.setLineColor(100, 100, 100)
  376. for rowIdx in range(0, self.height):
  377. y = CELL_PIXELS * rowIdx
  378. r.drawLine(0, y, widthPx, y)
  379. for colIdx in range(0, self.width):
  380. x = CELL_PIXELS * colIdx
  381. r.drawLine(x, 0, x, heightPx)
  382. # Render the grid
  383. for j in range(0, self.height):
  384. for i in range(0, self.width):
  385. cell = self.get(i, j)
  386. if cell == None:
  387. continue
  388. r.push()
  389. r.translate(i * CELL_PIXELS, j * CELL_PIXELS)
  390. cell.render(r)
  391. r.pop()
  392. r.pop()
  393. def encode(self):
  394. """
  395. Produce a compact numpy encoding of the grid
  396. """
  397. codeSize = self.width * self.height * 3
  398. array = np.zeros(shape=(self.width, self.height, 3), dtype='uint8')
  399. for j in range(0, self.height):
  400. for i in range(0, self.width):
  401. v = self.get(i, j)
  402. if v == None:
  403. continue
  404. array[i, j, 0] = OBJECT_TO_IDX[v.type]
  405. array[i, j, 1] = COLOR_TO_IDX[v.color]
  406. if hasattr(v, 'is_open') and v.is_open:
  407. array[i, j, 2] = 1
  408. return array
  409. def decode(array):
  410. """
  411. Decode an array grid encoding back into a grid
  412. """
  413. width = array.shape[0]
  414. height = array.shape[1]
  415. assert array.shape[2] == 3
  416. grid = Grid(width, height)
  417. for j in range(0, height):
  418. for i in range(0, width):
  419. typeIdx = array[i, j, 0]
  420. colorIdx = array[i, j, 1]
  421. openIdx = array[i, j, 2]
  422. if typeIdx == 0:
  423. continue
  424. objType = IDX_TO_OBJECT[typeIdx]
  425. color = IDX_TO_COLOR[colorIdx]
  426. is_open = True if openIdx == 1 else 0
  427. if objType == 'wall':
  428. v = Wall(color)
  429. elif objType == 'ball':
  430. v = Ball(color)
  431. elif objType == 'key':
  432. v = Key(color)
  433. elif objType == 'box':
  434. v = Box(color)
  435. elif objType == 'door':
  436. v = Door(color, is_open)
  437. elif objType == 'locked_door':
  438. v = LockedDoor(color, is_open)
  439. elif objType == 'goal':
  440. v = Goal()
  441. else:
  442. assert False, "unknown obj type in decode '%s'" % objType
  443. grid.set(i, j, v)
  444. return grid
  445. def process_vis(
  446. grid,
  447. agent_pos,
  448. n_rays = 32,
  449. n_steps = 24,
  450. a_min = math.pi,
  451. a_max = 2 * math.pi
  452. ):
  453. """
  454. Use ray casting to determine the visibility of each grid cell
  455. """
  456. mask = np.zeros(shape=(grid.width, grid.height), dtype=np.bool)
  457. ang_step = (a_max - a_min) / n_rays
  458. dst_step = math.sqrt(grid.width ** 2 + grid.height ** 2) / n_steps
  459. ax = agent_pos[0] + 0.5
  460. ay = agent_pos[1] + 0.5
  461. for ray_idx in range(0, n_rays):
  462. angle = a_min + ang_step * ray_idx
  463. dx = dst_step * math.cos(angle)
  464. dy = dst_step * math.sin(angle)
  465. for step_idx in range(0, n_steps):
  466. x = ax + (step_idx * dx)
  467. y = ay + (step_idx * dy)
  468. i = math.floor(x)
  469. j = math.floor(y)
  470. # If we're outside of the grid, stop
  471. if i < 0 or i >= grid.width or j < 0 or j >= grid.height:
  472. break
  473. # Mark this cell as visible
  474. mask[i, j] = True
  475. # If we hit the obstructor, stop
  476. cell = grid.get(i, j)
  477. if cell and not cell.see_behind():
  478. break
  479. for j in range(0, grid.height):
  480. for i in range(0, grid.width):
  481. if not mask[i, j]:
  482. grid.set(i, j, None)
  483. #grid.set(i, j, Wall('red'))
  484. return mask
  485. def process_vis_prop(
  486. grid,
  487. agent_pos
  488. ):
  489. mask = np.zeros(shape=(grid.width, grid.height), dtype=np.bool)
  490. mask[agent_pos[0], agent_pos[1]] = True
  491. for j in reversed(range(1, grid.height)):
  492. for i in range(0, grid.width-1):
  493. if not mask[i, j]:
  494. continue
  495. cell = grid.get(i, j)
  496. if cell and not cell.see_behind():
  497. continue
  498. mask[i+1, j] = True
  499. mask[i+1, j-1] = True
  500. mask[i, j-1] = True
  501. for i in reversed(range(1, grid.width)):
  502. if not mask[i, j]:
  503. continue
  504. cell = grid.get(i, j)
  505. if cell and not cell.see_behind():
  506. continue
  507. mask[i-1, j-1] = True
  508. mask[i-1, j] = True
  509. mask[i, j-1] = True
  510. for j in range(0, grid.height):
  511. for i in range(0, grid.width):
  512. if not mask[i, j]:
  513. grid.set(i, j, None)
  514. #grid.set(i, j, Wall('red'))
  515. class MiniGridEnv(gym.Env):
  516. """
  517. 2D grid world game environment
  518. """
  519. metadata = {
  520. 'render.modes': ['human', 'rgb_array', 'pixmap'],
  521. 'video.frames_per_second' : 10
  522. }
  523. # Enumeration of possible actions
  524. class Actions(IntEnum):
  525. # Turn left, turn right, move forward
  526. left = 0
  527. right = 1
  528. forward = 2
  529. # Pick up an object
  530. pickup = 3
  531. # Drop an object
  532. drop = 4
  533. # Toggle/activate an object
  534. toggle = 5
  535. # Wait/stay put/do nothing
  536. wait = 6
  537. def __init__(
  538. self,
  539. grid_size=16,
  540. max_steps=100,
  541. see_through_walls=False,
  542. seed=None
  543. ):
  544. # Action enumeration for this environment
  545. self.actions = MiniGridEnv.Actions
  546. # Actions are discrete integer values
  547. self.action_space = spaces.Discrete(len(self.actions))
  548. # Observations are dictionaries containing an
  549. # encoding of the grid and a textual 'mission' string
  550. self.observation_space = spaces.Box(
  551. low=0,
  552. high=255,
  553. shape=OBS_ARRAY_SIZE,
  554. dtype='uint8'
  555. )
  556. self.observation_space = spaces.Dict({
  557. 'image': self.observation_space
  558. })
  559. # Range of possible rewards
  560. self.reward_range = (-1, 1000)
  561. # Renderer object used to render the whole grid (full-scale)
  562. self.grid_render = None
  563. # Renderer used to render observations (small-scale agent view)
  564. self.obs_render = None
  565. # Environment configuration
  566. self.grid_size = grid_size
  567. self.max_steps = max_steps
  568. self.see_through_walls = see_through_walls
  569. # Starting position and direction for the agent
  570. self.start_pos = None
  571. self.start_dir = None
  572. # Initialize the RNG
  573. self.seed(seed=seed)
  574. # Initialize the state
  575. self.reset()
  576. def reset(self):
  577. # Generate a new random grid at the start of each episode
  578. # To keep the same grid for each episode, call env.seed() with
  579. # the same seed before calling env.reset()
  580. self._gen_grid(self.grid_size, self.grid_size)
  581. # These fields should be defined by _gen_grid
  582. assert self.start_pos != None
  583. assert self.start_dir != None
  584. # Check that the agent doesn't overlap with an object
  585. assert self.grid.get(*self.start_pos) is None
  586. # Place the agent in the starting position and direction
  587. self.agent_pos = self.start_pos
  588. self.agent_dir = self.start_dir
  589. # Item picked up, being carried, initially nothing
  590. self.carrying = None
  591. # Step count since episode start
  592. self.step_count = 0
  593. # Return first observation
  594. obs = self.gen_obs()
  595. return obs
  596. def seed(self, seed=1337):
  597. # Seed the random number generator
  598. self.np_random, _ = seeding.np_random(seed)
  599. return [seed]
  600. @property
  601. def steps_remaining(self):
  602. return self.max_steps - self.step_count
  603. def __str__(self):
  604. """
  605. Produce a pretty string of the environment's grid along with the agent.
  606. The agent is represented by `⏩`. A grid pixel is represented by 2-character
  607. string, the first one for the object and the second one for the color.
  608. """
  609. from copy import deepcopy
  610. def rotate_left(array):
  611. new_array = deepcopy(array)
  612. for i in range(len(array)):
  613. for j in range(len(array[0])):
  614. new_array[j][len(array[0])-1-i] = array[i][j]
  615. return new_array
  616. def vertically_symmetrize(array):
  617. new_array = deepcopy(array)
  618. for i in range(len(array)):
  619. for j in range(len(array[0])):
  620. new_array[i][len(array[0])-1-j] = array[i][j]
  621. return new_array
  622. # Map of object id to short string
  623. OBJECT_IDX_TO_IDS = {
  624. 0: ' ',
  625. 1: 'W',
  626. 2: 'D',
  627. 3: 'L',
  628. 4: 'K',
  629. 5: 'B',
  630. 6: 'X',
  631. 7: 'G'
  632. }
  633. # Short string for opened door
  634. OPENDED_DOOR_IDS = '_'
  635. # Map of color id to short string
  636. COLOR_IDX_TO_IDS = {
  637. 0: 'R',
  638. 1: 'G',
  639. 2: 'B',
  640. 3: 'P',
  641. 4: 'Y',
  642. 5: 'E'
  643. }
  644. # Map agent's direction to short string
  645. AGENT_DIR_TO_IDS = {
  646. 0: '⏩',
  647. 1: '⏬',
  648. 2: '⏪',
  649. 3: '⏫'
  650. }
  651. array = self.grid.encode()
  652. array = rotate_left(array)
  653. array = vertically_symmetrize(array)
  654. new_array = []
  655. for line in array:
  656. new_line = []
  657. for pixel in line:
  658. # If the door is opened
  659. if pixel[0] in [2, 3] and pixel[2] == 1:
  660. object_ids = OPENDED_DOOR_IDS
  661. else:
  662. object_ids = OBJECT_IDX_TO_IDS[pixel[0]]
  663. # If no object
  664. if pixel[0] == 0:
  665. color_ids = ' '
  666. else:
  667. color_ids = COLOR_IDX_TO_IDS[pixel[1]]
  668. new_line.append(object_ids + color_ids)
  669. new_array.append(new_line)
  670. # Add the agent
  671. new_array[self.agent_pos[1]][self.agent_pos[0]] = AGENT_DIR_TO_IDS[self.agent_dir]
  672. return "\n".join([" ".join(line) for line in new_array])
  673. def _gen_grid(self, width, height):
  674. assert False, "_gen_grid needs to be implemented by each environment"
  675. def _rand_int(self, low, high):
  676. """
  677. Generate random integer in [low,high[
  678. """
  679. return self.np_random.randint(low, high)
  680. def _rand_elem(self, iterable):
  681. """
  682. Pick a random element in a list
  683. """
  684. lst = list(iterable)
  685. idx = self._rand_int(0, len(lst))
  686. return lst[idx]
  687. def _rand_pos(self, xLow, xHigh, yLow, yHigh):
  688. """
  689. Generate a random (x,y) position tuple
  690. """
  691. return (
  692. self.np_random.randint(xLow, xHigh),
  693. self.np_random.randint(yLow, yHigh)
  694. )
  695. def place_obj(self, obj, top=None, size=None, reject_fn=None):
  696. """
  697. Place an object at an empty position in the grid
  698. :param top: top-left position of the rectangle where to place
  699. :param size: size of the rectangle where to place
  700. :param reject_fn: function to filter out potential positions
  701. """
  702. if top is None:
  703. top = (0, 0)
  704. if size is None:
  705. size = (self.grid.width, self.grid.height)
  706. while True:
  707. pos = (
  708. self._rand_int(top[0], top[0] + size[0]),
  709. self._rand_int(top[1], top[1] + size[1])
  710. )
  711. # Don't place the object on top of another object
  712. if self.grid.get(*pos) != None:
  713. continue
  714. # Don't place the object where the agent is
  715. if pos == self.start_pos:
  716. continue
  717. # Check if there is a filtering criterion
  718. if reject_fn and reject_fn(self, pos):
  719. continue
  720. break
  721. self.grid.set(*pos, obj)
  722. return pos
  723. def place_agent(self, top=None, size=None, rand_dir=True):
  724. """
  725. Set the agent's starting point at an empty position in the grid
  726. """
  727. pos = self.place_obj(None, top, size)
  728. self.start_pos = pos
  729. if rand_dir:
  730. self.start_dir = self._rand_int(0, 4)
  731. return pos
  732. def get_dir_vec(self):
  733. """
  734. Get the direction vector for the agent, pointing in the direction
  735. of forward movement.
  736. """
  737. # Pointing right
  738. if self.agent_dir == 0:
  739. return (1, 0)
  740. # Down (positive Y)
  741. elif self.agent_dir == 1:
  742. return (0, 1)
  743. # Pointing left
  744. elif self.agent_dir == 2:
  745. return (-1, 0)
  746. # Up (negative Y)
  747. elif self.agent_dir == 3:
  748. return (0, -1)
  749. else:
  750. assert False
  751. def get_right_vec(self):
  752. """
  753. Get the vector pointing to the right of the agent.
  754. """
  755. dx, dy = self.get_dir_vec()
  756. return -dy, dx
  757. def get_view_coords(self, i, j):
  758. """
  759. Translate and rotate absolute grid coordinates (i, j) into the
  760. agent's partially observable view (sub-grid). Note that the resulting
  761. coordinates may be negative or outside of the agent's view size.
  762. """
  763. ax, ay = self.agent_pos
  764. dx, dy = self.get_dir_vec()
  765. rx, ry = self.get_right_vec()
  766. # Compute the absolute coordinates of the top-left view corner
  767. sz = AGENT_VIEW_SIZE
  768. hs = AGENT_VIEW_SIZE // 2
  769. tx = ax + (dx * (sz-1)) - (rx * hs)
  770. ty = ay + (dy * (sz-1)) - (ry * hs)
  771. lx = i - tx
  772. ly = j - ty
  773. # Project the coordinates of the object relative to the top-left
  774. # corner onto the agent's own coordinate system
  775. vx = (rx*lx + ry*ly)
  776. vy = -(dx*lx + dy*ly)
  777. return vx, vy
  778. def get_view_exts(self):
  779. """
  780. Get the extents of the square set of tiles visible to the agent
  781. Note: the bottom extent indices are not included in the set
  782. """
  783. # Facing right
  784. if self.agent_dir == 0:
  785. topX = self.agent_pos[0]
  786. topY = self.agent_pos[1] - AGENT_VIEW_SIZE // 2
  787. # Facing down
  788. elif self.agent_dir == 1:
  789. topX = self.agent_pos[0] - AGENT_VIEW_SIZE // 2
  790. topY = self.agent_pos[1]
  791. # Facing left
  792. elif self.agent_dir == 2:
  793. topX = self.agent_pos[0] - AGENT_VIEW_SIZE + 1
  794. topY = self.agent_pos[1] - AGENT_VIEW_SIZE // 2
  795. # Facing up
  796. elif self.agent_dir == 3:
  797. topX = self.agent_pos[0] - AGENT_VIEW_SIZE // 2
  798. topY = self.agent_pos[1] - AGENT_VIEW_SIZE + 1
  799. else:
  800. assert False, "invalid agent direction"
  801. botX = topX + AGENT_VIEW_SIZE
  802. botY = topY + AGENT_VIEW_SIZE
  803. return (topX, topY, botX, botY)
  804. def agent_sees(self, x, y):
  805. """
  806. Check if a grid position is visible to the agent
  807. """
  808. vx, vy = self.get_view_coords(x, y)
  809. if vx < 0 or vy < 0 or vx >= AGENT_VIEW_SIZE or vy >= AGENT_VIEW_SIZE:
  810. return False
  811. obs = self.gen_obs()
  812. obs_grid = Grid.decode(obs['image'])
  813. obs_cell = obs_grid.get(vx, vy)
  814. world_cell = self.grid.get(x, y)
  815. return obs_cell is not None and obs_cell.type == world_cell.type
  816. def step(self, action):
  817. self.step_count += 1
  818. reward = 0
  819. done = False
  820. # Get the position in front of the agent
  821. u, v = self.get_dir_vec()
  822. fwd_pos = (self.agent_pos[0] + u, self.agent_pos[1] + v)
  823. # Get the contents of the cell in front of the agent
  824. fwd_cell = self.grid.get(*fwd_pos)
  825. # Rotate left
  826. if action == self.actions.left:
  827. self.agent_dir -= 1
  828. if self.agent_dir < 0:
  829. self.agent_dir += 4
  830. # Rotate right
  831. elif action == self.actions.right:
  832. self.agent_dir = (self.agent_dir + 1) % 4
  833. # Move forward
  834. elif action == self.actions.forward:
  835. if fwd_cell == None or fwd_cell.can_overlap():
  836. self.agent_pos = fwd_pos
  837. if fwd_cell != None and fwd_cell.type == 'goal':
  838. done = True
  839. reward = 1000 - self.step_count
  840. # Pick up an object
  841. elif action == self.actions.pickup:
  842. if fwd_cell and fwd_cell.can_pickup():
  843. if self.carrying is None:
  844. self.carrying = fwd_cell
  845. self.grid.set(*fwd_pos, None)
  846. # Drop an object
  847. elif action == self.actions.drop:
  848. if not fwd_cell and self.carrying:
  849. self.grid.set(*fwd_pos, self.carrying)
  850. self.carrying = None
  851. # Toggle/activate an object
  852. elif action == self.actions.toggle:
  853. if fwd_cell:
  854. fwd_cell.toggle(self, fwd_pos)
  855. # Wait/do nothing
  856. elif action == self.actions.wait:
  857. pass
  858. else:
  859. assert False, "unknown action"
  860. if self.step_count >= self.max_steps:
  861. done = True
  862. obs = self.gen_obs()
  863. return obs, reward, done, {}
  864. def gen_obs(self):
  865. """
  866. Generate the agent's view (partially observable, low-resolution encoding)
  867. """
  868. topX, topY, botX, botY = self.get_view_exts()
  869. grid = self.grid.slice(topX, topY, AGENT_VIEW_SIZE, AGENT_VIEW_SIZE)
  870. for i in range(self.agent_dir + 1):
  871. grid = grid.rotate_left()
  872. # Make it so the agent sees what it's carrying
  873. # We do this by placing the carried object at the agent's position
  874. # in the agent's partially observable view
  875. agent_pos = grid.width // 2, grid.height - 1
  876. if self.carrying:
  877. grid.set(*agent_pos, self.carrying)
  878. else:
  879. grid.set(*agent_pos, None)
  880. # Process occluders and visibility
  881. # Note that this incurs some performance cost
  882. if not self.see_through_walls:
  883. grid.process_vis_prop(agent_pos=(3, 6))
  884. # Encode the partially observable view into a numpy array
  885. image = grid.encode()
  886. assert hasattr(self, 'mission'), "environments must define a textual mission string"
  887. # Observations are dictionaries containing:
  888. # - an image (partially observable view of the environment)
  889. # - the agent's direction/orientation (acting as a compass)
  890. # - a textual mission string (instructions for the agent)
  891. obs = {
  892. 'image': image,
  893. 'direction': self.agent_dir,
  894. 'mission': self.mission
  895. }
  896. return obs
  897. def get_obs_render(self, obs):
  898. """
  899. Render an agent observation for visualization
  900. """
  901. if self.obs_render == None:
  902. self.obs_render = Renderer(
  903. AGENT_VIEW_SIZE * CELL_PIXELS // 2,
  904. AGENT_VIEW_SIZE * CELL_PIXELS // 2
  905. )
  906. r = self.obs_render
  907. r.beginFrame()
  908. grid = Grid.decode(obs)
  909. # Render the whole grid
  910. grid.render(r, CELL_PIXELS // 2)
  911. # Draw the agent
  912. r.push()
  913. r.scale(0.5, 0.5)
  914. r.translate(
  915. CELL_PIXELS * (0.5 + AGENT_VIEW_SIZE // 2),
  916. CELL_PIXELS * (AGENT_VIEW_SIZE - 0.5)
  917. )
  918. r.rotate(3 * 90)
  919. r.setLineColor(255, 0, 0)
  920. r.setColor(255, 0, 0)
  921. r.drawPolygon([
  922. (-12, 10),
  923. ( 12, 0),
  924. (-12, -10)
  925. ])
  926. r.pop()
  927. r.endFrame()
  928. return r.getPixmap()
  929. def render(self, mode='human', close=False):
  930. """
  931. Render the whole-grid human view
  932. """
  933. if close:
  934. if self.grid_render:
  935. self.grid_render.close()
  936. return
  937. if self.grid_render is None:
  938. self.grid_render = Renderer(
  939. self.grid_size * CELL_PIXELS,
  940. self.grid_size * CELL_PIXELS,
  941. True if mode == 'human' else False
  942. )
  943. r = self.grid_render
  944. r.beginFrame()
  945. # Render the whole grid
  946. self.grid.render(r, CELL_PIXELS)
  947. # Draw the agent
  948. r.push()
  949. r.translate(
  950. CELL_PIXELS * (self.agent_pos[0] + 0.5),
  951. CELL_PIXELS * (self.agent_pos[1] + 0.5)
  952. )
  953. r.rotate(self.agent_dir * 90)
  954. r.setLineColor(255, 0, 0)
  955. r.setColor(255, 0, 0)
  956. r.drawPolygon([
  957. (-12, 10),
  958. ( 12, 0),
  959. (-12, -10)
  960. ])
  961. r.pop()
  962. # Highlight what the agent can see
  963. topX, topY, botX, botY = self.get_view_exts()
  964. r.fillRect(
  965. topX * CELL_PIXELS,
  966. topY * CELL_PIXELS,
  967. AGENT_VIEW_SIZE * CELL_PIXELS,
  968. AGENT_VIEW_SIZE * CELL_PIXELS,
  969. 200, 200, 200, 75
  970. )
  971. r.endFrame()
  972. if mode == 'rgb_array':
  973. return r.getArray()
  974. elif mode == 'pixmap':
  975. return r.getPixmap()
  976. return r