minigrid_env.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781
  1. import math
  2. import gym
  3. from gym import error, spaces, utils
  4. from gym.utils import seeding
  5. import numpy as np
  6. from gym_minigrid.envs.rendering import *
  7. # Size in pixels of a cell in the full-scale human view
  8. CELL_PIXELS = 32
  9. # Number of cells (width and height) in the agent view
  10. AGENT_VIEW_SIZE = 7
  11. # Size of the array given as an observation to the agent
  12. OBS_ARRAY_SIZE = (AGENT_VIEW_SIZE, AGENT_VIEW_SIZE, 3)
  13. COLORS = {
  14. 'red' : (255, 0, 0),
  15. 'green' : (0, 255, 0),
  16. 'blue' : (0, 0, 255),
  17. 'purple': (112, 39, 195),
  18. 'yellow': (255, 255, 0),
  19. 'grey' : (100, 100, 100)
  20. }
  21. # Used to map colors to integers
  22. COLOR_TO_IDX = {
  23. 'red' : 0,
  24. 'green' : 1,
  25. 'blue' : 2,
  26. 'purple': 3,
  27. 'yellow': 4,
  28. 'grey' : 5
  29. }
  30. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  31. # Map of object type to integers
  32. OBJECT_TO_IDX = {
  33. 'empty' : 0,
  34. 'wall' : 1,
  35. 'door' : 2,
  36. 'locked_door' : 3,
  37. 'ball' : 4,
  38. 'key' : 5,
  39. 'goal' : 6
  40. }
  41. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  42. class WorldObj:
  43. """
  44. Base class for grid world objects
  45. """
  46. def __init__(self, type, color):
  47. assert type in OBJECT_TO_IDX, type
  48. assert color in COLOR_TO_IDX, color
  49. self.type = type
  50. self.color = color
  51. self.contains = None
  52. def canOverlap(self):
  53. """Can the agent overlap with this?"""
  54. return False
  55. def canPickup(self):
  56. """Can the agent pick this up?"""
  57. return False
  58. def canContain(self):
  59. """Can this contain another object?"""
  60. return False
  61. def toggle(self, env):
  62. """Method to trigger/toggle an action this object performs"""
  63. return False
  64. def render(self, r):
  65. assert False
  66. def _setColor(self, r):
  67. c = COLORS[self.color]
  68. r.setLineColor(c[0], c[1], c[2])
  69. r.setColor(c[0], c[1], c[2])
  70. class Goal(WorldObj):
  71. def __init__(self):
  72. super(Goal, self).__init__('goal', 'green')
  73. def render(self, r):
  74. self._setColor(r)
  75. r.drawPolygon([
  76. (0 , CELL_PIXELS),
  77. (CELL_PIXELS, CELL_PIXELS),
  78. (CELL_PIXELS, 0),
  79. (0 , 0)
  80. ])
  81. class Wall(WorldObj):
  82. def __init__(self):
  83. super(Wall, self).__init__('wall', 'grey')
  84. def render(self, r):
  85. self._setColor(r)
  86. r.drawPolygon([
  87. (0 , CELL_PIXELS),
  88. (CELL_PIXELS, CELL_PIXELS),
  89. (CELL_PIXELS, 0),
  90. (0 , 0)
  91. ])
  92. class Door(WorldObj):
  93. def __init__(self, color, isOpen=False):
  94. super(Door, self).__init__('door', color)
  95. self.isOpen = isOpen
  96. def render(self, r):
  97. c = COLORS[self.color]
  98. r.setLineColor(c[0], c[1], c[2])
  99. r.setColor(0, 0, 0)
  100. if self.isOpen:
  101. r.drawPolygon([
  102. (CELL_PIXELS-2, CELL_PIXELS),
  103. (CELL_PIXELS , CELL_PIXELS),
  104. (CELL_PIXELS , 0),
  105. (CELL_PIXELS-2, 0)
  106. ])
  107. return
  108. r.drawPolygon([
  109. (0 , CELL_PIXELS),
  110. (CELL_PIXELS, CELL_PIXELS),
  111. (CELL_PIXELS, 0),
  112. (0 , 0)
  113. ])
  114. r.drawPolygon([
  115. (2 , CELL_PIXELS-2),
  116. (CELL_PIXELS-2, CELL_PIXELS-2),
  117. (CELL_PIXELS-2, 2),
  118. (2 , 2)
  119. ])
  120. r.drawCircle(CELL_PIXELS * 0.75, CELL_PIXELS * 0.5, 2)
  121. def toggle(self, env):
  122. if not self.isOpen:
  123. self.isOpen = True
  124. return True
  125. return False
  126. def canOverlap(self):
  127. """The agent can only walk over this cell when the door is open"""
  128. return self.isOpen
  129. class LockedDoor(WorldObj):
  130. def __init__(self, color, isOpen=False):
  131. super(LockedDoor, self).__init__('locked_door', color)
  132. self.isOpen = isOpen
  133. def render(self, r):
  134. c = COLORS[self.color]
  135. r.setLineColor(c[0], c[1], c[2])
  136. r.setColor(0, 0, 0)
  137. if self.isOpen:
  138. r.drawPolygon([
  139. (CELL_PIXELS-2, CELL_PIXELS),
  140. (CELL_PIXELS , CELL_PIXELS),
  141. (CELL_PIXELS , 0),
  142. (CELL_PIXELS-2, 0)
  143. ])
  144. return
  145. r.drawPolygon([
  146. (0 , CELL_PIXELS),
  147. (CELL_PIXELS, CELL_PIXELS),
  148. (CELL_PIXELS, 0),
  149. (0 , 0)
  150. ])
  151. r.drawPolygon([
  152. (2 , CELL_PIXELS-2),
  153. (CELL_PIXELS-2, CELL_PIXELS-2),
  154. (CELL_PIXELS-2, 2),
  155. (2 , 2)
  156. ])
  157. r.drawLine(
  158. CELL_PIXELS * 0.75,
  159. CELL_PIXELS * 0.45,
  160. CELL_PIXELS * 0.75,
  161. CELL_PIXELS * 0.60
  162. )
  163. def toggle(self, env):
  164. # If the player has the right key to open the door
  165. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  166. self.isOpen = True
  167. # The key has been used, remove it from the agent
  168. env.carrying = None
  169. return True
  170. return False
  171. def canOverlap(self):
  172. """The agent can only walk over this cell when the door is open"""
  173. return self.isOpen
  174. class Ball(WorldObj):
  175. def __init__(self, color='blue'):
  176. super(Ball, self).__init__('ball', color)
  177. def canPickup(self):
  178. return True
  179. def render(self, r):
  180. self._setColor(r)
  181. r.drawCircle(CELL_PIXELS * 0.5, CELL_PIXELS * 0.5, 10)
  182. class Key(WorldObj):
  183. def __init__(self, color='blue'):
  184. super(Key, self).__init__('key', color)
  185. def canPickup(self):
  186. return True
  187. def render(self, r):
  188. self._setColor(r)
  189. # Vertical quad
  190. r.drawPolygon([
  191. (16, 10),
  192. (20, 10),
  193. (20, 28),
  194. (16, 28)
  195. ])
  196. # Teeth
  197. r.drawPolygon([
  198. (12, 19),
  199. (16, 19),
  200. (16, 21),
  201. (12, 21)
  202. ])
  203. r.drawPolygon([
  204. (12, 26),
  205. (16, 26),
  206. (16, 28),
  207. (12, 28)
  208. ])
  209. r.drawCircle(18, 9, 6)
  210. r.setLineColor(0, 0, 0)
  211. r.setColor(0, 0, 0)
  212. r.drawCircle(18, 9, 2)
  213. class Grid:
  214. """
  215. Represent a grid and operations on it
  216. """
  217. def __init__(self, width, height):
  218. assert width >= 4
  219. assert height >= 4
  220. self.width = width
  221. self.height = height
  222. self.grid = [None] * width * height
  223. def copy(self):
  224. from copy import deepcopy
  225. return deepcopy(self)
  226. def set(self, i, j, v):
  227. assert i >= 0 and i < self.width
  228. assert j >= 0 and j < self.height
  229. self.grid[j * self.width + i] = v
  230. def get(self, i, j):
  231. assert i >= 0 and i < self.width
  232. assert j >= 0 and j < self.height
  233. return self.grid[j * self.width + i]
  234. def rotateLeft(self):
  235. """
  236. Rotate the grid to the left (counter-clockwise)
  237. """
  238. grid = Grid(self.width, self.height)
  239. for j in range(0, self.height):
  240. for i in range(0, self.width):
  241. v = self.get(self.width - 1 - j, i)
  242. grid.set(i, j, v)
  243. return grid
  244. def slice(self, topX, topY, width, height):
  245. """
  246. Get a subset of the grid
  247. """
  248. grid = Grid(width, height)
  249. for j in range(0, height):
  250. for i in range(0, width):
  251. x = topX + i
  252. y = topY + j
  253. if x >= 0 and x < self.width and \
  254. y >= 0 and y < self.height:
  255. v = self.get(x, y)
  256. else:
  257. v = Wall()
  258. grid.set(i, j, v)
  259. return grid
  260. def render(self, r, tileSize):
  261. """
  262. Render this grid at a given scale
  263. :param r: target renderer object
  264. :param tileSize: tile size in pixels
  265. """
  266. assert r.width == self.width * tileSize
  267. assert r.height == self.height * tileSize
  268. # Total grid size at native scale
  269. widthPx = self.width * CELL_PIXELS
  270. heightPx = self.height * CELL_PIXELS
  271. # Draw background (out-of-world) tiles the same colors as walls
  272. # so the agent understands these areas are not reachable
  273. c = COLORS['grey']
  274. r.setLineColor(c[0], c[1], c[2])
  275. r.setColor(c[0], c[1], c[2])
  276. r.drawPolygon([
  277. (0 , heightPx),
  278. (widthPx, heightPx),
  279. (widthPx, 0),
  280. (0 , 0)
  281. ])
  282. r.push()
  283. # Internally, we draw at the "large" full-grid resolution, but we
  284. # use the renderer to scale back to the desired size
  285. r.scale(tileSize / CELL_PIXELS, tileSize / CELL_PIXELS)
  286. # Draw the background of the in-world cells black
  287. r.fillRect(
  288. 0,
  289. 0,
  290. widthPx,
  291. heightPx,
  292. 0, 0, 0
  293. )
  294. # Draw grid lines
  295. r.setLineColor(100, 100, 100)
  296. for rowIdx in range(0, self.height):
  297. y = CELL_PIXELS * rowIdx
  298. r.drawLine(0, y, widthPx, y)
  299. for colIdx in range(0, self.width):
  300. x = CELL_PIXELS * colIdx
  301. r.drawLine(x, 0, x, heightPx)
  302. # Render the grid
  303. for j in range(0, self.height):
  304. for i in range(0, self.width):
  305. cell = self.get(i, j)
  306. if cell == None:
  307. continue
  308. r.push()
  309. r.translate(i * CELL_PIXELS, j * CELL_PIXELS)
  310. cell.render(r)
  311. r.pop()
  312. r.pop()
  313. def encode(self):
  314. """
  315. Produce a compact numpy encoding of the grid
  316. """
  317. codeSize = self.width * self.height * 3
  318. array = np.zeros(shape=(self.width, self.height, 3), dtype='uint8')
  319. for j in range(0, self.height):
  320. for i in range(0, self.width):
  321. v = self.get(i, j)
  322. if v == None:
  323. continue
  324. array[i, j, 0] = OBJECT_TO_IDX[v.type]
  325. array[i, j, 1] = COLOR_TO_IDX[v.color]
  326. if hasattr(v, 'isOpen') and v.isOpen:
  327. array[i, j, 2] = 1
  328. return array
  329. def decode(array):
  330. """
  331. Decode an array grid encoding back into a grid
  332. """
  333. width = array.shape[0]
  334. height = array.shape[1]
  335. assert array.shape[2] == 3
  336. grid = Grid(width, height)
  337. for j in range(0, height):
  338. for i in range(0, width):
  339. typeIdx = array[i, j, 0]
  340. colorIdx = array[i, j, 1]
  341. openIdx = array[i, j, 2]
  342. if typeIdx == 0:
  343. continue
  344. objType = IDX_TO_OBJECT[typeIdx]
  345. color = IDX_TO_COLOR[colorIdx]
  346. isOpen = True if openIdx == 1 else 0
  347. if objType == 'wall':
  348. v = Wall()
  349. elif objType == 'ball':
  350. v = Ball(color)
  351. elif objType == 'key':
  352. v = Key(color)
  353. elif objType == 'door':
  354. v = Door(color, isOpen)
  355. elif objType == 'locked_door':
  356. v = LockedDoor(color, isOpen)
  357. elif objType == 'goal':
  358. v = Goal()
  359. else:
  360. assert False, "unknown obj type in decode '%s'" % objType
  361. grid.set(i, j, v)
  362. return grid
  363. class MiniGridEnv(gym.Env):
  364. """
  365. 2D grid world game environment
  366. """
  367. metadata = {
  368. 'render.modes': ['human', 'rgb_array', 'pixmap'],
  369. 'video.frames_per_second' : 10
  370. }
  371. # Possible actions
  372. NUM_ACTIONS = 4
  373. ACTION_LEFT = 0
  374. ACTION_RIGHT = 1
  375. ACTION_FORWARD = 2
  376. ACTION_TOGGLE = 3
  377. def __init__(self, gridSize=16, maxSteps=100):
  378. # Renderer object used to render the whole grid (full-scale)
  379. self.gridRender = None
  380. # Renderer used to render observations (small-scale agent view)
  381. self.obsRender = None
  382. # Actions are discrete integer values
  383. self.action_space = spaces.Discrete(MiniGridEnv.NUM_ACTIONS)
  384. # The observations are RGB images
  385. self.observation_space = spaces.Box(
  386. low=0,
  387. high=255,
  388. shape=OBS_ARRAY_SIZE
  389. )
  390. self.reward_range = (-1, 1000)
  391. # Environment configuration
  392. self.gridSize = gridSize
  393. self.maxSteps = maxSteps
  394. self.startPos = (1, 1)
  395. self.startDir = 0
  396. # Initialize the state
  397. self.seed()
  398. self.reset()
  399. def _genGrid(self, width, height):
  400. """
  401. Generate a new grid
  402. """
  403. # Initialize the grid
  404. grid = Grid(width, height)
  405. # Place walls around the edges
  406. for i in range(0, width):
  407. grid.set(i, 0, Wall())
  408. grid.set(i, height - 1, Wall())
  409. for j in range(0, height):
  410. grid.set(0, j, Wall())
  411. grid.set(height - 1, j, Wall())
  412. # Place a goal in the bottom-left corner
  413. grid.set(width - 2, height - 2, Goal())
  414. return grid
  415. def _reset(self):
  416. # Place the agent in the starting position and direction
  417. self.agentPos = self.startPos
  418. self.agentDir = self.startDir
  419. # Item picked up, being carried, initially nothing
  420. self.carrying = None
  421. # Step count since episode start
  422. self.stepCount = 0
  423. # Restore the initial grid
  424. self.grid = self.seedGrid.copy()
  425. # Return first observation
  426. obs = self._genObs()
  427. return obs
  428. def _seed(self, seed=None):
  429. """
  430. The seed function sets the random elements of the environment,
  431. and initializes the world.
  432. """
  433. # By default, make things deterministic, always
  434. # produce the same environment
  435. if seed == None:
  436. seed = 1337
  437. # Seed the random number generator
  438. self.np_random, _ = seeding.np_random(seed)
  439. self.grid = self._genGrid(self.gridSize, self.gridSize)
  440. # Store a copy of the grid so we can restore it on reset
  441. self.seedGrid = self.grid.copy()
  442. return [seed]
  443. def _randInt(self, low, high):
  444. return self.np_random.randint(low, high)
  445. def _randElem(self, iterable):
  446. lst = list(iterable)
  447. idx = self._randInt(0, len(lst))
  448. return lst[idx]
  449. def getStepsRemaining(self):
  450. return self.maxSteps - self.stepCount
  451. def getDirVec(self):
  452. """
  453. Get the direction vector for the agent, pointing in the direction
  454. of forward movement.
  455. """
  456. # Pointing right
  457. if self.agentDir == 0:
  458. return (1, 0)
  459. # Down (positive Y)
  460. elif self.agentDir == 1:
  461. return (0, 1)
  462. # Pointing left
  463. elif self.agentDir == 2:
  464. return (-1, 0)
  465. # Up (negative Y)
  466. elif self.agentDir == 3:
  467. return (0, -1)
  468. else:
  469. assert False
  470. def getViewExts(self):
  471. """
  472. Get the extents of the square set of tiles visible to the agent
  473. Note: the bottom extent indices are not included in the set
  474. """
  475. # Facing right
  476. if self.agentDir == 0:
  477. topX = self.agentPos[0]
  478. topY = self.agentPos[1] - AGENT_VIEW_SIZE // 2
  479. # Facing down
  480. elif self.agentDir == 1:
  481. topX = self.agentPos[0] - AGENT_VIEW_SIZE // 2
  482. topY = self.agentPos[1]
  483. # Facing right
  484. elif self.agentDir == 2:
  485. topX = self.agentPos[0] - AGENT_VIEW_SIZE + 1
  486. topY = self.agentPos[1] - AGENT_VIEW_SIZE // 2
  487. # Facing up
  488. elif self.agentDir == 3:
  489. topX = self.agentPos[0] - AGENT_VIEW_SIZE // 2
  490. topY = self.agentPos[1] - AGENT_VIEW_SIZE + 1
  491. else:
  492. assert False
  493. botX = topX + AGENT_VIEW_SIZE
  494. botY = topY + AGENT_VIEW_SIZE
  495. return (topX, topY, botX, botY)
  496. def _step(self, action):
  497. self.stepCount += 1
  498. reward = 0
  499. done = False
  500. # Rotate left
  501. if action == MiniGridEnv.ACTION_LEFT:
  502. self.agentDir -= 1
  503. if self.agentDir < 0:
  504. self.agentDir += 4
  505. # Rotate right
  506. elif action == MiniGridEnv.ACTION_RIGHT:
  507. self.agentDir = (self.agentDir + 1) % 4
  508. # Move forward
  509. elif action == MiniGridEnv.ACTION_FORWARD:
  510. u, v = self.getDirVec()
  511. newPos = (self.agentPos[0] + u, self.agentPos[1] + v)
  512. targetCell = self.grid.get(newPos[0], newPos[1])
  513. if targetCell == None or targetCell.canOverlap():
  514. self.agentPos = newPos
  515. elif targetCell.type == 'goal':
  516. done = True
  517. reward = 1000 - self.stepCount
  518. # Pick up or trigger/activate an item
  519. elif action == MiniGridEnv.ACTION_TOGGLE:
  520. u, v = self.getDirVec()
  521. cell = self.grid.get(self.agentPos[0] + u, self.agentPos[1] + v)
  522. if cell and cell.canPickup() and self.carrying is None:
  523. self.carrying = cell
  524. self.grid.set(self.agentPos[0] + u, self.agentPos[1] + v, None)
  525. elif cell:
  526. cell.toggle(self)
  527. else:
  528. assert False, "unknown action"
  529. if self.stepCount >= self.maxSteps:
  530. done = True
  531. obs = self._genObs()
  532. return obs, reward, done, {}
  533. def _genObs(self):
  534. """
  535. Generate the agent's view (partially observable, low-resolution encoding)
  536. """
  537. topX, topY, botX, botY = self.getViewExts()
  538. grid = self.grid.slice(topX, topY, AGENT_VIEW_SIZE, AGENT_VIEW_SIZE)
  539. for i in range(self.agentDir + 1):
  540. grid = grid.rotateLeft()
  541. obs = grid.encode()
  542. return obs
  543. def getObsRender(self, obs):
  544. """
  545. Render an agent observation for visualization
  546. """
  547. if self.obsRender == None:
  548. self.obsRender = Renderer(
  549. AGENT_VIEW_SIZE * CELL_PIXELS // 2,
  550. AGENT_VIEW_SIZE * CELL_PIXELS // 2
  551. )
  552. r = self.obsRender
  553. r.beginFrame()
  554. grid = Grid.decode(obs)
  555. # Render the whole grid
  556. grid.render(r, CELL_PIXELS // 2)
  557. # Draw the agent
  558. r.push()
  559. r.scale(0.5, 0.5)
  560. r.translate(
  561. CELL_PIXELS * (0.5 + AGENT_VIEW_SIZE // 2),
  562. CELL_PIXELS * (AGENT_VIEW_SIZE - 0.5)
  563. )
  564. r.rotate(3 * 90)
  565. r.setLineColor(255, 0, 0)
  566. r.setColor(255, 0, 0)
  567. r.drawPolygon([
  568. (-12, 10),
  569. ( 12, 0),
  570. (-12, -10)
  571. ])
  572. r.pop()
  573. r.endFrame()
  574. return r.getPixmap()
  575. def _render(self, mode='human', close=False):
  576. """
  577. Render the whole-grid human view
  578. """
  579. if close:
  580. if self.gridRender:
  581. self.gridRender.close()
  582. return
  583. if self.gridRender is None:
  584. self.gridRender = Renderer(
  585. self.gridSize * CELL_PIXELS,
  586. self.gridSize * CELL_PIXELS,
  587. True if mode == 'human' else False
  588. )
  589. r = self.gridRender
  590. r.beginFrame()
  591. # Render the whole grid
  592. self.grid.render(r, CELL_PIXELS)
  593. # Draw the agent
  594. r.push()
  595. r.translate(
  596. CELL_PIXELS * (self.agentPos[0] + 0.5),
  597. CELL_PIXELS * (self.agentPos[1] + 0.5)
  598. )
  599. r.rotate(self.agentDir * 90)
  600. r.setLineColor(255, 0, 0)
  601. r.setColor(255, 0, 0)
  602. r.drawPolygon([
  603. (-12, 10),
  604. ( 12, 0),
  605. (-12, -10)
  606. ])
  607. r.pop()
  608. # Highlight what the agent can see
  609. topX, topY, botX, botY = self.getViewExts()
  610. r.fillRect(
  611. topX * CELL_PIXELS,
  612. topY * CELL_PIXELS,
  613. AGENT_VIEW_SIZE * CELL_PIXELS,
  614. AGENT_VIEW_SIZE * CELL_PIXELS,
  615. 200, 200, 200, 75
  616. )
  617. r.endFrame()
  618. if mode == 'rgb_array':
  619. return r.getArray()
  620. elif mode == 'pixmap':
  621. return r.getPixmap()
  622. return r