minigrid.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778
  1. import math
  2. import gym
  3. from enum import IntEnum
  4. import numpy as np
  5. from gym import error, spaces, utils
  6. from gym.utils import seeding
  7. from gym_minigrid.rendering import *
  8. # Size in pixels of a cell in the full-scale human view
  9. CELL_PIXELS = 32
  10. # Number of cells (width and height) in the agent view
  11. AGENT_VIEW_SIZE = 7
  12. # Size of the array given as an observation to the agent
  13. OBS_ARRAY_SIZE = (AGENT_VIEW_SIZE, AGENT_VIEW_SIZE, 3)
  14. COLORS = {
  15. 'red' : (255, 0, 0),
  16. 'green' : (0, 255, 0),
  17. 'blue' : (0, 0, 255),
  18. 'purple': (112, 39, 195),
  19. 'yellow': (255, 255, 0),
  20. 'grey' : (100, 100, 100)
  21. }
  22. # Used to map colors to integers
  23. COLOR_TO_IDX = {
  24. 'red' : 0,
  25. 'green' : 1,
  26. 'blue' : 2,
  27. 'purple': 3,
  28. 'yellow': 4,
  29. 'grey' : 5
  30. }
  31. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  32. # Map of object type to integers
  33. OBJECT_TO_IDX = {
  34. 'empty' : 0,
  35. 'wall' : 1,
  36. 'door' : 2,
  37. 'locked_door' : 3,
  38. 'ball' : 4,
  39. 'key' : 5,
  40. 'goal' : 6
  41. }
  42. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  43. class WorldObj:
  44. """
  45. Base class for grid world objects
  46. """
  47. def __init__(self, type, color):
  48. assert type in OBJECT_TO_IDX, type
  49. assert color in COLOR_TO_IDX, color
  50. self.type = type
  51. self.color = color
  52. self.contains = None
  53. def canOverlap(self):
  54. """Can the agent overlap with this?"""
  55. return False
  56. def canPickup(self):
  57. """Can the agent pick this up?"""
  58. return False
  59. def canContain(self):
  60. """Can this contain another object?"""
  61. return False
  62. def toggle(self, env):
  63. """Method to trigger/toggle an action this object performs"""
  64. return False
  65. def render(self, r):
  66. assert False
  67. def _setColor(self, r):
  68. c = COLORS[self.color]
  69. r.setLineColor(c[0], c[1], c[2])
  70. r.setColor(c[0], c[1], c[2])
  71. class Goal(WorldObj):
  72. def __init__(self):
  73. super(Goal, self).__init__('goal', 'green')
  74. def render(self, r):
  75. self._setColor(r)
  76. r.drawPolygon([
  77. (0 , CELL_PIXELS),
  78. (CELL_PIXELS, CELL_PIXELS),
  79. (CELL_PIXELS, 0),
  80. (0 , 0)
  81. ])
  82. class Wall(WorldObj):
  83. def __init__(self, color='grey'):
  84. super(Wall, self).__init__('wall', color)
  85. def render(self, r):
  86. self._setColor(r)
  87. r.drawPolygon([
  88. (0 , CELL_PIXELS),
  89. (CELL_PIXELS, CELL_PIXELS),
  90. (CELL_PIXELS, 0),
  91. (0 , 0)
  92. ])
  93. class Door(WorldObj):
  94. def __init__(self, color, isOpen=False):
  95. super(Door, self).__init__('door', color)
  96. self.isOpen = isOpen
  97. def render(self, r):
  98. c = COLORS[self.color]
  99. r.setLineColor(c[0], c[1], c[2])
  100. r.setColor(0, 0, 0)
  101. if self.isOpen:
  102. r.drawPolygon([
  103. (CELL_PIXELS-2, CELL_PIXELS),
  104. (CELL_PIXELS , CELL_PIXELS),
  105. (CELL_PIXELS , 0),
  106. (CELL_PIXELS-2, 0)
  107. ])
  108. return
  109. r.drawPolygon([
  110. (0 , CELL_PIXELS),
  111. (CELL_PIXELS, CELL_PIXELS),
  112. (CELL_PIXELS, 0),
  113. (0 , 0)
  114. ])
  115. r.drawPolygon([
  116. (2 , CELL_PIXELS-2),
  117. (CELL_PIXELS-2, CELL_PIXELS-2),
  118. (CELL_PIXELS-2, 2),
  119. (2 , 2)
  120. ])
  121. r.drawCircle(CELL_PIXELS * 0.75, CELL_PIXELS * 0.5, 2)
  122. def toggle(self, env):
  123. if not self.isOpen:
  124. self.isOpen = True
  125. return True
  126. return False
  127. def canOverlap(self):
  128. """The agent can only walk over this cell when the door is open"""
  129. return self.isOpen
  130. class LockedDoor(WorldObj):
  131. def __init__(self, color, isOpen=False):
  132. super(LockedDoor, self).__init__('locked_door', color)
  133. self.isOpen = isOpen
  134. def render(self, r):
  135. c = COLORS[self.color]
  136. r.setLineColor(c[0], c[1], c[2])
  137. r.setColor(0, 0, 0)
  138. if self.isOpen:
  139. r.drawPolygon([
  140. (CELL_PIXELS-2, CELL_PIXELS),
  141. (CELL_PIXELS , CELL_PIXELS),
  142. (CELL_PIXELS , 0),
  143. (CELL_PIXELS-2, 0)
  144. ])
  145. return
  146. r.drawPolygon([
  147. (0 , CELL_PIXELS),
  148. (CELL_PIXELS, CELL_PIXELS),
  149. (CELL_PIXELS, 0),
  150. (0 , 0)
  151. ])
  152. r.drawPolygon([
  153. (2 , CELL_PIXELS-2),
  154. (CELL_PIXELS-2, CELL_PIXELS-2),
  155. (CELL_PIXELS-2, 2),
  156. (2 , 2)
  157. ])
  158. r.drawLine(
  159. CELL_PIXELS * 0.75,
  160. CELL_PIXELS * 0.45,
  161. CELL_PIXELS * 0.75,
  162. CELL_PIXELS * 0.60
  163. )
  164. def toggle(self, env):
  165. # If the player has the right key to open the door
  166. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  167. self.isOpen = True
  168. # The key has been used, remove it from the agent
  169. env.carrying = None
  170. return True
  171. return False
  172. def canOverlap(self):
  173. """The agent can only walk over this cell when the door is open"""
  174. return self.isOpen
  175. class Ball(WorldObj):
  176. def __init__(self, color='blue'):
  177. super(Ball, self).__init__('ball', color)
  178. def canPickup(self):
  179. return True
  180. def render(self, r):
  181. self._setColor(r)
  182. r.drawCircle(CELL_PIXELS * 0.5, CELL_PIXELS * 0.5, 10)
  183. class Key(WorldObj):
  184. def __init__(self, color='blue'):
  185. super(Key, self).__init__('key', color)
  186. def canPickup(self):
  187. return True
  188. def render(self, r):
  189. self._setColor(r)
  190. # Vertical quad
  191. r.drawPolygon([
  192. (16, 10),
  193. (20, 10),
  194. (20, 28),
  195. (16, 28)
  196. ])
  197. # Teeth
  198. r.drawPolygon([
  199. (12, 19),
  200. (16, 19),
  201. (16, 21),
  202. (12, 21)
  203. ])
  204. r.drawPolygon([
  205. (12, 26),
  206. (16, 26),
  207. (16, 28),
  208. (12, 28)
  209. ])
  210. r.drawCircle(18, 9, 6)
  211. r.setLineColor(0, 0, 0)
  212. r.setColor(0, 0, 0)
  213. r.drawCircle(18, 9, 2)
  214. class Grid:
  215. """
  216. Represent a grid and operations on it
  217. """
  218. def __init__(self, width, height):
  219. assert width >= 4
  220. assert height >= 4
  221. self.width = width
  222. self.height = height
  223. self.grid = [None] * width * height
  224. def copy(self):
  225. from copy import deepcopy
  226. return deepcopy(self)
  227. def set(self, i, j, v):
  228. assert i >= 0 and i < self.width
  229. assert j >= 0 and j < self.height
  230. self.grid[j * self.width + i] = v
  231. def get(self, i, j):
  232. assert i >= 0 and i < self.width
  233. assert j >= 0 and j < self.height
  234. return self.grid[j * self.width + i]
  235. def rotateLeft(self):
  236. """
  237. Rotate the grid to the left (counter-clockwise)
  238. """
  239. grid = Grid(self.width, self.height)
  240. for j in range(0, self.height):
  241. for i in range(0, self.width):
  242. v = self.get(self.width - 1 - j, i)
  243. grid.set(i, j, v)
  244. return grid
  245. def slice(self, topX, topY, width, height):
  246. """
  247. Get a subset of the grid
  248. """
  249. grid = Grid(width, height)
  250. for j in range(0, height):
  251. for i in range(0, width):
  252. x = topX + i
  253. y = topY + j
  254. if x >= 0 and x < self.width and \
  255. y >= 0 and y < self.height:
  256. v = self.get(x, y)
  257. else:
  258. v = Wall()
  259. grid.set(i, j, v)
  260. return grid
  261. def render(self, r, tileSize):
  262. """
  263. Render this grid at a given scale
  264. :param r: target renderer object
  265. :param tileSize: tile size in pixels
  266. """
  267. assert r.width == self.width * tileSize
  268. assert r.height == self.height * tileSize
  269. # Total grid size at native scale
  270. widthPx = self.width * CELL_PIXELS
  271. heightPx = self.height * CELL_PIXELS
  272. # Draw background (out-of-world) tiles the same colors as walls
  273. # so the agent understands these areas are not reachable
  274. c = COLORS['grey']
  275. r.setLineColor(c[0], c[1], c[2])
  276. r.setColor(c[0], c[1], c[2])
  277. r.drawPolygon([
  278. (0 , heightPx),
  279. (widthPx, heightPx),
  280. (widthPx, 0),
  281. (0 , 0)
  282. ])
  283. r.push()
  284. # Internally, we draw at the "large" full-grid resolution, but we
  285. # use the renderer to scale back to the desired size
  286. r.scale(tileSize / CELL_PIXELS, tileSize / CELL_PIXELS)
  287. # Draw the background of the in-world cells black
  288. r.fillRect(
  289. 0,
  290. 0,
  291. widthPx,
  292. heightPx,
  293. 0, 0, 0
  294. )
  295. # Draw grid lines
  296. r.setLineColor(100, 100, 100)
  297. for rowIdx in range(0, self.height):
  298. y = CELL_PIXELS * rowIdx
  299. r.drawLine(0, y, widthPx, y)
  300. for colIdx in range(0, self.width):
  301. x = CELL_PIXELS * colIdx
  302. r.drawLine(x, 0, x, heightPx)
  303. # Render the grid
  304. for j in range(0, self.height):
  305. for i in range(0, self.width):
  306. cell = self.get(i, j)
  307. if cell == None:
  308. continue
  309. r.push()
  310. r.translate(i * CELL_PIXELS, j * CELL_PIXELS)
  311. cell.render(r)
  312. r.pop()
  313. r.pop()
  314. def encode(self):
  315. """
  316. Produce a compact numpy encoding of the grid
  317. """
  318. codeSize = self.width * self.height * 3
  319. array = np.zeros(shape=(self.width, self.height, 3), dtype='uint8')
  320. for j in range(0, self.height):
  321. for i in range(0, self.width):
  322. v = self.get(i, j)
  323. if v == None:
  324. continue
  325. array[i, j, 0] = OBJECT_TO_IDX[v.type]
  326. array[i, j, 1] = COLOR_TO_IDX[v.color]
  327. if hasattr(v, 'isOpen') and v.isOpen:
  328. array[i, j, 2] = 1
  329. return array
  330. def decode(array):
  331. """
  332. Decode an array grid encoding back into a grid
  333. """
  334. width = array.shape[0]
  335. height = array.shape[1]
  336. assert array.shape[2] == 3
  337. grid = Grid(width, height)
  338. for j in range(0, height):
  339. for i in range(0, width):
  340. typeIdx = array[i, j, 0]
  341. colorIdx = array[i, j, 1]
  342. openIdx = array[i, j, 2]
  343. if typeIdx == 0:
  344. continue
  345. objType = IDX_TO_OBJECT[typeIdx]
  346. color = IDX_TO_COLOR[colorIdx]
  347. isOpen = True if openIdx == 1 else 0
  348. if objType == 'wall':
  349. v = Wall()
  350. elif objType == 'ball':
  351. v = Ball(color)
  352. elif objType == 'key':
  353. v = Key(color)
  354. elif objType == 'door':
  355. v = Door(color, isOpen)
  356. elif objType == 'locked_door':
  357. v = LockedDoor(color, isOpen)
  358. elif objType == 'goal':
  359. v = Goal()
  360. else:
  361. assert False, "unknown obj type in decode '%s'" % objType
  362. grid.set(i, j, v)
  363. return grid
  364. class MiniGridEnv(gym.Env):
  365. """
  366. 2D grid world game environment
  367. """
  368. metadata = {
  369. 'render.modes': ['human', 'rgb_array', 'pixmap'],
  370. 'video.frames_per_second' : 10
  371. }
  372. # Enumeration of possible actions
  373. class Actions(IntEnum):
  374. left = 0
  375. right = 1
  376. forward = 2
  377. toggle = 3
  378. def __init__(self, gridSize=16, maxSteps=100):
  379. # Action enumeration for this environment
  380. self.actions = MiniGridEnv.Actions
  381. # Actions are discrete integer values
  382. self.action_space = spaces.Discrete(len(self.actions))
  383. # The observations are RGB images
  384. self.observation_space = spaces.Box(
  385. low=0,
  386. high=255,
  387. shape=OBS_ARRAY_SIZE
  388. )
  389. # Range of possible rewards
  390. self.reward_range = (-1, 1000)
  391. # Renderer object used to render the whole grid (full-scale)
  392. self.gridRender = None
  393. # Renderer used to render observations (small-scale agent view)
  394. self.obsRender = None
  395. # Environment configuration
  396. self.gridSize = gridSize
  397. self.maxSteps = maxSteps
  398. self.startPos = (1, 1)
  399. self.startDir = 0
  400. # Initialize the state
  401. self.seed()
  402. self.reset()
  403. def _genGrid(self, width, height):
  404. """
  405. Generate a new grid
  406. """
  407. # Initialize the grid
  408. grid = Grid(width, height)
  409. # Place walls around the edges
  410. for i in range(0, width):
  411. grid.set(i, 0, Wall())
  412. grid.set(i, height - 1, Wall())
  413. for j in range(0, height):
  414. grid.set(0, j, Wall())
  415. grid.set(height - 1, j, Wall())
  416. # Place a goal in the bottom-left corner
  417. grid.set(width - 2, height - 2, Goal())
  418. return grid
  419. def _reset(self):
  420. # Generate a new random grid at the start of each episode
  421. # To prevent this behavior, call env.seed with the same
  422. # seed before env.reset
  423. self.grid = self._genGrid(self.gridSize, self.gridSize)
  424. # Place the agent in the starting position and direction
  425. self.agentPos = self.startPos
  426. self.agentDir = self.startDir
  427. # Item picked up, being carried, initially nothing
  428. self.carrying = None
  429. # Step count since episode start
  430. self.stepCount = 0
  431. # Return first observation
  432. obs = self._genObs()
  433. return obs
  434. def _seed(self, seed=1337):
  435. """
  436. The seed function sets the random elements of the environment,
  437. and initializes the world.
  438. """
  439. # Seed the random number generator
  440. self.np_random, _ = seeding.np_random(seed)
  441. return [seed]
  442. def _randInt(self, low, high):
  443. return self.np_random.randint(low, high)
  444. def _randElem(self, iterable):
  445. lst = list(iterable)
  446. idx = self._randInt(0, len(lst))
  447. return lst[idx]
  448. def getStepsRemaining(self):
  449. return self.maxSteps - self.stepCount
  450. def getDirVec(self):
  451. """
  452. Get the direction vector for the agent, pointing in the direction
  453. of forward movement.
  454. """
  455. # Pointing right
  456. if self.agentDir == 0:
  457. return (1, 0)
  458. # Down (positive Y)
  459. elif self.agentDir == 1:
  460. return (0, 1)
  461. # Pointing left
  462. elif self.agentDir == 2:
  463. return (-1, 0)
  464. # Up (negative Y)
  465. elif self.agentDir == 3:
  466. return (0, -1)
  467. else:
  468. assert False
  469. def getViewExts(self):
  470. """
  471. Get the extents of the square set of tiles visible to the agent
  472. Note: the bottom extent indices are not included in the set
  473. """
  474. # Facing right
  475. if self.agentDir == 0:
  476. topX = self.agentPos[0]
  477. topY = self.agentPos[1] - AGENT_VIEW_SIZE // 2
  478. # Facing down
  479. elif self.agentDir == 1:
  480. topX = self.agentPos[0] - AGENT_VIEW_SIZE // 2
  481. topY = self.agentPos[1]
  482. # Facing right
  483. elif self.agentDir == 2:
  484. topX = self.agentPos[0] - AGENT_VIEW_SIZE + 1
  485. topY = self.agentPos[1] - AGENT_VIEW_SIZE // 2
  486. # Facing up
  487. elif self.agentDir == 3:
  488. topX = self.agentPos[0] - AGENT_VIEW_SIZE // 2
  489. topY = self.agentPos[1] - AGENT_VIEW_SIZE + 1
  490. else:
  491. assert False
  492. botX = topX + AGENT_VIEW_SIZE
  493. botY = topY + AGENT_VIEW_SIZE
  494. return (topX, topY, botX, botY)
  495. def _step(self, action):
  496. self.stepCount += 1
  497. reward = 0
  498. done = False
  499. # Rotate left
  500. if action == self.actions.left:
  501. self.agentDir -= 1
  502. if self.agentDir < 0:
  503. self.agentDir += 4
  504. # Rotate right
  505. elif action == self.actions.right:
  506. self.agentDir = (self.agentDir + 1) % 4
  507. # Move forward
  508. elif action == self.actions.forward:
  509. u, v = self.getDirVec()
  510. newPos = (self.agentPos[0] + u, self.agentPos[1] + v)
  511. targetCell = self.grid.get(newPos[0], newPos[1])
  512. if targetCell == None or targetCell.canOverlap():
  513. self.agentPos = newPos
  514. elif targetCell.type == 'goal':
  515. done = True
  516. reward = 1000 - self.stepCount
  517. # Pick up or trigger/activate an item
  518. elif action == self.actions.toggle:
  519. u, v = self.getDirVec()
  520. cell = self.grid.get(self.agentPos[0] + u, self.agentPos[1] + v)
  521. if cell and cell.canPickup() and self.carrying is None:
  522. self.carrying = cell
  523. self.grid.set(self.agentPos[0] + u, self.agentPos[1] + v, None)
  524. elif cell:
  525. cell.toggle(self)
  526. else:
  527. assert False, "unknown action"
  528. if self.stepCount >= self.maxSteps:
  529. done = True
  530. obs = self._genObs()
  531. return obs, reward, done, {}
  532. def _genObs(self):
  533. """
  534. Generate the agent's view (partially observable, low-resolution encoding)
  535. """
  536. topX, topY, botX, botY = self.getViewExts()
  537. grid = self.grid.slice(topX, topY, AGENT_VIEW_SIZE, AGENT_VIEW_SIZE)
  538. for i in range(self.agentDir + 1):
  539. grid = grid.rotateLeft()
  540. obs = grid.encode()
  541. return obs
  542. def getObsRender(self, obs):
  543. """
  544. Render an agent observation for visualization
  545. """
  546. if self.obsRender == None:
  547. self.obsRender = Renderer(
  548. AGENT_VIEW_SIZE * CELL_PIXELS // 2,
  549. AGENT_VIEW_SIZE * CELL_PIXELS // 2
  550. )
  551. r = self.obsRender
  552. r.beginFrame()
  553. grid = Grid.decode(obs)
  554. # Render the whole grid
  555. grid.render(r, CELL_PIXELS // 2)
  556. # Draw the agent
  557. r.push()
  558. r.scale(0.5, 0.5)
  559. r.translate(
  560. CELL_PIXELS * (0.5 + AGENT_VIEW_SIZE // 2),
  561. CELL_PIXELS * (AGENT_VIEW_SIZE - 0.5)
  562. )
  563. r.rotate(3 * 90)
  564. r.setLineColor(255, 0, 0)
  565. r.setColor(255, 0, 0)
  566. r.drawPolygon([
  567. (-12, 10),
  568. ( 12, 0),
  569. (-12, -10)
  570. ])
  571. r.pop()
  572. r.endFrame()
  573. return r.getPixmap()
  574. def _render(self, mode='human', close=False):
  575. """
  576. Render the whole-grid human view
  577. """
  578. if close:
  579. if self.gridRender:
  580. self.gridRender.close()
  581. return
  582. if self.gridRender is None:
  583. self.gridRender = Renderer(
  584. self.gridSize * CELL_PIXELS,
  585. self.gridSize * CELL_PIXELS,
  586. True if mode == 'human' else False
  587. )
  588. r = self.gridRender
  589. r.beginFrame()
  590. # Render the whole grid
  591. self.grid.render(r, CELL_PIXELS)
  592. # Draw the agent
  593. r.push()
  594. r.translate(
  595. CELL_PIXELS * (self.agentPos[0] + 0.5),
  596. CELL_PIXELS * (self.agentPos[1] + 0.5)
  597. )
  598. r.rotate(self.agentDir * 90)
  599. r.setLineColor(255, 0, 0)
  600. r.setColor(255, 0, 0)
  601. r.drawPolygon([
  602. (-12, 10),
  603. ( 12, 0),
  604. (-12, -10)
  605. ])
  606. r.pop()
  607. # Highlight what the agent can see
  608. topX, topY, botX, botY = self.getViewExts()
  609. r.fillRect(
  610. topX * CELL_PIXELS,
  611. topY * CELL_PIXELS,
  612. AGENT_VIEW_SIZE * CELL_PIXELS,
  613. AGENT_VIEW_SIZE * CELL_PIXELS,
  614. 200, 200, 200, 75
  615. )
  616. r.endFrame()
  617. if mode == 'rgb_array':
  618. return r.getArray()
  619. elif mode == 'pixmap':
  620. return r.getPixmap()
  621. return r