minigrid.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773
  1. import math
  2. import gym
  3. from gym import error, spaces, utils
  4. from gym.utils import seeding
  5. import numpy as np
  6. from gym_minigrid.rendering import *
  7. # Size in pixels of a cell in the full-scale human view
  8. CELL_PIXELS = 32
  9. # Number of cells (width and height) in the agent view
  10. AGENT_VIEW_SIZE = 7
  11. # Size of the array given as an observation to the agent
  12. OBS_ARRAY_SIZE = (AGENT_VIEW_SIZE, AGENT_VIEW_SIZE, 3)
  13. COLORS = {
  14. 'red' : (255, 0, 0),
  15. 'green' : (0, 255, 0),
  16. 'blue' : (0, 0, 255),
  17. 'purple': (112, 39, 195),
  18. 'yellow': (255, 255, 0),
  19. 'grey' : (100, 100, 100)
  20. }
  21. # Used to map colors to integers
  22. COLOR_TO_IDX = {
  23. 'red' : 0,
  24. 'green' : 1,
  25. 'blue' : 2,
  26. 'purple': 3,
  27. 'yellow': 4,
  28. 'grey' : 5
  29. }
  30. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  31. # Map of object type to integers
  32. OBJECT_TO_IDX = {
  33. 'empty' : 0,
  34. 'wall' : 1,
  35. 'door' : 2,
  36. 'locked_door' : 3,
  37. 'ball' : 4,
  38. 'key' : 5,
  39. 'goal' : 6
  40. }
  41. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  42. class WorldObj:
  43. """
  44. Base class for grid world objects
  45. """
  46. def __init__(self, type, color):
  47. assert type in OBJECT_TO_IDX, type
  48. assert color in COLOR_TO_IDX, color
  49. self.type = type
  50. self.color = color
  51. self.contains = None
  52. def canOverlap(self):
  53. """Can the agent overlap with this?"""
  54. return False
  55. def canPickup(self):
  56. """Can the agent pick this up?"""
  57. return False
  58. def canContain(self):
  59. """Can this contain another object?"""
  60. return False
  61. def toggle(self, env):
  62. """Method to trigger/toggle an action this object performs"""
  63. return False
  64. def render(self, r):
  65. assert False
  66. def _setColor(self, r):
  67. c = COLORS[self.color]
  68. r.setLineColor(c[0], c[1], c[2])
  69. r.setColor(c[0], c[1], c[2])
  70. class Goal(WorldObj):
  71. def __init__(self):
  72. super(Goal, self).__init__('goal', 'green')
  73. def render(self, r):
  74. self._setColor(r)
  75. r.drawPolygon([
  76. (0 , CELL_PIXELS),
  77. (CELL_PIXELS, CELL_PIXELS),
  78. (CELL_PIXELS, 0),
  79. (0 , 0)
  80. ])
  81. class Wall(WorldObj):
  82. def __init__(self, color='grey'):
  83. super(Wall, self).__init__('wall', color)
  84. def render(self, r):
  85. self._setColor(r)
  86. r.drawPolygon([
  87. (0 , CELL_PIXELS),
  88. (CELL_PIXELS, CELL_PIXELS),
  89. (CELL_PIXELS, 0),
  90. (0 , 0)
  91. ])
  92. class Door(WorldObj):
  93. def __init__(self, color, isOpen=False):
  94. super(Door, self).__init__('door', color)
  95. self.isOpen = isOpen
  96. def render(self, r):
  97. c = COLORS[self.color]
  98. r.setLineColor(c[0], c[1], c[2])
  99. r.setColor(0, 0, 0)
  100. if self.isOpen:
  101. r.drawPolygon([
  102. (CELL_PIXELS-2, CELL_PIXELS),
  103. (CELL_PIXELS , CELL_PIXELS),
  104. (CELL_PIXELS , 0),
  105. (CELL_PIXELS-2, 0)
  106. ])
  107. return
  108. r.drawPolygon([
  109. (0 , CELL_PIXELS),
  110. (CELL_PIXELS, CELL_PIXELS),
  111. (CELL_PIXELS, 0),
  112. (0 , 0)
  113. ])
  114. r.drawPolygon([
  115. (2 , CELL_PIXELS-2),
  116. (CELL_PIXELS-2, CELL_PIXELS-2),
  117. (CELL_PIXELS-2, 2),
  118. (2 , 2)
  119. ])
  120. r.drawCircle(CELL_PIXELS * 0.75, CELL_PIXELS * 0.5, 2)
  121. def toggle(self, env):
  122. if not self.isOpen:
  123. self.isOpen = True
  124. return True
  125. return False
  126. def canOverlap(self):
  127. """The agent can only walk over this cell when the door is open"""
  128. return self.isOpen
  129. class LockedDoor(WorldObj):
  130. def __init__(self, color, isOpen=False):
  131. super(LockedDoor, self).__init__('locked_door', color)
  132. self.isOpen = isOpen
  133. def render(self, r):
  134. c = COLORS[self.color]
  135. r.setLineColor(c[0], c[1], c[2])
  136. r.setColor(0, 0, 0)
  137. if self.isOpen:
  138. r.drawPolygon([
  139. (CELL_PIXELS-2, CELL_PIXELS),
  140. (CELL_PIXELS , CELL_PIXELS),
  141. (CELL_PIXELS , 0),
  142. (CELL_PIXELS-2, 0)
  143. ])
  144. return
  145. r.drawPolygon([
  146. (0 , CELL_PIXELS),
  147. (CELL_PIXELS, CELL_PIXELS),
  148. (CELL_PIXELS, 0),
  149. (0 , 0)
  150. ])
  151. r.drawPolygon([
  152. (2 , CELL_PIXELS-2),
  153. (CELL_PIXELS-2, CELL_PIXELS-2),
  154. (CELL_PIXELS-2, 2),
  155. (2 , 2)
  156. ])
  157. r.drawLine(
  158. CELL_PIXELS * 0.75,
  159. CELL_PIXELS * 0.45,
  160. CELL_PIXELS * 0.75,
  161. CELL_PIXELS * 0.60
  162. )
  163. def toggle(self, env):
  164. # If the player has the right key to open the door
  165. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  166. self.isOpen = True
  167. # The key has been used, remove it from the agent
  168. env.carrying = None
  169. return True
  170. return False
  171. def canOverlap(self):
  172. """The agent can only walk over this cell when the door is open"""
  173. return self.isOpen
  174. class Ball(WorldObj):
  175. def __init__(self, color='blue'):
  176. super(Ball, self).__init__('ball', color)
  177. def canPickup(self):
  178. return True
  179. def render(self, r):
  180. self._setColor(r)
  181. r.drawCircle(CELL_PIXELS * 0.5, CELL_PIXELS * 0.5, 10)
  182. class Key(WorldObj):
  183. def __init__(self, color='blue'):
  184. super(Key, self).__init__('key', color)
  185. def canPickup(self):
  186. return True
  187. def render(self, r):
  188. self._setColor(r)
  189. # Vertical quad
  190. r.drawPolygon([
  191. (16, 10),
  192. (20, 10),
  193. (20, 28),
  194. (16, 28)
  195. ])
  196. # Teeth
  197. r.drawPolygon([
  198. (12, 19),
  199. (16, 19),
  200. (16, 21),
  201. (12, 21)
  202. ])
  203. r.drawPolygon([
  204. (12, 26),
  205. (16, 26),
  206. (16, 28),
  207. (12, 28)
  208. ])
  209. r.drawCircle(18, 9, 6)
  210. r.setLineColor(0, 0, 0)
  211. r.setColor(0, 0, 0)
  212. r.drawCircle(18, 9, 2)
  213. class Grid:
  214. """
  215. Represent a grid and operations on it
  216. """
  217. def __init__(self, width, height):
  218. assert width >= 4
  219. assert height >= 4
  220. self.width = width
  221. self.height = height
  222. self.grid = [None] * width * height
  223. def copy(self):
  224. from copy import deepcopy
  225. return deepcopy(self)
  226. def set(self, i, j, v):
  227. assert i >= 0 and i < self.width
  228. assert j >= 0 and j < self.height
  229. self.grid[j * self.width + i] = v
  230. def get(self, i, j):
  231. assert i >= 0 and i < self.width
  232. assert j >= 0 and j < self.height
  233. return self.grid[j * self.width + i]
  234. def rotateLeft(self):
  235. """
  236. Rotate the grid to the left (counter-clockwise)
  237. """
  238. grid = Grid(self.width, self.height)
  239. for j in range(0, self.height):
  240. for i in range(0, self.width):
  241. v = self.get(self.width - 1 - j, i)
  242. grid.set(i, j, v)
  243. return grid
  244. def slice(self, topX, topY, width, height):
  245. """
  246. Get a subset of the grid
  247. """
  248. grid = Grid(width, height)
  249. for j in range(0, height):
  250. for i in range(0, width):
  251. x = topX + i
  252. y = topY + j
  253. if x >= 0 and x < self.width and \
  254. y >= 0 and y < self.height:
  255. v = self.get(x, y)
  256. else:
  257. v = Wall()
  258. grid.set(i, j, v)
  259. return grid
  260. def render(self, r, tileSize):
  261. """
  262. Render this grid at a given scale
  263. :param r: target renderer object
  264. :param tileSize: tile size in pixels
  265. """
  266. assert r.width == self.width * tileSize
  267. assert r.height == self.height * tileSize
  268. # Total grid size at native scale
  269. widthPx = self.width * CELL_PIXELS
  270. heightPx = self.height * CELL_PIXELS
  271. # Draw background (out-of-world) tiles the same colors as walls
  272. # so the agent understands these areas are not reachable
  273. c = COLORS['grey']
  274. r.setLineColor(c[0], c[1], c[2])
  275. r.setColor(c[0], c[1], c[2])
  276. r.drawPolygon([
  277. (0 , heightPx),
  278. (widthPx, heightPx),
  279. (widthPx, 0),
  280. (0 , 0)
  281. ])
  282. r.push()
  283. # Internally, we draw at the "large" full-grid resolution, but we
  284. # use the renderer to scale back to the desired size
  285. r.scale(tileSize / CELL_PIXELS, tileSize / CELL_PIXELS)
  286. # Draw the background of the in-world cells black
  287. r.fillRect(
  288. 0,
  289. 0,
  290. widthPx,
  291. heightPx,
  292. 0, 0, 0
  293. )
  294. # Draw grid lines
  295. r.setLineColor(100, 100, 100)
  296. for rowIdx in range(0, self.height):
  297. y = CELL_PIXELS * rowIdx
  298. r.drawLine(0, y, widthPx, y)
  299. for colIdx in range(0, self.width):
  300. x = CELL_PIXELS * colIdx
  301. r.drawLine(x, 0, x, heightPx)
  302. # Render the grid
  303. for j in range(0, self.height):
  304. for i in range(0, self.width):
  305. cell = self.get(i, j)
  306. if cell == None:
  307. continue
  308. r.push()
  309. r.translate(i * CELL_PIXELS, j * CELL_PIXELS)
  310. cell.render(r)
  311. r.pop()
  312. r.pop()
  313. def encode(self):
  314. """
  315. Produce a compact numpy encoding of the grid
  316. """
  317. codeSize = self.width * self.height * 3
  318. array = np.zeros(shape=(self.width, self.height, 3), dtype='uint8')
  319. for j in range(0, self.height):
  320. for i in range(0, self.width):
  321. v = self.get(i, j)
  322. if v == None:
  323. continue
  324. array[i, j, 0] = OBJECT_TO_IDX[v.type]
  325. array[i, j, 1] = COLOR_TO_IDX[v.color]
  326. if hasattr(v, 'isOpen') and v.isOpen:
  327. array[i, j, 2] = 1
  328. return array
  329. def decode(array):
  330. """
  331. Decode an array grid encoding back into a grid
  332. """
  333. width = array.shape[0]
  334. height = array.shape[1]
  335. assert array.shape[2] == 3
  336. grid = Grid(width, height)
  337. for j in range(0, height):
  338. for i in range(0, width):
  339. typeIdx = array[i, j, 0]
  340. colorIdx = array[i, j, 1]
  341. openIdx = array[i, j, 2]
  342. if typeIdx == 0:
  343. continue
  344. objType = IDX_TO_OBJECT[typeIdx]
  345. color = IDX_TO_COLOR[colorIdx]
  346. isOpen = True if openIdx == 1 else 0
  347. if objType == 'wall':
  348. v = Wall()
  349. elif objType == 'ball':
  350. v = Ball(color)
  351. elif objType == 'key':
  352. v = Key(color)
  353. elif objType == 'door':
  354. v = Door(color, isOpen)
  355. elif objType == 'locked_door':
  356. v = LockedDoor(color, isOpen)
  357. elif objType == 'goal':
  358. v = Goal()
  359. else:
  360. assert False, "unknown obj type in decode '%s'" % objType
  361. grid.set(i, j, v)
  362. return grid
  363. class MiniGridEnv(gym.Env):
  364. """
  365. 2D grid world game environment
  366. """
  367. metadata = {
  368. 'render.modes': ['human', 'rgb_array', 'pixmap'],
  369. 'video.frames_per_second' : 10
  370. }
  371. # Possible actions
  372. NUM_ACTIONS = 4
  373. ACTION_LEFT = 0
  374. ACTION_RIGHT = 1
  375. ACTION_FORWARD = 2
  376. ACTION_TOGGLE = 3
  377. def __init__(self, gridSize=16, maxSteps=100):
  378. # Renderer object used to render the whole grid (full-scale)
  379. self.gridRender = None
  380. # Renderer used to render observations (small-scale agent view)
  381. self.obsRender = None
  382. # Actions are discrete integer values
  383. self.action_space = spaces.Discrete(MiniGridEnv.NUM_ACTIONS)
  384. # The observations are RGB images
  385. self.observation_space = spaces.Box(
  386. low=0,
  387. high=255,
  388. shape=OBS_ARRAY_SIZE
  389. )
  390. self.reward_range = (-1, 1000)
  391. # Environment configuration
  392. self.gridSize = gridSize
  393. self.maxSteps = maxSteps
  394. self.startPos = (1, 1)
  395. self.startDir = 0
  396. # Initialize the state
  397. self.seed()
  398. self.reset()
  399. def _genGrid(self, width, height):
  400. """
  401. Generate a new grid
  402. """
  403. # Initialize the grid
  404. grid = Grid(width, height)
  405. # Place walls around the edges
  406. for i in range(0, width):
  407. grid.set(i, 0, Wall())
  408. grid.set(i, height - 1, Wall())
  409. for j in range(0, height):
  410. grid.set(0, j, Wall())
  411. grid.set(height - 1, j, Wall())
  412. # Place a goal in the bottom-left corner
  413. grid.set(width - 2, height - 2, Goal())
  414. return grid
  415. def _reset(self):
  416. # Generate a new random grid at the start of each episode
  417. # To prevent this behavior, call env.seed with the same
  418. # seed before env.reset
  419. self.grid = self._genGrid(self.gridSize, self.gridSize)
  420. # Place the agent in the starting position and direction
  421. self.agentPos = self.startPos
  422. self.agentDir = self.startDir
  423. # Item picked up, being carried, initially nothing
  424. self.carrying = None
  425. # Step count since episode start
  426. self.stepCount = 0
  427. # Return first observation
  428. obs = self._genObs()
  429. return obs
  430. def _seed(self, seed=1337):
  431. """
  432. The seed function sets the random elements of the environment,
  433. and initializes the world.
  434. """
  435. # Seed the random number generator
  436. self.np_random, _ = seeding.np_random(seed)
  437. return [seed]
  438. def _randInt(self, low, high):
  439. return self.np_random.randint(low, high)
  440. def _randElem(self, iterable):
  441. lst = list(iterable)
  442. idx = self._randInt(0, len(lst))
  443. return lst[idx]
  444. def getStepsRemaining(self):
  445. return self.maxSteps - self.stepCount
  446. def getDirVec(self):
  447. """
  448. Get the direction vector for the agent, pointing in the direction
  449. of forward movement.
  450. """
  451. # Pointing right
  452. if self.agentDir == 0:
  453. return (1, 0)
  454. # Down (positive Y)
  455. elif self.agentDir == 1:
  456. return (0, 1)
  457. # Pointing left
  458. elif self.agentDir == 2:
  459. return (-1, 0)
  460. # Up (negative Y)
  461. elif self.agentDir == 3:
  462. return (0, -1)
  463. else:
  464. assert False
  465. def getViewExts(self):
  466. """
  467. Get the extents of the square set of tiles visible to the agent
  468. Note: the bottom extent indices are not included in the set
  469. """
  470. # Facing right
  471. if self.agentDir == 0:
  472. topX = self.agentPos[0]
  473. topY = self.agentPos[1] - AGENT_VIEW_SIZE // 2
  474. # Facing down
  475. elif self.agentDir == 1:
  476. topX = self.agentPos[0] - AGENT_VIEW_SIZE // 2
  477. topY = self.agentPos[1]
  478. # Facing right
  479. elif self.agentDir == 2:
  480. topX = self.agentPos[0] - AGENT_VIEW_SIZE + 1
  481. topY = self.agentPos[1] - AGENT_VIEW_SIZE // 2
  482. # Facing up
  483. elif self.agentDir == 3:
  484. topX = self.agentPos[0] - AGENT_VIEW_SIZE // 2
  485. topY = self.agentPos[1] - AGENT_VIEW_SIZE + 1
  486. else:
  487. assert False
  488. botX = topX + AGENT_VIEW_SIZE
  489. botY = topY + AGENT_VIEW_SIZE
  490. return (topX, topY, botX, botY)
  491. def _step(self, action):
  492. self.stepCount += 1
  493. reward = 0
  494. done = False
  495. # Rotate left
  496. if action == MiniGridEnv.ACTION_LEFT:
  497. self.agentDir -= 1
  498. if self.agentDir < 0:
  499. self.agentDir += 4
  500. # Rotate right
  501. elif action == MiniGridEnv.ACTION_RIGHT:
  502. self.agentDir = (self.agentDir + 1) % 4
  503. # Move forward
  504. elif action == MiniGridEnv.ACTION_FORWARD:
  505. u, v = self.getDirVec()
  506. newPos = (self.agentPos[0] + u, self.agentPos[1] + v)
  507. targetCell = self.grid.get(newPos[0], newPos[1])
  508. if targetCell == None or targetCell.canOverlap():
  509. self.agentPos = newPos
  510. elif targetCell.type == 'goal':
  511. done = True
  512. reward = 1000 - self.stepCount
  513. # Pick up or trigger/activate an item
  514. elif action == MiniGridEnv.ACTION_TOGGLE:
  515. u, v = self.getDirVec()
  516. cell = self.grid.get(self.agentPos[0] + u, self.agentPos[1] + v)
  517. if cell and cell.canPickup() and self.carrying is None:
  518. self.carrying = cell
  519. self.grid.set(self.agentPos[0] + u, self.agentPos[1] + v, None)
  520. elif cell:
  521. cell.toggle(self)
  522. else:
  523. assert False, "unknown action"
  524. if self.stepCount >= self.maxSteps:
  525. done = True
  526. obs = self._genObs()
  527. return obs, reward, done, {}
  528. def _genObs(self):
  529. """
  530. Generate the agent's view (partially observable, low-resolution encoding)
  531. """
  532. topX, topY, botX, botY = self.getViewExts()
  533. grid = self.grid.slice(topX, topY, AGENT_VIEW_SIZE, AGENT_VIEW_SIZE)
  534. for i in range(self.agentDir + 1):
  535. grid = grid.rotateLeft()
  536. obs = grid.encode()
  537. return obs
  538. def getObsRender(self, obs):
  539. """
  540. Render an agent observation for visualization
  541. """
  542. if self.obsRender == None:
  543. self.obsRender = Renderer(
  544. AGENT_VIEW_SIZE * CELL_PIXELS // 2,
  545. AGENT_VIEW_SIZE * CELL_PIXELS // 2
  546. )
  547. r = self.obsRender
  548. r.beginFrame()
  549. grid = Grid.decode(obs)
  550. # Render the whole grid
  551. grid.render(r, CELL_PIXELS // 2)
  552. # Draw the agent
  553. r.push()
  554. r.scale(0.5, 0.5)
  555. r.translate(
  556. CELL_PIXELS * (0.5 + AGENT_VIEW_SIZE // 2),
  557. CELL_PIXELS * (AGENT_VIEW_SIZE - 0.5)
  558. )
  559. r.rotate(3 * 90)
  560. r.setLineColor(255, 0, 0)
  561. r.setColor(255, 0, 0)
  562. r.drawPolygon([
  563. (-12, 10),
  564. ( 12, 0),
  565. (-12, -10)
  566. ])
  567. r.pop()
  568. r.endFrame()
  569. return r.getPixmap()
  570. def _render(self, mode='human', close=False):
  571. """
  572. Render the whole-grid human view
  573. """
  574. if close:
  575. if self.gridRender:
  576. self.gridRender.close()
  577. return
  578. if self.gridRender is None:
  579. self.gridRender = Renderer(
  580. self.gridSize * CELL_PIXELS,
  581. self.gridSize * CELL_PIXELS,
  582. True if mode == 'human' else False
  583. )
  584. r = self.gridRender
  585. r.beginFrame()
  586. # Render the whole grid
  587. self.grid.render(r, CELL_PIXELS)
  588. # Draw the agent
  589. r.push()
  590. r.translate(
  591. CELL_PIXELS * (self.agentPos[0] + 0.5),
  592. CELL_PIXELS * (self.agentPos[1] + 0.5)
  593. )
  594. r.rotate(self.agentDir * 90)
  595. r.setLineColor(255, 0, 0)
  596. r.setColor(255, 0, 0)
  597. r.drawPolygon([
  598. (-12, 10),
  599. ( 12, 0),
  600. (-12, -10)
  601. ])
  602. r.pop()
  603. # Highlight what the agent can see
  604. topX, topY, botX, botY = self.getViewExts()
  605. r.fillRect(
  606. topX * CELL_PIXELS,
  607. topY * CELL_PIXELS,
  608. AGENT_VIEW_SIZE * CELL_PIXELS,
  609. AGENT_VIEW_SIZE * CELL_PIXELS,
  610. 200, 200, 200, 75
  611. )
  612. r.endFrame()
  613. if mode == 'rgb_array':
  614. return r.getArray()
  615. elif mode == 'pixmap':
  616. return r.getPixmap()
  617. return r