# minigrid.py (33 KB)
  1. import math
  2. import gym
  3. from enum import IntEnum
  4. import numpy as np
  5. from gym import error, spaces, utils
  6. from gym.utils import seeding
  7. from gym_minigrid.rendering import *
  8. # Size in pixels of a cell in the full-scale human view
  9. CELL_PIXELS = 32
  10. # Number of cells (width and height) in the agent view
  11. AGENT_VIEW_SIZE = 7
  12. # Size of the array given as an observation to the agent
  13. OBS_ARRAY_SIZE = (AGENT_VIEW_SIZE, AGENT_VIEW_SIZE, 3)
  14. # Map of color names to RGB values
  15. COLORS = {
  16. 'red' : (255, 0, 0),
  17. 'green' : (0, 255, 0),
  18. 'blue' : (0, 0, 255),
  19. 'purple': (112, 39, 195),
  20. 'yellow': (255, 255, 0),
  21. 'grey' : (100, 100, 100)
  22. }
  23. COLOR_NAMES = sorted(list(COLORS.keys()))
  24. # Used to map colors to integers
  25. COLOR_TO_IDX = {
  26. 'red' : 0,
  27. 'green' : 1,
  28. 'blue' : 2,
  29. 'purple': 3,
  30. 'yellow': 4,
  31. 'grey' : 5
  32. }
  33. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  34. # Map of object type to integers
  35. OBJECT_TO_IDX = {
  36. 'empty' : 0,
  37. 'wall' : 1,
  38. 'door' : 2,
  39. 'locked_door' : 3,
  40. 'key' : 4,
  41. 'ball' : 5,
  42. 'box' : 6,
  43. 'goal' : 7
  44. }
  45. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  46. # Map of agent direction indices to vectors
  47. DIR_TO_VEC = [
  48. # Pointing right (positive X)
  49. np.array((1, 0)),
  50. # Down (positive Y)
  51. np.array((0, 1)),
  52. # Pointing left (negative X)
  53. np.array((-1, 0)),
  54. # Up (negative Y)
  55. np.array((0, -1)),
  56. ]
  57. class WorldObj:
  58. """
  59. Base class for grid world objects
  60. """
  61. def __init__(self, type, color):
  62. assert type in OBJECT_TO_IDX, type
  63. assert color in COLOR_TO_IDX, color
  64. self.type = type
  65. self.color = color
  66. self.contains = None
  67. def can_overlap(self):
  68. """Can the agent overlap with this?"""
  69. return False
  70. def can_pickup(self):
  71. """Can the agent pick this up?"""
  72. return False
  73. def can_contain(self):
  74. """Can this contain another object?"""
  75. return False
  76. def see_behind(self):
  77. """Can the agent see behind this object?"""
  78. return True
  79. def toggle(self, env, pos):
  80. """Method to trigger/toggle an action this object performs"""
  81. return False
  82. def render(self, r):
  83. """Draw this object with the given renderer"""
  84. raise NotImplementedError
  85. def _set_color(self, r):
  86. """Set the color of this object as the active drawing color"""
  87. c = COLORS[self.color]
  88. r.setLineColor(c[0], c[1], c[2])
  89. r.setColor(c[0], c[1], c[2])
  90. class Goal(WorldObj):
  91. def __init__(self):
  92. super(Goal, self).__init__('goal', 'green')
  93. def can_overlap(self):
  94. return True
  95. def render(self, r):
  96. self._set_color(r)
  97. r.drawPolygon([
  98. (0 , CELL_PIXELS),
  99. (CELL_PIXELS, CELL_PIXELS),
  100. (CELL_PIXELS, 0),
  101. (0 , 0)
  102. ])
  103. class Wall(WorldObj):
  104. def __init__(self, color='grey'):
  105. super(Wall, self).__init__('wall', color)
  106. def see_behind(self):
  107. return False
  108. def render(self, r):
  109. self._set_color(r)
  110. r.drawPolygon([
  111. (0 , CELL_PIXELS),
  112. (CELL_PIXELS, CELL_PIXELS),
  113. (CELL_PIXELS, 0),
  114. (0 , 0)
  115. ])
  116. class Door(WorldObj):
  117. def __init__(self, color, is_open=False):
  118. super(Door, self).__init__('door', color)
  119. self.is_open = is_open
  120. def can_overlap(self):
  121. """The agent can only walk over this cell when the door is open"""
  122. return self.is_open
  123. def see_behind(self):
  124. return self.is_open
  125. def toggle(self, env, pos):
  126. self.is_open = not self.is_open
  127. return True
  128. def render(self, r):
  129. c = COLORS[self.color]
  130. r.setLineColor(c[0], c[1], c[2])
  131. r.setColor(0, 0, 0)
  132. if self.is_open:
  133. r.drawPolygon([
  134. (CELL_PIXELS-2, CELL_PIXELS),
  135. (CELL_PIXELS , CELL_PIXELS),
  136. (CELL_PIXELS , 0),
  137. (CELL_PIXELS-2, 0)
  138. ])
  139. return
  140. r.drawPolygon([
  141. (0 , CELL_PIXELS),
  142. (CELL_PIXELS, CELL_PIXELS),
  143. (CELL_PIXELS, 0),
  144. (0 , 0)
  145. ])
  146. r.drawPolygon([
  147. (2 , CELL_PIXELS-2),
  148. (CELL_PIXELS-2, CELL_PIXELS-2),
  149. (CELL_PIXELS-2, 2),
  150. (2 , 2)
  151. ])
  152. r.drawCircle(CELL_PIXELS * 0.75, CELL_PIXELS * 0.5, 2)
  153. class LockedDoor(WorldObj):
  154. def __init__(self, color, is_open=False):
  155. super(LockedDoor, self).__init__('locked_door', color)
  156. self.is_open = is_open
  157. def toggle(self, env, pos):
  158. # If the player has the right key to open the door
  159. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  160. self.is_open = True
  161. # The key has been used, remove it from the agent
  162. env.carrying = None
  163. return True
  164. return False
  165. def can_overlap(self):
  166. """The agent can only walk over this cell when the door is open"""
  167. return self.is_open
  168. def see_behind(self):
  169. return self.is_open
  170. def render(self, r):
  171. c = COLORS[self.color]
  172. r.setLineColor(c[0], c[1], c[2])
  173. r.setColor(c[0], c[1], c[2], 50)
  174. if self.is_open:
  175. r.drawPolygon([
  176. (CELL_PIXELS-2, CELL_PIXELS),
  177. (CELL_PIXELS , CELL_PIXELS),
  178. (CELL_PIXELS , 0),
  179. (CELL_PIXELS-2, 0)
  180. ])
  181. return
  182. r.drawPolygon([
  183. (0 , CELL_PIXELS),
  184. (CELL_PIXELS, CELL_PIXELS),
  185. (CELL_PIXELS, 0),
  186. (0 , 0)
  187. ])
  188. r.drawPolygon([
  189. (2 , CELL_PIXELS-2),
  190. (CELL_PIXELS-2, CELL_PIXELS-2),
  191. (CELL_PIXELS-2, 2),
  192. (2 , 2)
  193. ])
  194. r.drawLine(
  195. CELL_PIXELS * 0.55,
  196. CELL_PIXELS * 0.5,
  197. CELL_PIXELS * 0.75,
  198. CELL_PIXELS * 0.5
  199. )
  200. class Key(WorldObj):
  201. def __init__(self, color='blue'):
  202. super(Key, self).__init__('key', color)
  203. def can_pickup(self):
  204. return True
  205. def render(self, r):
  206. self._set_color(r)
  207. # Vertical quad
  208. r.drawPolygon([
  209. (16, 10),
  210. (20, 10),
  211. (20, 28),
  212. (16, 28)
  213. ])
  214. # Teeth
  215. r.drawPolygon([
  216. (12, 19),
  217. (16, 19),
  218. (16, 21),
  219. (12, 21)
  220. ])
  221. r.drawPolygon([
  222. (12, 26),
  223. (16, 26),
  224. (16, 28),
  225. (12, 28)
  226. ])
  227. r.drawCircle(18, 9, 6)
  228. r.setLineColor(0, 0, 0)
  229. r.setColor(0, 0, 0)
  230. r.drawCircle(18, 9, 2)
  231. class Ball(WorldObj):
  232. def __init__(self, color='blue'):
  233. super(Ball, self).__init__('ball', color)
  234. def can_pickup(self):
  235. return True
  236. def render(self, r):
  237. self._set_color(r)
  238. r.drawCircle(CELL_PIXELS * 0.5, CELL_PIXELS * 0.5, 10)
  239. class Box(WorldObj):
  240. def __init__(self, color, contains=None):
  241. super(Box, self).__init__('box', color)
  242. self.contains = contains
  243. def can_pickup(self):
  244. return True
  245. def render(self, r):
  246. c = COLORS[self.color]
  247. r.setLineColor(c[0], c[1], c[2])
  248. r.setColor(0, 0, 0)
  249. r.setLineWidth(2)
  250. r.drawPolygon([
  251. (4 , CELL_PIXELS-4),
  252. (CELL_PIXELS-4, CELL_PIXELS-4),
  253. (CELL_PIXELS-4, 4),
  254. (4 , 4)
  255. ])
  256. r.drawLine(
  257. 4,
  258. CELL_PIXELS / 2,
  259. CELL_PIXELS - 4,
  260. CELL_PIXELS / 2
  261. )
  262. r.setLineWidth(1)
  263. def toggle(self, env, pos):
  264. # Replace the box by its contents
  265. env.grid.set(*pos, self.contains)
  266. return True
  267. class Grid:
  268. """
  269. Represent a grid and operations on it
  270. """
  271. def __init__(self, width, height):
  272. assert width >= 4
  273. assert height >= 4
  274. self.width = width
  275. self.height = height
  276. self.grid = [None] * width * height
  277. def __contains__(self, key):
  278. if isinstance(key, WorldObj):
  279. for e in self.grid:
  280. if e is key:
  281. return True
  282. elif isinstance(key, tuple):
  283. for e in self.grid:
  284. if e is None:
  285. continue
  286. if (e.color, e.type) == key:
  287. return True
  288. return False
  289. def __eq__(self, other):
  290. grid1 = self.encode()
  291. grid2 = other.encode()
  292. return np.array_equal(grid2, grid1)
  293. def __ne__(self, other):
  294. return not self == other
  295. def copy(self):
  296. from copy import deepcopy
  297. return deepcopy(self)
  298. def set(self, i, j, v):
  299. assert i >= 0 and i < self.width
  300. assert j >= 0 and j < self.height
  301. self.grid[j * self.width + i] = v
  302. def get(self, i, j):
  303. assert i >= 0 and i < self.width
  304. assert j >= 0 and j < self.height
  305. return self.grid[j * self.width + i]
  306. def horz_wall(self, x, y, length=None):
  307. if length is None:
  308. length = self.width - x
  309. for i in range(0, length):
  310. self.set(x + i, y, Wall())
  311. def vert_wall(self, x, y, length=None):
  312. if length is None:
  313. length = self.height - y
  314. for j in range(0, length):
  315. self.set(x, y + j, Wall())
  316. def wall_rect(self, x, y, w, h):
  317. self.horz_wall(x, y, w)
  318. self.horz_wall(x, y+h-1, w)
  319. self.vert_wall(x, y, h)
  320. self.vert_wall(x+w-1, y, h)
  321. def rotate_left(self):
  322. """
  323. Rotate the grid to the left (counter-clockwise)
  324. """
  325. grid = Grid(self.width, self.height)
  326. for j in range(0, self.height):
  327. for i in range(0, self.width):
  328. v = self.get(self.width - 1 - j, i)
  329. grid.set(i, j, v)
  330. return grid
  331. def slice(self, topX, topY, width, height):
  332. """
  333. Get a subset of the grid
  334. """
  335. grid = Grid(width, height)
  336. for j in range(0, height):
  337. for i in range(0, width):
  338. x = topX + i
  339. y = topY + j
  340. if x >= 0 and x < self.width and \
  341. y >= 0 and y < self.height:
  342. v = self.get(x, y)
  343. else:
  344. v = Wall()
  345. grid.set(i, j, v)
  346. return grid
  347. def render(self, r, tile_size):
  348. """
  349. Render this grid at a given scale
  350. :param r: target renderer object
  351. :param tile_size: tile size in pixels
  352. """
  353. assert r.width == self.width * tile_size
  354. assert r.height == self.height * tile_size
  355. # Total grid size at native scale
  356. widthPx = self.width * CELL_PIXELS
  357. heightPx = self.height * CELL_PIXELS
  358. r.push()
  359. # Internally, we draw at the "large" full-grid resolution, but we
  360. # use the renderer to scale back to the desired size
  361. r.scale(tile_size / CELL_PIXELS, tile_size / CELL_PIXELS)
  362. # Draw the background of the in-world cells black
  363. r.fillRect(
  364. 0,
  365. 0,
  366. widthPx,
  367. heightPx,
  368. 0, 0, 0
  369. )
  370. # Draw grid lines
  371. r.setLineColor(100, 100, 100)
  372. for rowIdx in range(0, self.height):
  373. y = CELL_PIXELS * rowIdx
  374. r.drawLine(0, y, widthPx, y)
  375. for colIdx in range(0, self.width):
  376. x = CELL_PIXELS * colIdx
  377. r.drawLine(x, 0, x, heightPx)
  378. # Render the grid
  379. for j in range(0, self.height):
  380. for i in range(0, self.width):
  381. cell = self.get(i, j)
  382. if cell == None:
  383. continue
  384. r.push()
  385. r.translate(i * CELL_PIXELS, j * CELL_PIXELS)
  386. cell.render(r)
  387. r.pop()
  388. r.pop()
  389. def encode(self):
  390. """
  391. Produce a compact numpy encoding of the grid
  392. """
  393. codeSize = self.width * self.height * 3
  394. array = np.zeros(shape=(self.width, self.height, 3), dtype='uint8')
  395. for j in range(0, self.height):
  396. for i in range(0, self.width):
  397. v = self.get(i, j)
  398. if v == None:
  399. continue
  400. array[i, j, 0] = OBJECT_TO_IDX[v.type]
  401. array[i, j, 1] = COLOR_TO_IDX[v.color]
  402. if hasattr(v, 'is_open') and v.is_open:
  403. array[i, j, 2] = 1
  404. return array
  405. def decode(array):
  406. """
  407. Decode an array grid encoding back into a grid
  408. """
  409. width = array.shape[0]
  410. height = array.shape[1]
  411. assert array.shape[2] == 3
  412. grid = Grid(width, height)
  413. for j in range(0, height):
  414. for i in range(0, width):
  415. typeIdx = array[i, j, 0]
  416. colorIdx = array[i, j, 1]
  417. openIdx = array[i, j, 2]
  418. if typeIdx == 0:
  419. continue
  420. objType = IDX_TO_OBJECT[typeIdx]
  421. color = IDX_TO_COLOR[colorIdx]
  422. is_open = True if openIdx == 1 else 0
  423. if objType == 'wall':
  424. v = Wall(color)
  425. elif objType == 'ball':
  426. v = Ball(color)
  427. elif objType == 'key':
  428. v = Key(color)
  429. elif objType == 'box':
  430. v = Box(color)
  431. elif objType == 'door':
  432. v = Door(color, is_open)
  433. elif objType == 'locked_door':
  434. v = LockedDoor(color, is_open)
  435. elif objType == 'goal':
  436. v = Goal()
  437. else:
  438. assert False, "unknown obj type in decode '%s'" % objType
  439. grid.set(i, j, v)
  440. return grid
  441. def process_vis(grid, agent_pos):
  442. mask = np.zeros(shape=(grid.width, grid.height), dtype=np.bool)
  443. mask[agent_pos[0], agent_pos[1]] = True
  444. for j in reversed(range(1, grid.height)):
  445. for i in range(0, grid.width-1):
  446. if not mask[i, j]:
  447. continue
  448. cell = grid.get(i, j)
  449. if cell and not cell.see_behind():
  450. continue
  451. mask[i+1, j] = True
  452. mask[i+1, j-1] = True
  453. mask[i, j-1] = True
  454. for i in reversed(range(1, grid.width)):
  455. if not mask[i, j]:
  456. continue
  457. cell = grid.get(i, j)
  458. if cell and not cell.see_behind():
  459. continue
  460. mask[i-1, j-1] = True
  461. mask[i-1, j] = True
  462. mask[i, j-1] = True
  463. for j in range(0, grid.height):
  464. for i in range(0, grid.width):
  465. if not mask[i, j]:
  466. grid.set(i, j, None)
  467. return mask
  468. class MiniGridEnv(gym.Env):
  469. """
  470. 2D grid world game environment
  471. """
  472. metadata = {
  473. 'render.modes': ['human', 'rgb_array', 'pixmap'],
  474. 'video.frames_per_second' : 10
  475. }
  476. # Enumeration of possible actions
  477. class Actions(IntEnum):
  478. # Turn left, turn right, move forward
  479. left = 0
  480. right = 1
  481. forward = 2
  482. # Pick up an object
  483. pickup = 3
  484. # Drop an object
  485. drop = 4
  486. # Toggle/activate an object
  487. toggle = 5
  488. # Done completing task
  489. done = 6
  490. def __init__(
  491. self,
  492. grid_size=16,
  493. max_steps=100,
  494. see_through_walls=False,
  495. seed=1337
  496. ):
  497. # Action enumeration for this environment
  498. self.actions = MiniGridEnv.Actions
  499. # Actions are discrete integer values
  500. self.action_space = spaces.Discrete(len(self.actions))
  501. # Observations are dictionaries containing an
  502. # encoding of the grid and a textual 'mission' string
  503. self.observation_space = spaces.Box(
  504. low=0,
  505. high=255,
  506. shape=OBS_ARRAY_SIZE,
  507. dtype='uint8'
  508. )
  509. self.observation_space = spaces.Dict({
  510. 'image': self.observation_space
  511. })
  512. # Range of possible rewards
  513. self.reward_range = (0, 1)
  514. # Renderer object used to render the whole grid (full-scale)
  515. self.grid_render = None
  516. # Renderer used to render observations (small-scale agent view)
  517. self.obs_render = None
  518. # Environment configuration
  519. self.grid_size = grid_size
  520. self.max_steps = max_steps
  521. self.see_through_walls = see_through_walls
  522. # Starting position and direction for the agent
  523. self.start_pos = None
  524. self.start_dir = None
  525. # Initialize the RNG
  526. self.seed(seed=seed)
  527. # Initialize the state
  528. self.reset()
  529. def reset(self):
  530. # Generate a new random grid at the start of each episode
  531. # To keep the same grid for each episode, call env.seed() with
  532. # the same seed before calling env.reset()
  533. self._gen_grid(self.grid_size, self.grid_size)
  534. # These fields should be defined by _gen_grid
  535. assert self.start_pos is not None
  536. assert self.start_dir is not None
  537. # Check that the agent doesn't overlap with an object
  538. assert self.grid.get(*self.start_pos) is None
  539. # Place the agent in the starting position and direction
  540. self.agent_pos = self.start_pos
  541. self.agent_dir = self.start_dir
  542. # Item picked up, being carried, initially nothing
  543. self.carrying = None
  544. # Step count since episode start
  545. self.step_count = 0
  546. # Return first observation
  547. obs = self.gen_obs()
  548. return obs
  549. def seed(self, seed=1337):
  550. # Seed the random number generator
  551. self.np_random, _ = seeding.np_random(seed)
  552. return [seed]
  553. @property
  554. def steps_remaining(self):
  555. return self.max_steps - self.step_count
  556. def __str__(self):
  557. """
  558. Produce a pretty string of the environment's grid along with the agent.
  559. The agent is represented by `⏩`. A grid pixel is represented by 2-character
  560. string, the first one for the object and the second one for the color.
  561. """
  562. from copy import deepcopy
  563. def rotate_left(array):
  564. new_array = deepcopy(array)
  565. for i in range(len(array)):
  566. for j in range(len(array[0])):
  567. new_array[j][len(array[0])-1-i] = array[i][j]
  568. return new_array
  569. def vertically_symmetrize(array):
  570. new_array = deepcopy(array)
  571. for i in range(len(array)):
  572. for j in range(len(array[0])):
  573. new_array[i][len(array[0])-1-j] = array[i][j]
  574. return new_array
  575. # Map of object id to short string
  576. OBJECT_IDX_TO_IDS = {
  577. 0: ' ',
  578. 1: 'W',
  579. 2: 'D',
  580. 3: 'L',
  581. 4: 'K',
  582. 5: 'B',
  583. 6: 'X',
  584. 7: 'G'
  585. }
  586. # Short string for opened door
  587. OPENDED_DOOR_IDS = '_'
  588. # Map of color id to short string
  589. COLOR_IDX_TO_IDS = {
  590. 0: 'R',
  591. 1: 'G',
  592. 2: 'B',
  593. 3: 'P',
  594. 4: 'Y',
  595. 5: 'E'
  596. }
  597. # Map agent's direction to short string
  598. AGENT_DIR_TO_IDS = {
  599. 0: '⏩',
  600. 1: '⏬',
  601. 2: '⏪',
  602. 3: '⏫'
  603. }
  604. array = self.grid.encode()
  605. array = rotate_left(array)
  606. array = vertically_symmetrize(array)
  607. new_array = []
  608. for line in array:
  609. new_line = []
  610. for pixel in line:
  611. # If the door is opened
  612. if pixel[0] in [2, 3] and pixel[2] == 1:
  613. object_ids = OPENDED_DOOR_IDS
  614. else:
  615. object_ids = OBJECT_IDX_TO_IDS[pixel[0]]
  616. # If no object
  617. if pixel[0] == 0:
  618. color_ids = ' '
  619. else:
  620. color_ids = COLOR_IDX_TO_IDS[pixel[1]]
  621. new_line.append(object_ids + color_ids)
  622. new_array.append(new_line)
  623. # Add the agent
  624. new_array[self.agent_pos[1]][self.agent_pos[0]] = AGENT_DIR_TO_IDS[self.agent_dir]
  625. return "\n".join([" ".join(line) for line in new_array])
  626. def _gen_grid(self, width, height):
  627. assert False, "_gen_grid needs to be implemented by each environment"
  628. def _rand_int(self, low, high):
  629. """
  630. Generate random integer in [low,high[
  631. """
  632. return self.np_random.randint(low, high)
  633. def _rand_bool(self):
  634. """
  635. Generate random boolean value
  636. """
  637. return (self.np_random.randint(0, 2) == 0)
  638. def _rand_elem(self, iterable):
  639. """
  640. Pick a random element in a list
  641. """
  642. lst = list(iterable)
  643. idx = self._rand_int(0, len(lst))
  644. return lst[idx]
  645. def _rand_subset(self, iterable, num_elems):
  646. """
  647. Sample a random subset of distinct elements of a list
  648. """
  649. lst = list(iterable)
  650. assert num_elems <= len(lst)
  651. out = []
  652. while len(out) < num_elems:
  653. elem = self._rand_elem(lst)
  654. lst.remove(elem)
  655. out.append(elem)
  656. return out
  657. def _rand_color(self):
  658. """
  659. Generate a random color name (string)
  660. """
  661. return self._rand_elem(COLOR_NAMES)
  662. def _rand_pos(self, xLow, xHigh, yLow, yHigh):
  663. """
  664. Generate a random (x,y) position tuple
  665. """
  666. return (
  667. self.np_random.randint(xLow, xHigh),
  668. self.np_random.randint(yLow, yHigh)
  669. )
  670. def place_obj(self, obj, top=None, size=None, reject_fn=None):
  671. """
  672. Place an object at an empty position in the grid
  673. :param top: top-left position of the rectangle where to place
  674. :param size: size of the rectangle where to place
  675. :param reject_fn: function to filter out potential positions
  676. """
  677. if top is None:
  678. top = (0, 0)
  679. if size is None:
  680. size = (self.grid.width, self.grid.height)
  681. while True:
  682. pos = np.array((
  683. self._rand_int(top[0], top[0] + size[0]),
  684. self._rand_int(top[1], top[1] + size[1])
  685. ))
  686. # Don't place the object on top of another object
  687. if self.grid.get(*pos) != None:
  688. continue
  689. # Don't place the object where the agent is
  690. if np.array_equal(pos, self.start_pos):
  691. continue
  692. # Check if there is a filtering criterion
  693. if reject_fn and reject_fn(self, pos):
  694. continue
  695. break
  696. self.grid.set(*pos, obj)
  697. return pos
  698. def place_agent(self, top=None, size=None, rand_dir=True):
  699. """
  700. Set the agent's starting point at an empty position in the grid
  701. """
  702. self.start_pos = None
  703. pos = self.place_obj(None, top, size)
  704. self.start_pos = pos
  705. if rand_dir:
  706. self.start_dir = self._rand_int(0, 4)
  707. return pos
  708. def get_dir_vec(self):
  709. """
  710. Get the direction vector for the agent, pointing in the direction
  711. of forward movement.
  712. """
  713. assert self.agent_dir >= 0 and self.agent_dir < 4
  714. return DIR_TO_VEC[self.agent_dir]
  715. def get_right_vec(self):
  716. """
  717. Get the vector pointing to the right of the agent.
  718. """
  719. dx, dy = self.get_dir_vec()
  720. return np.array((-dy, dx))
  721. def get_view_coords(self, i, j):
  722. """
  723. Translate and rotate absolute grid coordinates (i, j) into the
  724. agent's partially observable view (sub-grid). Note that the resulting
  725. coordinates may be negative or outside of the agent's view size.
  726. """
  727. ax, ay = self.agent_pos
  728. dx, dy = self.get_dir_vec()
  729. rx, ry = self.get_right_vec()
  730. # Compute the absolute coordinates of the top-left view corner
  731. sz = AGENT_VIEW_SIZE
  732. hs = AGENT_VIEW_SIZE // 2
  733. tx = ax + (dx * (sz-1)) - (rx * hs)
  734. ty = ay + (dy * (sz-1)) - (ry * hs)
  735. lx = i - tx
  736. ly = j - ty
  737. # Project the coordinates of the object relative to the top-left
  738. # corner onto the agent's own coordinate system
  739. vx = (rx*lx + ry*ly)
  740. vy = -(dx*lx + dy*ly)
  741. return vx, vy
  742. def get_view_exts(self):
  743. """
  744. Get the extents of the square set of tiles visible to the agent
  745. Note: the bottom extent indices are not included in the set
  746. """
  747. # Facing right
  748. if self.agent_dir == 0:
  749. topX = self.agent_pos[0]
  750. topY = self.agent_pos[1] - AGENT_VIEW_SIZE // 2
  751. # Facing down
  752. elif self.agent_dir == 1:
  753. topX = self.agent_pos[0] - AGENT_VIEW_SIZE // 2
  754. topY = self.agent_pos[1]
  755. # Facing left
  756. elif self.agent_dir == 2:
  757. topX = self.agent_pos[0] - AGENT_VIEW_SIZE + 1
  758. topY = self.agent_pos[1] - AGENT_VIEW_SIZE // 2
  759. # Facing up
  760. elif self.agent_dir == 3:
  761. topX = self.agent_pos[0] - AGENT_VIEW_SIZE // 2
  762. topY = self.agent_pos[1] - AGENT_VIEW_SIZE + 1
  763. else:
  764. assert False, "invalid agent direction"
  765. botX = topX + AGENT_VIEW_SIZE
  766. botY = topY + AGENT_VIEW_SIZE
  767. return (topX, topY, botX, botY)
  768. def agent_sees(self, x, y):
  769. """
  770. Check if a grid position is visible to the agent
  771. """
  772. vx, vy = self.get_view_coords(x, y)
  773. if vx < 0 or vy < 0 or vx >= AGENT_VIEW_SIZE or vy >= AGENT_VIEW_SIZE:
  774. return False
  775. obs = self.gen_obs()
  776. obs_grid = Grid.decode(obs['image'])
  777. obs_cell = obs_grid.get(vx, vy)
  778. world_cell = self.grid.get(x, y)
  779. return obs_cell is not None and obs_cell.type == world_cell.type
  780. def step(self, action):
  781. self.step_count += 1
  782. reward = 0
  783. done = False
  784. # Get the position in front of the agent
  785. fwd_pos = self.agent_pos + self.get_dir_vec()
  786. # Get the contents of the cell in front of the agent
  787. fwd_cell = self.grid.get(*fwd_pos)
  788. # Rotate left
  789. if action == self.actions.left:
  790. self.agent_dir -= 1
  791. if self.agent_dir < 0:
  792. self.agent_dir += 4
  793. # Rotate right
  794. elif action == self.actions.right:
  795. self.agent_dir = (self.agent_dir + 1) % 4
  796. # Move forward
  797. elif action == self.actions.forward:
  798. if fwd_cell == None or fwd_cell.can_overlap():
  799. self.agent_pos = fwd_pos
  800. if fwd_cell != None and fwd_cell.type == 'goal':
  801. done = True
  802. reward = 1
  803. # Pick up an object
  804. elif action == self.actions.pickup:
  805. if fwd_cell and fwd_cell.can_pickup():
  806. if self.carrying is None:
  807. self.carrying = fwd_cell
  808. self.grid.set(*fwd_pos, None)
  809. # Drop an object
  810. elif action == self.actions.drop:
  811. if not fwd_cell and self.carrying:
  812. self.grid.set(*fwd_pos, self.carrying)
  813. self.carrying = None
  814. # Toggle/activate an object
  815. elif action == self.actions.toggle:
  816. if fwd_cell:
  817. fwd_cell.toggle(self, fwd_pos)
  818. # Done action (not used by default)
  819. elif action == self.actions.done:
  820. pass
  821. else:
  822. assert False, "unknown action"
  823. if self.step_count >= self.max_steps:
  824. done = True
  825. obs = self.gen_obs()
  826. return obs, reward, done, {}
  827. def gen_obs_grid(self):
  828. """
  829. Generate the sub-grid observed by the agent.
  830. This method also outputs a visibility mask telling us which grid
  831. cells the agent can actually see.
  832. """
  833. topX, topY, botX, botY = self.get_view_exts()
  834. grid = self.grid.slice(topX, topY, AGENT_VIEW_SIZE, AGENT_VIEW_SIZE)
  835. for i in range(self.agent_dir + 1):
  836. grid = grid.rotate_left()
  837. # Process occluders and visibility
  838. # Note that this incurs some performance cost
  839. if not self.see_through_walls:
  840. vis_mask = grid.process_vis(agent_pos=(AGENT_VIEW_SIZE // 2 , AGENT_VIEW_SIZE - 1))
  841. else:
  842. vis_mask = np.ones(shape=(grid.width, grid.height), dtype=np.bool)
  843. # Make it so the agent sees what it's carrying
  844. # We do this by placing the carried object at the agent's position
  845. # in the agent's partially observable view
  846. agent_pos = grid.width // 2, grid.height - 1
  847. if self.carrying:
  848. grid.set(*agent_pos, self.carrying)
  849. else:
  850. grid.set(*agent_pos, None)
  851. return grid, vis_mask
  852. def gen_obs(self):
  853. """
  854. Generate the agent's view (partially observable, low-resolution encoding)
  855. """
  856. grid, vis_mask = self.gen_obs_grid()
  857. # Encode the partially observable view into a numpy array
  858. image = grid.encode()
  859. assert hasattr(self, 'mission'), "environments must define a textual mission string"
  860. # Observations are dictionaries containing:
  861. # - an image (partially observable view of the environment)
  862. # - the agent's direction/orientation (acting as a compass)
  863. # - a textual mission string (instructions for the agent)
  864. obs = {
  865. 'image': image,
  866. 'direction': self.agent_dir,
  867. 'mission': self.mission
  868. }
  869. return obs
  870. def get_obs_render(self, obs, tile_pixels=CELL_PIXELS//2):
  871. """
  872. Render an agent observation for visualization
  873. """
  874. if self.obs_render == None:
  875. self.obs_render = Renderer(
  876. AGENT_VIEW_SIZE * tile_pixels,
  877. AGENT_VIEW_SIZE * tile_pixels
  878. )
  879. r = self.obs_render
  880. r.beginFrame()
  881. grid = Grid.decode(obs)
  882. # Render the whole grid
  883. grid.render(r, tile_pixels)
  884. # Draw the agent
  885. ratio = tile_pixels / CELL_PIXELS
  886. r.push()
  887. r.scale(ratio, ratio)
  888. r.translate(
  889. CELL_PIXELS * (0.5 + AGENT_VIEW_SIZE // 2),
  890. CELL_PIXELS * (AGENT_VIEW_SIZE - 0.5)
  891. )
  892. r.rotate(3 * 90)
  893. r.setLineColor(255, 0, 0)
  894. r.setColor(255, 0, 0)
  895. r.drawPolygon([
  896. (-12, 10),
  897. ( 12, 0),
  898. (-12, -10)
  899. ])
  900. r.pop()
  901. r.endFrame()
  902. return r.getPixmap()
  903. def render(self, mode='human', close=False):
  904. """
  905. Render the whole-grid human view
  906. """
  907. if close:
  908. if self.grid_render:
  909. self.grid_render.close()
  910. return
  911. if self.grid_render is None:
  912. self.grid_render = Renderer(
  913. self.grid_size * CELL_PIXELS,
  914. self.grid_size * CELL_PIXELS,
  915. True if mode == 'human' else False
  916. )
  917. r = self.grid_render
  918. r.beginFrame()
  919. # Render the whole grid
  920. self.grid.render(r, CELL_PIXELS)
  921. # Draw the agent
  922. r.push()
  923. r.translate(
  924. CELL_PIXELS * (self.agent_pos[0] + 0.5),
  925. CELL_PIXELS * (self.agent_pos[1] + 0.5)
  926. )
  927. r.rotate(self.agent_dir * 90)
  928. r.setLineColor(255, 0, 0)
  929. r.setColor(255, 0, 0)
  930. r.drawPolygon([
  931. (-12, 10),
  932. ( 12, 0),
  933. (-12, -10)
  934. ])
  935. r.pop()
  936. # Compute which cells are visible to the agent
  937. _, vis_mask = self.gen_obs_grid()
  938. # Compute the absolute coordinates of the bottom-left corner
  939. # of the agent's view area
  940. f_vec = self.get_dir_vec()
  941. r_vec = self.get_right_vec()
  942. top_left = self.agent_pos + f_vec * (AGENT_VIEW_SIZE-1) - r_vec * (AGENT_VIEW_SIZE // 2)
  943. # For each cell in the visibility mask
  944. for vis_j in range(0, AGENT_VIEW_SIZE):
  945. for vis_i in range(0, AGENT_VIEW_SIZE):
  946. # If this cell is not visible, don't highlight it
  947. if not vis_mask[vis_i, vis_j]:
  948. continue
  949. # Compute the world coordinates of this cell
  950. abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)
  951. # Highlight the cell
  952. r.fillRect(
  953. abs_i * CELL_PIXELS,
  954. abs_j * CELL_PIXELS,
  955. CELL_PIXELS,
  956. CELL_PIXELS,
  957. 255, 255, 255, 75
  958. )
  959. r.endFrame()
  960. if mode == 'rgb_array':
  961. return r.getArray()
  962. elif mode == 'pixmap':
  963. return r.getPixmap()
  964. return r