# minigrid.py
  1. import math
  2. import gym
  3. from enum import IntEnum
  4. import numpy as np
  5. from gym import error, spaces, utils
  6. from gym.utils import seeding
  7. from gym_minigrid.rendering import *
  8. # Size in pixels of a cell in the full-scale human view
  9. CELL_PIXELS = 32
  10. # Number of cells (width and height) in the agent view
  11. AGENT_VIEW_SIZE = 7
  12. # Size of the array given as an observation to the agent
  13. OBS_ARRAY_SIZE = (AGENT_VIEW_SIZE, AGENT_VIEW_SIZE, 3)
  14. # Map of color names to RGB values
  15. COLORS = {
  16. 'red' : (255, 0, 0),
  17. 'green' : (0, 255, 0),
  18. 'blue' : (0, 0, 255),
  19. 'purple': (112, 39, 195),
  20. 'yellow': (255, 255, 0),
  21. 'grey' : (100, 100, 100)
  22. }
  23. COLOR_NAMES = sorted(list(COLORS.keys()))
  24. # Used to map colors to integers
  25. COLOR_TO_IDX = {
  26. 'red' : 0,
  27. 'green' : 1,
  28. 'blue' : 2,
  29. 'purple': 3,
  30. 'yellow': 4,
  31. 'grey' : 5
  32. }
  33. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  34. # Map of object type to integers
  35. OBJECT_TO_IDX = {
  36. 'empty' : 0,
  37. 'wall' : 1,
  38. 'door' : 2,
  39. 'locked_door' : 3,
  40. 'key' : 4,
  41. 'ball' : 5,
  42. 'box' : 6,
  43. 'goal' : 7
  44. }
  45. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  46. class WorldObj:
  47. """
  48. Base class for grid world objects
  49. """
  50. def __init__(self, type, color):
  51. assert type in OBJECT_TO_IDX, type
  52. assert color in COLOR_TO_IDX, color
  53. self.type = type
  54. self.color = color
  55. self.contains = None
  56. def canOverlap(self):
  57. """Can the agent overlap with this?"""
  58. return False
  59. def canPickup(self):
  60. """Can the agent pick this up?"""
  61. return False
  62. def canContain(self):
  63. """Can this contain another object?"""
  64. return False
  65. def toggle(self, env, pos):
  66. """Method to trigger/toggle an action this object performs"""
  67. return False
  68. def render(self, r):
  69. assert False
  70. def _setColor(self, r):
  71. c = COLORS[self.color]
  72. r.setLineColor(c[0], c[1], c[2])
  73. r.setColor(c[0], c[1], c[2])
  74. class Goal(WorldObj):
  75. def __init__(self):
  76. super(Goal, self).__init__('goal', 'green')
  77. def render(self, r):
  78. self._setColor(r)
  79. r.drawPolygon([
  80. (0 , CELL_PIXELS),
  81. (CELL_PIXELS, CELL_PIXELS),
  82. (CELL_PIXELS, 0),
  83. (0 , 0)
  84. ])
  85. class Wall(WorldObj):
  86. def __init__(self, color='grey'):
  87. super(Wall, self).__init__('wall', color)
  88. def render(self, r):
  89. self._setColor(r)
  90. r.drawPolygon([
  91. (0 , CELL_PIXELS),
  92. (CELL_PIXELS, CELL_PIXELS),
  93. (CELL_PIXELS, 0),
  94. (0 , 0)
  95. ])
  96. class Door(WorldObj):
  97. def __init__(self, color, isOpen=False):
  98. super(Door, self).__init__('door', color)
  99. self.isOpen = isOpen
  100. def render(self, r):
  101. c = COLORS[self.color]
  102. r.setLineColor(c[0], c[1], c[2])
  103. r.setColor(0, 0, 0)
  104. if self.isOpen:
  105. r.drawPolygon([
  106. (CELL_PIXELS-2, CELL_PIXELS),
  107. (CELL_PIXELS , CELL_PIXELS),
  108. (CELL_PIXELS , 0),
  109. (CELL_PIXELS-2, 0)
  110. ])
  111. return
  112. r.drawPolygon([
  113. (0 , CELL_PIXELS),
  114. (CELL_PIXELS, CELL_PIXELS),
  115. (CELL_PIXELS, 0),
  116. (0 , 0)
  117. ])
  118. r.drawPolygon([
  119. (2 , CELL_PIXELS-2),
  120. (CELL_PIXELS-2, CELL_PIXELS-2),
  121. (CELL_PIXELS-2, 2),
  122. (2 , 2)
  123. ])
  124. r.drawCircle(CELL_PIXELS * 0.75, CELL_PIXELS * 0.5, 2)
  125. def toggle(self, env, pos):
  126. if not self.isOpen:
  127. self.isOpen = True
  128. return True
  129. return False
  130. def canOverlap(self):
  131. """The agent can only walk over this cell when the door is open"""
  132. return self.isOpen
  133. class LockedDoor(WorldObj):
  134. def __init__(self, color, isOpen=False):
  135. super(LockedDoor, self).__init__('locked_door', color)
  136. self.isOpen = isOpen
  137. def render(self, r):
  138. c = COLORS[self.color]
  139. r.setLineColor(c[0], c[1], c[2])
  140. r.setColor(c[0], c[1], c[2], 50)
  141. if self.isOpen:
  142. r.drawPolygon([
  143. (CELL_PIXELS-2, CELL_PIXELS),
  144. (CELL_PIXELS , CELL_PIXELS),
  145. (CELL_PIXELS , 0),
  146. (CELL_PIXELS-2, 0)
  147. ])
  148. return
  149. r.drawPolygon([
  150. (0 , CELL_PIXELS),
  151. (CELL_PIXELS, CELL_PIXELS),
  152. (CELL_PIXELS, 0),
  153. (0 , 0)
  154. ])
  155. r.drawPolygon([
  156. (2 , CELL_PIXELS-2),
  157. (CELL_PIXELS-2, CELL_PIXELS-2),
  158. (CELL_PIXELS-2, 2),
  159. (2 , 2)
  160. ])
  161. r.drawLine(
  162. CELL_PIXELS * 0.55,
  163. CELL_PIXELS * 0.5,
  164. CELL_PIXELS * 0.75,
  165. CELL_PIXELS * 0.5
  166. )
  167. def toggle(self, env, pos):
  168. # If the player has the right key to open the door
  169. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  170. self.isOpen = True
  171. # The key has been used, remove it from the agent
  172. env.carrying = None
  173. return True
  174. return False
  175. def canOverlap(self):
  176. """The agent can only walk over this cell when the door is open"""
  177. return self.isOpen
  178. class Key(WorldObj):
  179. def __init__(self, color='blue'):
  180. super(Key, self).__init__('key', color)
  181. def canPickup(self):
  182. return True
  183. def render(self, r):
  184. self._setColor(r)
  185. # Vertical quad
  186. r.drawPolygon([
  187. (16, 10),
  188. (20, 10),
  189. (20, 28),
  190. (16, 28)
  191. ])
  192. # Teeth
  193. r.drawPolygon([
  194. (12, 19),
  195. (16, 19),
  196. (16, 21),
  197. (12, 21)
  198. ])
  199. r.drawPolygon([
  200. (12, 26),
  201. (16, 26),
  202. (16, 28),
  203. (12, 28)
  204. ])
  205. r.drawCircle(18, 9, 6)
  206. r.setLineColor(0, 0, 0)
  207. r.setColor(0, 0, 0)
  208. r.drawCircle(18, 9, 2)
  209. class Ball(WorldObj):
  210. def __init__(self, color='blue'):
  211. super(Ball, self).__init__('ball', color)
  212. def canPickup(self):
  213. return True
  214. def render(self, r):
  215. self._setColor(r)
  216. r.drawCircle(CELL_PIXELS * 0.5, CELL_PIXELS * 0.5, 10)
  217. class Box(WorldObj):
  218. def __init__(self, color, contains=None):
  219. super(Box, self).__init__('box', color)
  220. self.contains = contains
  221. def render(self, r):
  222. c = COLORS[self.color]
  223. r.setLineColor(c[0], c[1], c[2])
  224. r.setColor(0, 0, 0)
  225. r.setLineWidth(2)
  226. r.drawPolygon([
  227. (4 , CELL_PIXELS-4),
  228. (CELL_PIXELS-4, CELL_PIXELS-4),
  229. (CELL_PIXELS-4, 4),
  230. (4 , 4)
  231. ])
  232. r.drawLine(
  233. 4,
  234. CELL_PIXELS / 2,
  235. CELL_PIXELS - 4,
  236. CELL_PIXELS / 2
  237. )
  238. r.setLineWidth(1)
  239. def toggle(self, env, pos):
  240. # Replace the box by its contents
  241. env.grid.set(*pos, self.contains)
  242. return True
  243. class Grid:
  244. """
  245. Represent a grid and operations on it
  246. """
  247. def __init__(self, width, height):
  248. assert width >= 4
  249. assert height >= 4
  250. self.width = width
  251. self.height = height
  252. self.grid = [None] * width * height
  253. def __contains__(self, key):
  254. if isinstance(key, WorldObj):
  255. for e in self.grid:
  256. if e is key:
  257. return True
  258. elif isinstance(key, tuple):
  259. for e in self.grid:
  260. if e is None:
  261. continue
  262. if (e.color, e.type) == key:
  263. return True
  264. return False
  265. def copy(self):
  266. from copy import deepcopy
  267. return deepcopy(self)
  268. def set(self, i, j, v):
  269. assert i >= 0 and i < self.width
  270. assert j >= 0 and j < self.height
  271. self.grid[j * self.width + i] = v
  272. def get(self, i, j):
  273. assert i >= 0 and i < self.width
  274. assert j >= 0 and j < self.height
  275. return self.grid[j * self.width + i]
  276. def horzWall(self, x, y, length=None):
  277. if length is None:
  278. length = self.width - x
  279. for i in range(0, length):
  280. self.set(x + i, y, Wall())
  281. def vertWall(self, x, y, length=None):
  282. if length is None:
  283. length = self.height - y
  284. for j in range(0, length):
  285. self.set(x, y + j, Wall())
  286. def wallRect(self, x, y, w, h):
  287. self.horzWall(x, y, w)
  288. self.horzWall(x, y+h-1, w)
  289. self.vertWall(x, y, h)
  290. self.vertWall(x+w-1, y, h)
  291. def rotateLeft(self):
  292. """
  293. Rotate the grid to the left (counter-clockwise)
  294. """
  295. grid = Grid(self.width, self.height)
  296. for j in range(0, self.height):
  297. for i in range(0, self.width):
  298. v = self.get(self.width - 1 - j, i)
  299. grid.set(i, j, v)
  300. return grid
  301. def slice(self, topX, topY, width, height):
  302. """
  303. Get a subset of the grid
  304. """
  305. grid = Grid(width, height)
  306. for j in range(0, height):
  307. for i in range(0, width):
  308. x = topX + i
  309. y = topY + j
  310. if x >= 0 and x < self.width and \
  311. y >= 0 and y < self.height:
  312. v = self.get(x, y)
  313. else:
  314. v = Wall()
  315. grid.set(i, j, v)
  316. return grid
  317. def render(self, r, tileSize):
  318. """
  319. Render this grid at a given scale
  320. :param r: target renderer object
  321. :param tileSize: tile size in pixels
  322. """
  323. assert r.width == self.width * tileSize
  324. assert r.height == self.height * tileSize
  325. # Total grid size at native scale
  326. widthPx = self.width * CELL_PIXELS
  327. heightPx = self.height * CELL_PIXELS
  328. # Draw background (out-of-world) tiles the same colors as walls
  329. # so the agent understands these areas are not reachable
  330. c = COLORS['grey']
  331. r.setLineColor(c[0], c[1], c[2])
  332. r.setColor(c[0], c[1], c[2])
  333. r.drawPolygon([
  334. (0 , heightPx),
  335. (widthPx, heightPx),
  336. (widthPx, 0),
  337. (0 , 0)
  338. ])
  339. r.push()
  340. # Internally, we draw at the "large" full-grid resolution, but we
  341. # use the renderer to scale back to the desired size
  342. r.scale(tileSize / CELL_PIXELS, tileSize / CELL_PIXELS)
  343. # Draw the background of the in-world cells black
  344. r.fillRect(
  345. 0,
  346. 0,
  347. widthPx,
  348. heightPx,
  349. 0, 0, 0
  350. )
  351. # Draw grid lines
  352. r.setLineColor(100, 100, 100)
  353. for rowIdx in range(0, self.height):
  354. y = CELL_PIXELS * rowIdx
  355. r.drawLine(0, y, widthPx, y)
  356. for colIdx in range(0, self.width):
  357. x = CELL_PIXELS * colIdx
  358. r.drawLine(x, 0, x, heightPx)
  359. # Render the grid
  360. for j in range(0, self.height):
  361. for i in range(0, self.width):
  362. cell = self.get(i, j)
  363. if cell == None:
  364. continue
  365. r.push()
  366. r.translate(i * CELL_PIXELS, j * CELL_PIXELS)
  367. cell.render(r)
  368. r.pop()
  369. r.pop()
  370. def encode(self):
  371. """
  372. Produce a compact numpy encoding of the grid
  373. """
  374. codeSize = self.width * self.height * 3
  375. array = np.zeros(shape=(self.width, self.height, 3), dtype='uint8')
  376. for j in range(0, self.height):
  377. for i in range(0, self.width):
  378. v = self.get(i, j)
  379. if v == None:
  380. continue
  381. array[i, j, 0] = OBJECT_TO_IDX[v.type]
  382. array[i, j, 1] = COLOR_TO_IDX[v.color]
  383. if hasattr(v, 'isOpen') and v.isOpen:
  384. array[i, j, 2] = 1
  385. return array
  386. def decode(array):
  387. """
  388. Decode an array grid encoding back into a grid
  389. """
  390. width = array.shape[0]
  391. height = array.shape[1]
  392. assert array.shape[2] == 3
  393. grid = Grid(width, height)
  394. for j in range(0, height):
  395. for i in range(0, width):
  396. typeIdx = array[i, j, 0]
  397. colorIdx = array[i, j, 1]
  398. openIdx = array[i, j, 2]
  399. if typeIdx == 0:
  400. continue
  401. objType = IDX_TO_OBJECT[typeIdx]
  402. color = IDX_TO_COLOR[colorIdx]
  403. isOpen = True if openIdx == 1 else 0
  404. if objType == 'wall':
  405. v = Wall(color)
  406. elif objType == 'ball':
  407. v = Ball(color)
  408. elif objType == 'key':
  409. v = Key(color)
  410. elif objType == 'box':
  411. v = Box(color)
  412. elif objType == 'door':
  413. v = Door(color, isOpen)
  414. elif objType == 'locked_door':
  415. v = LockedDoor(color, isOpen)
  416. elif objType == 'goal':
  417. v = Goal()
  418. else:
  419. assert False, "unknown obj type in decode '%s'" % objType
  420. grid.set(i, j, v)
  421. return grid
  422. class MiniGridEnv(gym.Env):
  423. """
  424. 2D grid world game environment
  425. """
  426. metadata = {
  427. 'render.modes': ['human', 'rgb_array', 'pixmap'],
  428. 'video.frames_per_second' : 10
  429. }
  430. # Enumeration of possible actions
  431. class Actions(IntEnum):
  432. # Turn left, turn right, move forward
  433. left = 0
  434. right = 1
  435. forward = 2
  436. # Toggle/pick up/activate object
  437. toggle = 3
  438. # Wait/stay put/do nothing
  439. wait = 4
  440. def __init__(self, gridSize=16, maxSteps=100):
  441. # Action enumeration for this environment
  442. self.actions = MiniGridEnv.Actions
  443. # Actions are discrete integer values
  444. self.action_space = spaces.Discrete(len(self.actions))
  445. # Observations are dictionaries containing an
  446. # encoding of the grid and a textual 'mission' string
  447. self.observation_space = spaces.Box(
  448. low=0,
  449. high=255,
  450. shape=OBS_ARRAY_SIZE,
  451. dtype='uint8'
  452. )
  453. self.observation_space = spaces.Dict({
  454. 'image': self.observation_space
  455. })
  456. # Range of possible rewards
  457. self.reward_range = (-1, 1000)
  458. # Renderer object used to render the whole grid (full-scale)
  459. self.gridRender = None
  460. # Renderer used to render observations (small-scale agent view)
  461. self.obsRender = None
  462. # Environment configuration
  463. self.gridSize = gridSize
  464. self.maxSteps = maxSteps
  465. self.startPos = (1, 1)
  466. self.startDir = 0
  467. # Initialize the state
  468. self.seed()
  469. self.reset()
  470. def _genGrid(self, width, height):
  471. assert False, "_genGrid needs to be implemented by each environment"
  472. def reset(self):
  473. # Generate a new random grid at the start of each episode
  474. # To keep the same grid for each episode, call env.seed() with
  475. # the same seed before calling env.reset()
  476. self._genGrid(self.gridSize, self.gridSize)
  477. # Place the agent in the starting position and direction
  478. self.agentPos = self.startPos
  479. self.agentDir = self.startDir
  480. # Item picked up, being carried, initially nothing
  481. self.carrying = None
  482. # Step count since episode start
  483. self.stepCount = 0
  484. # Return first observation
  485. obs = self._genObs()
  486. return obs
  487. def seed(self, seed=1337):
  488. # Seed the random number generator
  489. self.np_random, _ = seeding.np_random(seed)
  490. return [seed]
  491. def _randInt(self, low, high):
  492. """
  493. Generate random integer in [low,high[
  494. """
  495. return self.np_random.randint(low, high)
  496. def _randElem(self, iterable):
  497. """
  498. Pick a random element in a list
  499. """
  500. lst = list(iterable)
  501. idx = self._randInt(0, len(lst))
  502. return lst[idx]
  503. def _randPos(self, xLow, xHigh, yLow, yHigh):
  504. """
  505. Generate a random (x,y) position tuple
  506. """
  507. return (
  508. self.np_random.randint(xLow, xHigh),
  509. self.np_random.randint(yLow, yHigh)
  510. )
  511. def placeObj(self, obj, top=None, size=None):
  512. """
  513. Place an object at an empty position in the grid
  514. :param top: top-left position of the rectangle where to place
  515. :param size: size of the rectangle where to place
  516. """
  517. if top is None:
  518. top = (0, 0)
  519. if size is None:
  520. size = (self.grid.width, self.grid.height)
  521. while True:
  522. pos = (
  523. self._randInt(top[0], top[0] + size[0]),
  524. self._randInt(top[1], top[1] + size[1])
  525. )
  526. if self.grid.get(*pos) != None:
  527. continue
  528. if pos == self.startPos:
  529. continue
  530. break
  531. self.grid.set(*pos, obj)
  532. return pos
  533. def placeAgent(self, randDir=True):
  534. """
  535. Set the agent's starting point at an empty position in the grid
  536. """
  537. pos = self.placeObj(None)
  538. self.startPos = pos
  539. if randDir:
  540. self.startDir = self._randInt(0, 4)
  541. return pos
  542. def getStepsRemaining(self):
  543. return self.maxSteps - self.stepCount
  544. def getDirVec(self):
  545. """
  546. Get the direction vector for the agent, pointing in the direction
  547. of forward movement.
  548. """
  549. # Pointing right
  550. if self.agentDir == 0:
  551. return (1, 0)
  552. # Down (positive Y)
  553. elif self.agentDir == 1:
  554. return (0, 1)
  555. # Pointing left
  556. elif self.agentDir == 2:
  557. return (-1, 0)
  558. # Up (negative Y)
  559. elif self.agentDir == 3:
  560. return (0, -1)
  561. else:
  562. assert False
  563. def getViewExts(self):
  564. """
  565. Get the extents of the square set of tiles visible to the agent
  566. Note: the bottom extent indices are not included in the set
  567. """
  568. # Facing right
  569. if self.agentDir == 0:
  570. topX = self.agentPos[0]
  571. topY = self.agentPos[1] - AGENT_VIEW_SIZE // 2
  572. # Facing down
  573. elif self.agentDir == 1:
  574. topX = self.agentPos[0] - AGENT_VIEW_SIZE // 2
  575. topY = self.agentPos[1]
  576. # Facing left
  577. elif self.agentDir == 2:
  578. topX = self.agentPos[0] - AGENT_VIEW_SIZE + 1
  579. topY = self.agentPos[1] - AGENT_VIEW_SIZE // 2
  580. # Facing up
  581. elif self.agentDir == 3:
  582. topX = self.agentPos[0] - AGENT_VIEW_SIZE // 2
  583. topY = self.agentPos[1] - AGENT_VIEW_SIZE + 1
  584. else:
  585. assert False, "invalid agent direction"
  586. botX = topX + AGENT_VIEW_SIZE
  587. botY = topY + AGENT_VIEW_SIZE
  588. return (topX, topY, botX, botY)
  589. def agentSees(self, x, y):
  590. """
  591. Check if a grid position is visible to the agent
  592. """
  593. topX, topY, botX, botY = self.getViewExts()
  594. return (x >= topX and x < botX and y >= topY and y < botY)
  595. def step(self, action):
  596. self.stepCount += 1
  597. reward = 0
  598. done = False
  599. # Rotate left
  600. if action == self.actions.left:
  601. self.agentDir -= 1
  602. if self.agentDir < 0:
  603. self.agentDir += 4
  604. # Rotate right
  605. elif action == self.actions.right:
  606. self.agentDir = (self.agentDir + 1) % 4
  607. # Move forward
  608. elif action == self.actions.forward:
  609. u, v = self.getDirVec()
  610. newPos = (self.agentPos[0] + u, self.agentPos[1] + v)
  611. targetCell = self.grid.get(newPos[0], newPos[1])
  612. if targetCell == None or targetCell.canOverlap():
  613. self.agentPos = newPos
  614. elif targetCell.type == 'goal':
  615. done = True
  616. reward = 1000 - self.stepCount
  617. # Pick up or trigger/activate an item
  618. elif action == self.actions.toggle:
  619. u, v = self.getDirVec()
  620. objPos = (self.agentPos[0] + u, self.agentPos[1] + v)
  621. cell = self.grid.get(*objPos)
  622. if cell and cell.canPickup():
  623. if self.carrying is None:
  624. self.carrying = cell
  625. self.grid.set(*objPos, None)
  626. elif cell:
  627. cell.toggle(self, objPos)
  628. elif self.carrying:
  629. self.grid.set(*objPos, self.carrying)
  630. self.carrying = None
  631. # Wait/do nothing
  632. elif action == self.actions.wait:
  633. pass
  634. else:
  635. assert False, "unknown action"
  636. if self.stepCount >= self.maxSteps:
  637. done = True
  638. obs = self._genObs()
  639. return obs, reward, done, {}
  640. def _genObs(self):
  641. """
  642. Generate the agent's view (partially observable, low-resolution encoding)
  643. """
  644. topX, topY, botX, botY = self.getViewExts()
  645. grid = self.grid.slice(topX, topY, AGENT_VIEW_SIZE, AGENT_VIEW_SIZE)
  646. for i in range(self.agentDir + 1):
  647. grid = grid.rotateLeft()
  648. # Make it so the agent sees what it's carrying
  649. # We do this by placing the carried object at the agent's position
  650. # in the agent's partially observable view
  651. agentPos = grid.width // 2, grid.height - 1
  652. if self.carrying:
  653. grid.set(*agentPos, self.carrying)
  654. else:
  655. grid.set(*agentPos, None)
  656. # Encode the partially observable view into a numpy array
  657. image = grid.encode()
  658. assert hasattr(self, 'mission'), "environments must define a textual mission string"
  659. # Observations are dictionaries with both an image
  660. # and a textual mission string
  661. obs = {
  662. 'image': image,
  663. 'mission': self.mission
  664. }
  665. return obs
  666. def getObsRender(self, obs):
  667. """
  668. Render an agent observation for visualization
  669. """
  670. if self.obsRender == None:
  671. self.obsRender = Renderer(
  672. AGENT_VIEW_SIZE * CELL_PIXELS // 2,
  673. AGENT_VIEW_SIZE * CELL_PIXELS // 2
  674. )
  675. r = self.obsRender
  676. r.beginFrame()
  677. grid = Grid.decode(obs)
  678. # Render the whole grid
  679. grid.render(r, CELL_PIXELS // 2)
  680. # Draw the agent
  681. r.push()
  682. r.scale(0.5, 0.5)
  683. r.translate(
  684. CELL_PIXELS * (0.5 + AGENT_VIEW_SIZE // 2),
  685. CELL_PIXELS * (AGENT_VIEW_SIZE - 0.5)
  686. )
  687. r.rotate(3 * 90)
  688. r.setLineColor(255, 0, 0)
  689. r.setColor(255, 0, 0)
  690. r.drawPolygon([
  691. (-12, 10),
  692. ( 12, 0),
  693. (-12, -10)
  694. ])
  695. r.pop()
  696. r.endFrame()
  697. return r.getPixmap()
  698. def render(self, mode='human', close=False):
  699. """
  700. Render the whole-grid human view
  701. """
  702. if close:
  703. if self.gridRender:
  704. self.gridRender.close()
  705. return
  706. if self.gridRender is None:
  707. self.gridRender = Renderer(
  708. self.gridSize * CELL_PIXELS,
  709. self.gridSize * CELL_PIXELS,
  710. True if mode == 'human' else False
  711. )
  712. r = self.gridRender
  713. r.beginFrame()
  714. # Render the whole grid
  715. self.grid.render(r, CELL_PIXELS)
  716. # Draw the agent
  717. r.push()
  718. r.translate(
  719. CELL_PIXELS * (self.agentPos[0] + 0.5),
  720. CELL_PIXELS * (self.agentPos[1] + 0.5)
  721. )
  722. r.rotate(self.agentDir * 90)
  723. r.setLineColor(255, 0, 0)
  724. r.setColor(255, 0, 0)
  725. r.drawPolygon([
  726. (-12, 10),
  727. ( 12, 0),
  728. (-12, -10)
  729. ])
  730. r.pop()
  731. # Highlight what the agent can see
  732. topX, topY, botX, botY = self.getViewExts()
  733. r.fillRect(
  734. topX * CELL_PIXELS,
  735. topY * CELL_PIXELS,
  736. AGENT_VIEW_SIZE * CELL_PIXELS,
  737. AGENT_VIEW_SIZE * CELL_PIXELS,
  738. 200, 200, 200, 75
  739. )
  740. r.endFrame()
  741. if mode == 'rgb_array':
  742. return r.getArray()
  743. elif mode == 'pixmap':
  744. return r.getPixmap()
  745. return r