minigrid.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945
  1. import math
  2. import gym
  3. from enum import IntEnum
  4. import numpy as np
  5. from gym import error, spaces, utils
  6. from gym.utils import seeding
  7. from gym_minigrid.rendering import *
  8. # Size in pixels of a cell in the full-scale human view
  9. CELL_PIXELS = 32
  10. # Number of cells (width and height) in the agent view
  11. AGENT_VIEW_SIZE = 7
  12. # Size of the array given as an observation to the agent
  13. OBS_ARRAY_SIZE = (AGENT_VIEW_SIZE, AGENT_VIEW_SIZE, 3)
  14. # Map of color names to RGB values
  15. COLORS = {
  16. 'red' : (255, 0, 0),
  17. 'green' : (0, 255, 0),
  18. 'blue' : (0, 0, 255),
  19. 'purple': (112, 39, 195),
  20. 'yellow': (255, 255, 0),
  21. 'grey' : (100, 100, 100)
  22. }
  23. COLOR_NAMES = sorted(list(COLORS.keys()))
  24. # Used to map colors to integers
  25. COLOR_TO_IDX = {
  26. 'red' : 0,
  27. 'green' : 1,
  28. 'blue' : 2,
  29. 'purple': 3,
  30. 'yellow': 4,
  31. 'grey' : 5
  32. }
  33. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  34. # Map of object type to integers
  35. OBJECT_TO_IDX = {
  36. 'empty' : 0,
  37. 'wall' : 1,
  38. 'door' : 2,
  39. 'locked_door' : 3,
  40. 'key' : 4,
  41. 'ball' : 5,
  42. 'box' : 6,
  43. 'goal' : 7
  44. }
  45. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  46. class WorldObj:
  47. """
  48. Base class for grid world objects
  49. """
  50. def __init__(self, type, color):
  51. assert type in OBJECT_TO_IDX, type
  52. assert color in COLOR_TO_IDX, color
  53. self.type = type
  54. self.color = color
  55. self.contains = None
  56. def canOverlap(self):
  57. """Can the agent overlap with this?"""
  58. return False
  59. def canPickup(self):
  60. """Can the agent pick this up?"""
  61. return False
  62. def canContain(self):
  63. """Can this contain another object?"""
  64. return False
  65. def toggle(self, env, pos):
  66. """Method to trigger/toggle an action this object performs"""
  67. return False
  68. def render(self, r):
  69. assert False
  70. def _setColor(self, r):
  71. c = COLORS[self.color]
  72. r.setLineColor(c[0], c[1], c[2])
  73. r.setColor(c[0], c[1], c[2])
  74. class Goal(WorldObj):
  75. def __init__(self):
  76. super(Goal, self).__init__('goal', 'green')
  77. def canOverlap(self):
  78. return True
  79. def render(self, r):
  80. self._setColor(r)
  81. r.drawPolygon([
  82. (0 , CELL_PIXELS),
  83. (CELL_PIXELS, CELL_PIXELS),
  84. (CELL_PIXELS, 0),
  85. (0 , 0)
  86. ])
  87. class Wall(WorldObj):
  88. def __init__(self, color='grey'):
  89. super(Wall, self).__init__('wall', color)
  90. def render(self, r):
  91. self._setColor(r)
  92. r.drawPolygon([
  93. (0 , CELL_PIXELS),
  94. (CELL_PIXELS, CELL_PIXELS),
  95. (CELL_PIXELS, 0),
  96. (0 , 0)
  97. ])
  98. class Door(WorldObj):
  99. def __init__(self, color, isOpen=False):
  100. super(Door, self).__init__('door', color)
  101. self.isOpen = isOpen
  102. def render(self, r):
  103. c = COLORS[self.color]
  104. r.setLineColor(c[0], c[1], c[2])
  105. r.setColor(0, 0, 0)
  106. if self.isOpen:
  107. r.drawPolygon([
  108. (CELL_PIXELS-2, CELL_PIXELS),
  109. (CELL_PIXELS , CELL_PIXELS),
  110. (CELL_PIXELS , 0),
  111. (CELL_PIXELS-2, 0)
  112. ])
  113. return
  114. r.drawPolygon([
  115. (0 , CELL_PIXELS),
  116. (CELL_PIXELS, CELL_PIXELS),
  117. (CELL_PIXELS, 0),
  118. (0 , 0)
  119. ])
  120. r.drawPolygon([
  121. (2 , CELL_PIXELS-2),
  122. (CELL_PIXELS-2, CELL_PIXELS-2),
  123. (CELL_PIXELS-2, 2),
  124. (2 , 2)
  125. ])
  126. r.drawCircle(CELL_PIXELS * 0.75, CELL_PIXELS * 0.5, 2)
  127. def toggle(self, env, pos):
  128. if not self.isOpen:
  129. self.isOpen = True
  130. return True
  131. return False
  132. def canOverlap(self):
  133. """The agent can only walk over this cell when the door is open"""
  134. return self.isOpen
  135. class LockedDoor(WorldObj):
  136. def __init__(self, color, isOpen=False):
  137. super(LockedDoor, self).__init__('locked_door', color)
  138. self.isOpen = isOpen
  139. def render(self, r):
  140. c = COLORS[self.color]
  141. r.setLineColor(c[0], c[1], c[2])
  142. r.setColor(c[0], c[1], c[2], 50)
  143. if self.isOpen:
  144. r.drawPolygon([
  145. (CELL_PIXELS-2, CELL_PIXELS),
  146. (CELL_PIXELS , CELL_PIXELS),
  147. (CELL_PIXELS , 0),
  148. (CELL_PIXELS-2, 0)
  149. ])
  150. return
  151. r.drawPolygon([
  152. (0 , CELL_PIXELS),
  153. (CELL_PIXELS, CELL_PIXELS),
  154. (CELL_PIXELS, 0),
  155. (0 , 0)
  156. ])
  157. r.drawPolygon([
  158. (2 , CELL_PIXELS-2),
  159. (CELL_PIXELS-2, CELL_PIXELS-2),
  160. (CELL_PIXELS-2, 2),
  161. (2 , 2)
  162. ])
  163. r.drawLine(
  164. CELL_PIXELS * 0.55,
  165. CELL_PIXELS * 0.5,
  166. CELL_PIXELS * 0.75,
  167. CELL_PIXELS * 0.5
  168. )
  169. def toggle(self, env, pos):
  170. # If the player has the right key to open the door
  171. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  172. self.isOpen = True
  173. # The key has been used, remove it from the agent
  174. env.carrying = None
  175. return True
  176. return False
  177. def canOverlap(self):
  178. """The agent can only walk over this cell when the door is open"""
  179. return self.isOpen
  180. class Key(WorldObj):
  181. def __init__(self, color='blue'):
  182. super(Key, self).__init__('key', color)
  183. def canPickup(self):
  184. return True
  185. def render(self, r):
  186. self._setColor(r)
  187. # Vertical quad
  188. r.drawPolygon([
  189. (16, 10),
  190. (20, 10),
  191. (20, 28),
  192. (16, 28)
  193. ])
  194. # Teeth
  195. r.drawPolygon([
  196. (12, 19),
  197. (16, 19),
  198. (16, 21),
  199. (12, 21)
  200. ])
  201. r.drawPolygon([
  202. (12, 26),
  203. (16, 26),
  204. (16, 28),
  205. (12, 28)
  206. ])
  207. r.drawCircle(18, 9, 6)
  208. r.setLineColor(0, 0, 0)
  209. r.setColor(0, 0, 0)
  210. r.drawCircle(18, 9, 2)
  211. class Ball(WorldObj):
  212. def __init__(self, color='blue'):
  213. super(Ball, self).__init__('ball', color)
  214. def canPickup(self):
  215. return True
  216. def render(self, r):
  217. self._setColor(r)
  218. r.drawCircle(CELL_PIXELS * 0.5, CELL_PIXELS * 0.5, 10)
  219. class Box(WorldObj):
  220. def __init__(self, color, contains=None):
  221. super(Box, self).__init__('box', color)
  222. self.contains = contains
  223. def render(self, r):
  224. c = COLORS[self.color]
  225. r.setLineColor(c[0], c[1], c[2])
  226. r.setColor(0, 0, 0)
  227. r.setLineWidth(2)
  228. r.drawPolygon([
  229. (4 , CELL_PIXELS-4),
  230. (CELL_PIXELS-4, CELL_PIXELS-4),
  231. (CELL_PIXELS-4, 4),
  232. (4 , 4)
  233. ])
  234. r.drawLine(
  235. 4,
  236. CELL_PIXELS / 2,
  237. CELL_PIXELS - 4,
  238. CELL_PIXELS / 2
  239. )
  240. r.setLineWidth(1)
  241. def toggle(self, env, pos):
  242. # Replace the box by its contents
  243. env.grid.set(*pos, self.contains)
  244. return True
  245. class Grid:
  246. """
  247. Represent a grid and operations on it
  248. """
  249. def __init__(self, width, height):
  250. assert width >= 4
  251. assert height >= 4
  252. self.width = width
  253. self.height = height
  254. self.grid = [None] * width * height
  255. def __contains__(self, key):
  256. if isinstance(key, WorldObj):
  257. for e in self.grid:
  258. if e is key:
  259. return True
  260. elif isinstance(key, tuple):
  261. for e in self.grid:
  262. if e is None:
  263. continue
  264. if (e.color, e.type) == key:
  265. return True
  266. return False
  267. def copy(self):
  268. from copy import deepcopy
  269. return deepcopy(self)
  270. def set(self, i, j, v):
  271. assert i >= 0 and i < self.width
  272. assert j >= 0 and j < self.height
  273. self.grid[j * self.width + i] = v
  274. def get(self, i, j):
  275. assert i >= 0 and i < self.width
  276. assert j >= 0 and j < self.height
  277. return self.grid[j * self.width + i]
  278. def horzWall(self, x, y, length=None):
  279. if length is None:
  280. length = self.width - x
  281. for i in range(0, length):
  282. self.set(x + i, y, Wall())
  283. def vertWall(self, x, y, length=None):
  284. if length is None:
  285. length = self.height - y
  286. for j in range(0, length):
  287. self.set(x, y + j, Wall())
  288. def wallRect(self, x, y, w, h):
  289. self.horzWall(x, y, w)
  290. self.horzWall(x, y+h-1, w)
  291. self.vertWall(x, y, h)
  292. self.vertWall(x+w-1, y, h)
  293. def rotateLeft(self):
  294. """
  295. Rotate the grid to the left (counter-clockwise)
  296. """
  297. grid = Grid(self.width, self.height)
  298. for j in range(0, self.height):
  299. for i in range(0, self.width):
  300. v = self.get(self.width - 1 - j, i)
  301. grid.set(i, j, v)
  302. return grid
  303. def slice(self, topX, topY, width, height):
  304. """
  305. Get a subset of the grid
  306. """
  307. grid = Grid(width, height)
  308. for j in range(0, height):
  309. for i in range(0, width):
  310. x = topX + i
  311. y = topY + j
  312. if x >= 0 and x < self.width and \
  313. y >= 0 and y < self.height:
  314. v = self.get(x, y)
  315. else:
  316. v = Wall()
  317. grid.set(i, j, v)
  318. return grid
  319. def render(self, r, tileSize):
  320. """
  321. Render this grid at a given scale
  322. :param r: target renderer object
  323. :param tileSize: tile size in pixels
  324. """
  325. assert r.width == self.width * tileSize
  326. assert r.height == self.height * tileSize
  327. # Total grid size at native scale
  328. widthPx = self.width * CELL_PIXELS
  329. heightPx = self.height * CELL_PIXELS
  330. # Draw background (out-of-world) tiles the same colors as walls
  331. # so the agent understands these areas are not reachable
  332. c = COLORS['grey']
  333. r.setLineColor(c[0], c[1], c[2])
  334. r.setColor(c[0], c[1], c[2])
  335. r.drawPolygon([
  336. (0 , heightPx),
  337. (widthPx, heightPx),
  338. (widthPx, 0),
  339. (0 , 0)
  340. ])
  341. r.push()
  342. # Internally, we draw at the "large" full-grid resolution, but we
  343. # use the renderer to scale back to the desired size
  344. r.scale(tileSize / CELL_PIXELS, tileSize / CELL_PIXELS)
  345. # Draw the background of the in-world cells black
  346. r.fillRect(
  347. 0,
  348. 0,
  349. widthPx,
  350. heightPx,
  351. 0, 0, 0
  352. )
  353. # Draw grid lines
  354. r.setLineColor(100, 100, 100)
  355. for rowIdx in range(0, self.height):
  356. y = CELL_PIXELS * rowIdx
  357. r.drawLine(0, y, widthPx, y)
  358. for colIdx in range(0, self.width):
  359. x = CELL_PIXELS * colIdx
  360. r.drawLine(x, 0, x, heightPx)
  361. # Render the grid
  362. for j in range(0, self.height):
  363. for i in range(0, self.width):
  364. cell = self.get(i, j)
  365. if cell == None:
  366. continue
  367. r.push()
  368. r.translate(i * CELL_PIXELS, j * CELL_PIXELS)
  369. cell.render(r)
  370. r.pop()
  371. r.pop()
  372. def encode(self):
  373. """
  374. Produce a compact numpy encoding of the grid
  375. """
  376. codeSize = self.width * self.height * 3
  377. array = np.zeros(shape=(self.width, self.height, 3), dtype='uint8')
  378. for j in range(0, self.height):
  379. for i in range(0, self.width):
  380. v = self.get(i, j)
  381. if v == None:
  382. continue
  383. array[i, j, 0] = OBJECT_TO_IDX[v.type]
  384. array[i, j, 1] = COLOR_TO_IDX[v.color]
  385. if hasattr(v, 'isOpen') and v.isOpen:
  386. array[i, j, 2] = 1
  387. return array
  388. def decode(array):
  389. """
  390. Decode an array grid encoding back into a grid
  391. """
  392. width = array.shape[0]
  393. height = array.shape[1]
  394. assert array.shape[2] == 3
  395. grid = Grid(width, height)
  396. for j in range(0, height):
  397. for i in range(0, width):
  398. typeIdx = array[i, j, 0]
  399. colorIdx = array[i, j, 1]
  400. openIdx = array[i, j, 2]
  401. if typeIdx == 0:
  402. continue
  403. objType = IDX_TO_OBJECT[typeIdx]
  404. color = IDX_TO_COLOR[colorIdx]
  405. isOpen = True if openIdx == 1 else 0
  406. if objType == 'wall':
  407. v = Wall(color)
  408. elif objType == 'ball':
  409. v = Ball(color)
  410. elif objType == 'key':
  411. v = Key(color)
  412. elif objType == 'box':
  413. v = Box(color)
  414. elif objType == 'door':
  415. v = Door(color, isOpen)
  416. elif objType == 'locked_door':
  417. v = LockedDoor(color, isOpen)
  418. elif objType == 'goal':
  419. v = Goal()
  420. else:
  421. assert False, "unknown obj type in decode '%s'" % objType
  422. grid.set(i, j, v)
  423. return grid
  424. class MiniGridEnv(gym.Env):
  425. """
  426. 2D grid world game environment
  427. """
  428. metadata = {
  429. 'render.modes': ['human', 'rgb_array', 'pixmap'],
  430. 'video.frames_per_second' : 10
  431. }
  432. # Enumeration of possible actions
  433. class Actions(IntEnum):
  434. # Turn left, turn right, move forward
  435. left = 0
  436. right = 1
  437. forward = 2
  438. # Pick up an object
  439. pickup = 3
  440. # Drop an object
  441. drop = 4
  442. # Toggle/activate an object
  443. toggle = 5
  444. # Wait/stay put/do nothing
  445. wait = 6
  446. def __init__(self, gridSize=16, maxSteps=100):
  447. # Action enumeration for this environment
  448. self.actions = MiniGridEnv.Actions
  449. # Actions are discrete integer values
  450. self.action_space = spaces.Discrete(len(self.actions))
  451. # Observations are dictionaries containing an
  452. # encoding of the grid and a textual 'mission' string
  453. self.observation_space = spaces.Box(
  454. low=0,
  455. high=255,
  456. shape=OBS_ARRAY_SIZE,
  457. dtype='uint8'
  458. )
  459. self.observation_space = spaces.Dict({
  460. 'image': self.observation_space
  461. })
  462. # Range of possible rewards
  463. self.reward_range = (-1, 1000)
  464. # Renderer object used to render the whole grid (full-scale)
  465. self.gridRender = None
  466. # Renderer used to render observations (small-scale agent view)
  467. self.obsRender = None
  468. # Environment configuration
  469. self.gridSize = gridSize
  470. self.maxSteps = maxSteps
  471. self.startPos = (1, 1)
  472. self.startDir = 0
  473. # Initialize the state
  474. self.seed()
  475. self.reset()
  476. def _genGrid(self, width, height):
  477. assert False, "_genGrid needs to be implemented by each environment"
  478. def reset(self):
  479. # Generate a new random grid at the start of each episode
  480. # To keep the same grid for each episode, call env.seed() with
  481. # the same seed before calling env.reset()
  482. self._genGrid(self.gridSize, self.gridSize)
  483. # Place the agent in the starting position and direction
  484. self.agentPos = self.startPos
  485. self.agentDir = self.startDir
  486. # Item picked up, being carried, initially nothing
  487. self.carrying = None
  488. # Step count since episode start
  489. self.stepCount = 0
  490. # Return first observation
  491. obs = self._genObs()
  492. return obs
  493. def seed(self, seed=1337):
  494. # Seed the random number generator
  495. self.np_random, _ = seeding.np_random(seed)
  496. return [seed]
  497. def _randInt(self, low, high):
  498. """
  499. Generate random integer in [low,high[
  500. """
  501. return self.np_random.randint(low, high)
  502. def _randElem(self, iterable):
  503. """
  504. Pick a random element in a list
  505. """
  506. lst = list(iterable)
  507. idx = self._randInt(0, len(lst))
  508. return lst[idx]
  509. def _randPos(self, xLow, xHigh, yLow, yHigh):
  510. """
  511. Generate a random (x,y) position tuple
  512. """
  513. return (
  514. self.np_random.randint(xLow, xHigh),
  515. self.np_random.randint(yLow, yHigh)
  516. )
  517. def placeObj(self, obj, top=None, size=None):
  518. """
  519. Place an object at an empty position in the grid
  520. :param top: top-left position of the rectangle where to place
  521. :param size: size of the rectangle where to place
  522. """
  523. if top is None:
  524. top = (0, 0)
  525. if size is None:
  526. size = (self.grid.width, self.grid.height)
  527. while True:
  528. pos = (
  529. self._randInt(top[0], top[0] + size[0]),
  530. self._randInt(top[1], top[1] + size[1])
  531. )
  532. if self.grid.get(*pos) != None:
  533. continue
  534. if pos == self.startPos:
  535. continue
  536. break
  537. self.grid.set(*pos, obj)
  538. return pos
  539. def placeAgent(self, randDir=True):
  540. """
  541. Set the agent's starting point at an empty position in the grid
  542. """
  543. pos = self.placeObj(None)
  544. self.startPos = pos
  545. if randDir:
  546. self.startDir = self._randInt(0, 4)
  547. return pos
  548. def getStepsRemaining(self):
  549. return self.maxSteps - self.stepCount
  550. def getDirVec(self):
  551. """
  552. Get the direction vector for the agent, pointing in the direction
  553. of forward movement.
  554. """
  555. # Pointing right
  556. if self.agentDir == 0:
  557. return (1, 0)
  558. # Down (positive Y)
  559. elif self.agentDir == 1:
  560. return (0, 1)
  561. # Pointing left
  562. elif self.agentDir == 2:
  563. return (-1, 0)
  564. # Up (negative Y)
  565. elif self.agentDir == 3:
  566. return (0, -1)
  567. else:
  568. assert False
  569. def getViewExts(self):
  570. """
  571. Get the extents of the square set of tiles visible to the agent
  572. Note: the bottom extent indices are not included in the set
  573. """
  574. # Facing right
  575. if self.agentDir == 0:
  576. topX = self.agentPos[0]
  577. topY = self.agentPos[1] - AGENT_VIEW_SIZE // 2
  578. # Facing down
  579. elif self.agentDir == 1:
  580. topX = self.agentPos[0] - AGENT_VIEW_SIZE // 2
  581. topY = self.agentPos[1]
  582. # Facing left
  583. elif self.agentDir == 2:
  584. topX = self.agentPos[0] - AGENT_VIEW_SIZE + 1
  585. topY = self.agentPos[1] - AGENT_VIEW_SIZE // 2
  586. # Facing up
  587. elif self.agentDir == 3:
  588. topX = self.agentPos[0] - AGENT_VIEW_SIZE // 2
  589. topY = self.agentPos[1] - AGENT_VIEW_SIZE + 1
  590. else:
  591. assert False, "invalid agent direction"
  592. botX = topX + AGENT_VIEW_SIZE
  593. botY = topY + AGENT_VIEW_SIZE
  594. return (topX, topY, botX, botY)
  595. def agentSees(self, x, y):
  596. """
  597. Check if a grid position is visible to the agent
  598. """
  599. topX, topY, botX, botY = self.getViewExts()
  600. return (x >= topX and x < botX and y >= topY and y < botY)
  601. def step(self, action):
  602. self.stepCount += 1
  603. reward = 0
  604. done = False
  605. # Get the position in front of the agent
  606. u, v = self.getDirVec()
  607. fwdPos = (self.agentPos[0] + u, self.agentPos[1] + v)
  608. # Get the contents of the cell in front of the agent
  609. fwdCell = self.grid.get(*fwdPos)
  610. # Rotate left
  611. if action == self.actions.left:
  612. self.agentDir -= 1
  613. if self.agentDir < 0:
  614. self.agentDir += 4
  615. # Rotate right
  616. elif action == self.actions.right:
  617. self.agentDir = (self.agentDir + 1) % 4
  618. # Move forward
  619. elif action == self.actions.forward:
  620. if fwdCell == None or fwdCell.canOverlap():
  621. self.agentPos = fwdPos
  622. if fwdCell != None and fwdCell.type == 'goal':
  623. done = True
  624. reward = 1000 - self.stepCount
  625. # Pick up an object
  626. elif action == self.actions.pickup:
  627. if fwdCell and fwdCell.canPickup():
  628. if self.carrying is None:
  629. self.carrying = fwdCell
  630. self.grid.set(*fwdPos, None)
  631. # Drop an object
  632. elif action == self.actions.drop:
  633. if not fwdCell and self.carrying:
  634. self.grid.set(*fwdPos, self.carrying)
  635. self.carrying = None
  636. # Toggle/activate an object
  637. elif action == self.actions.toggle:
  638. if fwdCell:
  639. fwdCell.toggle(self, fwdPos)
  640. # Wait/do nothing
  641. elif action == self.actions.wait:
  642. pass
  643. else:
  644. assert False, "unknown action"
  645. if self.stepCount >= self.maxSteps:
  646. done = True
  647. obs = self._genObs()
  648. return obs, reward, done, {}
  649. def _genObs(self):
  650. """
  651. Generate the agent's view (partially observable, low-resolution encoding)
  652. """
  653. topX, topY, botX, botY = self.getViewExts()
  654. grid = self.grid.slice(topX, topY, AGENT_VIEW_SIZE, AGENT_VIEW_SIZE)
  655. for i in range(self.agentDir + 1):
  656. grid = grid.rotateLeft()
  657. # Make it so the agent sees what it's carrying
  658. # We do this by placing the carried object at the agent's position
  659. # in the agent's partially observable view
  660. agentPos = grid.width // 2, grid.height - 1
  661. if self.carrying:
  662. grid.set(*agentPos, self.carrying)
  663. else:
  664. grid.set(*agentPos, None)
  665. # Encode the partially observable view into a numpy array
  666. image = grid.encode()
  667. assert hasattr(self, 'mission'), "environments must define a textual mission string"
  668. # Observations are dictionaries with both an image
  669. # and a textual mission string
  670. obs = {
  671. 'image': image,
  672. 'mission': self.mission
  673. }
  674. return obs
  675. def getObsRender(self, obs):
  676. """
  677. Render an agent observation for visualization
  678. """
  679. if self.obsRender == None:
  680. self.obsRender = Renderer(
  681. AGENT_VIEW_SIZE * CELL_PIXELS // 2,
  682. AGENT_VIEW_SIZE * CELL_PIXELS // 2
  683. )
  684. r = self.obsRender
  685. r.beginFrame()
  686. grid = Grid.decode(obs)
  687. # Render the whole grid
  688. grid.render(r, CELL_PIXELS // 2)
  689. # Draw the agent
  690. r.push()
  691. r.scale(0.5, 0.5)
  692. r.translate(
  693. CELL_PIXELS * (0.5 + AGENT_VIEW_SIZE // 2),
  694. CELL_PIXELS * (AGENT_VIEW_SIZE - 0.5)
  695. )
  696. r.rotate(3 * 90)
  697. r.setLineColor(255, 0, 0)
  698. r.setColor(255, 0, 0)
  699. r.drawPolygon([
  700. (-12, 10),
  701. ( 12, 0),
  702. (-12, -10)
  703. ])
  704. r.pop()
  705. r.endFrame()
  706. return r.getPixmap()
  707. def render(self, mode='human', close=False):
  708. """
  709. Render the whole-grid human view
  710. """
  711. if close:
  712. if self.gridRender:
  713. self.gridRender.close()
  714. return
  715. if self.gridRender is None:
  716. self.gridRender = Renderer(
  717. self.gridSize * CELL_PIXELS,
  718. self.gridSize * CELL_PIXELS,
  719. True if mode == 'human' else False
  720. )
  721. r = self.gridRender
  722. r.beginFrame()
  723. # Render the whole grid
  724. self.grid.render(r, CELL_PIXELS)
  725. # Draw the agent
  726. r.push()
  727. r.translate(
  728. CELL_PIXELS * (self.agentPos[0] + 0.5),
  729. CELL_PIXELS * (self.agentPos[1] + 0.5)
  730. )
  731. r.rotate(self.agentDir * 90)
  732. r.setLineColor(255, 0, 0)
  733. r.setColor(255, 0, 0)
  734. r.drawPolygon([
  735. (-12, 10),
  736. ( 12, 0),
  737. (-12, -10)
  738. ])
  739. r.pop()
  740. # Highlight what the agent can see
  741. topX, topY, botX, botY = self.getViewExts()
  742. r.fillRect(
  743. topX * CELL_PIXELS,
  744. topY * CELL_PIXELS,
  745. AGENT_VIEW_SIZE * CELL_PIXELS,
  746. AGENT_VIEW_SIZE * CELL_PIXELS,
  747. 200, 200, 200, 75
  748. )
  749. r.endFrame()
  750. if mode == 'rgb_array':
  751. return r.getArray()
  752. elif mode == 'pixmap':
  753. return r.getPixmap()
  754. return r