minigrid.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978
  1. import math
  2. import gym
  3. from enum import IntEnum
  4. import numpy as np
  5. from gym import error, spaces, utils
  6. from gym.utils import seeding
  7. from gym_minigrid.rendering import *
  8. # Size in pixels of a cell in the full-scale human view
  9. CELL_PIXELS = 32
  10. # Number of cells (width and height) in the agent view
  11. AGENT_VIEW_SIZE = 7
  12. # Size of the array given as an observation to the agent
  13. OBS_ARRAY_SIZE = (AGENT_VIEW_SIZE, AGENT_VIEW_SIZE, 3)
  14. # Map of color names to RGB values
  15. COLORS = {
  16. 'red' : (255, 0, 0),
  17. 'green' : (0, 255, 0),
  18. 'blue' : (0, 0, 255),
  19. 'purple': (112, 39, 195),
  20. 'yellow': (255, 255, 0),
  21. 'grey' : (100, 100, 100)
  22. }
  23. COLOR_NAMES = sorted(list(COLORS.keys()))
  24. # Used to map colors to integers
  25. COLOR_TO_IDX = {
  26. 'red' : 0,
  27. 'green' : 1,
  28. 'blue' : 2,
  29. 'purple': 3,
  30. 'yellow': 4,
  31. 'grey' : 5
  32. }
  33. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  34. # Map of object type to integers
  35. OBJECT_TO_IDX = {
  36. 'empty' : 0,
  37. 'wall' : 1,
  38. 'door' : 2,
  39. 'locked_door' : 3,
  40. 'key' : 4,
  41. 'ball' : 5,
  42. 'box' : 6,
  43. 'goal' : 7
  44. }
  45. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  46. class WorldObj:
  47. """
  48. Base class for grid world objects
  49. """
  50. def __init__(self, type, color):
  51. assert type in OBJECT_TO_IDX, type
  52. assert color in COLOR_TO_IDX, color
  53. self.type = type
  54. self.color = color
  55. self.contains = None
  56. def isOpaque(self):
  57. """Can the agent see objects that are behind this?"""
  58. return False
  59. def canOverlap(self):
  60. """Can the agent overlap with this?"""
  61. return False
  62. def canPickup(self):
  63. """Can the agent pick this up?"""
  64. return False
  65. def canContain(self):
  66. """Can this contain another object?"""
  67. return False
  68. def toggle(self, env, pos):
  69. """Method to trigger/toggle an action this object performs"""
  70. return False
  71. def render(self, r):
  72. assert False
  73. def _setColor(self, r):
  74. c = COLORS[self.color]
  75. r.setLineColor(c[0], c[1], c[2])
  76. r.setColor(c[0], c[1], c[2])
  77. class Goal(WorldObj):
  78. def __init__(self):
  79. super(Goal, self).__init__('goal', 'green')
  80. def render(self, r):
  81. self._setColor(r)
  82. r.drawPolygon([
  83. (0 , CELL_PIXELS),
  84. (CELL_PIXELS, CELL_PIXELS),
  85. (CELL_PIXELS, 0),
  86. (0 , 0)
  87. ])
  88. class Wall(WorldObj):
  89. def __init__(self, color='grey'):
  90. super(Wall, self).__init__('wall', color)
  91. def isOpaque(self):
  92. return True
  93. def render(self, r):
  94. self._setColor(r)
  95. r.drawPolygon([
  96. (0 , CELL_PIXELS),
  97. (CELL_PIXELS, CELL_PIXELS),
  98. (CELL_PIXELS, 0),
  99. (0 , 0)
  100. ])
  101. class Door(WorldObj):
  102. def __init__(self, color, isOpen=False, type='door'):
  103. super().__init__('door', color)
  104. self.isOpen = isOpen
  105. def isOpaque(self):
  106. return not self.isOpen
  107. def canOverlap(self):
  108. """The agent can only walk over this cell when the door is open"""
  109. return self.isOpen
  110. def toggle(self, env, pos):
  111. if not self.isOpen:
  112. self.isOpen = True
  113. return True
  114. return False
  115. def render(self, r):
  116. c = COLORS[self.color]
  117. r.setLineColor(c[0], c[1], c[2])
  118. r.setColor(0, 0, 0)
  119. if self.isOpen:
  120. r.drawPolygon([
  121. (CELL_PIXELS-2, CELL_PIXELS),
  122. (CELL_PIXELS , CELL_PIXELS),
  123. (CELL_PIXELS , 0),
  124. (CELL_PIXELS-2, 0)
  125. ])
  126. return
  127. r.drawPolygon([
  128. (0 , CELL_PIXELS),
  129. (CELL_PIXELS, CELL_PIXELS),
  130. (CELL_PIXELS, 0),
  131. (0 , 0)
  132. ])
  133. r.drawPolygon([
  134. (2 , CELL_PIXELS-2),
  135. (CELL_PIXELS-2, CELL_PIXELS-2),
  136. (CELL_PIXELS-2, 2),
  137. (2 , 2)
  138. ])
  139. r.drawCircle(CELL_PIXELS * 0.75, CELL_PIXELS * 0.5, 2)
  140. class LockedDoor(Door):
  141. def __init__(self, color, isOpen=False):
  142. super().__init__(color, isOpen, type='locked_door')
  143. def toggle(self, env, pos):
  144. # If the player has the right key to open the door
  145. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  146. self.isOpen = True
  147. # The key has been used, remove it from the agent
  148. env.carrying = None
  149. return True
  150. return False
  151. def render(self, r):
  152. c = COLORS[self.color]
  153. r.setLineColor(c[0], c[1], c[2])
  154. r.setColor(c[0], c[1], c[2], 50)
  155. if self.isOpen:
  156. r.drawPolygon([
  157. (CELL_PIXELS-2, CELL_PIXELS),
  158. (CELL_PIXELS , CELL_PIXELS),
  159. (CELL_PIXELS , 0),
  160. (CELL_PIXELS-2, 0)
  161. ])
  162. return
  163. r.drawPolygon([
  164. (0 , CELL_PIXELS),
  165. (CELL_PIXELS, CELL_PIXELS),
  166. (CELL_PIXELS, 0),
  167. (0 , 0)
  168. ])
  169. r.drawPolygon([
  170. (2 , CELL_PIXELS-2),
  171. (CELL_PIXELS-2, CELL_PIXELS-2),
  172. (CELL_PIXELS-2, 2),
  173. (2 , 2)
  174. ])
  175. r.drawLine(
  176. CELL_PIXELS * 0.55,
  177. CELL_PIXELS * 0.5,
  178. CELL_PIXELS * 0.75,
  179. CELL_PIXELS * 0.5
  180. )
  181. class Key(WorldObj):
  182. def __init__(self, color='blue'):
  183. super(Key, self).__init__('key', color)
  184. def canPickup(self):
  185. return True
  186. def render(self, r):
  187. self._setColor(r)
  188. # Vertical quad
  189. r.drawPolygon([
  190. (16, 10),
  191. (20, 10),
  192. (20, 28),
  193. (16, 28)
  194. ])
  195. # Teeth
  196. r.drawPolygon([
  197. (12, 19),
  198. (16, 19),
  199. (16, 21),
  200. (12, 21)
  201. ])
  202. r.drawPolygon([
  203. (12, 26),
  204. (16, 26),
  205. (16, 28),
  206. (12, 28)
  207. ])
  208. r.drawCircle(18, 9, 6)
  209. r.setLineColor(0, 0, 0)
  210. r.setColor(0, 0, 0)
  211. r.drawCircle(18, 9, 2)
  212. class Ball(WorldObj):
  213. def __init__(self, color='blue'):
  214. super(Ball, self).__init__('ball', color)
  215. def canPickup(self):
  216. return True
  217. def render(self, r):
  218. self._setColor(r)
  219. r.drawCircle(CELL_PIXELS * 0.5, CELL_PIXELS * 0.5, 10)
  220. class Box(WorldObj):
  221. def __init__(self, color, contains=None):
  222. super(Box, self).__init__('box', color)
  223. self.contains = contains
  224. def render(self, r):
  225. c = COLORS[self.color]
  226. r.setLineColor(c[0], c[1], c[2])
  227. r.setColor(0, 0, 0)
  228. r.setLineWidth(2)
  229. r.drawPolygon([
  230. (4 , CELL_PIXELS-4),
  231. (CELL_PIXELS-4, CELL_PIXELS-4),
  232. (CELL_PIXELS-4, 4),
  233. (4 , 4)
  234. ])
  235. r.drawLine(
  236. 4,
  237. CELL_PIXELS / 2,
  238. CELL_PIXELS - 4,
  239. CELL_PIXELS / 2
  240. )
  241. r.setLineWidth(1)
  242. def toggle(self, env, pos):
  243. # Replace the box by its contents
  244. env.grid.set(*pos, self.contains)
  245. return True
  246. class Grid:
  247. """
  248. Represent a grid and operations on it
  249. """
  250. def __init__(self, width, height):
  251. assert width >= 4
  252. assert height >= 4
  253. self.width = width
  254. self.height = height
  255. self.grid = [None] * width * height
  256. def __contains__(self, key):
  257. if isinstance(key, WorldObj):
  258. for e in self.grid:
  259. if e is key:
  260. return True
  261. elif isinstance(key, tuple):
  262. for e in self.grid:
  263. if e is None:
  264. continue
  265. if (e.color, e.type) == key:
  266. return True
  267. return False
  268. def copy(self):
  269. from copy import deepcopy
  270. return deepcopy(self)
  271. def set(self, i, j, v):
  272. assert i >= 0 and i < self.width
  273. assert j >= 0 and j < self.height
  274. self.grid[j * self.width + i] = v
  275. def get(self, i, j):
  276. assert i >= 0 and i < self.width
  277. assert j >= 0 and j < self.height
  278. return self.grid[j * self.width + i]
  279. def horzWall(self, x, y, length=None):
  280. if length is None:
  281. length = self.width - x
  282. for i in range(0, length):
  283. self.set(x + i, y, Wall())
  284. def vertWall(self, x, y, length=None):
  285. if length is None:
  286. length = self.height - y
  287. for j in range(0, length):
  288. self.set(x, y + j, Wall())
  289. def rotateLeft(self):
  290. """
  291. Rotate the grid to the left (counter-clockwise)
  292. """
  293. grid = Grid(self.width, self.height)
  294. for j in range(0, self.height):
  295. for i in range(0, self.width):
  296. v = self.get(self.width - 1 - j, i)
  297. grid.set(i, j, v)
  298. return grid
  299. def slice(self, topX, topY, width, height):
  300. """
  301. Get a subset of the grid
  302. """
  303. grid = Grid(width, height)
  304. for j in range(0, height):
  305. for i in range(0, width):
  306. x = topX + i
  307. y = topY + j
  308. if x >= 0 and x < self.width and \
  309. y >= 0 and y < self.height:
  310. v = self.get(x, y)
  311. else:
  312. v = Wall()
  313. grid.set(i, j, v)
  314. return grid
  315. def render(self, r, tileSize):
  316. """
  317. Render this grid at a given scale
  318. :param r: target renderer object
  319. :param tileSize: tile size in pixels
  320. """
  321. assert r.width == self.width * tileSize
  322. assert r.height == self.height * tileSize
  323. # Total grid size at native scale
  324. widthPx = self.width * CELL_PIXELS
  325. heightPx = self.height * CELL_PIXELS
  326. # Draw background (out-of-world) tiles the same colors as walls
  327. # so the agent understands these areas are not reachable
  328. c = COLORS['grey']
  329. r.setLineColor(c[0], c[1], c[2])
  330. r.setColor(c[0], c[1], c[2])
  331. r.drawPolygon([
  332. (0 , heightPx),
  333. (widthPx, heightPx),
  334. (widthPx, 0),
  335. (0 , 0)
  336. ])
  337. r.push()
  338. # Internally, we draw at the "large" full-grid resolution, but we
  339. # use the renderer to scale back to the desired size
  340. r.scale(tileSize / CELL_PIXELS, tileSize / CELL_PIXELS)
  341. # Draw the background of the in-world cells black
  342. r.fillRect(
  343. 0,
  344. 0,
  345. widthPx,
  346. heightPx,
  347. 0, 0, 0
  348. )
  349. # Draw grid lines
  350. r.setLineColor(100, 100, 100)
  351. for rowIdx in range(0, self.height):
  352. y = CELL_PIXELS * rowIdx
  353. r.drawLine(0, y, widthPx, y)
  354. for colIdx in range(0, self.width):
  355. x = CELL_PIXELS * colIdx
  356. r.drawLine(x, 0, x, heightPx)
  357. # Render the grid
  358. for j in range(0, self.height):
  359. for i in range(0, self.width):
  360. cell = self.get(i, j)
  361. if cell == None:
  362. continue
  363. r.push()
  364. r.translate(i * CELL_PIXELS, j * CELL_PIXELS)
  365. cell.render(r)
  366. r.pop()
  367. r.pop()
  368. def encode(self):
  369. """
  370. Produce a compact numpy encoding of the grid
  371. """
  372. codeSize = self.width * self.height * 3
  373. array = np.zeros(shape=(self.width, self.height, 3), dtype='uint8')
  374. for j in range(0, self.height):
  375. for i in range(0, self.width):
  376. v = self.get(i, j)
  377. if v == None:
  378. continue
  379. array[i, j, 0] = OBJECT_TO_IDX[v.type]
  380. array[i, j, 1] = COLOR_TO_IDX[v.color]
  381. if hasattr(v, 'isOpen') and v.isOpen:
  382. array[i, j, 2] = 1
  383. return array
  384. def decode(array):
  385. """
  386. Decode an array grid encoding back into a grid
  387. """
  388. width = array.shape[0]
  389. height = array.shape[1]
  390. assert array.shape[2] == 3
  391. grid = Grid(width, height)
  392. for j in range(0, height):
  393. for i in range(0, width):
  394. typeIdx = array[i, j, 0]
  395. colorIdx = array[i, j, 1]
  396. openIdx = array[i, j, 2]
  397. if typeIdx == 0:
  398. continue
  399. objType = IDX_TO_OBJECT[typeIdx]
  400. color = IDX_TO_COLOR[colorIdx]
  401. isOpen = True if openIdx == 1 else 0
  402. if objType == 'wall':
  403. v = Wall(color)
  404. elif objType == 'ball':
  405. v = Ball(color)
  406. elif objType == 'key':
  407. v = Key(color)
  408. elif objType == 'box':
  409. v = Box(color)
  410. elif objType == 'door':
  411. v = Door(color, isOpen)
  412. elif objType == 'locked_door':
  413. v = LockedDoor(color, isOpen)
  414. elif objType == 'goal':
  415. v = Goal()
  416. else:
  417. assert False, "unknown obj type in decode '%s'" % objType
  418. grid.set(i, j, v)
  419. return grid
  420. class MiniGridEnv(gym.Env):
  421. """
  422. 2D grid world game environment
  423. """
  424. metadata = {
  425. 'render.modes': ['human', 'rgb_array', 'pixmap'],
  426. 'video.frames_per_second' : 10
  427. }
  428. # Enumeration of possible actions
  429. class Actions(IntEnum):
  430. left = 0
  431. right = 1
  432. forward = 2
  433. # Toggle/pick up/activate object
  434. toggle = 3
  435. # Wait/stay put/do nothing
  436. wait = 4
  437. def __init__(self, gridSize=16, maxSteps=100):
  438. # Action enumeration for this environment
  439. self.actions = MiniGridEnv.Actions
  440. # Actions are discrete integer values
  441. self.action_space = spaces.Discrete(len(self.actions))
  442. # Observations are dictionaries containing an
  443. # encoding of the grid and a textual 'mission' string
  444. self.observation_space = spaces.Box(
  445. low=0,
  446. high=255,
  447. shape=OBS_ARRAY_SIZE,
  448. dtype='uint8'
  449. )
  450. self.observation_space = spaces.Dict({
  451. 'image': self.observation_space
  452. })
  453. # Range of possible rewards
  454. self.reward_range = (-1, 1000)
  455. # Renderer object used to render the whole grid (full-scale)
  456. self.gridRender = None
  457. # Renderer used to render observations (small-scale agent view)
  458. self.obsRender = None
  459. # Environment configuration
  460. self.gridSize = gridSize
  461. self.maxSteps = maxSteps
  462. self.startPos = (1, 1)
  463. self.startDir = 0
  464. # Initialize the state
  465. self.seed()
  466. self.reset()
  467. def _genGrid(self, width, height):
  468. assert False, "_genGrid needs to be implemented by each environment"
  469. def reset(self):
  470. # Generate a new random grid at the start of each episode
  471. # To keep the same grid for each episode, call env.seed() with
  472. # the same seed before calling env.reset()
  473. self._genGrid(self.gridSize, self.gridSize)
  474. # Place the agent in the starting position and direction
  475. self.agentPos = self.startPos
  476. self.agentDir = self.startDir
  477. # Item picked up, being carried, initially nothing
  478. self.carrying = None
  479. # Step count since episode start
  480. self.stepCount = 0
  481. # Return first observation
  482. obs = self._genObs()
  483. return obs
  484. def seed(self, seed=1337):
  485. # Seed the random number generator
  486. self.np_random, _ = seeding.np_random(seed)
  487. return [seed]
  488. def _randInt(self, low, high):
  489. """
  490. Generate random integer in [low,high[
  491. """
  492. return self.np_random.randint(low, high)
  493. def _randElem(self, iterable):
  494. """
  495. Pick a random element in a list
  496. """
  497. lst = list(iterable)
  498. idx = self._randInt(0, len(lst))
  499. return lst[idx]
  500. def _randPos(self, xLow, xHigh, yLow, yHigh):
  501. """
  502. Generate a random (x,y) position tuple
  503. """
  504. return (
  505. self.np_random.randint(xLow, xHigh),
  506. self.np_random.randint(yLow, yHigh)
  507. )
  508. def placeObj(self, obj):
  509. """
  510. Place an object at an empty position in the grid
  511. """
  512. while True:
  513. pos = (
  514. self._randInt(0, self.grid.width),
  515. self._randInt(0, self.grid.height)
  516. )
  517. if self.grid.get(*pos) != None:
  518. continue
  519. if pos == self.startPos:
  520. continue
  521. break
  522. self.grid.set(*pos, obj)
  523. return pos
  524. def placeAgent(self, randDir=True):
  525. """
  526. Set the agent's starting point at an empty position in the grid
  527. """
  528. pos = self.placeObj(None)
  529. self.startPos = pos
  530. if randDir:
  531. self.startDir = self._randInt(0, 4)
  532. return pos
  533. def getStepsRemaining(self):
  534. return self.maxSteps - self.stepCount
  535. def getDirVec(self):
  536. """
  537. Get the direction vector for the agent, pointing in the direction
  538. of forward movement.
  539. """
  540. # Pointing right
  541. if self.agentDir == 0:
  542. return (1, 0)
  543. # Down (positive Y)
  544. elif self.agentDir == 1:
  545. return (0, 1)
  546. # Pointing left
  547. elif self.agentDir == 2:
  548. return (-1, 0)
  549. # Up (negative Y)
  550. elif self.agentDir == 3:
  551. return (0, -1)
  552. else:
  553. assert False
  554. def getViewExts(self):
  555. """
  556. Get the extents of the square set of tiles visible to the agent
  557. Note: the bottom extent indices are not included in the set
  558. """
  559. # Facing right
  560. if self.agentDir == 0:
  561. topX = self.agentPos[0]
  562. topY = self.agentPos[1] - AGENT_VIEW_SIZE // 2
  563. # Facing down
  564. elif self.agentDir == 1:
  565. topX = self.agentPos[0] - AGENT_VIEW_SIZE // 2
  566. topY = self.agentPos[1]
  567. # Facing right
  568. elif self.agentDir == 2:
  569. topX = self.agentPos[0] - AGENT_VIEW_SIZE + 1
  570. topY = self.agentPos[1] - AGENT_VIEW_SIZE // 2
  571. # Facing up
  572. elif self.agentDir == 3:
  573. topX = self.agentPos[0] - AGENT_VIEW_SIZE // 2
  574. topY = self.agentPos[1] - AGENT_VIEW_SIZE + 1
  575. else:
  576. assert False, "invalid agent direction"
  577. botX = topX + AGENT_VIEW_SIZE
  578. botY = topY + AGENT_VIEW_SIZE
  579. return (topX, topY, botX, botY)
  580. def agentSees(self, x, y):
  581. """
  582. Check if a grid position is visible to the agent
  583. """
  584. topX, topY, botX, botY = self.getViewExts()
  585. return (x >= topX and x < botX and y >= topY and y < botY)
  586. def step(self, action):
  587. self.stepCount += 1
  588. reward = 0
  589. done = False
  590. # Rotate left
  591. if action == self.actions.left:
  592. self.agentDir -= 1
  593. if self.agentDir < 0:
  594. self.agentDir += 4
  595. # Rotate right
  596. elif action == self.actions.right:
  597. self.agentDir = (self.agentDir + 1) % 4
  598. # Move forward
  599. elif action == self.actions.forward:
  600. u, v = self.getDirVec()
  601. newPos = (self.agentPos[0] + u, self.agentPos[1] + v)
  602. targetCell = self.grid.get(newPos[0], newPos[1])
  603. if targetCell == None or targetCell.canOverlap():
  604. self.agentPos = newPos
  605. elif targetCell.type == 'goal':
  606. done = True
  607. reward = 1000 - self.stepCount
  608. # Pick up or trigger/activate an item
  609. elif action == self.actions.toggle:
  610. u, v = self.getDirVec()
  611. objPos = (self.agentPos[0] + u, self.agentPos[1] + v)
  612. cell = self.grid.get(*objPos)
  613. if cell and cell.canPickup():
  614. if self.carrying is None:
  615. self.carrying = cell
  616. self.grid.set(*objPos, None)
  617. elif cell:
  618. cell.toggle(self, objPos)
  619. elif self.carrying:
  620. self.grid.set(*objPos, self.carrying)
  621. self.carrying = None
  622. # Wait/do nothing
  623. elif action == self.actions.wait:
  624. pass
  625. else:
  626. assert False, "unknown action"
  627. if self.stepCount >= self.maxSteps:
  628. done = True
  629. obs = self._genObs()
  630. return obs, reward, done, {}
  631. def _getObsGrid(self):
  632. topX, topY, botX, botY = self.getViewExts()
  633. grid = self.grid.slice(topX, topY, AGENT_VIEW_SIZE, AGENT_VIEW_SIZE)
  634. for i in range(self.agentDir+1):
  635. grid = grid.rotateLeft()
  636. # Matrix of opaque objects (rows going forward, columns)
  637. objMatrix = np.zeros((AGENT_VIEW_SIZE, AGENT_VIEW_SIZE))
  638. for i in range(0, AGENT_VIEW_SIZE):
  639. for j in range(0, AGENT_VIEW_SIZE):
  640. obj = grid.get(i, AGENT_VIEW_SIZE - (1+j))
  641. if obj is not None and obj.isOpaque():
  642. print('opaque')
  643. objMatrix[i, j] = 1
  644. visMatrix = np.zeros((AGENT_VIEW_SIZE, AGENT_VIEW_SIZE))
  645. mid = AGENT_VIEW_SIZE // 2
  646. # The agent can see its own position
  647. visMatrix[mid, 0] = 1
  648. # Agent can see in the first row if not obstructed
  649. for k in range(0, mid):
  650. i = mid + 1 + k
  651. if visMatrix[i-1, 0] and not objMatrix[i, 0]:
  652. visMatrix[i, 0] = 1
  653. for k in range(0, mid):
  654. i = mid - 1 - k
  655. if visMatrix[i+1, 0] and not objMatrix[i, 0]:
  656. visMatrix[i, 0] = 1
  657. # Propagate forward the region in which the agent can see
  658. for i in range(0, AGENT_VIEW_SIZE):
  659. for j in range(1, AGENT_VIEW_SIZE):
  660. if visMatrix[i, j-1] and not objMatrix[i, j]:
  661. visMatrix[i, j] = 1
  662. # Make cells immediately adjacent to the visible region also visible
  663. adjMatrix = np.copy(visMatrix)
  664. for i in range(0, AGENT_VIEW_SIZE):
  665. for j in range(0, AGENT_VIEW_SIZE):
  666. if j > 0 and visMatrix[i, j-1]:
  667. adjMatrix[i, j] = 1
  668. if i > 0:
  669. if visMatrix[i-1, j]:
  670. adjMatrix[i, j] = 1
  671. if j > 0 and visMatrix[i-1, j-1]:
  672. adjMatrix[i, j] = 1
  673. if i < AGENT_VIEW_SIZE - 1:
  674. if visMatrix[i+1, j]:
  675. adjMatrix[i, j] = 1
  676. if j > 0 and visMatrix[i+1, j-1]:
  677. adjMatrix[i, j] = 1
  678. visMatrix = adjMatrix
  679. # Hide non-visible cells
  680. for i in range(0, AGENT_VIEW_SIZE):
  681. for j in range(0, AGENT_VIEW_SIZE):
  682. if visMatrix[i, j] == 0:
  683. grid.set(i, AGENT_VIEW_SIZE - (1+j), Wall('red'))
  684. return grid, visMatrix
  685. def _genObs(self):
  686. """
  687. Generate the agent's view (partially observable, low-resolution encoding)
  688. """
  689. grid, visMatrix = self._getObsGrid()
  690. # Make it so the agent sees what it's carrying
  691. # We do this by placing the carried object at the agent's position
  692. # in the agent's partially observable view
  693. agentPos = grid.width // 2, grid.height - 1
  694. if self.carrying:
  695. grid.set(*agentPos, self.carrying)
  696. else:
  697. grid.set(*agentPos, None)
  698. # Encode the partially observable view into a numpy array
  699. image = grid.encode()
  700. assert hasattr(self, 'mission'), "environments must define a textual mission string"
  701. # Observations are dictionaries with both an image
  702. # and a textual mission string
  703. obs = {
  704. 'image': image,
  705. 'mission': self.mission
  706. }
  707. return obs
  708. def getObsRender(self, obs):
  709. """
  710. Render an agent observation for visualization
  711. """
  712. if self.obsRender == None:
  713. self.obsRender = Renderer(
  714. AGENT_VIEW_SIZE * CELL_PIXELS // 2,
  715. AGENT_VIEW_SIZE * CELL_PIXELS // 2
  716. )
  717. r = self.obsRender
  718. r.beginFrame()
  719. grid = Grid.decode(obs)
  720. # Render the whole grid
  721. grid.render(r, CELL_PIXELS // 2)
  722. # Draw the agent
  723. r.push()
  724. r.scale(0.5, 0.5)
  725. r.translate(
  726. CELL_PIXELS * (0.5 + AGENT_VIEW_SIZE // 2),
  727. CELL_PIXELS * (AGENT_VIEW_SIZE - 0.5)
  728. )
  729. r.rotate(3 * 90)
  730. r.setLineColor(255, 0, 0)
  731. r.setColor(255, 0, 0)
  732. r.drawPolygon([
  733. (-12, 10),
  734. ( 12, 0),
  735. (-12, -10)
  736. ])
  737. r.pop()
  738. r.endFrame()
  739. return r.getPixmap()
  740. def render(self, mode='human', close=False):
  741. """
  742. Render the whole-grid human view
  743. """
  744. if close:
  745. if self.gridRender:
  746. self.gridRender.close()
  747. return
  748. if self.gridRender is None:
  749. self.gridRender = Renderer(
  750. self.gridSize * CELL_PIXELS,
  751. self.gridSize * CELL_PIXELS,
  752. True if mode == 'human' else False
  753. )
  754. r = self.gridRender
  755. r.beginFrame()
  756. # Render the whole grid
  757. self.grid.render(r, CELL_PIXELS)
  758. # Draw the agent
  759. r.push()
  760. r.translate(
  761. CELL_PIXELS * (self.agentPos[0] + 0.5),
  762. CELL_PIXELS * (self.agentPos[1] + 0.5)
  763. )
  764. r.rotate(self.agentDir * 90)
  765. r.setLineColor(255, 0, 0)
  766. r.setColor(255, 0, 0)
  767. r.drawPolygon([
  768. (-12, 10),
  769. ( 12, 0),
  770. (-12, -10)
  771. ])
  772. r.pop()
  773. # Highlight what the agent can see
  774. topX, topY, botX, botY = self.getViewExts()
  775. r.fillRect(
  776. topX * CELL_PIXELS,
  777. topY * CELL_PIXELS,
  778. AGENT_VIEW_SIZE * CELL_PIXELS,
  779. AGENT_VIEW_SIZE * CELL_PIXELS,
  780. 200, 200, 200, 75
  781. )
  782. r.endFrame()
  783. if mode == 'rgb_array':
  784. return r.getArray()
  785. elif mode == 'pixmap':
  786. return r.getPixmap()
  787. return r