minigrid.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178
  1. import math
  2. import gym
  3. from enum import IntEnum
  4. import numpy as np
  5. from gym import error, spaces, utils
  6. from gym.utils import seeding
  7. from gym_minigrid.rendering import *
  8. # Size in pixels of a cell in the full-scale human view
  9. CELL_PIXELS = 32
  10. # Number of cells (width and height) in the agent view
  11. AGENT_VIEW_SIZE = 7
  12. # Size of the array given as an observation to the agent
  13. OBS_ARRAY_SIZE = (AGENT_VIEW_SIZE, AGENT_VIEW_SIZE, 3)
  14. # Map of color names to RGB values
  15. COLORS = {
  16. 'red' : (255, 0, 0),
  17. 'green' : (0, 255, 0),
  18. 'blue' : (0, 0, 255),
  19. 'purple': (112, 39, 195),
  20. 'yellow': (255, 255, 0),
  21. 'grey' : (100, 100, 100)
  22. }
  23. COLOR_NAMES = sorted(list(COLORS.keys()))
  24. # Used to map colors to integers
  25. COLOR_TO_IDX = {
  26. 'red' : 0,
  27. 'green' : 1,
  28. 'blue' : 2,
  29. 'purple': 3,
  30. 'yellow': 4,
  31. 'grey' : 5
  32. }
  33. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  34. # Map of object type to integers
  35. OBJECT_TO_IDX = {
  36. 'empty' : 0,
  37. 'wall' : 1,
  38. 'door' : 2,
  39. 'locked_door' : 3,
  40. 'key' : 4,
  41. 'ball' : 5,
  42. 'box' : 6,
  43. 'goal' : 7
  44. }
  45. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  46. class WorldObj:
  47. """
  48. Base class for grid world objects
  49. """
  50. def __init__(self, type, color):
  51. assert type in OBJECT_TO_IDX, type
  52. assert color in COLOR_TO_IDX, color
  53. self.type = type
  54. self.color = color
  55. self.contains = None
  56. def can_overlap(self):
  57. """Can the agent overlap with this?"""
  58. return False
  59. def canPickup(self):
  60. """Can the agent pick this up?"""
  61. return False
  62. def canContain(self):
  63. """Can this contain another object?"""
  64. return False
  65. def see_behind(self):
  66. """Can the agent see behind this object?"""
  67. return True
  68. def toggle(self, env, pos):
  69. """Method to trigger/toggle an action this object performs"""
  70. return False
  71. def render(self, r):
  72. assert False
  73. def _set_color(self, r):
  74. c = COLORS[self.color]
  75. r.setLineColor(c[0], c[1], c[2])
  76. r.setColor(c[0], c[1], c[2])
  77. class Goal(WorldObj):
  78. def __init__(self):
  79. super(Goal, self).__init__('goal', 'green')
  80. def can_overlap(self):
  81. return True
  82. def render(self, r):
  83. self._set_color(r)
  84. r.drawPolygon([
  85. (0 , CELL_PIXELS),
  86. (CELL_PIXELS, CELL_PIXELS),
  87. (CELL_PIXELS, 0),
  88. (0 , 0)
  89. ])
  90. class Wall(WorldObj):
  91. def __init__(self, color='grey'):
  92. super(Wall, self).__init__('wall', color)
  93. def see_behind(self):
  94. return False
  95. def render(self, r):
  96. self._set_color(r)
  97. r.drawPolygon([
  98. (0 , CELL_PIXELS),
  99. (CELL_PIXELS, CELL_PIXELS),
  100. (CELL_PIXELS, 0),
  101. (0 , 0)
  102. ])
  103. class Door(WorldObj):
  104. def __init__(self, color, is_open=False):
  105. super(Door, self).__init__('door', color)
  106. self.is_open = is_open
  107. def can_overlap(self):
  108. """The agent can only walk over this cell when the door is open"""
  109. return self.is_open
  110. def see_behind(self):
  111. return self.is_open
  112. def toggle(self, env, pos):
  113. if not self.is_open:
  114. self.is_open = True
  115. return True
  116. return False
  117. def render(self, r):
  118. c = COLORS[self.color]
  119. r.setLineColor(c[0], c[1], c[2])
  120. r.setColor(0, 0, 0)
  121. if self.is_open:
  122. r.drawPolygon([
  123. (CELL_PIXELS-2, CELL_PIXELS),
  124. (CELL_PIXELS , CELL_PIXELS),
  125. (CELL_PIXELS , 0),
  126. (CELL_PIXELS-2, 0)
  127. ])
  128. return
  129. r.drawPolygon([
  130. (0 , CELL_PIXELS),
  131. (CELL_PIXELS, CELL_PIXELS),
  132. (CELL_PIXELS, 0),
  133. (0 , 0)
  134. ])
  135. r.drawPolygon([
  136. (2 , CELL_PIXELS-2),
  137. (CELL_PIXELS-2, CELL_PIXELS-2),
  138. (CELL_PIXELS-2, 2),
  139. (2 , 2)
  140. ])
  141. r.drawCircle(CELL_PIXELS * 0.75, CELL_PIXELS * 0.5, 2)
  142. class LockedDoor(WorldObj):
  143. def __init__(self, color, is_open=False):
  144. super(LockedDoor, self).__init__('locked_door', color)
  145. self.is_open = is_open
  146. def toggle(self, env, pos):
  147. # If the player has the right key to open the door
  148. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  149. self.is_open = True
  150. # The key has been used, remove it from the agent
  151. env.carrying = None
  152. return True
  153. return False
  154. def can_overlap(self):
  155. """The agent can only walk over this cell when the door is open"""
  156. return self.is_open
  157. def render(self, r):
  158. c = COLORS[self.color]
  159. r.setLineColor(c[0], c[1], c[2])
  160. r.setColor(c[0], c[1], c[2], 50)
  161. if self.is_open:
  162. r.drawPolygon([
  163. (CELL_PIXELS-2, CELL_PIXELS),
  164. (CELL_PIXELS , CELL_PIXELS),
  165. (CELL_PIXELS , 0),
  166. (CELL_PIXELS-2, 0)
  167. ])
  168. return
  169. r.drawPolygon([
  170. (0 , CELL_PIXELS),
  171. (CELL_PIXELS, CELL_PIXELS),
  172. (CELL_PIXELS, 0),
  173. (0 , 0)
  174. ])
  175. r.drawPolygon([
  176. (2 , CELL_PIXELS-2),
  177. (CELL_PIXELS-2, CELL_PIXELS-2),
  178. (CELL_PIXELS-2, 2),
  179. (2 , 2)
  180. ])
  181. r.drawLine(
  182. CELL_PIXELS * 0.55,
  183. CELL_PIXELS * 0.5,
  184. CELL_PIXELS * 0.75,
  185. CELL_PIXELS * 0.5
  186. )
  187. class Key(WorldObj):
  188. def __init__(self, color='blue'):
  189. super(Key, self).__init__('key', color)
  190. def canPickup(self):
  191. return True
  192. def render(self, r):
  193. self._set_color(r)
  194. # Vertical quad
  195. r.drawPolygon([
  196. (16, 10),
  197. (20, 10),
  198. (20, 28),
  199. (16, 28)
  200. ])
  201. # Teeth
  202. r.drawPolygon([
  203. (12, 19),
  204. (16, 19),
  205. (16, 21),
  206. (12, 21)
  207. ])
  208. r.drawPolygon([
  209. (12, 26),
  210. (16, 26),
  211. (16, 28),
  212. (12, 28)
  213. ])
  214. r.drawCircle(18, 9, 6)
  215. r.setLineColor(0, 0, 0)
  216. r.setColor(0, 0, 0)
  217. r.drawCircle(18, 9, 2)
  218. class Ball(WorldObj):
  219. def __init__(self, color='blue'):
  220. super(Ball, self).__init__('ball', color)
  221. def canPickup(self):
  222. return True
  223. def render(self, r):
  224. self._set_color(r)
  225. r.drawCircle(CELL_PIXELS * 0.5, CELL_PIXELS * 0.5, 10)
  226. class Box(WorldObj):
  227. def __init__(self, color, contains=None):
  228. super(Box, self).__init__('box', color)
  229. self.contains = contains
  230. def canPickup(self):
  231. return True
  232. def render(self, r):
  233. c = COLORS[self.color]
  234. r.setLineColor(c[0], c[1], c[2])
  235. r.setColor(0, 0, 0)
  236. r.setLineWidth(2)
  237. r.drawPolygon([
  238. (4 , CELL_PIXELS-4),
  239. (CELL_PIXELS-4, CELL_PIXELS-4),
  240. (CELL_PIXELS-4, 4),
  241. (4 , 4)
  242. ])
  243. r.drawLine(
  244. 4,
  245. CELL_PIXELS / 2,
  246. CELL_PIXELS - 4,
  247. CELL_PIXELS / 2
  248. )
  249. r.setLineWidth(1)
  250. def toggle(self, env, pos):
  251. # Replace the box by its contents
  252. env.grid.set(*pos, self.contains)
  253. return True
  254. class Grid:
  255. """
  256. Represent a grid and operations on it
  257. """
  258. def __init__(self, width, height):
  259. assert width >= 4
  260. assert height >= 4
  261. self.width = width
  262. self.height = height
  263. self.grid = [None] * width * height
  264. def __contains__(self, key):
  265. if isinstance(key, WorldObj):
  266. for e in self.grid:
  267. if e is key:
  268. return True
  269. elif isinstance(key, tuple):
  270. for e in self.grid:
  271. if e is None:
  272. continue
  273. if (e.color, e.type) == key:
  274. return True
  275. return False
  276. def __eq__(self, other):
  277. grid1 = self.encode()
  278. grid2 = other.encode()
  279. return np.array_equal(grid2, grid1)
  280. def __ne__(self, other):
  281. return not self == other
  282. def copy(self):
  283. from copy import deepcopy
  284. return deepcopy(self)
  285. def set(self, i, j, v):
  286. assert i >= 0 and i < self.width
  287. assert j >= 0 and j < self.height
  288. self.grid[j * self.width + i] = v
  289. def get(self, i, j):
  290. assert i >= 0 and i < self.width
  291. assert j >= 0 and j < self.height
  292. return self.grid[j * self.width + i]
  293. def horzWall(self, x, y, length=None):
  294. if length is None:
  295. length = self.width - x
  296. for i in range(0, length):
  297. self.set(x + i, y, Wall())
  298. def vertWall(self, x, y, length=None):
  299. if length is None:
  300. length = self.height - y
  301. for j in range(0, length):
  302. self.set(x, y + j, Wall())
  303. def wallRect(self, x, y, w, h):
  304. self.horzWall(x, y, w)
  305. self.horzWall(x, y+h-1, w)
  306. self.vertWall(x, y, h)
  307. self.vertWall(x+w-1, y, h)
  308. def rotateLeft(self):
  309. """
  310. Rotate the grid to the left (counter-clockwise)
  311. """
  312. grid = Grid(self.width, self.height)
  313. for j in range(0, self.height):
  314. for i in range(0, self.width):
  315. v = self.get(self.width - 1 - j, i)
  316. grid.set(i, j, v)
  317. return grid
  318. def slice(self, topX, topY, width, height):
  319. """
  320. Get a subset of the grid
  321. """
  322. grid = Grid(width, height)
  323. for j in range(0, height):
  324. for i in range(0, width):
  325. x = topX + i
  326. y = topY + j
  327. if x >= 0 and x < self.width and \
  328. y >= 0 and y < self.height:
  329. v = self.get(x, y)
  330. else:
  331. v = Wall()
  332. grid.set(i, j, v)
  333. return grid
  334. def render(self, r, tileSize):
  335. """
  336. Render this grid at a given scale
  337. :param r: target renderer object
  338. :param tileSize: tile size in pixels
  339. """
  340. assert r.width == self.width * tileSize
  341. assert r.height == self.height * tileSize
  342. # Total grid size at native scale
  343. widthPx = self.width * CELL_PIXELS
  344. heightPx = self.height * CELL_PIXELS
  345. """
  346. # Draw background (out-of-world) tiles the same colors as walls
  347. # so the agent understands these areas are not reachable
  348. c = COLORS['grey']
  349. r.setLineColor(c[0], c[1], c[2])
  350. r.setColor(c[0], c[1], c[2])
  351. r.drawPolygon([
  352. (0 , heightPx),
  353. (widthPx, heightPx),
  354. (widthPx, 0),
  355. (0 , 0)
  356. ])
  357. """
  358. r.push()
  359. # Internally, we draw at the "large" full-grid resolution, but we
  360. # use the renderer to scale back to the desired size
  361. r.scale(tileSize / CELL_PIXELS, tileSize / CELL_PIXELS)
  362. # Draw the background of the in-world cells black
  363. r.fillRect(
  364. 0,
  365. 0,
  366. widthPx,
  367. heightPx,
  368. 0, 0, 0
  369. )
  370. # Draw grid lines
  371. r.setLineColor(100, 100, 100)
  372. for rowIdx in range(0, self.height):
  373. y = CELL_PIXELS * rowIdx
  374. r.drawLine(0, y, widthPx, y)
  375. for colIdx in range(0, self.width):
  376. x = CELL_PIXELS * colIdx
  377. r.drawLine(x, 0, x, heightPx)
  378. # Render the grid
  379. for j in range(0, self.height):
  380. for i in range(0, self.width):
  381. cell = self.get(i, j)
  382. if cell == None:
  383. continue
  384. r.push()
  385. r.translate(i * CELL_PIXELS, j * CELL_PIXELS)
  386. cell.render(r)
  387. r.pop()
  388. r.pop()
  389. def encode(self):
  390. """
  391. Produce a compact numpy encoding of the grid
  392. """
  393. codeSize = self.width * self.height * 3
  394. array = np.zeros(shape=(self.width, self.height, 3), dtype='uint8')
  395. for j in range(0, self.height):
  396. for i in range(0, self.width):
  397. v = self.get(i, j)
  398. if v == None:
  399. continue
  400. array[i, j, 0] = OBJECT_TO_IDX[v.type]
  401. array[i, j, 1] = COLOR_TO_IDX[v.color]
  402. if hasattr(v, 'is_open') and v.is_open:
  403. array[i, j, 2] = 1
  404. return array
  405. def decode(array):
  406. """
  407. Decode an array grid encoding back into a grid
  408. """
  409. width = array.shape[0]
  410. height = array.shape[1]
  411. assert array.shape[2] == 3
  412. grid = Grid(width, height)
  413. for j in range(0, height):
  414. for i in range(0, width):
  415. typeIdx = array[i, j, 0]
  416. colorIdx = array[i, j, 1]
  417. openIdx = array[i, j, 2]
  418. if typeIdx == 0:
  419. continue
  420. objType = IDX_TO_OBJECT[typeIdx]
  421. color = IDX_TO_COLOR[colorIdx]
  422. is_open = True if openIdx == 1 else 0
  423. if objType == 'wall':
  424. v = Wall(color)
  425. elif objType == 'ball':
  426. v = Ball(color)
  427. elif objType == 'key':
  428. v = Key(color)
  429. elif objType == 'box':
  430. v = Box(color)
  431. elif objType == 'door':
  432. v = Door(color, is_open)
  433. elif objType == 'locked_door':
  434. v = LockedDoor(color, is_open)
  435. elif objType == 'goal':
  436. v = Goal()
  437. else:
  438. assert False, "unknown obj type in decode '%s'" % objType
  439. grid.set(i, j, v)
  440. return grid
  441. def process_vis(
  442. grid,
  443. agent_pos,
  444. n_rays = 32,
  445. n_steps = 24,
  446. a_min = math.pi,
  447. a_max = 2 * math.pi
  448. ):
  449. """
  450. Use ray casting to determine the visibility of each grid cell
  451. """
  452. mask = np.zeros(shape=(grid.width, grid.height), dtype=np.bool)
  453. ang_step = (a_max - a_min) / n_rays
  454. dst_step = math.sqrt(grid.width ** 2 + grid.height ** 2) / n_steps
  455. ax = agent_pos[0] + 0.5
  456. ay = agent_pos[1] + 0.5
  457. for ray_idx in range(0, n_rays):
  458. angle = a_min + ang_step * ray_idx
  459. dx = dst_step * math.cos(angle)
  460. dy = dst_step * math.sin(angle)
  461. for step_idx in range(0, n_steps):
  462. x = ax + (step_idx * dx)
  463. y = ay + (step_idx * dy)
  464. i = math.floor(x)
  465. j = math.floor(y)
  466. # If we're outside of the grid, stop
  467. if i < 0 or i >= grid.width or j < 0 or j >= grid.height:
  468. break
  469. # Mark this cell as visible
  470. mask[i, j] = True
  471. # If we hit the obstructor, stop
  472. cell = grid.get(i, j)
  473. if cell and not cell.see_behind():
  474. break
  475. for j in range(0, grid.height):
  476. for i in range(0, grid.width):
  477. if not mask[i, j]:
  478. grid.set(i, j, None)
  479. #grid.set(i, j, Wall('red'))
  480. return mask
  481. class MiniGridEnv(gym.Env):
  482. """
  483. 2D grid world game environment
  484. """
  485. metadata = {
  486. 'render.modes': ['human', 'rgb_array', 'pixmap'],
  487. 'video.frames_per_second' : 10
  488. }
  489. # Enumeration of possible actions
  490. class Actions(IntEnum):
  491. # Turn left, turn right, move forward
  492. left = 0
  493. right = 1
  494. forward = 2
  495. # Pick up an object
  496. pickup = 3
  497. # Drop an object
  498. drop = 4
  499. # Toggle/activate an object
  500. toggle = 5
  501. # Wait/stay put/do nothing
  502. wait = 6
  503. def __init__(self, grid_size=16, max_steps=100):
  504. # Action enumeration for this environment
  505. self.actions = MiniGridEnv.Actions
  506. # Actions are discrete integer values
  507. self.action_space = spaces.Discrete(len(self.actions))
  508. # Observations are dictionaries containing an
  509. # encoding of the grid and a textual 'mission' string
  510. self.observation_space = spaces.Box(
  511. low=0,
  512. high=255,
  513. shape=OBS_ARRAY_SIZE,
  514. dtype='uint8'
  515. )
  516. self.observation_space = spaces.Dict({
  517. 'image': self.observation_space
  518. })
  519. # Range of possible rewards
  520. self.reward_range = (-1, 1000)
  521. # Renderer object used to render the whole grid (full-scale)
  522. self.grid_render = None
  523. # Renderer used to render observations (small-scale agent view)
  524. self.obs_render = None
  525. # Environment configuration
  526. self.grid_size = grid_size
  527. self.max_steps = max_steps
  528. # Starting position and direction for the agent
  529. self.start_pos = None
  530. self.start_dir = None
  531. # Initialize the state
  532. self.seed()
  533. self.reset()
  534. def reset(self):
  535. # Generate a new random grid at the start of each episode
  536. # To keep the same grid for each episode, call env.seed() with
  537. # the same seed before calling env.reset()
  538. self._genGrid(self.grid_size, self.grid_size)
  539. # These fields should be defined by _genGrid
  540. assert self.start_pos != None
  541. assert self.start_dir != None
  542. # Check that the agent doesn't overlap with an object
  543. assert self.grid.get(*self.start_pos) is None
  544. # Place the agent in the starting position and direction
  545. self.agent_pos = self.start_pos
  546. self.agent_dir = self.start_dir
  547. # Item picked up, being carried, initially nothing
  548. self.carrying = None
  549. # Step count since episode start
  550. self.step_count = 0
  551. # Return first observation
  552. obs = self.gen_obs()
  553. return obs
  554. def seed(self, seed=1337):
  555. # Seed the random number generator
  556. self.np_random, _ = seeding.np_random(seed)
  557. return [seed]
  558. @property
  559. def steps_remaining(self):
  560. return self.max_steps - self.step_count
  561. def __str__(self):
  562. """
  563. Produce a pretty string of the environment's grid along with the agent.
  564. The agent is represented by `⏩`. A grid pixel is represented by 2-character
  565. string, the first one for the object and the second one for the color.
  566. """
  567. from copy import deepcopy
  568. def rotate_left(array):
  569. new_array = deepcopy(array)
  570. for i in range(len(array)):
  571. for j in range(len(array[0])):
  572. new_array[j][len(array[0])-1-i] = array[i][j]
  573. return new_array
  574. def vertically_symmetrize(array):
  575. new_array = deepcopy(array)
  576. for i in range(len(array)):
  577. for j in range(len(array[0])):
  578. new_array[i][len(array[0])-1-j] = array[i][j]
  579. return new_array
  580. # Map of object id to short string
  581. OBJECT_IDX_TO_IDS = {
  582. 0: ' ',
  583. 1: 'W',
  584. 2: 'D',
  585. 3: 'L',
  586. 4: 'K',
  587. 5: 'B',
  588. 6: 'X',
  589. 7: 'G'
  590. }
  591. # Short string for opened door
  592. OPENDED_DOOR_IDS = '_'
  593. # Map of color id to short string
  594. COLOR_IDX_TO_IDS = {
  595. 0: 'R',
  596. 1: 'G',
  597. 2: 'B',
  598. 3: 'P',
  599. 4: 'Y',
  600. 5: 'E'
  601. }
  602. # Map agent's direction to short string
  603. AGENT_DIR_TO_IDS = {
  604. 0: '⏩',
  605. 1: '⏬',
  606. 2: '⏪',
  607. 3: '⏫'
  608. }
  609. array = self.grid.encode()
  610. array = rotate_left(array)
  611. array = vertically_symmetrize(array)
  612. new_array = []
  613. for line in array:
  614. new_line = []
  615. for pixel in line:
  616. # If the door is opened
  617. if pixel[0] in [2, 3] and pixel[2] == 1:
  618. object_ids = OPENDED_DOOR_IDS
  619. else:
  620. object_ids = OBJECT_IDX_TO_IDS[pixel[0]]
  621. # If no object
  622. if pixel[0] == 0:
  623. color_ids = ' '
  624. else:
  625. color_ids = COLOR_IDX_TO_IDS[pixel[1]]
  626. new_line.append(object_ids + color_ids)
  627. new_array.append(new_line)
  628. # Add the agent
  629. new_array[self.agent_pos[1]][self.agent_pos[0]] = AGENT_DIR_TO_IDS[self.agent_dir]
  630. return "\n".join([" ".join(line) for line in new_array])
  631. def _genGrid(self, width, height):
  632. assert False, "_genGrid needs to be implemented by each environment"
  633. def _randInt(self, low, high):
  634. """
  635. Generate random integer in [low,high[
  636. """
  637. return self.np_random.randint(low, high)
  638. def _randElem(self, iterable):
  639. """
  640. Pick a random element in a list
  641. """
  642. lst = list(iterable)
  643. idx = self._randInt(0, len(lst))
  644. return lst[idx]
  645. def _randPos(self, xLow, xHigh, yLow, yHigh):
  646. """
  647. Generate a random (x,y) position tuple
  648. """
  649. return (
  650. self.np_random.randint(xLow, xHigh),
  651. self.np_random.randint(yLow, yHigh)
  652. )
  653. def placeObj(self, obj, top=None, size=None, reject_fn=None):
  654. """
  655. Place an object at an empty position in the grid
  656. :param top: top-left position of the rectangle where to place
  657. :param size: size of the rectangle where to place
  658. :param reject_fn: function to filter out potential positions
  659. """
  660. if top is None:
  661. top = (0, 0)
  662. if size is None:
  663. size = (self.grid.width, self.grid.height)
  664. while True:
  665. pos = (
  666. self._randInt(top[0], top[0] + size[0]),
  667. self._randInt(top[1], top[1] + size[1])
  668. )
  669. # Don't place the object on top of another object
  670. if self.grid.get(*pos) != None:
  671. continue
  672. # Don't place the object where the agent is
  673. if pos == self.start_pos:
  674. continue
  675. # Check if there is a filtering criterion
  676. if reject_fn and reject_fn(self, pos):
  677. continue
  678. break
  679. self.grid.set(*pos, obj)
  680. return pos
  681. def placeAgent(self, top=None, size=None, randDir=True):
  682. """
  683. Set the agent's starting point at an empty position in the grid
  684. """
  685. pos = self.placeObj(None, top, size)
  686. self.start_pos = pos
  687. if randDir:
  688. self.start_dir = self._randInt(0, 4)
  689. return pos
  690. def get_dir_vec(self):
  691. """
  692. Get the direction vector for the agent, pointing in the direction
  693. of forward movement.
  694. """
  695. # Pointing right
  696. if self.agent_dir == 0:
  697. return (1, 0)
  698. # Down (positive Y)
  699. elif self.agent_dir == 1:
  700. return (0, 1)
  701. # Pointing left
  702. elif self.agent_dir == 2:
  703. return (-1, 0)
  704. # Up (negative Y)
  705. elif self.agent_dir == 3:
  706. return (0, -1)
  707. else:
  708. assert False
  709. def get_right_vec(self):
  710. """
  711. Get the vector pointing to the right of the agent.
  712. """
  713. dx, dy = self.get_dir_vec()
  714. return -dy, dx
  715. def get_view_coords(self, i, j):
  716. """
  717. Translate and rotate absolute grid coordinates (i, j) into the
  718. agent's partially observable view (sub-grid). Note that the resulting
  719. coordinates may be negative or outside of the agent's view size.
  720. """
  721. ax, ay = self.agent_pos
  722. dx, dy = self.get_dir_vec()
  723. rx, ry = self.get_right_vec()
  724. # Compute the absolute coordinates of the top-left view corner
  725. sz = AGENT_VIEW_SIZE
  726. hs = AGENT_VIEW_SIZE // 2
  727. tx = ax + (dx * (sz-1)) - (rx * hs)
  728. ty = ay + (dy * (sz-1)) - (ry * hs)
  729. lx = i - tx
  730. ly = j - ty
  731. # Project the coordinates of the object relative to the top-left
  732. # corner onto the agent's own coordinate system
  733. vx = (rx*lx + ry*ly)
  734. vy = -(dx*lx + dy*ly)
  735. return vx, vy
  736. def get_view_exts(self):
  737. """
  738. Get the extents of the square set of tiles visible to the agent
  739. Note: the bottom extent indices are not included in the set
  740. """
  741. # Facing right
  742. if self.agent_dir == 0:
  743. topX = self.agent_pos[0]
  744. topY = self.agent_pos[1] - AGENT_VIEW_SIZE // 2
  745. # Facing down
  746. elif self.agent_dir == 1:
  747. topX = self.agent_pos[0] - AGENT_VIEW_SIZE // 2
  748. topY = self.agent_pos[1]
  749. # Facing left
  750. elif self.agent_dir == 2:
  751. topX = self.agent_pos[0] - AGENT_VIEW_SIZE + 1
  752. topY = self.agent_pos[1] - AGENT_VIEW_SIZE // 2
  753. # Facing up
  754. elif self.agent_dir == 3:
  755. topX = self.agent_pos[0] - AGENT_VIEW_SIZE // 2
  756. topY = self.agent_pos[1] - AGENT_VIEW_SIZE + 1
  757. else:
  758. assert False, "invalid agent direction"
  759. botX = topX + AGENT_VIEW_SIZE
  760. botY = topY + AGENT_VIEW_SIZE
  761. return (topX, topY, botX, botY)
  762. def agent_sees(self, x, y):
  763. """
  764. Check if a grid position is visible to the agent
  765. """
  766. vx, vy = self.get_view_coords(x, y)
  767. if vx < 0 or vy < 0 or vx >= AGENT_VIEW_SIZE or vy >= AGENT_VIEW_SIZE:
  768. return False
  769. obs = self.gen_obs()
  770. obs_grid = Grid.decode(obs['image'])
  771. obs_cell = obs_grid.get(vx, vy)
  772. world_cell = self.grid.get(x, y)
  773. return obs_cell is not None and obs_cell.type == world_cell.type
  774. def step(self, action):
  775. self.step_count += 1
  776. reward = 0
  777. done = False
  778. # Get the position in front of the agent
  779. u, v = self.get_dir_vec()
  780. fwdPos = (self.agent_pos[0] + u, self.agent_pos[1] + v)
  781. # Get the contents of the cell in front of the agent
  782. fwdCell = self.grid.get(*fwdPos)
  783. # Rotate left
  784. if action == self.actions.left:
  785. self.agent_dir -= 1
  786. if self.agent_dir < 0:
  787. self.agent_dir += 4
  788. # Rotate right
  789. elif action == self.actions.right:
  790. self.agent_dir = (self.agent_dir + 1) % 4
  791. # Move forward
  792. elif action == self.actions.forward:
  793. if fwdCell == None or fwdCell.can_overlap():
  794. self.agent_pos = fwdPos
  795. if fwdCell != None and fwdCell.type == 'goal':
  796. done = True
  797. reward = 1000 - self.step_count
  798. # Pick up an object
  799. elif action == self.actions.pickup:
  800. if fwdCell and fwdCell.canPickup():
  801. if self.carrying is None:
  802. self.carrying = fwdCell
  803. self.grid.set(*fwdPos, None)
  804. # Drop an object
  805. elif action == self.actions.drop:
  806. if not fwdCell and self.carrying:
  807. self.grid.set(*fwdPos, self.carrying)
  808. self.carrying = None
  809. # Toggle/activate an object
  810. elif action == self.actions.toggle:
  811. if fwdCell:
  812. fwdCell.toggle(self, fwdPos)
  813. # Wait/do nothing
  814. elif action == self.actions.wait:
  815. pass
  816. else:
  817. assert False, "unknown action"
  818. if self.step_count >= self.max_steps:
  819. done = True
  820. obs = self.gen_obs()
  821. return obs, reward, done, {}
  822. def gen_obs(self):
  823. """
  824. Generate the agent's view (partially observable, low-resolution encoding)
  825. """
  826. topX, topY, botX, botY = self.get_view_exts()
  827. grid = self.grid.slice(topX, topY, AGENT_VIEW_SIZE, AGENT_VIEW_SIZE)
  828. for i in range(self.agent_dir + 1):
  829. grid = grid.rotateLeft()
  830. # Make it so the agent sees what it's carrying
  831. # We do this by placing the carried object at the agent's position
  832. # in the agent's partially observable view
  833. agent_pos = grid.width // 2, grid.height - 1
  834. if self.carrying:
  835. grid.set(*agent_pos, self.carrying)
  836. else:
  837. grid.set(*agent_pos, None)
  838. # Process occluders and visibility
  839. grid.process_vis(agent_pos=(3, 6))
  840. # Encode the partially observable view into a numpy array
  841. image = grid.encode()
  842. assert hasattr(self, 'mission'), "environments must define a textual mission string"
  843. # Observations are dictionaries containing:
  844. # - an image (partially observable view of the environment)
  845. # - the agent's direction/orientation (acting as a compass)
  846. # - a textual mission string (instructions for the agent)
  847. obs = {
  848. 'image': image,
  849. 'direction': self.agent_dir,
  850. 'mission': self.mission
  851. }
  852. return obs
  853. def get_obs_render(self, obs):
  854. """
  855. Render an agent observation for visualization
  856. """
  857. if self.obs_render == None:
  858. self.obs_render = Renderer(
  859. AGENT_VIEW_SIZE * CELL_PIXELS // 2,
  860. AGENT_VIEW_SIZE * CELL_PIXELS // 2
  861. )
  862. r = self.obs_render
  863. r.beginFrame()
  864. grid = Grid.decode(obs)
  865. # Render the whole grid
  866. grid.render(r, CELL_PIXELS // 2)
  867. # Draw the agent
  868. r.push()
  869. r.scale(0.5, 0.5)
  870. r.translate(
  871. CELL_PIXELS * (0.5 + AGENT_VIEW_SIZE // 2),
  872. CELL_PIXELS * (AGENT_VIEW_SIZE - 0.5)
  873. )
  874. r.rotate(3 * 90)
  875. r.setLineColor(255, 0, 0)
  876. r.setColor(255, 0, 0)
  877. r.drawPolygon([
  878. (-12, 10),
  879. ( 12, 0),
  880. (-12, -10)
  881. ])
  882. r.pop()
  883. r.endFrame()
  884. return r.getPixmap()
  885. def render(self, mode='human', close=False):
  886. """
  887. Render the whole-grid human view
  888. """
  889. if close:
  890. if self.grid_render:
  891. self.grid_render.close()
  892. return
  893. if self.grid_render is None:
  894. self.grid_render = Renderer(
  895. self.grid_size * CELL_PIXELS,
  896. self.grid_size * CELL_PIXELS,
  897. True if mode == 'human' else False
  898. )
  899. r = self.grid_render
  900. r.beginFrame()
  901. # Render the whole grid
  902. self.grid.render(r, CELL_PIXELS)
  903. # Draw the agent
  904. r.push()
  905. r.translate(
  906. CELL_PIXELS * (self.agent_pos[0] + 0.5),
  907. CELL_PIXELS * (self.agent_pos[1] + 0.5)
  908. )
  909. r.rotate(self.agent_dir * 90)
  910. r.setLineColor(255, 0, 0)
  911. r.setColor(255, 0, 0)
  912. r.drawPolygon([
  913. (-12, 10),
  914. ( 12, 0),
  915. (-12, -10)
  916. ])
  917. r.pop()
  918. # Highlight what the agent can see
  919. topX, topY, botX, botY = self.get_view_exts()
  920. r.fillRect(
  921. topX * CELL_PIXELS,
  922. topY * CELL_PIXELS,
  923. AGENT_VIEW_SIZE * CELL_PIXELS,
  924. AGENT_VIEW_SIZE * CELL_PIXELS,
  925. 200, 200, 200, 75
  926. )
  927. r.endFrame()
  928. if mode == 'rgb_array':
  929. return r.getArray()
  930. elif mode == 'pixmap':
  931. return r.getPixmap()
  932. return r