minigrid.py 34 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261
  1. import math
  2. import gym
  3. from enum import IntEnum
  4. import numpy as np
  5. from gym import error, spaces, utils
  6. from gym.utils import seeding
  7. from gym_minigrid.rendering import *
  8. # Size in pixels of a cell in the full-scale human view
  9. CELL_PIXELS = 32
  10. # Number of cells (width and height) in the agent view
  11. AGENT_VIEW_SIZE = 7
  12. # Size of the array given as an observation to the agent
  13. OBS_ARRAY_SIZE = (AGENT_VIEW_SIZE, AGENT_VIEW_SIZE, 3)
  14. # Map of color names to RGB values
  15. COLORS = {
  16. 'red' : (255, 0, 0),
  17. 'green' : (0, 255, 0),
  18. 'blue' : (0, 0, 255),
  19. 'purple': (112, 39, 195),
  20. 'yellow': (255, 255, 0),
  21. 'grey' : (100, 100, 100)
  22. }
  23. COLOR_NAMES = sorted(list(COLORS.keys()))
  24. # Used to map colors to integers
  25. COLOR_TO_IDX = {
  26. 'red' : 0,
  27. 'green' : 1,
  28. 'blue' : 2,
  29. 'purple': 3,
  30. 'yellow': 4,
  31. 'grey' : 5
  32. }
  33. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  34. # Map of object type to integers
  35. OBJECT_TO_IDX = {
  36. 'empty' : 0,
  37. 'wall' : 1,
  38. 'door' : 2,
  39. 'locked_door' : 3,
  40. 'key' : 4,
  41. 'ball' : 5,
  42. 'box' : 6,
  43. 'goal' : 7
  44. }
  45. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  46. # Map of agent direction indices to vectors
  47. DIR_TO_VEC = [
  48. # Pointing right (positive X)
  49. np.array((1, 0)),
  50. # Down (positive Y)
  51. np.array((0, 1)),
  52. # Pointing left (negative X)
  53. np.array((-1, 0)),
  54. # Up (negative Y)
  55. np.array((0, -1)),
  56. ]
  57. class WorldObj:
  58. """
  59. Base class for grid world objects
  60. """
  61. def __init__(self, type, color):
  62. assert type in OBJECT_TO_IDX, type
  63. assert color in COLOR_TO_IDX, color
  64. self.type = type
  65. self.color = color
  66. self.contains = None
  67. # Initial position of the object
  68. self.init_pos = None
  69. # Current position of the object
  70. self.cur_pos = None
  71. def can_overlap(self):
  72. """Can the agent overlap with this?"""
  73. return False
  74. def can_pickup(self):
  75. """Can the agent pick this up?"""
  76. return False
  77. def can_contain(self):
  78. """Can this contain another object?"""
  79. return False
  80. def see_behind(self):
  81. """Can the agent see behind this object?"""
  82. return True
  83. def toggle(self, env, pos):
  84. """Method to trigger/toggle an action this object performs"""
  85. return False
  86. def render(self, r):
  87. """Draw this object with the given renderer"""
  88. raise NotImplementedError
  89. def _set_color(self, r):
  90. """Set the color of this object as the active drawing color"""
  91. c = COLORS[self.color]
  92. r.setLineColor(c[0], c[1], c[2])
  93. r.setColor(c[0], c[1], c[2])
  94. class Goal(WorldObj):
  95. def __init__(self):
  96. super(Goal, self).__init__('goal', 'green')
  97. def can_overlap(self):
  98. return True
  99. def render(self, r):
  100. self._set_color(r)
  101. r.drawPolygon([
  102. (0 , CELL_PIXELS),
  103. (CELL_PIXELS, CELL_PIXELS),
  104. (CELL_PIXELS, 0),
  105. (0 , 0)
  106. ])
  107. class Wall(WorldObj):
  108. def __init__(self, color='grey'):
  109. super(Wall, self).__init__('wall', color)
  110. def see_behind(self):
  111. return False
  112. def render(self, r):
  113. self._set_color(r)
  114. r.drawPolygon([
  115. (0 , CELL_PIXELS),
  116. (CELL_PIXELS, CELL_PIXELS),
  117. (CELL_PIXELS, 0),
  118. (0 , 0)
  119. ])
  120. class Door(WorldObj):
  121. def __init__(self, color, is_open=False):
  122. super(Door, self).__init__('door', color)
  123. self.is_open = is_open
  124. def can_overlap(self):
  125. """The agent can only walk over this cell when the door is open"""
  126. return self.is_open
  127. def see_behind(self):
  128. return self.is_open
  129. def toggle(self, env, pos):
  130. self.is_open = not self.is_open
  131. return True
  132. def render(self, r):
  133. c = COLORS[self.color]
  134. r.setLineColor(c[0], c[1], c[2])
  135. r.setColor(0, 0, 0)
  136. if self.is_open:
  137. r.drawPolygon([
  138. (CELL_PIXELS-2, CELL_PIXELS),
  139. (CELL_PIXELS , CELL_PIXELS),
  140. (CELL_PIXELS , 0),
  141. (CELL_PIXELS-2, 0)
  142. ])
  143. return
  144. r.drawPolygon([
  145. (0 , CELL_PIXELS),
  146. (CELL_PIXELS, CELL_PIXELS),
  147. (CELL_PIXELS, 0),
  148. (0 , 0)
  149. ])
  150. r.drawPolygon([
  151. (2 , CELL_PIXELS-2),
  152. (CELL_PIXELS-2, CELL_PIXELS-2),
  153. (CELL_PIXELS-2, 2),
  154. (2 , 2)
  155. ])
  156. r.drawCircle(CELL_PIXELS * 0.75, CELL_PIXELS * 0.5, 2)
  157. class LockedDoor(WorldObj):
  158. def __init__(self, color, is_open=False):
  159. super(LockedDoor, self).__init__('locked_door', color)
  160. self.is_open = is_open
  161. def toggle(self, env, pos):
  162. # If the player has the right key to open the door
  163. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  164. self.is_open = True
  165. # The key has been used, remove it from the agent
  166. env.carrying = None
  167. return True
  168. return False
  169. def can_overlap(self):
  170. """The agent can only walk over this cell when the door is open"""
  171. return self.is_open
  172. def see_behind(self):
  173. return self.is_open
  174. def render(self, r):
  175. c = COLORS[self.color]
  176. r.setLineColor(c[0], c[1], c[2])
  177. r.setColor(c[0], c[1], c[2], 50)
  178. if self.is_open:
  179. r.drawPolygon([
  180. (CELL_PIXELS-2, CELL_PIXELS),
  181. (CELL_PIXELS , CELL_PIXELS),
  182. (CELL_PIXELS , 0),
  183. (CELL_PIXELS-2, 0)
  184. ])
  185. return
  186. r.drawPolygon([
  187. (0 , CELL_PIXELS),
  188. (CELL_PIXELS, CELL_PIXELS),
  189. (CELL_PIXELS, 0),
  190. (0 , 0)
  191. ])
  192. r.drawPolygon([
  193. (2 , CELL_PIXELS-2),
  194. (CELL_PIXELS-2, CELL_PIXELS-2),
  195. (CELL_PIXELS-2, 2),
  196. (2 , 2)
  197. ])
  198. r.drawLine(
  199. CELL_PIXELS * 0.55,
  200. CELL_PIXELS * 0.5,
  201. CELL_PIXELS * 0.75,
  202. CELL_PIXELS * 0.5
  203. )
  204. class Key(WorldObj):
  205. def __init__(self, color='blue'):
  206. super(Key, self).__init__('key', color)
  207. def can_pickup(self):
  208. return True
  209. def render(self, r):
  210. self._set_color(r)
  211. # Vertical quad
  212. r.drawPolygon([
  213. (16, 10),
  214. (20, 10),
  215. (20, 28),
  216. (16, 28)
  217. ])
  218. # Teeth
  219. r.drawPolygon([
  220. (12, 19),
  221. (16, 19),
  222. (16, 21),
  223. (12, 21)
  224. ])
  225. r.drawPolygon([
  226. (12, 26),
  227. (16, 26),
  228. (16, 28),
  229. (12, 28)
  230. ])
  231. r.drawCircle(18, 9, 6)
  232. r.setLineColor(0, 0, 0)
  233. r.setColor(0, 0, 0)
  234. r.drawCircle(18, 9, 2)
  235. class Ball(WorldObj):
  236. def __init__(self, color='blue'):
  237. super(Ball, self).__init__('ball', color)
  238. def can_pickup(self):
  239. return True
  240. def render(self, r):
  241. self._set_color(r)
  242. r.drawCircle(CELL_PIXELS * 0.5, CELL_PIXELS * 0.5, 10)
  243. class Box(WorldObj):
  244. def __init__(self, color, contains=None):
  245. super(Box, self).__init__('box', color)
  246. self.contains = contains
  247. def can_pickup(self):
  248. return True
  249. def render(self, r):
  250. c = COLORS[self.color]
  251. r.setLineColor(c[0], c[1], c[2])
  252. r.setColor(0, 0, 0)
  253. r.setLineWidth(2)
  254. r.drawPolygon([
  255. (4 , CELL_PIXELS-4),
  256. (CELL_PIXELS-4, CELL_PIXELS-4),
  257. (CELL_PIXELS-4, 4),
  258. (4 , 4)
  259. ])
  260. r.drawLine(
  261. 4,
  262. CELL_PIXELS / 2,
  263. CELL_PIXELS - 4,
  264. CELL_PIXELS / 2
  265. )
  266. r.setLineWidth(1)
  267. def toggle(self, env, pos):
  268. # Replace the box by its contents
  269. env.grid.set(*pos, self.contains)
  270. return True
  271. class Grid:
  272. """
  273. Represent a grid and operations on it
  274. """
  275. def __init__(self, width, height):
  276. assert width >= 4
  277. assert height >= 4
  278. self.width = width
  279. self.height = height
  280. self.grid = [None] * width * height
  281. def __contains__(self, key):
  282. if isinstance(key, WorldObj):
  283. for e in self.grid:
  284. if e is key:
  285. return True
  286. elif isinstance(key, tuple):
  287. for e in self.grid:
  288. if e is None:
  289. continue
  290. if (e.color, e.type) == key:
  291. return True
  292. return False
  293. def __eq__(self, other):
  294. grid1 = self.encode()
  295. grid2 = other.encode()
  296. return np.array_equal(grid2, grid1)
  297. def __ne__(self, other):
  298. return not self == other
  299. def copy(self):
  300. from copy import deepcopy
  301. return deepcopy(self)
  302. def set(self, i, j, v):
  303. assert i >= 0 and i < self.width
  304. assert j >= 0 and j < self.height
  305. self.grid[j * self.width + i] = v
  306. def get(self, i, j):
  307. assert i >= 0 and i < self.width
  308. assert j >= 0 and j < self.height
  309. return self.grid[j * self.width + i]
  310. def horz_wall(self, x, y, length=None):
  311. if length is None:
  312. length = self.width - x
  313. for i in range(0, length):
  314. self.set(x + i, y, Wall())
  315. def vert_wall(self, x, y, length=None):
  316. if length is None:
  317. length = self.height - y
  318. for j in range(0, length):
  319. self.set(x, y + j, Wall())
  320. def wall_rect(self, x, y, w, h):
  321. self.horz_wall(x, y, w)
  322. self.horz_wall(x, y+h-1, w)
  323. self.vert_wall(x, y, h)
  324. self.vert_wall(x+w-1, y, h)
  325. def rotate_left(self):
  326. """
  327. Rotate the grid to the left (counter-clockwise)
  328. """
  329. grid = Grid(self.width, self.height)
  330. for j in range(0, self.height):
  331. for i in range(0, self.width):
  332. v = self.get(self.width - 1 - j, i)
  333. grid.set(i, j, v)
  334. return grid
  335. def slice(self, topX, topY, width, height):
  336. """
  337. Get a subset of the grid
  338. """
  339. grid = Grid(width, height)
  340. for j in range(0, height):
  341. for i in range(0, width):
  342. x = topX + i
  343. y = topY + j
  344. if x >= 0 and x < self.width and \
  345. y >= 0 and y < self.height:
  346. v = self.get(x, y)
  347. else:
  348. v = Wall()
  349. grid.set(i, j, v)
  350. return grid
  351. def render(self, r, tile_size):
  352. """
  353. Render this grid at a given scale
  354. :param r: target renderer object
  355. :param tile_size: tile size in pixels
  356. """
  357. assert r.width == self.width * tile_size
  358. assert r.height == self.height * tile_size
  359. # Total grid size at native scale
  360. widthPx = self.width * CELL_PIXELS
  361. heightPx = self.height * CELL_PIXELS
  362. r.push()
  363. # Internally, we draw at the "large" full-grid resolution, but we
  364. # use the renderer to scale back to the desired size
  365. r.scale(tile_size / CELL_PIXELS, tile_size / CELL_PIXELS)
  366. # Draw the background of the in-world cells black
  367. r.fillRect(
  368. 0,
  369. 0,
  370. widthPx,
  371. heightPx,
  372. 0, 0, 0
  373. )
  374. # Draw grid lines
  375. r.setLineColor(100, 100, 100)
  376. for rowIdx in range(0, self.height):
  377. y = CELL_PIXELS * rowIdx
  378. r.drawLine(0, y, widthPx, y)
  379. for colIdx in range(0, self.width):
  380. x = CELL_PIXELS * colIdx
  381. r.drawLine(x, 0, x, heightPx)
  382. # Render the grid
  383. for j in range(0, self.height):
  384. for i in range(0, self.width):
  385. cell = self.get(i, j)
  386. if cell == None:
  387. continue
  388. r.push()
  389. r.translate(i * CELL_PIXELS, j * CELL_PIXELS)
  390. cell.render(r)
  391. r.pop()
  392. r.pop()
  393. def encode(self):
  394. """
  395. Produce a compact numpy encoding of the grid
  396. """
  397. codeSize = self.width * self.height * 3
  398. array = np.zeros(shape=(self.width, self.height, 3), dtype='uint8')
  399. for j in range(0, self.height):
  400. for i in range(0, self.width):
  401. v = self.get(i, j)
  402. if v == None:
  403. continue
  404. array[i, j, 0] = OBJECT_TO_IDX[v.type]
  405. array[i, j, 1] = COLOR_TO_IDX[v.color]
  406. if hasattr(v, 'is_open') and v.is_open:
  407. array[i, j, 2] = 1
  408. return array
  409. def decode(array):
  410. """
  411. Decode an array grid encoding back into a grid
  412. """
  413. width = array.shape[0]
  414. height = array.shape[1]
  415. assert array.shape[2] == 3
  416. grid = Grid(width, height)
  417. for j in range(0, height):
  418. for i in range(0, width):
  419. typeIdx = array[i, j, 0]
  420. colorIdx = array[i, j, 1]
  421. openIdx = array[i, j, 2]
  422. if typeIdx == 0:
  423. continue
  424. objType = IDX_TO_OBJECT[typeIdx]
  425. color = IDX_TO_COLOR[colorIdx]
  426. is_open = True if openIdx == 1 else 0
  427. if objType == 'wall':
  428. v = Wall(color)
  429. elif objType == 'ball':
  430. v = Ball(color)
  431. elif objType == 'key':
  432. v = Key(color)
  433. elif objType == 'box':
  434. v = Box(color)
  435. elif objType == 'door':
  436. v = Door(color, is_open)
  437. elif objType == 'locked_door':
  438. v = LockedDoor(color, is_open)
  439. elif objType == 'goal':
  440. v = Goal()
  441. else:
  442. assert False, "unknown obj type in decode '%s'" % objType
  443. grid.set(i, j, v)
  444. return grid
  445. def process_vis(grid, agent_pos):
  446. mask = np.zeros(shape=(grid.width, grid.height), dtype=np.bool)
  447. mask[agent_pos[0], agent_pos[1]] = True
  448. for j in reversed(range(1, grid.height)):
  449. for i in range(0, grid.width-1):
  450. if not mask[i, j]:
  451. continue
  452. cell = grid.get(i, j)
  453. if cell and not cell.see_behind():
  454. continue
  455. mask[i+1, j] = True
  456. mask[i+1, j-1] = True
  457. mask[i, j-1] = True
  458. for i in reversed(range(1, grid.width)):
  459. if not mask[i, j]:
  460. continue
  461. cell = grid.get(i, j)
  462. if cell and not cell.see_behind():
  463. continue
  464. mask[i-1, j-1] = True
  465. mask[i-1, j] = True
  466. mask[i, j-1] = True
  467. for j in range(0, grid.height):
  468. for i in range(0, grid.width):
  469. if not mask[i, j]:
  470. grid.set(i, j, None)
  471. return mask
  472. class MiniGridEnv(gym.Env):
  473. """
  474. 2D grid world game environment
  475. """
  476. metadata = {
  477. 'render.modes': ['human', 'rgb_array', 'pixmap'],
  478. 'video.frames_per_second' : 10
  479. }
  480. # Enumeration of possible actions
  481. class Actions(IntEnum):
  482. # Turn left, turn right, move forward
  483. left = 0
  484. right = 1
  485. forward = 2
  486. # Pick up an object
  487. pickup = 3
  488. # Drop an object
  489. drop = 4
  490. # Toggle/activate an object
  491. toggle = 5
  492. # Done completing task
  493. done = 6
  494. def __init__(
  495. self,
  496. grid_size=16,
  497. max_steps=100,
  498. see_through_walls=False,
  499. seed=1337
  500. ):
  501. # Action enumeration for this environment
  502. self.actions = MiniGridEnv.Actions
  503. # Actions are discrete integer values
  504. self.action_space = spaces.Discrete(len(self.actions))
  505. # Observations are dictionaries containing an
  506. # encoding of the grid and a textual 'mission' string
  507. self.observation_space = spaces.Box(
  508. low=0,
  509. high=255,
  510. shape=OBS_ARRAY_SIZE,
  511. dtype='uint8'
  512. )
  513. self.observation_space = spaces.Dict({
  514. 'image': self.observation_space
  515. })
  516. # Range of possible rewards
  517. self.reward_range = (0, 1)
  518. # Renderer object used to render the whole grid (full-scale)
  519. self.grid_render = None
  520. # Renderer used to render observations (small-scale agent view)
  521. self.obs_render = None
  522. # Environment configuration
  523. self.grid_size = grid_size
  524. self.max_steps = max_steps
  525. self.see_through_walls = see_through_walls
  526. # Starting position and direction for the agent
  527. self.start_pos = None
  528. self.start_dir = None
  529. # Initialize the RNG
  530. self.seed(seed=seed)
  531. # Initialize the state
  532. self.reset()
  533. def reset(self):
  534. # Generate a new random grid at the start of each episode
  535. # To keep the same grid for each episode, call env.seed() with
  536. # the same seed before calling env.reset()
  537. self._gen_grid(self.grid_size, self.grid_size)
  538. # These fields should be defined by _gen_grid
  539. assert self.start_pos is not None
  540. assert self.start_dir is not None
  541. # Check that the agent doesn't overlap with an object
  542. assert self.grid.get(*self.start_pos) is None
  543. # Place the agent in the starting position and direction
  544. self.agent_pos = self.start_pos
  545. self.agent_dir = self.start_dir
  546. # Item picked up, being carried, initially nothing
  547. self.carrying = None
  548. # Step count since episode start
  549. self.step_count = 0
  550. # Return first observation
  551. obs = self.gen_obs()
  552. return obs
  553. def seed(self, seed=1337):
  554. # Seed the random number generator
  555. self.np_random, _ = seeding.np_random(seed)
  556. return [seed]
  557. @property
  558. def steps_remaining(self):
  559. return self.max_steps - self.step_count
  560. def __str__(self):
  561. """
  562. Produce a pretty string of the environment's grid along with the agent.
  563. The agent is represented by `⏩`. A grid pixel is represented by 2-character
  564. string, the first one for the object and the second one for the color.
  565. """
  566. from copy import deepcopy
  567. def rotate_left(array):
  568. new_array = deepcopy(array)
  569. for i in range(len(array)):
  570. for j in range(len(array[0])):
  571. new_array[j][len(array[0])-1-i] = array[i][j]
  572. return new_array
  573. def vertically_symmetrize(array):
  574. new_array = deepcopy(array)
  575. for i in range(len(array)):
  576. for j in range(len(array[0])):
  577. new_array[i][len(array[0])-1-j] = array[i][j]
  578. return new_array
  579. # Map of object id to short string
  580. OBJECT_IDX_TO_IDS = {
  581. 0: ' ',
  582. 1: 'W',
  583. 2: 'D',
  584. 3: 'L',
  585. 4: 'K',
  586. 5: 'B',
  587. 6: 'X',
  588. 7: 'G'
  589. }
  590. # Short string for opened door
  591. OPENDED_DOOR_IDS = '_'
  592. # Map of color id to short string
  593. COLOR_IDX_TO_IDS = {
  594. 0: 'R',
  595. 1: 'G',
  596. 2: 'B',
  597. 3: 'P',
  598. 4: 'Y',
  599. 5: 'E'
  600. }
  601. # Map agent's direction to short string
  602. AGENT_DIR_TO_IDS = {
  603. 0: '⏩',
  604. 1: '⏬',
  605. 2: '⏪',
  606. 3: '⏫'
  607. }
  608. array = self.grid.encode()
  609. array = rotate_left(array)
  610. array = vertically_symmetrize(array)
  611. new_array = []
  612. for line in array:
  613. new_line = []
  614. for pixel in line:
  615. # If the door is opened
  616. if pixel[0] in [2, 3] and pixel[2] == 1:
  617. object_ids = OPENDED_DOOR_IDS
  618. else:
  619. object_ids = OBJECT_IDX_TO_IDS[pixel[0]]
  620. # If no object
  621. if pixel[0] == 0:
  622. color_ids = ' '
  623. else:
  624. color_ids = COLOR_IDX_TO_IDS[pixel[1]]
  625. new_line.append(object_ids + color_ids)
  626. new_array.append(new_line)
  627. # Add the agent
  628. new_array[self.agent_pos[1]][self.agent_pos[0]] = AGENT_DIR_TO_IDS[self.agent_dir]
  629. return "\n".join([" ".join(line) for line in new_array])
  630. def _gen_grid(self, width, height):
  631. assert False, "_gen_grid needs to be implemented by each environment"
  632. def _reward(self):
  633. """
  634. Compute the reward to be given upon success
  635. """
  636. return 1 - 0.9 * (self.step_count / self.max_steps)
  637. def _rand_int(self, low, high):
  638. """
  639. Generate random integer in [low,high[
  640. """
  641. return self.np_random.randint(low, high)
  642. def _rand_float(self, low, high):
  643. """
  644. Generate random float in [low,high[
  645. """
  646. return self.np_random.uniform(low, high)
  647. def _rand_bool(self):
  648. """
  649. Generate random boolean value
  650. """
  651. return (self.np_random.randint(0, 2) == 0)
  652. def _rand_elem(self, iterable):
  653. """
  654. Pick a random element in a list
  655. """
  656. lst = list(iterable)
  657. idx = self._rand_int(0, len(lst))
  658. return lst[idx]
  659. def _rand_subset(self, iterable, num_elems):
  660. """
  661. Sample a random subset of distinct elements of a list
  662. """
  663. lst = list(iterable)
  664. assert num_elems <= len(lst)
  665. out = []
  666. while len(out) < num_elems:
  667. elem = self._rand_elem(lst)
  668. lst.remove(elem)
  669. out.append(elem)
  670. return out
  671. def _rand_color(self):
  672. """
  673. Generate a random color name (string)
  674. """
  675. return self._rand_elem(COLOR_NAMES)
  676. def _rand_pos(self, xLow, xHigh, yLow, yHigh):
  677. """
  678. Generate a random (x,y) position tuple
  679. """
  680. return (
  681. self.np_random.randint(xLow, xHigh),
  682. self.np_random.randint(yLow, yHigh)
  683. )
  684. def place_obj(self, obj, top=None, size=None, reject_fn=None):
  685. """
  686. Place an object at an empty position in the grid
  687. :param top: top-left position of the rectangle where to place
  688. :param size: size of the rectangle where to place
  689. :param reject_fn: function to filter out potential positions
  690. """
  691. if top is None:
  692. top = (0, 0)
  693. if size is None:
  694. size = (self.grid.width, self.grid.height)
  695. while True:
  696. pos = np.array((
  697. self._rand_int(top[0], top[0] + size[0]),
  698. self._rand_int(top[1], top[1] + size[1])
  699. ))
  700. # Don't place the object on top of another object
  701. if self.grid.get(*pos) != None:
  702. continue
  703. # Don't place the object where the agent is
  704. if np.array_equal(pos, self.start_pos):
  705. continue
  706. # Check if there is a filtering criterion
  707. if reject_fn and reject_fn(self, pos):
  708. continue
  709. break
  710. self.grid.set(*pos, obj)
  711. if obj is not None:
  712. obj.init_pos = pos
  713. obj.cur_pos = pos
  714. return pos
  715. def place_agent(self, top=None, size=None, rand_dir=True):
  716. """
  717. Set the agent's starting point at an empty position in the grid
  718. """
  719. self.start_pos = None
  720. pos = self.place_obj(None, top, size)
  721. self.start_pos = pos
  722. if rand_dir:
  723. self.start_dir = self._rand_int(0, 4)
  724. return pos
  725. @property
  726. def dir_vec(self):
  727. """
  728. Get the direction vector for the agent, pointing in the direction
  729. of forward movement.
  730. """
  731. assert self.agent_dir >= 0 and self.agent_dir < 4
  732. return DIR_TO_VEC[self.agent_dir]
  733. @property
  734. def right_vec(self):
  735. """
  736. Get the vector pointing to the right of the agent.
  737. """
  738. dx, dy = self.dir_vec
  739. return np.array((-dy, dx))
  740. @property
  741. def front_pos(self):
  742. """
  743. Get the position of the cell that is right in front of the agent
  744. """
  745. return self.agent_pos + self.dir_vec
  746. def get_view_coords(self, i, j):
  747. """
  748. Translate and rotate absolute grid coordinates (i, j) into the
  749. agent's partially observable view (sub-grid). Note that the resulting
  750. coordinates may be negative or outside of the agent's view size.
  751. """
  752. ax, ay = self.agent_pos
  753. dx, dy = self.dir_vec
  754. rx, ry = self.right_vec
  755. # Compute the absolute coordinates of the top-left view corner
  756. sz = AGENT_VIEW_SIZE
  757. hs = AGENT_VIEW_SIZE // 2
  758. tx = ax + (dx * (sz-1)) - (rx * hs)
  759. ty = ay + (dy * (sz-1)) - (ry * hs)
  760. lx = i - tx
  761. ly = j - ty
  762. # Project the coordinates of the object relative to the top-left
  763. # corner onto the agent's own coordinate system
  764. vx = (rx*lx + ry*ly)
  765. vy = -(dx*lx + dy*ly)
  766. return vx, vy
  767. def get_view_exts(self):
  768. """
  769. Get the extents of the square set of tiles visible to the agent
  770. Note: the bottom extent indices are not included in the set
  771. """
  772. # Facing right
  773. if self.agent_dir == 0:
  774. topX = self.agent_pos[0]
  775. topY = self.agent_pos[1] - AGENT_VIEW_SIZE // 2
  776. # Facing down
  777. elif self.agent_dir == 1:
  778. topX = self.agent_pos[0] - AGENT_VIEW_SIZE // 2
  779. topY = self.agent_pos[1]
  780. # Facing left
  781. elif self.agent_dir == 2:
  782. topX = self.agent_pos[0] - AGENT_VIEW_SIZE + 1
  783. topY = self.agent_pos[1] - AGENT_VIEW_SIZE // 2
  784. # Facing up
  785. elif self.agent_dir == 3:
  786. topX = self.agent_pos[0] - AGENT_VIEW_SIZE // 2
  787. topY = self.agent_pos[1] - AGENT_VIEW_SIZE + 1
  788. else:
  789. assert False, "invalid agent direction"
  790. botX = topX + AGENT_VIEW_SIZE
  791. botY = topY + AGENT_VIEW_SIZE
  792. return (topX, topY, botX, botY)
  793. def agent_sees(self, x, y):
  794. """
  795. Check if a grid position is visible to the agent
  796. """
  797. vx, vy = self.get_view_coords(x, y)
  798. if vx < 0 or vy < 0 or vx >= AGENT_VIEW_SIZE or vy >= AGENT_VIEW_SIZE:
  799. return False
  800. obs = self.gen_obs()
  801. obs_grid = Grid.decode(obs['image'])
  802. obs_cell = obs_grid.get(vx, vy)
  803. world_cell = self.grid.get(x, y)
  804. return obs_cell is not None and obs_cell.type == world_cell.type
  805. def step(self, action):
  806. self.step_count += 1
  807. reward = 0
  808. done = False
  809. # Get the position in front of the agent
  810. fwd_pos = self.front_pos
  811. # Get the contents of the cell in front of the agent
  812. fwd_cell = self.grid.get(*fwd_pos)
  813. # Rotate left
  814. if action == self.actions.left:
  815. self.agent_dir -= 1
  816. if self.agent_dir < 0:
  817. self.agent_dir += 4
  818. # Rotate right
  819. elif action == self.actions.right:
  820. self.agent_dir = (self.agent_dir + 1) % 4
  821. # Move forward
  822. elif action == self.actions.forward:
  823. if fwd_cell == None or fwd_cell.can_overlap():
  824. self.agent_pos = fwd_pos
  825. if fwd_cell != None and fwd_cell.type == 'goal':
  826. done = True
  827. reward = self._reward()
  828. # Pick up an object
  829. elif action == self.actions.pickup:
  830. if fwd_cell and fwd_cell.can_pickup():
  831. if self.carrying is None:
  832. self.carrying = fwd_cell
  833. self.carrying.cur_pos = np.array([-1, -1])
  834. self.grid.set(*fwd_pos, None)
  835. # Drop an object
  836. elif action == self.actions.drop:
  837. if not fwd_cell and self.carrying:
  838. self.grid.set(*fwd_pos, self.carrying)
  839. self.carrying.cur_pos = fwd_pos
  840. self.carrying = None
  841. # Toggle/activate an object
  842. elif action == self.actions.toggle:
  843. if fwd_cell:
  844. fwd_cell.toggle(self, fwd_pos)
  845. # Done action (not used by default)
  846. elif action == self.actions.done:
  847. pass
  848. else:
  849. assert False, "unknown action"
  850. if self.step_count >= self.max_steps:
  851. done = True
  852. obs = self.gen_obs()
  853. return obs, reward, done, {}
  854. def gen_obs_grid(self):
  855. """
  856. Generate the sub-grid observed by the agent.
  857. This method also outputs a visibility mask telling us which grid
  858. cells the agent can actually see.
  859. """
  860. topX, topY, botX, botY = self.get_view_exts()
  861. grid = self.grid.slice(topX, topY, AGENT_VIEW_SIZE, AGENT_VIEW_SIZE)
  862. for i in range(self.agent_dir + 1):
  863. grid = grid.rotate_left()
  864. # Process occluders and visibility
  865. # Note that this incurs some performance cost
  866. if not self.see_through_walls:
  867. vis_mask = grid.process_vis(agent_pos=(AGENT_VIEW_SIZE // 2 , AGENT_VIEW_SIZE - 1))
  868. else:
  869. vis_mask = np.ones(shape=(grid.width, grid.height), dtype=np.bool)
  870. # Make it so the agent sees what it's carrying
  871. # We do this by placing the carried object at the agent's position
  872. # in the agent's partially observable view
  873. agent_pos = grid.width // 2, grid.height - 1
  874. if self.carrying:
  875. grid.set(*agent_pos, self.carrying)
  876. else:
  877. grid.set(*agent_pos, None)
  878. return grid, vis_mask
  879. def gen_obs(self):
  880. """
  881. Generate the agent's view (partially observable, low-resolution encoding)
  882. """
  883. grid, vis_mask = self.gen_obs_grid()
  884. # Encode the partially observable view into a numpy array
  885. image = grid.encode()
  886. assert hasattr(self, 'mission'), "environments must define a textual mission string"
  887. # Observations are dictionaries containing:
  888. # - an image (partially observable view of the environment)
  889. # - the agent's direction/orientation (acting as a compass)
  890. # - a textual mission string (instructions for the agent)
  891. obs = {
  892. 'image': image,
  893. 'direction': self.agent_dir,
  894. 'mission': self.mission
  895. }
  896. return obs
  897. def get_obs_render(self, obs, tile_pixels=CELL_PIXELS//2):
  898. """
  899. Render an agent observation for visualization
  900. """
  901. if self.obs_render == None:
  902. self.obs_render = Renderer(
  903. AGENT_VIEW_SIZE * tile_pixels,
  904. AGENT_VIEW_SIZE * tile_pixels
  905. )
  906. r = self.obs_render
  907. r.beginFrame()
  908. grid = Grid.decode(obs)
  909. # Render the whole grid
  910. grid.render(r, tile_pixels)
  911. # Draw the agent
  912. ratio = tile_pixels / CELL_PIXELS
  913. r.push()
  914. r.scale(ratio, ratio)
  915. r.translate(
  916. CELL_PIXELS * (0.5 + AGENT_VIEW_SIZE // 2),
  917. CELL_PIXELS * (AGENT_VIEW_SIZE - 0.5)
  918. )
  919. r.rotate(3 * 90)
  920. r.setLineColor(255, 0, 0)
  921. r.setColor(255, 0, 0)
  922. r.drawPolygon([
  923. (-12, 10),
  924. ( 12, 0),
  925. (-12, -10)
  926. ])
  927. r.pop()
  928. r.endFrame()
  929. return r.getPixmap()
  930. def render(self, mode='human', close=False):
  931. """
  932. Render the whole-grid human view
  933. """
  934. if close:
  935. if self.grid_render:
  936. self.grid_render.close()
  937. return
  938. if self.grid_render is None:
  939. self.grid_render = Renderer(
  940. self.grid_size * CELL_PIXELS,
  941. self.grid_size * CELL_PIXELS,
  942. True if mode == 'human' else False
  943. )
  944. r = self.grid_render
  945. r.beginFrame()
  946. # Render the whole grid
  947. self.grid.render(r, CELL_PIXELS)
  948. # Draw the agent
  949. r.push()
  950. r.translate(
  951. CELL_PIXELS * (self.agent_pos[0] + 0.5),
  952. CELL_PIXELS * (self.agent_pos[1] + 0.5)
  953. )
  954. r.rotate(self.agent_dir * 90)
  955. r.setLineColor(255, 0, 0)
  956. r.setColor(255, 0, 0)
  957. r.drawPolygon([
  958. (-12, 10),
  959. ( 12, 0),
  960. (-12, -10)
  961. ])
  962. r.pop()
  963. # Compute which cells are visible to the agent
  964. _, vis_mask = self.gen_obs_grid()
  965. # Compute the absolute coordinates of the bottom-left corner
  966. # of the agent's view area
  967. f_vec = self.dir_vec
  968. r_vec = self.right_vec
  969. top_left = self.agent_pos + f_vec * (AGENT_VIEW_SIZE-1) - r_vec * (AGENT_VIEW_SIZE // 2)
  970. # For each cell in the visibility mask
  971. for vis_j in range(0, AGENT_VIEW_SIZE):
  972. for vis_i in range(0, AGENT_VIEW_SIZE):
  973. # If this cell is not visible, don't highlight it
  974. if not vis_mask[vis_i, vis_j]:
  975. continue
  976. # Compute the world coordinates of this cell
  977. abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)
  978. # Highlight the cell
  979. r.fillRect(
  980. abs_i * CELL_PIXELS,
  981. abs_j * CELL_PIXELS,
  982. CELL_PIXELS,
  983. CELL_PIXELS,
  984. 255, 255, 255, 75
  985. )
  986. r.endFrame()
  987. if mode == 'rgb_array':
  988. return r.getArray()
  989. elif mode == 'pixmap':
  990. return r.getPixmap()
  991. return r