minigrid.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131
  1. import math
  2. import gym
  3. from enum import IntEnum
  4. import numpy as np
  5. from gym import error, spaces, utils
  6. from gym.utils import seeding
  7. from gym_minigrid.rendering import *
  8. # Size in pixels of a cell in the full-scale human view
  9. CELL_PIXELS = 32
  10. # Number of cells (width and height) in the agent view
  11. AGENT_VIEW_SIZE = 7
  12. # Size of the array given as an observation to the agent
  13. OBS_ARRAY_SIZE = (AGENT_VIEW_SIZE, AGENT_VIEW_SIZE, 3)
  14. # Map of color names to RGB values
  15. COLORS = {
  16. 'red' : (255, 0, 0),
  17. 'green' : (0, 255, 0),
  18. 'blue' : (0, 0, 255),
  19. 'purple': (112, 39, 195),
  20. 'yellow': (255, 255, 0),
  21. 'grey' : (100, 100, 100)
  22. }
  23. COLOR_NAMES = sorted(list(COLORS.keys()))
  24. # Used to map colors to integers
  25. COLOR_TO_IDX = {
  26. 'red' : 0,
  27. 'green' : 1,
  28. 'blue' : 2,
  29. 'purple': 3,
  30. 'yellow': 4,
  31. 'grey' : 5
  32. }
  33. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  34. # Map of object type to integers
  35. OBJECT_TO_IDX = {
  36. 'empty' : 0,
  37. 'wall' : 1,
  38. 'door' : 2,
  39. 'locked_door' : 3,
  40. 'key' : 4,
  41. 'ball' : 5,
  42. 'box' : 6,
  43. 'goal' : 7
  44. }
  45. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  46. class WorldObj:
  47. """
  48. Base class for grid world objects
  49. """
  50. def __init__(self, type, color):
  51. assert type in OBJECT_TO_IDX, type
  52. assert color in COLOR_TO_IDX, color
  53. self.type = type
  54. self.color = color
  55. self.contains = None
  56. def can_overlap(self):
  57. """Can the agent overlap with this?"""
  58. return False
  59. def canPickup(self):
  60. """Can the agent pick this up?"""
  61. return False
  62. def canContain(self):
  63. """Can this contain another object?"""
  64. return False
  65. def see_behind(self):
  66. """Can the agent see behind this object?"""
  67. return True
  68. def toggle(self, env, pos):
  69. """Method to trigger/toggle an action this object performs"""
  70. return False
  71. def render(self, r):
  72. assert False
  73. def _set_color(self, r):
  74. c = COLORS[self.color]
  75. r.setLineColor(c[0], c[1], c[2])
  76. r.setColor(c[0], c[1], c[2])
  77. class Goal(WorldObj):
  78. def __init__(self):
  79. super(Goal, self).__init__('goal', 'green')
  80. def can_overlap(self):
  81. return True
  82. def render(self, r):
  83. self._set_color(r)
  84. r.drawPolygon([
  85. (0 , CELL_PIXELS),
  86. (CELL_PIXELS, CELL_PIXELS),
  87. (CELL_PIXELS, 0),
  88. (0 , 0)
  89. ])
  90. class Wall(WorldObj):
  91. def __init__(self, color='grey'):
  92. super(Wall, self).__init__('wall', color)
  93. def see_behind(self):
  94. return False
  95. def render(self, r):
  96. self._set_color(r)
  97. r.drawPolygon([
  98. (0 , CELL_PIXELS),
  99. (CELL_PIXELS, CELL_PIXELS),
  100. (CELL_PIXELS, 0),
  101. (0 , 0)
  102. ])
  103. class Door(WorldObj):
  104. def __init__(self, color, isOpen=False):
  105. super(Door, self).__init__('door', color)
  106. self.isOpen = isOpen
  107. def can_overlap(self):
  108. """The agent can only walk over this cell when the door is open"""
  109. return self.isOpen
  110. def see_behind(self):
  111. return self.isOpen
  112. def toggle(self, env, pos):
  113. if not self.isOpen:
  114. self.isOpen = True
  115. return True
  116. return False
  117. def render(self, r):
  118. c = COLORS[self.color]
  119. r.setLineColor(c[0], c[1], c[2])
  120. r.setColor(0, 0, 0)
  121. if self.isOpen:
  122. r.drawPolygon([
  123. (CELL_PIXELS-2, CELL_PIXELS),
  124. (CELL_PIXELS , CELL_PIXELS),
  125. (CELL_PIXELS , 0),
  126. (CELL_PIXELS-2, 0)
  127. ])
  128. return
  129. r.drawPolygon([
  130. (0 , CELL_PIXELS),
  131. (CELL_PIXELS, CELL_PIXELS),
  132. (CELL_PIXELS, 0),
  133. (0 , 0)
  134. ])
  135. r.drawPolygon([
  136. (2 , CELL_PIXELS-2),
  137. (CELL_PIXELS-2, CELL_PIXELS-2),
  138. (CELL_PIXELS-2, 2),
  139. (2 , 2)
  140. ])
  141. r.drawCircle(CELL_PIXELS * 0.75, CELL_PIXELS * 0.5, 2)
  142. class LockedDoor(WorldObj):
  143. def __init__(self, color, isOpen=False):
  144. super(LockedDoor, self).__init__('locked_door', color)
  145. self.isOpen = isOpen
  146. def toggle(self, env, pos):
  147. # If the player has the right key to open the door
  148. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  149. self.isOpen = True
  150. # The key has been used, remove it from the agent
  151. env.carrying = None
  152. return True
  153. return False
  154. def can_overlap(self):
  155. """The agent can only walk over this cell when the door is open"""
  156. return self.isOpen
  157. def render(self, r):
  158. c = COLORS[self.color]
  159. r.setLineColor(c[0], c[1], c[2])
  160. r.setColor(c[0], c[1], c[2], 50)
  161. if self.isOpen:
  162. r.drawPolygon([
  163. (CELL_PIXELS-2, CELL_PIXELS),
  164. (CELL_PIXELS , CELL_PIXELS),
  165. (CELL_PIXELS , 0),
  166. (CELL_PIXELS-2, 0)
  167. ])
  168. return
  169. r.drawPolygon([
  170. (0 , CELL_PIXELS),
  171. (CELL_PIXELS, CELL_PIXELS),
  172. (CELL_PIXELS, 0),
  173. (0 , 0)
  174. ])
  175. r.drawPolygon([
  176. (2 , CELL_PIXELS-2),
  177. (CELL_PIXELS-2, CELL_PIXELS-2),
  178. (CELL_PIXELS-2, 2),
  179. (2 , 2)
  180. ])
  181. r.drawLine(
  182. CELL_PIXELS * 0.55,
  183. CELL_PIXELS * 0.5,
  184. CELL_PIXELS * 0.75,
  185. CELL_PIXELS * 0.5
  186. )
  187. class Key(WorldObj):
  188. def __init__(self, color='blue'):
  189. super(Key, self).__init__('key', color)
  190. def canPickup(self):
  191. return True
  192. def render(self, r):
  193. self._set_color(r)
  194. # Vertical quad
  195. r.drawPolygon([
  196. (16, 10),
  197. (20, 10),
  198. (20, 28),
  199. (16, 28)
  200. ])
  201. # Teeth
  202. r.drawPolygon([
  203. (12, 19),
  204. (16, 19),
  205. (16, 21),
  206. (12, 21)
  207. ])
  208. r.drawPolygon([
  209. (12, 26),
  210. (16, 26),
  211. (16, 28),
  212. (12, 28)
  213. ])
  214. r.drawCircle(18, 9, 6)
  215. r.setLineColor(0, 0, 0)
  216. r.setColor(0, 0, 0)
  217. r.drawCircle(18, 9, 2)
  218. class Ball(WorldObj):
  219. def __init__(self, color='blue'):
  220. super(Ball, self).__init__('ball', color)
  221. def canPickup(self):
  222. return True
  223. def render(self, r):
  224. self._set_color(r)
  225. r.drawCircle(CELL_PIXELS * 0.5, CELL_PIXELS * 0.5, 10)
  226. class Box(WorldObj):
  227. def __init__(self, color, contains=None):
  228. super(Box, self).__init__('box', color)
  229. self.contains = contains
  230. def canPickup(self):
  231. return True
  232. def render(self, r):
  233. c = COLORS[self.color]
  234. r.setLineColor(c[0], c[1], c[2])
  235. r.setColor(0, 0, 0)
  236. r.setLineWidth(2)
  237. r.drawPolygon([
  238. (4 , CELL_PIXELS-4),
  239. (CELL_PIXELS-4, CELL_PIXELS-4),
  240. (CELL_PIXELS-4, 4),
  241. (4 , 4)
  242. ])
  243. r.drawLine(
  244. 4,
  245. CELL_PIXELS / 2,
  246. CELL_PIXELS - 4,
  247. CELL_PIXELS / 2
  248. )
  249. r.setLineWidth(1)
  250. def toggle(self, env, pos):
  251. # Replace the box by its contents
  252. env.grid.set(*pos, self.contains)
  253. return True
  254. class Grid:
  255. """
  256. Represent a grid and operations on it
  257. """
  258. def __init__(self, width, height):
  259. assert width >= 4
  260. assert height >= 4
  261. self.width = width
  262. self.height = height
  263. self.grid = [None] * width * height
  264. def __contains__(self, key):
  265. if isinstance(key, WorldObj):
  266. for e in self.grid:
  267. if e is key:
  268. return True
  269. elif isinstance(key, tuple):
  270. for e in self.grid:
  271. if e is None:
  272. continue
  273. if (e.color, e.type) == key:
  274. return True
  275. return False
  276. def __eq__(self, other):
  277. grid1 = self.encode()
  278. grid2 = other.encode()
  279. return np.array_equal(grid2, grid1)
  280. def __ne__(self, other):
  281. return not self == other
  282. def copy(self):
  283. from copy import deepcopy
  284. return deepcopy(self)
  285. def set(self, i, j, v):
  286. assert i >= 0 and i < self.width
  287. assert j >= 0 and j < self.height
  288. self.grid[j * self.width + i] = v
  289. def get(self, i, j):
  290. assert i >= 0 and i < self.width
  291. assert j >= 0 and j < self.height
  292. return self.grid[j * self.width + i]
  293. def horzWall(self, x, y, length=None):
  294. if length is None:
  295. length = self.width - x
  296. for i in range(0, length):
  297. self.set(x + i, y, Wall())
  298. def vertWall(self, x, y, length=None):
  299. if length is None:
  300. length = self.height - y
  301. for j in range(0, length):
  302. self.set(x, y + j, Wall())
  303. def wallRect(self, x, y, w, h):
  304. self.horzWall(x, y, w)
  305. self.horzWall(x, y+h-1, w)
  306. self.vertWall(x, y, h)
  307. self.vertWall(x+w-1, y, h)
  308. def rotateLeft(self):
  309. """
  310. Rotate the grid to the left (counter-clockwise)
  311. """
  312. grid = Grid(self.width, self.height)
  313. for j in range(0, self.height):
  314. for i in range(0, self.width):
  315. v = self.get(self.width - 1 - j, i)
  316. grid.set(i, j, v)
  317. return grid
  318. def slice(self, topX, topY, width, height):
  319. """
  320. Get a subset of the grid
  321. """
  322. grid = Grid(width, height)
  323. for j in range(0, height):
  324. for i in range(0, width):
  325. x = topX + i
  326. y = topY + j
  327. if x >= 0 and x < self.width and \
  328. y >= 0 and y < self.height:
  329. v = self.get(x, y)
  330. else:
  331. v = Wall()
  332. grid.set(i, j, v)
  333. return grid
  334. def render(self, r, tileSize):
  335. """
  336. Render this grid at a given scale
  337. :param r: target renderer object
  338. :param tileSize: tile size in pixels
  339. """
  340. assert r.width == self.width * tileSize
  341. assert r.height == self.height * tileSize
  342. # Total grid size at native scale
  343. widthPx = self.width * CELL_PIXELS
  344. heightPx = self.height * CELL_PIXELS
  345. # Draw background (out-of-world) tiles the same colors as walls
  346. # so the agent understands these areas are not reachable
  347. c = COLORS['grey']
  348. r.setLineColor(c[0], c[1], c[2])
  349. r.setColor(c[0], c[1], c[2])
  350. r.drawPolygon([
  351. (0 , heightPx),
  352. (widthPx, heightPx),
  353. (widthPx, 0),
  354. (0 , 0)
  355. ])
  356. r.push()
  357. # Internally, we draw at the "large" full-grid resolution, but we
  358. # use the renderer to scale back to the desired size
  359. r.scale(tileSize / CELL_PIXELS, tileSize / CELL_PIXELS)
  360. # Draw the background of the in-world cells black
  361. r.fillRect(
  362. 0,
  363. 0,
  364. widthPx,
  365. heightPx,
  366. 0, 0, 0
  367. )
  368. # Draw grid lines
  369. r.setLineColor(100, 100, 100)
  370. for rowIdx in range(0, self.height):
  371. y = CELL_PIXELS * rowIdx
  372. r.drawLine(0, y, widthPx, y)
  373. for colIdx in range(0, self.width):
  374. x = CELL_PIXELS * colIdx
  375. r.drawLine(x, 0, x, heightPx)
  376. # Render the grid
  377. for j in range(0, self.height):
  378. for i in range(0, self.width):
  379. cell = self.get(i, j)
  380. if cell == None:
  381. continue
  382. r.push()
  383. r.translate(i * CELL_PIXELS, j * CELL_PIXELS)
  384. cell.render(r)
  385. r.pop()
  386. r.pop()
  387. def encode(self):
  388. """
  389. Produce a compact numpy encoding of the grid
  390. """
  391. codeSize = self.width * self.height * 3
  392. array = np.zeros(shape=(self.width, self.height, 3), dtype='uint8')
  393. for j in range(0, self.height):
  394. for i in range(0, self.width):
  395. v = self.get(i, j)
  396. if v == None:
  397. continue
  398. array[i, j, 0] = OBJECT_TO_IDX[v.type]
  399. array[i, j, 1] = COLOR_TO_IDX[v.color]
  400. if hasattr(v, 'isOpen') and v.isOpen:
  401. array[i, j, 2] = 1
  402. return array
  403. def decode(array):
  404. """
  405. Decode an array grid encoding back into a grid
  406. """
  407. width = array.shape[0]
  408. height = array.shape[1]
  409. assert array.shape[2] == 3
  410. grid = Grid(width, height)
  411. for j in range(0, height):
  412. for i in range(0, width):
  413. typeIdx = array[i, j, 0]
  414. colorIdx = array[i, j, 1]
  415. openIdx = array[i, j, 2]
  416. if typeIdx == 0:
  417. continue
  418. objType = IDX_TO_OBJECT[typeIdx]
  419. color = IDX_TO_COLOR[colorIdx]
  420. isOpen = True if openIdx == 1 else 0
  421. if objType == 'wall':
  422. v = Wall(color)
  423. elif objType == 'ball':
  424. v = Ball(color)
  425. elif objType == 'key':
  426. v = Key(color)
  427. elif objType == 'box':
  428. v = Box(color)
  429. elif objType == 'door':
  430. v = Door(color, isOpen)
  431. elif objType == 'locked_door':
  432. v = LockedDoor(color, isOpen)
  433. elif objType == 'goal':
  434. v = Goal()
  435. else:
  436. assert False, "unknown obj type in decode '%s'" % objType
  437. grid.set(i, j, v)
  438. return grid
  439. def process_vis(
  440. grid,
  441. agent_pos,
  442. n_rays = 32,
  443. n_steps = 24,
  444. a_min = math.pi,
  445. a_max = 2 * math.pi
  446. ):
  447. """
  448. Use ray casting to determine the visibility of each grid cell
  449. """
  450. mask = np.zeros(shape=(grid.width, grid.height), dtype=np.bool)
  451. ang_step = (a_max - a_min) / n_rays
  452. dst_step = math.sqrt(grid.width ** 2 + grid.height ** 2) / n_steps
  453. ax = agent_pos[0] + 0.5
  454. ay = agent_pos[1] + 0.5
  455. for ray_idx in range(0, n_rays):
  456. angle = a_min + ang_step * ray_idx
  457. dx = dst_step * math.cos(angle)
  458. dy = dst_step * math.sin(angle)
  459. for step_idx in range(0, n_steps):
  460. x = ax + (step_idx * dx)
  461. y = ay + (step_idx * dy)
  462. i = math.floor(x)
  463. j = math.floor(y)
  464. # If we're outside of the grid, stop
  465. if i < 0 or i >= grid.width or j < 0 or j >= grid.height:
  466. break
  467. # Mark this cell as visible
  468. mask[i, j] = True
  469. # If we hit the obstructor, stop
  470. cell = grid.get(i, j)
  471. if cell and not cell.see_behind():
  472. break
  473. for j in range(0, grid.height):
  474. for i in range(0, grid.width):
  475. if not mask[i, j]:
  476. #grid.set(i, j, None)
  477. grid.set(i, j, Wall('red'))
  478. return mask
  479. class MiniGridEnv(gym.Env):
  480. """
  481. 2D grid world game environment
  482. """
  483. metadata = {
  484. 'render.modes': ['human', 'rgb_array', 'pixmap'],
  485. 'video.frames_per_second' : 10
  486. }
  487. # Enumeration of possible actions
  488. class Actions(IntEnum):
  489. # Turn left, turn right, move forward
  490. left = 0
  491. right = 1
  492. forward = 2
  493. # Pick up an object
  494. pickup = 3
  495. # Drop an object
  496. drop = 4
  497. # Toggle/activate an object
  498. toggle = 5
  499. # Wait/stay put/do nothing
  500. wait = 6
  501. def __init__(self, grid_size=16, max_steps=100):
  502. # Action enumeration for this environment
  503. self.actions = MiniGridEnv.Actions
  504. # Actions are discrete integer values
  505. self.action_space = spaces.Discrete(len(self.actions))
  506. # Observations are dictionaries containing an
  507. # encoding of the grid and a textual 'mission' string
  508. self.observation_space = spaces.Box(
  509. low=0,
  510. high=255,
  511. shape=OBS_ARRAY_SIZE,
  512. dtype='uint8'
  513. )
  514. self.observation_space = spaces.Dict({
  515. 'image': self.observation_space
  516. })
  517. # Range of possible rewards
  518. self.reward_range = (-1, 1000)
  519. # Renderer object used to render the whole grid (full-scale)
  520. self.grid_render = None
  521. # Renderer used to render observations (small-scale agent view)
  522. self.obs_render = None
  523. # Environment configuration
  524. self.grid_size = grid_size
  525. self.max_steps = max_steps
  526. # Starting position and direction for the agent
  527. self.start_pos = None
  528. self.start_dir = None
  529. # Initialize the state
  530. self.seed()
  531. self.reset()
  532. def reset(self):
  533. # Generate a new random grid at the start of each episode
  534. # To keep the same grid for each episode, call env.seed() with
  535. # the same seed before calling env.reset()
  536. self._genGrid(self.grid_size, self.grid_size)
  537. # These fields should be defined by _genGrid
  538. assert self.start_pos != None
  539. assert self.start_dir != None
  540. # Check that the agent doesn't overlap with an object
  541. assert self.grid.get(*self.start_pos) is None
  542. # Place the agent in the starting position and direction
  543. self.agent_pos = self.start_pos
  544. self.agent_dir = self.start_dir
  545. # Item picked up, being carried, initially nothing
  546. self.carrying = None
  547. # Step count since episode start
  548. self.step_count = 0
  549. # Return first observation
  550. obs = self._gen_obs()
  551. return obs
  552. def seed(self, seed=1337):
  553. # Seed the random number generator
  554. self.np_random, _ = seeding.np_random(seed)
  555. return [seed]
  556. def __str__(self):
  557. """
  558. Produce a pretty string of the environment's grid along with the agent.
  559. The agent is represented by `⏩`. A grid pixel is represented by 2-character
  560. string, the first one for the object and the second one for the color.
  561. """
  562. from copy import deepcopy
  563. def rotate_left(array):
  564. new_array = deepcopy(array)
  565. for i in range(len(array)):
  566. for j in range(len(array[0])):
  567. new_array[j][len(array[0])-1-i] = array[i][j]
  568. return new_array
  569. def vertically_symmetrize(array):
  570. new_array = deepcopy(array)
  571. for i in range(len(array)):
  572. for j in range(len(array[0])):
  573. new_array[i][len(array[0])-1-j] = array[i][j]
  574. return new_array
  575. # Map of object id to short string
  576. OBJECT_IDX_TO_IDS = {
  577. 0: ' ',
  578. 1: 'W',
  579. 2: 'D',
  580. 3: 'L',
  581. 4: 'K',
  582. 5: 'B',
  583. 6: 'X',
  584. 7: 'G'
  585. }
  586. # Short string for opened door
  587. OPENDED_DOOR_IDS = '_'
  588. # Map of color id to short string
  589. COLOR_IDX_TO_IDS = {
  590. 0: 'R',
  591. 1: 'G',
  592. 2: 'B',
  593. 3: 'P',
  594. 4: 'Y',
  595. 5: 'E'
  596. }
  597. # Map agent's direction to short string
  598. AGENT_DIR_TO_IDS = {
  599. 0: '⏩',
  600. 1: '⏬',
  601. 2: '⏪',
  602. 3: '⏫'
  603. }
  604. array = self.grid.encode()
  605. array = rotate_left(array)
  606. array = vertically_symmetrize(array)
  607. new_array = []
  608. for line in array:
  609. new_line = []
  610. for pixel in line:
  611. # If the door is opened
  612. if pixel[0] in [2, 3] and pixel[2] == 1:
  613. object_ids = OPENDED_DOOR_IDS
  614. else:
  615. object_ids = OBJECT_IDX_TO_IDS[pixel[0]]
  616. # If no object
  617. if pixel[0] == 0:
  618. color_ids = ' '
  619. else:
  620. color_ids = COLOR_IDX_TO_IDS[pixel[1]]
  621. new_line.append(object_ids + color_ids)
  622. new_array.append(new_line)
  623. # Add the agent
  624. new_array[self.agent_pos[1]][self.agent_pos[0]] = AGENT_DIR_TO_IDS[self.agent_dir]
  625. return "\n".join([" ".join(line) for line in new_array])
  626. def _genGrid(self, width, height):
  627. assert False, "_genGrid needs to be implemented by each environment"
  628. def _randInt(self, low, high):
  629. """
  630. Generate random integer in [low,high[
  631. """
  632. return self.np_random.randint(low, high)
  633. def _randElem(self, iterable):
  634. """
  635. Pick a random element in a list
  636. """
  637. lst = list(iterable)
  638. idx = self._randInt(0, len(lst))
  639. return lst[idx]
  640. def _randPos(self, xLow, xHigh, yLow, yHigh):
  641. """
  642. Generate a random (x,y) position tuple
  643. """
  644. return (
  645. self.np_random.randint(xLow, xHigh),
  646. self.np_random.randint(yLow, yHigh)
  647. )
  648. def placeObj(self, obj, top=None, size=None, reject_fn=None):
  649. """
  650. Place an object at an empty position in the grid
  651. :param top: top-left position of the rectangle where to place
  652. :param size: size of the rectangle where to place
  653. :param reject_fn: function to filter out potential positions
  654. """
  655. if top is None:
  656. top = (0, 0)
  657. if size is None:
  658. size = (self.grid.width, self.grid.height)
  659. while True:
  660. pos = (
  661. self._randInt(top[0], top[0] + size[0]),
  662. self._randInt(top[1], top[1] + size[1])
  663. )
  664. # Don't place the object on top of another object
  665. if self.grid.get(*pos) != None:
  666. continue
  667. # Don't place the object where the agent is
  668. if pos == self.start_pos:
  669. continue
  670. # Check if there is a filtering criterion
  671. if reject_fn and reject_fn(self, pos):
  672. continue
  673. break
  674. self.grid.set(*pos, obj)
  675. return pos
  676. def placeAgent(self, top=None, size=None, randDir=True):
  677. """
  678. Set the agent's starting point at an empty position in the grid
  679. """
  680. pos = self.placeObj(None, top, size)
  681. self.start_pos = pos
  682. if randDir:
  683. self.start_dir = self._randInt(0, 4)
  684. return pos
  685. def getStepsRemaining(self):
  686. return self.max_steps - self.step_count
  687. def getDirVec(self):
  688. """
  689. Get the direction vector for the agent, pointing in the direction
  690. of forward movement.
  691. """
  692. # Pointing right
  693. if self.agent_dir == 0:
  694. return (1, 0)
  695. # Down (positive Y)
  696. elif self.agent_dir == 1:
  697. return (0, 1)
  698. # Pointing left
  699. elif self.agent_dir == 2:
  700. return (-1, 0)
  701. # Up (negative Y)
  702. elif self.agent_dir == 3:
  703. return (0, -1)
  704. else:
  705. assert False
  706. def getViewExts(self):
  707. """
  708. Get the extents of the square set of tiles visible to the agent
  709. Note: the bottom extent indices are not included in the set
  710. """
  711. # Facing right
  712. if self.agent_dir == 0:
  713. topX = self.agent_pos[0]
  714. topY = self.agent_pos[1] - AGENT_VIEW_SIZE // 2
  715. # Facing down
  716. elif self.agent_dir == 1:
  717. topX = self.agent_pos[0] - AGENT_VIEW_SIZE // 2
  718. topY = self.agent_pos[1]
  719. # Facing left
  720. elif self.agent_dir == 2:
  721. topX = self.agent_pos[0] - AGENT_VIEW_SIZE + 1
  722. topY = self.agent_pos[1] - AGENT_VIEW_SIZE // 2
  723. # Facing up
  724. elif self.agent_dir == 3:
  725. topX = self.agent_pos[0] - AGENT_VIEW_SIZE // 2
  726. topY = self.agent_pos[1] - AGENT_VIEW_SIZE + 1
  727. else:
  728. assert False, "invalid agent direction"
  729. botX = topX + AGENT_VIEW_SIZE
  730. botY = topY + AGENT_VIEW_SIZE
  731. return (topX, topY, botX, botY)
  732. def agentSees(self, x, y):
  733. """
  734. Check if a grid position is visible to the agent
  735. """
  736. topX, topY, botX, botY = self.getViewExts()
  737. return (x >= topX and x < botX and y >= topY and y < botY)
  738. def step(self, action):
  739. self.step_count += 1
  740. reward = 0
  741. done = False
  742. # Get the position in front of the agent
  743. u, v = self.getDirVec()
  744. fwdPos = (self.agent_pos[0] + u, self.agent_pos[1] + v)
  745. # Get the contents of the cell in front of the agent
  746. fwdCell = self.grid.get(*fwdPos)
  747. # Rotate left
  748. if action == self.actions.left:
  749. self.agent_dir -= 1
  750. if self.agent_dir < 0:
  751. self.agent_dir += 4
  752. # Rotate right
  753. elif action == self.actions.right:
  754. self.agent_dir = (self.agent_dir + 1) % 4
  755. # Move forward
  756. elif action == self.actions.forward:
  757. if fwdCell == None or fwdCell.can_overlap():
  758. self.agent_pos = fwdPos
  759. if fwdCell != None and fwdCell.type == 'goal':
  760. done = True
  761. reward = 1000 - self.step_count
  762. # Pick up an object
  763. elif action == self.actions.pickup:
  764. if fwdCell and fwdCell.canPickup():
  765. if self.carrying is None:
  766. self.carrying = fwdCell
  767. self.grid.set(*fwdPos, None)
  768. # Drop an object
  769. elif action == self.actions.drop:
  770. if not fwdCell and self.carrying:
  771. self.grid.set(*fwdPos, self.carrying)
  772. self.carrying = None
  773. # Toggle/activate an object
  774. elif action == self.actions.toggle:
  775. if fwdCell:
  776. fwdCell.toggle(self, fwdPos)
  777. # Wait/do nothing
  778. elif action == self.actions.wait:
  779. pass
  780. else:
  781. assert False, "unknown action"
  782. if self.step_count >= self.max_steps:
  783. done = True
  784. obs = self._gen_obs()
  785. return obs, reward, done, {}
  786. def _gen_obs(self):
  787. """
  788. Generate the agent's view (partially observable, low-resolution encoding)
  789. """
  790. topX, topY, botX, botY = self.getViewExts()
  791. grid = self.grid.slice(topX, topY, AGENT_VIEW_SIZE, AGENT_VIEW_SIZE)
  792. for i in range(self.agent_dir + 1):
  793. grid = grid.rotateLeft()
  794. # Make it so the agent sees what it's carrying
  795. # We do this by placing the carried object at the agent's position
  796. # in the agent's partially observable view
  797. agent_pos = grid.width // 2, grid.height - 1
  798. if self.carrying:
  799. grid.set(*agent_pos, self.carrying)
  800. else:
  801. grid.set(*agent_pos, None)
  802. # Process occluders and visibility
  803. grid.process_vis(agent_pos=(3, 6))
  804. # Encode the partially observable view into a numpy array
  805. image = grid.encode()
  806. assert hasattr(self, 'mission'), "environments must define a textual mission string"
  807. # Observations are dictionaries containing:
  808. # - an image (partially observable view of the environment)
  809. # - the agent's direction/orientation (acting as a compass)
  810. # - a textual mission string (instructions for the agent)
  811. obs = {
  812. 'image': image,
  813. 'direction': self.agent_dir,
  814. 'mission': self.mission
  815. }
  816. return obs
  817. def getObsRender(self, obs):
  818. """
  819. Render an agent observation for visualization
  820. """
  821. if self.obs_render == None:
  822. self.obs_render = Renderer(
  823. AGENT_VIEW_SIZE * CELL_PIXELS // 2,
  824. AGENT_VIEW_SIZE * CELL_PIXELS // 2
  825. )
  826. r = self.obs_render
  827. r.beginFrame()
  828. grid = Grid.decode(obs)
  829. # Render the whole grid
  830. grid.render(r, CELL_PIXELS // 2)
  831. # Draw the agent
  832. r.push()
  833. r.scale(0.5, 0.5)
  834. r.translate(
  835. CELL_PIXELS * (0.5 + AGENT_VIEW_SIZE // 2),
  836. CELL_PIXELS * (AGENT_VIEW_SIZE - 0.5)
  837. )
  838. r.rotate(3 * 90)
  839. r.setLineColor(255, 0, 0)
  840. r.setColor(255, 0, 0)
  841. r.drawPolygon([
  842. (-12, 10),
  843. ( 12, 0),
  844. (-12, -10)
  845. ])
  846. r.pop()
  847. r.endFrame()
  848. return r.getPixmap()
  849. def render(self, mode='human', close=False):
  850. """
  851. Render the whole-grid human view
  852. """
  853. if close:
  854. if self.grid_render:
  855. self.grid_render.close()
  856. return
  857. if self.grid_render is None:
  858. self.grid_render = Renderer(
  859. self.grid_size * CELL_PIXELS,
  860. self.grid_size * CELL_PIXELS,
  861. True if mode == 'human' else False
  862. )
  863. r = self.grid_render
  864. r.beginFrame()
  865. # Render the whole grid
  866. self.grid.render(r, CELL_PIXELS)
  867. # Draw the agent
  868. r.push()
  869. r.translate(
  870. CELL_PIXELS * (self.agent_pos[0] + 0.5),
  871. CELL_PIXELS * (self.agent_pos[1] + 0.5)
  872. )
  873. r.rotate(self.agent_dir * 90)
  874. r.setLineColor(255, 0, 0)
  875. r.setColor(255, 0, 0)
  876. r.drawPolygon([
  877. (-12, 10),
  878. ( 12, 0),
  879. (-12, -10)
  880. ])
  881. r.pop()
  882. # Highlight what the agent can see
  883. topX, topY, botX, botY = self.getViewExts()
  884. r.fillRect(
  885. topX * CELL_PIXELS,
  886. topY * CELL_PIXELS,
  887. AGENT_VIEW_SIZE * CELL_PIXELS,
  888. AGENT_VIEW_SIZE * CELL_PIXELS,
  889. 200, 200, 200, 75
  890. )
  891. r.endFrame()
  892. if mode == 'rgb_array':
  893. return r.getArray()
  894. elif mode == 'pixmap':
  895. return r.getPixmap()
  896. return r