# minigrid.py (28 KB)
  1. import math
  2. import gym
  3. from enum import IntEnum
  4. import numpy as np
  5. from gym import error, spaces, utils
  6. from gym.utils import seeding
  7. from gym_minigrid.rendering import *
  # Size in pixels of a cell in the full-scale human view
  CELL_PIXELS = 32

  # Number of cells (width and height) in the agent view
  AGENT_VIEW_SIZE = 7

  # Shape of the array given as an observation to the agent:
  # one (object type, color, open flag) triple per visible cell
  OBS_ARRAY_SIZE = (AGENT_VIEW_SIZE, AGENT_VIEW_SIZE, 3)

  # Map of color names to RGB values
  COLORS = {
      'red'   : (255, 0, 0),
      'green' : (0, 255, 0),
      'blue'  : (0, 0, 255),
      'purple': (112, 39, 195),
      'yellow': (255, 255, 0),
      'grey'  : (100, 100, 100)
  }
  # Color names in alphabetical order
  COLOR_NAMES = sorted(COLORS)

  # Used to map colors to integers. The codes follow this fixed sequence
  # (red=0 ... grey=5) and must stay stable, because they are written into
  # encoded observations.
  COLOR_TO_IDX = {
      name: idx
      for idx, name in enumerate(
          ('red', 'green', 'blue', 'purple', 'yellow', 'grey')
      )
  }
  IDX_TO_COLOR = {idx: name for name, idx in COLOR_TO_IDX.items()}

  # Map of object type to integers; code 0 ('empty') marks a cell with no object
  OBJECT_TO_IDX = {
      name: idx
      for idx, name in enumerate((
          'empty', 'wall', 'door', 'locked_door', 'key', 'ball', 'box', 'goal'
      ))
  }
  IDX_TO_OBJECT = {idx: name for name, idx in OBJECT_TO_IDX.items()}
  class WorldObj:
      """
      Base class for grid world objects.

      Subclasses override the capability predicates (canOverlap, canPickup,
      canContain, toggle) as needed and must implement ``render``.
      """

      def __init__(self, type, color):
          """
          :param type: object type name; must be a key of OBJECT_TO_IDX
          :param color: color name; must be a key of COLOR_TO_IDX
          """
          # ``type`` shadows the builtin, but the name is kept for API compatibility
          assert type in OBJECT_TO_IDX, type
          assert color in COLOR_TO_IDX, color
          self.type = type
          self.color = color
          # Object held inside this one, if any (Box sets this)
          self.contains = None

      def canOverlap(self):
          """Can the agent overlap with this?"""
          return False

      def canPickup(self):
          """Can the agent pick this up?"""
          return False

      def canContain(self):
          """Can this contain another object?"""
          return False

      def toggle(self, env, pos):
          """
          Method to trigger/toggle an action this object performs.

          :param env: the environment containing this object
          :param pos: (x, y) grid position of this object
          :return: True if toggling had an effect, False otherwise
          """
          return False

      def render(self, r):
          """
          Draw this object onto renderer ``r``.

          The grid translates the renderer to the cell's origin before calling
          this, so drawing uses cell-local coordinates in [0, CELL_PIXELS].
          """
          # Raise instead of ``assert False``: asserts are stripped under -O,
          # which would make this silently return None.
          raise NotImplementedError(
              "render needs to be implemented by each WorldObj subclass"
          )

      def _setColor(self, r):
          """Set both the line and fill color of ``r`` to this object's color."""
          c = COLORS[self.color]
          r.setLineColor(c[0], c[1], c[2])
          r.setColor(c[0], c[1], c[2])
  class Goal(WorldObj):
      """Green goal tile that the agent is allowed to step onto."""

      def __init__(self):
          super().__init__('goal', 'green')

      def canOverlap(self):
          # The agent may walk onto the goal cell
          return True

      def render(self, r):
          self._setColor(r)
          # Fill the entire cell with the goal color
          full = CELL_PIXELS
          r.drawPolygon([(0, full), (full, full), (full, 0), (0, 0)])
  class Wall(WorldObj):
      """Impassable wall tile, grey unless another color is given."""

      def __init__(self, color='grey'):
          super().__init__('wall', color)

      def render(self, r):
          self._setColor(r)
          # Walls paint the whole cell
          corners = [
              (0, CELL_PIXELS),
              (CELL_PIXELS, CELL_PIXELS),
              (CELL_PIXELS, 0),
              (0, 0),
          ]
          r.drawPolygon(corners)
  class Door(WorldObj):
      """
      Ordinary door: toggling a closed door opens it. Once open it stays open,
      and the agent can walk through it.
      """

      def __init__(self, color, isOpen=False):
          super().__init__('door', color)
          self.isOpen = isOpen

      def render(self, r):
          red, green, blue = COLORS[self.color]
          r.setLineColor(red, green, blue)
          r.setColor(0, 0, 0)
          edge = CELL_PIXELS
          if self.isOpen:
              # Open: only a thin strip along the right side of the cell
              r.drawPolygon([
                  (edge - 2, edge), (edge, edge), (edge, 0), (edge - 2, 0)
              ])
          else:
              # Closed: outer frame, inset panel and a small round handle
              r.drawPolygon([(0, edge), (edge, edge), (edge, 0), (0, 0)])
              r.drawPolygon([
                  (2, edge - 2), (edge - 2, edge - 2), (edge - 2, 2), (2, 2)
              ])
              r.drawCircle(edge * 0.75, edge * 0.5, 2)

      def toggle(self, env, pos):
          # Only a closed door reacts; there is no way to close it again
          if self.isOpen:
              return False
          self.isOpen = True
          return True

      def canOverlap(self):
          """The agent can only walk over this cell when the door is open"""
          return self.isOpen
  class LockedDoor(WorldObj):
      """
      Door that only opens when toggled by an agent carrying a key of the
      same color. Unlocking consumes the key; once open, the agent can walk
      through the door.
      """

      def __init__(self, color, isOpen=False):
          super(LockedDoor, self).__init__('locked_door', color)
          self.isOpen = isOpen

      def render(self, r):
          c = COLORS[self.color]
          r.setLineColor(c[0], c[1], c[2])
          # Fill in the door color; the 4th argument (50) is presumably an
          # alpha/transparency value - confirm against the renderer
          r.setColor(c[0], c[1], c[2], 50)

          if self.isOpen:
              # Open: only a thin strip along the right side of the cell
              r.drawPolygon([
                  (CELL_PIXELS-2, CELL_PIXELS),
                  (CELL_PIXELS  , CELL_PIXELS),
                  (CELL_PIXELS  , 0),
                  (CELL_PIXELS-2, 0)
              ])
              return

          # Closed: outer frame, inset panel and a horizontal keyhole bar
          r.drawPolygon([
              (0          , CELL_PIXELS),
              (CELL_PIXELS, CELL_PIXELS),
              (CELL_PIXELS, 0),
              (0          , 0)
          ])
          r.drawPolygon([
              (2            , CELL_PIXELS-2),
              (CELL_PIXELS-2, CELL_PIXELS-2),
              (CELL_PIXELS-2, 2),
              (2            , 2)
          ])
          r.drawLine(
              CELL_PIXELS * 0.55,
              CELL_PIXELS * 0.5,
              CELL_PIXELS * 0.75,
              CELL_PIXELS * 0.5
          )

      def toggle(self, env, pos):
          """
          Unlock the door if the agent carries a key of the same color.

          :param env: environment; its ``carrying`` attribute is inspected and
                      cleared when the key is used
          :param pos: (x, y) position of the door (unused)
          :return: True if the door was unlocked by this call, else False
          """
          # An already-open door must not swallow a matching key the agent
          # is still carrying
          if self.isOpen:
              return False

          # If the player has the right key to open the door
          carried = env.carrying
          if isinstance(carried, Key) and carried.color == self.color:
              self.isOpen = True
              # The key has been used, remove it from the agent
              env.carrying = None
              return True

          return False

      def canOverlap(self):
          """The agent can only walk over this cell when the door is open"""
          return self.isOpen
  class Key(WorldObj):
      """Key that the agent can pick up; it unlocks a LockedDoor of its color."""

      def __init__(self, color='blue'):
          super().__init__('key', color)

      def canPickup(self):
          return True

      def render(self, r):
          self._setColor(r)

          # Vertical shaft, then the two teeth, all in the key color
          quads = (
              [(16, 10), (20, 10), (20, 28), (16, 28)],
              [(12, 19), (16, 19), (16, 21), (12, 21)],
              [(12, 26), (16, 26), (16, 28), (12, 28)],
          )
          for quad in quads:
              r.drawPolygon(quad)

          # Ring at the top: a filled circle with a smaller black circle on it
          r.drawCircle(18, 9, 6)
          r.setLineColor(0, 0, 0)
          r.setColor(0, 0, 0)
          r.drawCircle(18, 9, 2)
  class Ball(WorldObj):
      """Ball the agent can pick up, drawn as a filled circle."""

      def __init__(self, color='blue'):
          super().__init__('ball', color)

      def canPickup(self):
          return True

      def render(self, r):
          self._setColor(r)
          center = CELL_PIXELS * 0.5
          r.drawCircle(center, center, 10)
  class Box(WorldObj):
      """
      Box that can be picked up; toggling it replaces the box with its
      contents (which may be None, leaving the cell empty).
      """

      def __init__(self, color, contains=None):
          super().__init__('box', color)
          self.contains = contains

      def canPickup(self):
          return True

      def render(self, r):
          rgb = COLORS[self.color]
          r.setLineColor(*rgb)
          r.setColor(0, 0, 0)
          r.setLineWidth(2)

          # Outline inset by 4 pixels from the cell border
          lo, hi = 4, CELL_PIXELS - 4
          r.drawPolygon([(lo, hi), (hi, hi), (hi, lo), (lo, lo)])

          # Horizontal line across the middle of the box
          mid = CELL_PIXELS / 2
          r.drawLine(lo, mid, hi, mid)

          # Restore the default line width for subsequent drawing
          r.setLineWidth(1)

      def toggle(self, env, pos):
          # Replace the box by its contents
          env.grid.set(*pos, self.contains)
          return True
  class Grid:
      """
      Represent a grid and operations on it.

      Cells are stored row-major in a flat list (index ``j * width + i``);
      each cell holds a WorldObj or None for an empty cell.
      """

      def __init__(self, width, height):
          assert width >= 4
          assert height >= 4

          self.width = width
          self.height = height

          self.grid = [None] * width * height

      def __contains__(self, key):
          """
          Membership test.

          :param key: either a WorldObj (matched by identity) or a
                      ``(color, type)`` tuple matched against every object
          :return: True if a matching object is in the grid
          """
          if isinstance(key, WorldObj):
              return any(e is key for e in self.grid)
          if isinstance(key, tuple):
              return any(
                  e is not None and (e.color, e.type) == key
                  for e in self.grid
              )
          return False

      def __eq__(self, other):
          """Two grids are equal when their compact encodings are equal."""
          # Let Python fall back to its default handling for non-grids
          # instead of crashing on a missing ``encode`` method
          if not isinstance(other, Grid):
              return NotImplemented
          return np.array_equal(other.encode(), self.encode())

      def __ne__(self, other):
          return not self == other

      def copy(self):
          """Return a deep copy of this grid, including the objects it holds."""
          from copy import deepcopy
          return deepcopy(self)

      def set(self, i, j, v):
          assert i >= 0 and i < self.width
          assert j >= 0 and j < self.height
          self.grid[j * self.width + i] = v

      def get(self, i, j):
          assert i >= 0 and i < self.width
          assert j >= 0 and j < self.height
          return self.grid[j * self.width + i]

      def horzWall(self, x, y, length=None):
          """Place a horizontal run of walls starting at (x, y)."""
          if length is None:
              length = self.width - x
          for i in range(0, length):
              self.set(x + i, y, Wall())

      def vertWall(self, x, y, length=None):
          """Place a vertical run of walls starting at (x, y)."""
          if length is None:
              length = self.height - y
          for j in range(0, length):
              self.set(x, y + j, Wall())

      def wallRect(self, x, y, w, h):
          """Surround the w x h rectangle with top-left corner (x, y) by walls."""
          self.horzWall(x, y, w)
          self.horzWall(x, y+h-1, w)
          self.vertWall(x, y, h)
          self.vertWall(x+w-1, y, h)

      def rotateLeft(self):
          """
          Rotate the grid to the left (counter-clockwise).

          The result has width and height swapped, so non-square grids are
          supported: cell (i, j) of the result is cell (width - 1 - j, i) of
          this grid.
          """
          grid = Grid(self.height, self.width)

          for j in range(grid.height):
              for i in range(grid.width):
                  grid.set(i, j, self.get(self.width - 1 - j, i))

          return grid

      def slice(self, topX, topY, width, height):
          """
          Get a subset of the grid.

          Positions outside this grid are filled with walls.
          """
          grid = Grid(width, height)

          for j in range(0, height):
              for i in range(0, width):
                  x = topX + i
                  y = topY + j

                  if 0 <= x < self.width and 0 <= y < self.height:
                      v = self.get(x, y)
                  else:
                      v = Wall()

                  grid.set(i, j, v)

          return grid

      def render(self, r, tileSize):
          """
          Render this grid at a given scale
          :param r: target renderer object
          :param tileSize: tile size in pixels
          """
          assert r.width == self.width * tileSize
          assert r.height == self.height * tileSize

          # Total grid size at native scale
          widthPx = self.width * CELL_PIXELS
          heightPx = self.height * CELL_PIXELS

          # Draw background (out-of-world) tiles the same colors as walls
          # so the agent understands these areas are not reachable
          c = COLORS['grey']
          r.setLineColor(c[0], c[1], c[2])
          r.setColor(c[0], c[1], c[2])
          r.drawPolygon([
              (0    , heightPx),
              (widthPx, heightPx),
              (widthPx,      0),
              (0    ,      0)
          ])

          r.push()

          # Internally, we draw at the "large" full-grid resolution, but we
          # use the renderer to scale back to the desired size
          r.scale(tileSize / CELL_PIXELS, tileSize / CELL_PIXELS)

          # Draw the background of the in-world cells black
          r.fillRect(
              0,
              0,
              widthPx,
              heightPx,
              0, 0, 0
          )

          # Draw grid lines
          r.setLineColor(100, 100, 100)
          for rowIdx in range(0, self.height):
              y = CELL_PIXELS * rowIdx
              r.drawLine(0, y, widthPx, y)
          for colIdx in range(0, self.width):
              x = CELL_PIXELS * colIdx
              r.drawLine(x, 0, x, heightPx)

          # Render the grid, each object in its own cell-local coordinates
          for j in range(0, self.height):
              for i in range(0, self.width):
                  cell = self.get(i, j)
                  if cell is None:
                      continue
                  r.push()
                  r.translate(i * CELL_PIXELS, j * CELL_PIXELS)
                  cell.render(r)
                  r.pop()

          r.pop()

      def encode(self):
          """
          Produce a compact numpy encoding of the grid.

          :return: uint8 array of shape (width, height, 3) indexed [x, y];
                   channels are object type index, color index and an open
                   flag (1 for open doors). Empty cells are all zeros.
          """
          array = np.zeros(shape=(self.width, self.height, 3), dtype='uint8')

          for j in range(0, self.height):
              for i in range(0, self.width):
                  v = self.get(i, j)

                  if v is None:
                      continue

                  array[i, j, 0] = OBJECT_TO_IDX[v.type]
                  array[i, j, 1] = COLOR_TO_IDX[v.color]

                  if getattr(v, 'isOpen', False):
                      array[i, j, 2] = 1

          return array

      @staticmethod
      def decode(array):
          """
          Decode an array grid encoding back into a grid.

          Inverse of ``encode``, except that box contents are not part of the
          encoding and therefore are not restored.
          """
          width = array.shape[0]
          height = array.shape[1]
          assert array.shape[2] == 3

          grid = Grid(width, height)

          for j in range(0, height):
              for i in range(0, width):
                  typeIdx = array[i, j, 0]
                  colorIdx = array[i, j, 1]
                  openIdx = array[i, j, 2]

                  if typeIdx == 0:
                      continue

                  objType = IDX_TO_OBJECT[typeIdx]
                  color = IDX_TO_COLOR[colorIdx]
                  # Always a real bool (the old code mixed True and 0)
                  isOpen = bool(openIdx == 1)

                  if objType == 'wall':
                      v = Wall(color)
                  elif objType == 'ball':
                      v = Ball(color)
                  elif objType == 'key':
                      v = Key(color)
                  elif objType == 'box':
                      v = Box(color)
                  elif objType == 'door':
                      v = Door(color, isOpen)
                  elif objType == 'locked_door':
                      v = LockedDoor(color, isOpen)
                  elif objType == 'goal':
                      v = Goal()
                  else:
                      assert False, "unknown obj type in decode '%s'" % objType

                  grid.set(i, j, v)

          return grid
  class MiniGridEnv(gym.Env):
      """
      2D grid world game environment.

      Subclasses implement ``_genGrid`` to build ``self.grid`` and set
      ``self.start_pos`` / ``self.start_dir``, and must define a textual
      ``self.mission`` string before observations are generated.
      """

      metadata = {
          'render.modes': ['human', 'rgb_array', 'pixmap'],
          'video.frames_per_second' : 10
      }

      # Enumeration of possible actions
      class Actions(IntEnum):
          # Turn left, turn right, move forward
          left = 0
          right = 1
          forward = 2

          # Pick up an object
          pickup = 3
          # Drop an object
          drop = 4
          # Toggle/activate an object
          toggle = 5

          # Wait/stay put/do nothing
          wait = 6

      def __init__(self, grid_size=16, max_steps=100):
          """
          :param grid_size: width and height (in cells) passed to ``_genGrid``
          :param max_steps: number of steps after which an episode ends
          """
          # Action enumeration for this environment
          self.actions = MiniGridEnv.Actions

          # Actions are discrete integer values
          self.action_space = spaces.Discrete(len(self.actions))

          # Observations are dictionaries whose 'image' entry is the encoded
          # partially observable view of the grid
          self.observation_space = spaces.Dict({
              'image': spaces.Box(
                  low=0,
                  high=255,
                  shape=OBS_ARRAY_SIZE,
                  dtype='uint8'
              )
          })

          # Range of possible rewards
          self.reward_range = (-1, 1000)

          # Renderer object used to render the whole grid (full-scale)
          self.grid_render = None

          # Renderer used to render observations (small-scale agent view)
          self.obs_render = None

          # Environment configuration
          self.grid_size = grid_size
          self.max_steps = max_steps

          # Starting position and direction for the agent
          self.start_pos = None
          self.start_dir = None

          # Initialize the state (reset() calls the subclass's _genGrid)
          self.seed()
          self.reset()

      def reset(self):
          """
          Start a new episode and return its first observation.

          A new random grid is generated at the start of each episode. To keep
          the same grid for each episode, call env.seed() with the same seed
          before calling env.reset().
          """
          self._genGrid(self.grid_size, self.grid_size)

          # These fields should be defined by _genGrid
          assert self.start_pos is not None
          assert self.start_dir is not None

          # Check that the agent doesn't overlap with an object
          assert self.grid.get(*self.start_pos) is None

          # Place the agent in the starting position and direction
          self.agent_pos = self.start_pos
          self.agent_dir = self.start_dir

          # Item picked up, being carried, initially nothing
          self.carrying = None

          # Step count since episode start
          self.step_count = 0

          # Return first observation
          return self._genObs()

      def seed(self, seed=1337):
          """Seed the random number generator and return ``[seed]``."""
          self.np_random, _ = seeding.np_random(seed)
          return [seed]

      def __str__(self):
          """
          Produce a pretty string of the environment's grid along with the agent.
          The agent is represented by an arrow glyph pointing in its direction.
          Every other cell is a 2-character string: the first character for the
          object and the second one for the color.
          """
          # Map of object id to short string
          OBJECT_IDX_TO_IDS = {
              0: ' ',
              1: 'W',
              2: 'D',
              3: 'L',
              4: 'K',
              5: 'B',
              6: 'X',
              7: 'G'
          }

          # Short string for an open door (regular or locked)
          OPEN_DOOR_IDS = '_'

          # Map of color id to short string
          COLOR_IDX_TO_IDS = {
              0: 'R',
              1: 'G',
              2: 'B',
              3: 'P',
              4: 'Y',
              5: 'E'
          }

          # Map agent's direction to short string
          AGENT_DIR_TO_IDS = {
              0: '⏩',
              1: '⏬',
              2: '⏪',
              3: '⏫'
          }

          door_idxs = (OBJECT_TO_IDX['door'], OBJECT_TO_IDX['locked_door'])

          # encode() indexes cells as [x, y]; transpose to [y, x] so each row
          # of the text is one row of the grid. Unlike a hand-rolled
          # rotate+mirror, this also works for non-square grids.
          array = self.grid.encode().transpose(1, 0, 2)

          lines = []
          for row in array:
              cells = []
              for obj_idx, color_idx, open_idx in row:
                  if obj_idx in door_idxs and open_idx == 1:
                      object_ids = OPEN_DOOR_IDS
                  else:
                      object_ids = OBJECT_IDX_TO_IDS[obj_idx]

                  # Empty cells have no color
                  if obj_idx == 0:
                      color_ids = ' '
                  else:
                      color_ids = COLOR_IDX_TO_IDS[color_idx]

                  cells.append(object_ids + color_ids)
              lines.append(cells)

          # Add the agent
          x, y = self.agent_pos
          lines[y][x] = AGENT_DIR_TO_IDS[self.agent_dir]

          return "\n".join(" ".join(line) for line in lines)

      def _genGrid(self, width, height):
          """Build ``self.grid`` and set the agent start; subclasses must override."""
          # Raise instead of ``assert False`` so this also fails under -O
          raise NotImplementedError(
              "_genGrid needs to be implemented by each environment"
          )

      def _randInt(self, low, high):
          """
          Generate random integer in [low, high)
          """
          return self.np_random.randint(low, high)

      def _randElem(self, iterable):
          """
          Pick a random element in a list
          """
          lst = list(iterable)
          idx = self._randInt(0, len(lst))
          return lst[idx]

      def _randPos(self, xLow, xHigh, yLow, yHigh):
          """
          Generate a random (x,y) position tuple, with x in [xLow, xHigh)
          and y in [yLow, yHigh)
          """
          return (
              self.np_random.randint(xLow, xHigh),
              self.np_random.randint(yLow, yHigh)
          )

      def placeObj(self, obj, top=None, size=None, reject_fn=None):
          """
          Place an object at an empty position in the grid.

          Random positions inside the rectangle are drawn until one is empty,
          differs from ``self.start_pos`` and is not rejected by ``reject_fn``.
          Note: this loops forever if no such position exists.

          :param obj: object to place (None leaves the chosen cell empty)
          :param top: top-left position of the rectangle where to place
          :param size: size of the rectangle where to place
          :param reject_fn: function ``(env, pos)``; a truthy result rejects
                            the candidate position
          :return: the chosen (x, y) position
          """
          if top is None:
              top = (0, 0)

          if size is None:
              size = (self.grid.width, self.grid.height)

          while True:
              pos = (
                  self._randInt(top[0], top[0] + size[0]),
                  self._randInt(top[1], top[1] + size[1])
              )

              # Don't place the object on top of another object
              if self.grid.get(*pos) is not None:
                  continue

              # Don't place the object where the agent is
              if pos == self.start_pos:
                  continue

              # Check if there is a filtering criterion
              if reject_fn and reject_fn(self, pos):
                  continue

              break

          self.grid.set(*pos, obj)
          return pos

      def placeAgent(self, top=None, size=None, randDir=True):
          """
          Set the agent's starting point at an empty position in the grid.

          :param randDir: if True, also pick a random starting direction;
                          otherwise ``self.start_dir`` must be set separately
          :return: the chosen (x, y) start position
          """
          pos = self.placeObj(None, top, size)
          self.start_pos = pos

          if randDir:
              self.start_dir = self._randInt(0, 4)

          return pos

      def getStepsRemaining(self):
          """Number of steps left before the episode is cut off."""
          return self.max_steps - self.step_count

      def getDirVec(self):
          """
          Get the direction vector for the agent, pointing in the direction
          of forward movement.
          """
          # Pointing right
          if self.agent_dir == 0:
              return (1, 0)
          # Down (positive Y)
          elif self.agent_dir == 1:
              return (0, 1)
          # Pointing left
          elif self.agent_dir == 2:
              return (-1, 0)
          # Up (negative Y)
          elif self.agent_dir == 3:
              return (0, -1)
          else:
              assert False

      def getViewExts(self):
          """
          Get the extents of the square set of tiles visible to the agent
          Note: the bottom extent indices are not included in the set
          """
          # Facing right
          if self.agent_dir == 0:
              topX = self.agent_pos[0]
              topY = self.agent_pos[1] - AGENT_VIEW_SIZE // 2
          # Facing down
          elif self.agent_dir == 1:
              topX = self.agent_pos[0] - AGENT_VIEW_SIZE // 2
              topY = self.agent_pos[1]
          # Facing left
          elif self.agent_dir == 2:
              topX = self.agent_pos[0] - AGENT_VIEW_SIZE + 1
              topY = self.agent_pos[1] - AGENT_VIEW_SIZE // 2
          # Facing up
          elif self.agent_dir == 3:
              topX = self.agent_pos[0] - AGENT_VIEW_SIZE // 2
              topY = self.agent_pos[1] - AGENT_VIEW_SIZE + 1
          else:
              assert False, "invalid agent direction"

          botX = topX + AGENT_VIEW_SIZE
          botY = topY + AGENT_VIEW_SIZE

          return (topX, topY, botX, botY)

      def agentSees(self, x, y):
          """
          Check if a grid position is visible to the agent
          """
          topX, topY, botX, botY = self.getViewExts()
          return (x >= topX and x < botX and y >= topY and y < botY)

      def step(self, action):
          """
          Apply ``action`` and advance the environment by one step.

          :return: (obs, reward, done, info). Moving onto the goal ends the
                   episode with reward ``1000 - step_count``; the episode also
                   ends once ``max_steps`` steps have been taken.
          """
          self.step_count += 1

          reward = 0
          done = False

          # Get the position in front of the agent
          u, v = self.getDirVec()
          fwdPos = (self.agent_pos[0] + u, self.agent_pos[1] + v)

          # Get the contents of the cell in front of the agent
          fwdCell = self.grid.get(*fwdPos)

          # Rotate left
          if action == self.actions.left:
              self.agent_dir -= 1
              if self.agent_dir < 0:
                  self.agent_dir += 4

          # Rotate right
          elif action == self.actions.right:
              self.agent_dir = (self.agent_dir + 1) % 4

          # Move forward
          elif action == self.actions.forward:
              if fwdCell is None or fwdCell.canOverlap():
                  self.agent_pos = fwdPos
              if fwdCell is not None and fwdCell.type == 'goal':
                  done = True
                  reward = 1000 - self.step_count

          # Pick up an object
          elif action == self.actions.pickup:
              if fwdCell and fwdCell.canPickup():
                  if self.carrying is None:
                      self.carrying = fwdCell
                      self.grid.set(*fwdPos, None)

          # Drop an object
          elif action == self.actions.drop:
              if not fwdCell and self.carrying:
                  self.grid.set(*fwdPos, self.carrying)
                  self.carrying = None

          # Toggle/activate an object
          elif action == self.actions.toggle:
              if fwdCell:
                  fwdCell.toggle(self, fwdPos)

          # Wait/do nothing
          elif action == self.actions.wait:
              pass

          else:
              assert False, "unknown action"

          if self.step_count >= self.max_steps:
              done = True

          obs = self._genObs()

          return obs, reward, done, {}

      def _genObs(self):
          """
          Generate the agent's view (partially observable, low-resolution encoding)

          :return: dict with 'image' (encoded view), 'direction' (the agent's
                   direction) and 'mission' (the textual mission string)
          """
          topX, topY, botX, botY = self.getViewExts()

          grid = self.grid.slice(topX, topY, AGENT_VIEW_SIZE, AGENT_VIEW_SIZE)

          # Rotate the view so the agent's facing direction points up
          for i in range(self.agent_dir + 1):
              grid = grid.rotateLeft()

          # Make it so the agent sees what it's carrying
          # We do this by placing the carried object at the agent's position
          # in the agent's partially observable view
          agent_pos = grid.width // 2, grid.height - 1
          if self.carrying:
              grid.set(*agent_pos, self.carrying)
          else:
              grid.set(*agent_pos, None)

          # Encode the partially observable view into a numpy array
          image = grid.encode()

          assert hasattr(self, 'mission'), "environments must define a textual mission string"

          # Observations are dictionaries containing:
          # - an image (partially observable view of the environment)
          # - the agent's direction/orientation (acting as a compass)
          # - a textual mission string (instructions for the agent)
          obs = {
              'image': image,
              'direction': self.agent_dir,
              'mission': self.mission
          }

          return obs

      def getObsRender(self, obs):
          """
          Render an agent observation for visualization.

          :param obs: either an observation dict as returned by step()/reset()
                      or just its encoded 'image' array
          :return: the renderer's pixmap
          """
          if self.obs_render is None:
              self.obs_render = Renderer(
                  AGENT_VIEW_SIZE * CELL_PIXELS // 2,
                  AGENT_VIEW_SIZE * CELL_PIXELS // 2
              )

          r = self.obs_render

          r.beginFrame()

          # Observations are dicts (see _genObs); decoding one directly would
          # fail, so extract the encoded image first
          image = obs['image'] if isinstance(obs, dict) else obs
          grid = Grid.decode(image)

          # Render the whole grid
          grid.render(r, CELL_PIXELS // 2)

          # Draw the agent at the bottom-middle of its view, facing up
          r.push()
          r.scale(0.5, 0.5)
          r.translate(
              CELL_PIXELS * (0.5 + AGENT_VIEW_SIZE // 2),
              CELL_PIXELS * (AGENT_VIEW_SIZE - 0.5)
          )
          r.rotate(3 * 90)
          r.setLineColor(255, 0, 0)
          r.setColor(255, 0, 0)
          r.drawPolygon([
              (-12, 10),
              ( 12, 0),
              (-12, -10)
          ])
          r.pop()

          r.endFrame()

          return r.getPixmap()

      def render(self, mode='human', close=False):
          """
          Render the whole-grid human view.

          :param mode: 'human', 'rgb_array' (returns an array) or 'pixmap'
                       (returns a pixmap); otherwise the renderer is returned
          :param close: close the renderer instead of drawing
          """
          if close:
              if self.grid_render:
                  self.grid_render.close()
                  # Drop the closed renderer so a later render() call builds a
                  # new one instead of reusing the closed one
                  self.grid_render = None
              return

          if self.grid_render is None:
              self.grid_render = Renderer(
                  self.grid_size * CELL_PIXELS,
                  self.grid_size * CELL_PIXELS,
                  mode == 'human'
              )

          r = self.grid_render

          r.beginFrame()

          # Render the whole grid
          self.grid.render(r, CELL_PIXELS)

          # Draw the agent
          r.push()
          r.translate(
              CELL_PIXELS * (self.agent_pos[0] + 0.5),
              CELL_PIXELS * (self.agent_pos[1] + 0.5)
          )
          r.rotate(self.agent_dir * 90)
          r.setLineColor(255, 0, 0)
          r.setColor(255, 0, 0)
          r.drawPolygon([
              (-12, 10),
              ( 12, 0),
              (-12, -10)
          ])
          r.pop()

          # Highlight what the agent can see
          topX, topY, botX, botY = self.getViewExts()
          r.fillRect(
              topX * CELL_PIXELS,
              topY * CELL_PIXELS,
              AGENT_VIEW_SIZE * CELL_PIXELS,
              AGENT_VIEW_SIZE * CELL_PIXELS,
              200, 200, 200, 75
          )

          r.endFrame()

          if mode == 'rgb_array':
              return r.getArray()
          elif mode == 'pixmap':
              return r.getPixmap()

          return r
  844. return r.getArray()
  845. elif mode == 'pixmap':
  846. return r.getPixmap()
  847. return r