minigrid.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177
  1. import math
  2. import gym
  3. from enum import IntEnum
  4. import numpy as np
  5. from gym import error, spaces, utils
  6. from gym.utils import seeding
  7. from gym_minigrid.rendering import *
  # Edge length, in pixels, of one cell when rendering the full-scale human view
  CELL_PIXELS = 32

  # Width and height, in cells, of the square region the agent observes
  AGENT_VIEW_SIZE = 7

  # Shape of the observation array: one (type, color, state) triple per cell
  OBS_ARRAY_SIZE = (AGENT_VIEW_SIZE, AGENT_VIEW_SIZE, 3)

  # RGB value used to draw each named color
  COLORS = {
      'red': (255, 0, 0),
      'green': (0, 255, 0),
      'blue': (0, 0, 255),
      'purple': (112, 39, 195),
      'yellow': (255, 255, 0),
      'grey': (100, 100, 100),
  }

  # Every color name, in alphabetical order
  COLOR_NAMES = sorted(COLORS)

  # Integer code of each color in the encoded observation (position in the
  # list below), together with the reverse lookup table
  COLOR_TO_IDX = {
      name: idx
      for idx, name in enumerate(
          ['red', 'green', 'blue', 'purple', 'yellow', 'grey']
      )
  }
  IDX_TO_COLOR = {idx: name for name, idx in COLOR_TO_IDX.items()}

  # Integer code of each object type in the encoded observation (position in
  # the list below), together with the reverse lookup table
  OBJECT_TO_IDX = {
      name: idx
      for idx, name in enumerate([
          'empty',
          'wall',
          'door',
          'locked_door',
          'key',
          'ball',
          'box',
          'goal',
      ])
  }
  IDX_TO_OBJECT = {idx: name for name, idx in OBJECT_TO_IDX.items()}
  class WorldObj:
      """
      Base class for grid world objects.

      Subclasses override the capability predicates (``can_overlap``,
      ``canPickup``, ``canContain``, ``see_behind``), ``toggle`` and
      ``render`` as needed.
      """

      def __init__(self, type, color):
          # Reject unknown types/colors early: both are later looked up in the
          # OBJECT_TO_IDX / COLOR_TO_IDX tables when the grid is encoded
          assert type in OBJECT_TO_IDX, type
          assert color in COLOR_TO_IDX, color
          self.type = type
          self.color = color
          # Object held inside this one (set by containers such as Box)
          self.contains = None

      def can_overlap(self):
          """Can the agent overlap with this?"""
          return False

      def canPickup(self):
          """Can the agent pick this up?"""
          return False

      def canContain(self):
          """Can this contain another object?"""
          return False

      def see_behind(self):
          """Can the agent see behind this object?"""
          return True

      def toggle(self, env, pos):
          """
          Trigger/toggle the action this object performs.

          :param env: environment in which the agent acts
          :param pos: (x, y) grid position of this object
          :return: True if toggling had an effect, False otherwise
          """
          return False

      def render(self, r):
          """
          Draw this object into a CELL_PIXELS x CELL_PIXELS cell.

          Must be overridden by every concrete subclass. An explicit raise
          is used instead of ``assert False`` so the check is not stripped
          when Python runs with ``-O``.

          :raises NotImplementedError: always, on the base class
          """
          raise NotImplementedError(
              '%s must implement render()' % self.__class__.__name__
          )

      def _set_color(self, r):
          """Use this object's color for both outlines and fills."""
          c = COLORS[self.color]
          r.setLineColor(c[0], c[1], c[2])
          r.setColor(c[0], c[1], c[2])
  class Goal(WorldObj):
      """Green goal tile that the agent is allowed to step onto."""

      def __init__(self):
          super().__init__('goal', 'green')

      def can_overlap(self):
          # The agent must be able to move onto the goal cell
          return True

      def render(self, r):
          """Fill the whole cell with the goal color."""
          self._set_color(r)
          size = CELL_PIXELS
          r.drawPolygon([(0, size), (size, size), (size, 0), (0, 0)])
  class Wall(WorldObj):
      """Impassable, opaque wall tile (grey unless another color is given)."""

      def __init__(self, color='grey'):
          super().__init__('wall', color)

      def see_behind(self):
          # Walls block the agent's line of sight
          return False

      def render(self, r):
          """Fill the whole cell with the wall color."""
          self._set_color(r)
          full = CELL_PIXELS
          corners = [(0, full), (full, full), (full, 0), (0, 0)]
          r.drawPolygon(corners)
  class Door(WorldObj):
      """
      Unlocked door. Toggling opens it; once open it can be walked over and
      seen through.
      """

      def __init__(self, color, is_open=False):
          super().__init__('door', color)
          self.is_open = is_open

      def can_overlap(self):
          """The agent can only walk over this cell when the door is open"""
          return self.is_open

      def see_behind(self):
          return self.is_open

      def toggle(self, env, pos):
          # Toggling only ever opens the door; an already-open door is unchanged
          if self.is_open:
              return False
          self.is_open = True
          return True

      def render(self, r):
          rgb = COLORS[self.color]
          r.setLineColor(rgb[0], rgb[1], rgb[2])
          r.setColor(0, 0, 0)

          px = CELL_PIXELS
          if self.is_open:
              # Open door: a thin strip along the right edge of the cell
              r.drawPolygon([
                  (px - 2, px),
                  (px, px),
                  (px, 0),
                  (px - 2, 0),
              ])
              return

          # Closed door: outer square, inset square, and a small circle on
          # the right-hand side
          r.drawPolygon([(0, px), (px, px), (px, 0), (0, 0)])
          r.drawPolygon([
              (2, px - 2),
              (px - 2, px - 2),
              (px - 2, 2),
              (2, 2),
          ])
          r.drawCircle(px * 0.75, px * 0.5, 2)
  class LockedDoor(WorldObj):
      """
      Door that only opens when the agent carries a Key of the same color.

      Opening it consumes that key. While closed, it can neither be walked
      over nor seen through.
      """

      def __init__(self, color, is_open=False):
          super(LockedDoor, self).__init__('locked_door', color)
          self.is_open = is_open

      def toggle(self, env, pos):
          """
          Try to open the door with the key the agent is carrying.

          :return: True if the door was opened (and the key consumed)
          """
          # NOTE(review): the key is consumed even if the door is already
          # open; confirm whether that is intended.
          # If the player has the right key to open the door
          if isinstance(env.carrying, Key) and env.carrying.color == self.color:
              self.is_open = True
              # The key has been used, remove it from the agent
              env.carrying = None
              return True
          return False

      def can_overlap(self):
          """The agent can only walk over this cell when the door is open"""
          return self.is_open

      def see_behind(self):
          """
          A closed locked door blocks the agent's view, like a closed Door.

          Without this override the WorldObj default (always True) let the
          agent see through closed locked doors.
          """
          return self.is_open

      def render(self, r):
          c = COLORS[self.color]
          r.setLineColor(c[0], c[1], c[2])
          # Translucent fill of the door color
          r.setColor(c[0], c[1], c[2], 50)

          if self.is_open:
              # Open door: a thin strip along the right edge of the cell
              r.drawPolygon([
                  (CELL_PIXELS - 2, CELL_PIXELS),
                  (CELL_PIXELS, CELL_PIXELS),
                  (CELL_PIXELS, 0),
                  (CELL_PIXELS - 2, 0)
              ])
              return

          r.drawPolygon([
              (0, CELL_PIXELS),
              (CELL_PIXELS, CELL_PIXELS),
              (CELL_PIXELS, 0),
              (0, 0)
          ])
          r.drawPolygon([
              (2, CELL_PIXELS - 2),
              (CELL_PIXELS - 2, CELL_PIXELS - 2),
              (CELL_PIXELS - 2, 2),
              (2, 2)
          ])
          # Short horizontal bar (keyhole-style marker) on the right side
          r.drawLine(
              CELL_PIXELS * 0.55,
              CELL_PIXELS * 0.5,
              CELL_PIXELS * 0.75,
              CELL_PIXELS * 0.5
          )
  class Key(WorldObj):
      """Key the agent can pick up; LockedDoor.toggle accepts a same-colored key."""

      def __init__(self, color='blue'):
          super().__init__('key', color)

      def canPickup(self):
          return True

      def render(self, r):
          self._set_color(r)

          # Vertical shaft, then the two teeth, all in the key's color
          quads = (
              [(16, 10), (20, 10), (20, 28), (16, 28)],
              [(12, 19), (16, 19), (16, 21), (12, 21)],
              [(12, 26), (16, 26), (16, 28), (12, 28)],
          )
          for quad in quads:
              r.drawPolygon(quad)

          # Ring at the top, hollowed out by a smaller black circle
          r.drawCircle(18, 9, 6)
          r.setLineColor(0, 0, 0)
          r.setColor(0, 0, 0)
          r.drawCircle(18, 9, 2)
  class Ball(WorldObj):
      """Ball that the agent can pick up."""

      def __init__(self, color='blue'):
          super().__init__('ball', color)

      def canPickup(self):
          return True

      def render(self, r):
          """Draw a filled circle of radius 10 at the center of the cell."""
          self._set_color(r)
          center = CELL_PIXELS * 0.5
          r.drawCircle(center, center, 10)
  class Box(WorldObj):
      """
      Box the agent can pick up. Toggling it replaces the box with whatever
      it contains.
      """

      def __init__(self, color, contains=None):
          super().__init__('box', color)
          self.contains = contains

      def canPickup(self):
          return True

      def render(self, r):
          rgb = COLORS[self.color]
          r.setLineColor(rgb[0], rgb[1], rgb[2])
          r.setColor(0, 0, 0)
          r.setLineWidth(2)

          # Outline inset 4px from the cell border, split by a horizontal
          # line through the middle
          lo, hi = 4, CELL_PIXELS - 4
          mid = CELL_PIXELS / 2
          r.drawPolygon([(lo, hi), (hi, hi), (hi, lo), (lo, lo)])
          r.drawLine(lo, mid, hi, mid)

          # Put the line width back to 1 for whatever is drawn next
          r.setLineWidth(1)

      def toggle(self, env, pos):
          # The box vanishes; its contents (possibly None) take over its cell
          env.grid.set(*pos, self.contains)
          return True
  class Grid:
      """
      Represent a grid and operations on it.

      Cells are stored row-major in a flat list; each cell normally holds a
      WorldObj, or None when empty.
      """

      def __init__(self, width, height):
          assert width >= 4
          assert height >= 4
          self.width = width
          self.height = height
          self.grid = [None] * width * height

      def __contains__(self, key):
          """
          Membership test.

          ``obj in grid`` checks identity against every stored object;
          ``(color, type) in grid`` checks whether any stored object has that
          color and type. Any other key is never contained.
          """
          if isinstance(key, WorldObj):
              return any(e is key for e in self.grid)
          if isinstance(key, tuple):
              return any(
                  e is not None and (e.color, e.type) == key
                  for e in self.grid
              )
          return False

      def __eq__(self, other):
          """Grids are equal when their compact encodings are identical."""
          return np.array_equal(other.encode(), self.encode())

      def __ne__(self, other):
          return not self == other

      def copy(self):
          """Return a deep copy; cell objects are copied, not shared."""
          from copy import deepcopy
          return deepcopy(self)

      def set(self, i, j, v):
          """Store ``v`` at column ``i``, row ``j``."""
          assert 0 <= i < self.width
          assert 0 <= j < self.height
          self.grid[j * self.width + i] = v

      def get(self, i, j):
          """Return the content of column ``i``, row ``j`` (None if empty)."""
          assert 0 <= i < self.width
          assert 0 <= j < self.height
          return self.grid[j * self.width + i]

      def horzWall(self, x, y, length=None):
          """Place walls rightwards from (x, y); by default up to the right edge."""
          if length is None:
              length = self.width - x
          for i in range(length):
              self.set(x + i, y, Wall())

      def vertWall(self, x, y, length=None):
          """Place walls downwards from (x, y); by default up to the bottom edge."""
          if length is None:
              length = self.height - y
          for j in range(length):
              self.set(x, y + j, Wall())

      def wallRect(self, x, y, w, h):
          """Outline a w x h rectangle of walls whose top-left cell is (x, y)."""
          self.horzWall(x, y, w)
          self.horzWall(x, y + h - 1, w)
          self.vertWall(x, y, h)
          self.vertWall(x + w - 1, y, h)

      def rotateLeft(self):
          """
          Rotate the grid to the left (counter-clockwise).

          The result has width and height swapped, so non-square grids are
          handled correctly. (The result used to be allocated with the
          original dimensions, which only worked for square grids.)
          """
          grid = Grid(self.height, self.width)
          for j in range(grid.height):
              for i in range(grid.width):
                  grid.set(i, j, self.get(self.width - 1 - j, i))
          return grid

      def slice(self, topX, topY, width, height):
          """
          Get a width x height subset of the grid whose top-left corner is
          (topX, topY). Positions outside this grid are filled with walls.
          """
          grid = Grid(width, height)
          for j in range(height):
              for i in range(width):
                  x = topX + i
                  y = topY + j
                  if 0 <= x < self.width and 0 <= y < self.height:
                      v = self.get(x, y)
                  else:
                      v = Wall()
                  grid.set(i, j, v)
          return grid

      def render(self, r, tileSize):
          """
          Render this grid at a given scale
          :param r: target renderer object
          :param tileSize: tile size in pixels
          """
          assert r.width == self.width * tileSize
          assert r.height == self.height * tileSize

          # Total grid size at native scale
          widthPx = self.width * CELL_PIXELS
          heightPx = self.height * CELL_PIXELS

          r.push()

          # Internally, we draw at the "large" full-grid resolution, but we
          # use the renderer to scale back to the desired size
          r.scale(tileSize / CELL_PIXELS, tileSize / CELL_PIXELS)

          # Draw the background of the in-world cells black
          r.fillRect(0, 0, widthPx, heightPx, 0, 0, 0)

          # Draw grid lines
          r.setLineColor(100, 100, 100)
          for rowIdx in range(self.height):
              y = CELL_PIXELS * rowIdx
              r.drawLine(0, y, widthPx, y)
          for colIdx in range(self.width):
              x = CELL_PIXELS * colIdx
              r.drawLine(x, 0, x, heightPx)

          # Render each non-empty cell in its own translated coordinate frame
          for j in range(self.height):
              for i in range(self.width):
                  cell = self.get(i, j)
                  if cell is None:
                      continue
                  r.push()
                  r.translate(i * CELL_PIXELS, j * CELL_PIXELS)
                  cell.render(r)
                  r.pop()

          r.pop()

      def encode(self):
          """
          Produce a compact numpy encoding of the grid.

          Returns a uint8 array of shape (width, height, 3), indexed [x, y].
          For each cell, channel 0 is the object type index, channel 1 the
          color index, and channel 2 is 1 when the object has a truthy
          ``is_open`` attribute (0 otherwise). Empty cells are all zeros.
          """
          array = np.zeros(shape=(self.width, self.height, 3), dtype='uint8')
          for j in range(self.height):
              for i in range(self.width):
                  v = self.get(i, j)
                  if v is None:
                      continue
                  array[i, j, 0] = OBJECT_TO_IDX[v.type]
                  array[i, j, 1] = COLOR_TO_IDX[v.color]
                  if getattr(v, 'is_open', False):
                      array[i, j, 2] = 1
          return array

      @staticmethod
      def decode(array):
          """
          Decode an array grid encoding (as produced by ``encode``) back into
          a grid.

          Declared static so it works both as ``Grid.decode(a)`` and when
          called through an instance. Box contents are not part of the
          encoding and therefore cannot be restored.
          """
          width = array.shape[0]
          height = array.shape[1]
          assert array.shape[2] == 3

          grid = Grid(width, height)
          for j in range(height):
              for i in range(width):
                  typeIdx = array[i, j, 0]
                  colorIdx = array[i, j, 1]
                  openIdx = array[i, j, 2]

                  if typeIdx == 0:
                      continue

                  objType = IDX_TO_OBJECT[typeIdx]
                  color = IDX_TO_COLOR[colorIdx]
                  # Always a real bool (previously this was either True or 0)
                  is_open = bool(openIdx == 1)

                  if objType == 'wall':
                      v = Wall(color)
                  elif objType == 'ball':
                      v = Ball(color)
                  elif objType == 'key':
                      v = Key(color)
                  elif objType == 'box':
                      v = Box(color)
                  elif objType == 'door':
                      v = Door(color, is_open)
                  elif objType == 'locked_door':
                      v = LockedDoor(color, is_open)
                  elif objType == 'goal':
                      v = Goal()
                  else:
                      assert False, "unknown obj type in decode '%s'" % objType

                  grid.set(i, j, v)

          return grid

      def process_vis(
          grid,
          agent_pos,
          n_rays=32,
          n_steps=24,
          a_min=math.pi,
          a_max=2 * math.pi
      ):
          """
          Use ray casting to determine the visibility of each grid cell.

          ``n_rays`` rays are cast from the center of the ``agent_pos`` cell
          at angles evenly spaced from ``a_min`` (inclusive) towards
          ``a_max`` (exclusive). Each ray advances ``n_steps`` equal steps and
          stops when it leaves the grid or reaches an object whose
          ``see_behind()`` is false (that object's own cell still counts as
          visible). Cells reached by no ray are cleared to None *in place*.

          :return: bool array of shape (width, height), True where visible
          """
          # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24
          mask = np.zeros(shape=(grid.width, grid.height), dtype=bool)

          ang_step = (a_max - a_min) / n_rays
          dst_step = math.sqrt(grid.width ** 2 + grid.height ** 2) / n_steps

          ax = agent_pos[0] + 0.5
          ay = agent_pos[1] + 0.5

          for ray_idx in range(n_rays):
              angle = a_min + ang_step * ray_idx
              dx = dst_step * math.cos(angle)
              dy = dst_step * math.sin(angle)

              for step_idx in range(n_steps):
                  x = ax + (step_idx * dx)
                  y = ay + (step_idx * dy)

                  i = math.floor(x)
                  j = math.floor(y)

                  # If we're outside of the grid, stop
                  if i < 0 or i >= grid.width or j < 0 or j >= grid.height:
                      break

                  # Mark this cell as visible
                  mask[i, j] = True

                  # If we hit the obstructor, stop
                  cell = grid.get(i, j)
                  if cell and not cell.see_behind():
                      break

          # Hide everything the rays did not reach
          for j in range(grid.height):
              for i in range(grid.width):
                  if not mask[i, j]:
                      grid.set(i, j, None)

          return mask
  class MiniGridEnv(gym.Env):
      """
      2D grid world game environment.

      Subclasses must implement ``_genGrid`` (which has to set ``self.grid``,
      ``self.start_pos`` and ``self.start_dir``) and must define a textual
      ``self.mission`` string.
      """

      metadata = {
          'render.modes': ['human', 'rgb_array', 'pixmap'],
          'video.frames_per_second': 10
      }

      # Enumeration of possible actions
      class Actions(IntEnum):
          # Turn left, turn right, move forward
          left = 0
          right = 1
          forward = 2

          # Pick up an object
          pickup = 3
          # Drop an object
          drop = 4

          # Toggle/activate an object
          toggle = 5

          # Wait/stay put/do nothing
          wait = 6

      def __init__(self, grid_size=16, max_steps=100):
          # Action enumeration for this environment
          self.actions = MiniGridEnv.Actions

          # Actions are discrete integer values
          self.action_space = spaces.Discrete(len(self.actions))

          # Observations are dictionaries; the declared space describes the
          # 'image' entry, a uint8 array of shape OBS_ARRAY_SIZE
          self.observation_space = spaces.Box(
              low=0,
              high=255,
              shape=OBS_ARRAY_SIZE,
              dtype='uint8'
          )
          self.observation_space = spaces.Dict({
              'image': self.observation_space
          })

          # Range of possible rewards
          self.reward_range = (-1, 1000)

          # Renderer object used to render the whole grid (full-scale)
          self.grid_render = None

          # Renderer used to render observations (small-scale agent view)
          self.obs_render = None

          # Environment configuration
          self.grid_size = grid_size
          self.max_steps = max_steps

          # Starting position and direction for the agent
          self.start_pos = None
          self.start_dir = None

          # Initialize the state
          self.seed()
          self.reset()

      def reset(self):
          """Start a new episode and return its first observation."""
          # Generate a new random grid at the start of each episode
          # To keep the same grid for each episode, call env.seed() with
          # the same seed before calling env.reset()
          self._genGrid(self.grid_size, self.grid_size)

          # These fields should be defined by _genGrid
          assert self.start_pos is not None
          assert self.start_dir is not None

          # Check that the agent doesn't overlap with an object
          assert self.grid.get(*self.start_pos) is None

          # Place the agent in the starting position and direction
          self.agent_pos = self.start_pos
          self.agent_dir = self.start_dir

          # Item picked up, being carried, initially nothing
          self.carrying = None

          # Step count since episode start
          self.step_count = 0

          # Return first observation
          obs = self._gen_obs()
          return obs

      def seed(self, seed=1337):
          # Seed the random number generator
          self.np_random, _ = seeding.np_random(seed)
          return [seed]

      def __str__(self):
          """
          Produce a pretty string of the environment's grid along with the agent.
          The agent is represented by `⏩`. A grid pixel is represented by 2-character
          string, the first one for the object and the second one for the color.
          """
          # Map of object id to short string
          OBJECT_IDX_TO_IDS = {
              0: ' ',
              1: 'W',
              2: 'D',
              3: 'L',
              4: 'K',
              5: 'B',
              6: 'X',
              7: 'G'
          }

          # Short string for opened door
          OPENED_DOOR_IDS = '_'

          # Map of color id to short string
          COLOR_IDX_TO_IDS = {
              0: 'R',
              1: 'G',
              2: 'B',
              3: 'P',
              4: 'Y',
              5: 'E'
          }

          # Map agent's direction to short string
          AGENT_DIR_TO_IDS = {
              0: '⏩',
              1: '⏬',
              2: '⏪',
              3: '⏫'
          }

          # encode() is indexed [x, y]; emit one text row per y. (This used to
          # be done with a rotate + mirror round-trip that only worked for
          # square grids.)
          array = self.grid.encode()
          new_array = []
          for y in range(array.shape[1]):
              new_line = []
              for x in range(array.shape[0]):
                  pixel = array[x, y]
                  # If the door is opened
                  if pixel[0] in [2, 3] and pixel[2] == 1:
                      object_ids = OPENED_DOOR_IDS
                  else:
                      object_ids = OBJECT_IDX_TO_IDS[pixel[0]]

                  # If no object
                  if pixel[0] == 0:
                      color_ids = ' '
                  else:
                      color_ids = COLOR_IDX_TO_IDS[pixel[1]]

                  new_line.append(object_ids + color_ids)
              new_array.append(new_line)

          # Add the agent
          new_array[self.agent_pos[1]][self.agent_pos[0]] = AGENT_DIR_TO_IDS[self.agent_dir]

          return "\n".join([" ".join(line) for line in new_array])

      def _genGrid(self, width, height):
          """
          Build ``self.grid`` and set ``self.start_pos``/``self.start_dir``.

          :raises NotImplementedError: always; every environment overrides this
          """
          raise NotImplementedError(
              "_genGrid needs to be implemented by each environment"
          )

      def _randInt(self, low, high):
          """
          Generate random integer in [low,high[
          """
          return self.np_random.randint(low, high)

      def _randElem(self, iterable):
          """
          Pick a random element in a list
          """
          lst = list(iterable)
          idx = self._randInt(0, len(lst))
          return lst[idx]

      def _randPos(self, xLow, xHigh, yLow, yHigh):
          """
          Generate a random (x,y) position tuple with x in [xLow, xHigh[ and
          y in [yLow, yHigh[
          """
          return (
              self._randInt(xLow, xHigh),
              self._randInt(yLow, yHigh)
          )

      def placeObj(self, obj, top=None, size=None, reject_fn=None, max_tries=math.inf):
          """
          Place an object at an empty position in the grid
          :param obj: object to place (None only picks a free position)
          :param top: top-left position of the rectangle where to place
          :param size: size of the rectangle where to place
          :param reject_fn: function to filter out potential positions
          :param max_tries: maximum number of candidate positions to sample
              (default: unlimited, i.e. the previous behavior)
          :return: the (x, y) position chosen
          :raises RuntimeError: if no acceptable position was found within
              ``max_tries`` samples
          """
          if top is None:
              top = (0, 0)

          if size is None:
              size = (self.grid.width, self.grid.height)

          num_tries = 0
          while True:
              # Without a bound this loop never ends on a full rectangle
              if num_tries >= max_tries:
                  raise RuntimeError(
                      'placeObj: no free position found after %d tries' % num_tries
                  )
              num_tries += 1

              pos = (
                  self._randInt(top[0], top[0] + size[0]),
                  self._randInt(top[1], top[1] + size[1])
              )

              # Don't place the object on top of another object
              if self.grid.get(*pos) is not None:
                  continue

              # Don't place the object where the agent is
              if pos == self.start_pos:
                  continue

              # Check if there is a filtering criterion
              if reject_fn and reject_fn(self, pos):
                  continue

              break

          self.grid.set(*pos, obj)
          return pos

      def placeAgent(self, top=None, size=None, randDir=True):
          """
          Set the agent's starting point at an empty position in the grid
          """
          pos = self.placeObj(None, top, size)
          self.start_pos = pos

          if randDir:
              self.start_dir = self._randInt(0, 4)

          return pos

      def getStepsRemaining(self):
          return self.max_steps - self.step_count

      def getDirVec(self):
          """
          Get the direction vector for the agent, pointing in the direction
          of forward movement.
          """
          # Pointing right
          if self.agent_dir == 0:
              return (1, 0)
          # Down (positive Y)
          elif self.agent_dir == 1:
              return (0, 1)
          # Pointing left
          elif self.agent_dir == 2:
              return (-1, 0)
          # Up (negative Y)
          elif self.agent_dir == 3:
              return (0, -1)
          else:
              assert False

      def get_right_vec(self):
          """
          Get the vector pointing to the right of the agent.
          """
          dx, dy = self.getDirVec()
          return -dy, dx

      def get_view_coords(self, i, j):
          """
          Translate and rotate absolute grid coordinates (i, j) into the
          agent's partially observable view (sub-grid). Note that the resulting
          coordinates may be negative or outside of the agent's view size.
          """
          ax, ay = self.agent_pos
          dx, dy = self.getDirVec()
          rx, ry = self.get_right_vec()

          # Compute the absolute coordinates of the top-left view corner
          sz = AGENT_VIEW_SIZE
          hs = AGENT_VIEW_SIZE // 2
          tx = ax + (dx * (sz - 1)) - (rx * hs)
          ty = ay + (dy * (sz - 1)) - (ry * hs)

          lx = i - tx
          ly = j - ty

          # Project the coordinates of the object relative to the top-left
          # corner onto the agent's own coordinate system
          vx = (rx * lx + ry * ly)
          vy = -(dx * lx + dy * ly)

          return vx, vy

      def get_view_exts(self):
          """
          Get the extents of the square set of tiles visible to the agent
          Note: the bottom extent indices are not included in the set
          """
          # Facing right
          if self.agent_dir == 0:
              topX = self.agent_pos[0]
              topY = self.agent_pos[1] - AGENT_VIEW_SIZE // 2
          # Facing down
          elif self.agent_dir == 1:
              topX = self.agent_pos[0] - AGENT_VIEW_SIZE // 2
              topY = self.agent_pos[1]
          # Facing left
          elif self.agent_dir == 2:
              topX = self.agent_pos[0] - AGENT_VIEW_SIZE + 1
              topY = self.agent_pos[1] - AGENT_VIEW_SIZE // 2
          # Facing up
          elif self.agent_dir == 3:
              topX = self.agent_pos[0] - AGENT_VIEW_SIZE // 2
              topY = self.agent_pos[1] - AGENT_VIEW_SIZE + 1
          else:
              assert False, "invalid agent direction"

          botX = topX + AGENT_VIEW_SIZE
          botY = topY + AGENT_VIEW_SIZE

          return (topX, topY, botX, botY)

      def agent_sees(self, x, y):
          """
          Check if a grid position is visible to the agent

          :return: True when (x, y) lies in the agent's view, is not occluded,
              and the observed object has the same type as the world object
          """
          vx, vy = self.get_view_coords(x, y)

          if vx < 0 or vy < 0 or vx >= AGENT_VIEW_SIZE or vy >= AGENT_VIEW_SIZE:
              return False

          obs = self._gen_obs()
          obs_grid = Grid.decode(obs['image'])
          obs_cell = obs_grid.get(vx, vy)
          world_cell = self.grid.get(x, y)

          # The observation shows the carried object at the agent's own cell,
          # where the world grid is empty; guard against world_cell being None
          # (this used to raise AttributeError)
          return (
              obs_cell is not None
              and world_cell is not None
              and obs_cell.type == world_cell.type
          )

      def step(self, action):
          """
          Apply one action.

          :return: (obs, reward, done, info) tuple
          """
          self.step_count += 1

          reward = 0
          done = False

          # Get the position in front of the agent
          u, v = self.getDirVec()
          fwdPos = (self.agent_pos[0] + u, self.agent_pos[1] + v)

          # Get the contents of the cell in front of the agent
          fwdCell = self.grid.get(*fwdPos)

          # Rotate left
          if action == self.actions.left:
              self.agent_dir = (self.agent_dir - 1) % 4

          # Rotate right
          elif action == self.actions.right:
              self.agent_dir = (self.agent_dir + 1) % 4

          # Move forward
          elif action == self.actions.forward:
              if fwdCell is None or fwdCell.can_overlap():
                  self.agent_pos = fwdPos
              if fwdCell is not None and fwdCell.type == 'goal':
                  done = True
                  # Reaching the goal sooner yields a larger reward
                  reward = 1000 - self.step_count

          # Pick up an object
          elif action == self.actions.pickup:
              if fwdCell and fwdCell.canPickup():
                  if self.carrying is None:
                      self.carrying = fwdCell
                      self.grid.set(*fwdPos, None)

          # Drop an object
          elif action == self.actions.drop:
              if not fwdCell and self.carrying:
                  self.grid.set(*fwdPos, self.carrying)
                  self.carrying = None

          # Toggle/activate an object
          elif action == self.actions.toggle:
              if fwdCell:
                  fwdCell.toggle(self, fwdPos)

          # Wait/do nothing
          elif action == self.actions.wait:
              pass

          else:
              assert False, "unknown action"

          if self.step_count >= self.max_steps:
              done = True

          obs = self._gen_obs()

          return obs, reward, done, {}

      def _gen_obs(self):
          """
          Generate the agent's view (partially observable, low-resolution encoding)
          """
          topX, topY, botX, botY = self.get_view_exts()

          grid = self.grid.slice(topX, topY, AGENT_VIEW_SIZE, AGENT_VIEW_SIZE)

          for _ in range(self.agent_dir + 1):
              grid = grid.rotateLeft()

          # Make it so the agent sees what it's carrying
          # We do this by placing the carried object at the agent's position
          # in the agent's partially observable view
          agent_pos = grid.width // 2, grid.height - 1
          if self.carrying:
              grid.set(*agent_pos, self.carrying)
          else:
              grid.set(*agent_pos, None)

          # Process occluders and visibility, casting rays from the agent's
          # cell in the view (previously hard-coded to (3, 6), which only
          # matched AGENT_VIEW_SIZE == 7)
          grid.process_vis(agent_pos=agent_pos)

          # Encode the partially observable view into a numpy array
          image = grid.encode()

          assert hasattr(self, 'mission'), "environments must define a textual mission string"

          # Observations are dictionaries containing:
          # - an image (partially observable view of the environment)
          # - the agent's direction/orientation (acting as a compass)
          # - a textual mission string (instructions for the agent)
          obs = {
              'image': image,
              'direction': self.agent_dir,
              'mission': self.mission
          }

          return obs

      def getObsRender(self, obs):
          """
          Render an agent observation for visualization

          :param obs: encoded observation image (the 'image' entry of an obs)
          """
          if self.obs_render is None:
              self.obs_render = Renderer(
                  AGENT_VIEW_SIZE * CELL_PIXELS // 2,
                  AGENT_VIEW_SIZE * CELL_PIXELS // 2
              )

          r = self.obs_render

          r.beginFrame()

          grid = Grid.decode(obs)

          # Render the whole grid
          grid.render(r, CELL_PIXELS // 2)

          # Draw the agent at its fixed spot in the view (bottom-center cell)
          r.push()
          r.scale(0.5, 0.5)
          r.translate(
              CELL_PIXELS * (0.5 + AGENT_VIEW_SIZE // 2),
              CELL_PIXELS * (AGENT_VIEW_SIZE - 0.5)
          )
          r.rotate(3 * 90)
          r.setLineColor(255, 0, 0)
          r.setColor(255, 0, 0)
          r.drawPolygon([
              (-12, 10),
              (12, 0),
              (-12, -10)
          ])
          r.pop()

          r.endFrame()

          return r.getPixmap()

      def render(self, mode='human', close=False):
          """
          Render the whole-grid human view
          """
          if close:
              if self.grid_render:
                  self.grid_render.close()
                  # Forget the closed renderer so a later call creates a new one
                  self.grid_render = None
              return

          if self.grid_render is None:
              self.grid_render = Renderer(
                  self.grid_size * CELL_PIXELS,
                  self.grid_size * CELL_PIXELS,
                  True if mode == 'human' else False
              )

          r = self.grid_render

          r.beginFrame()

          # Render the whole grid
          self.grid.render(r, CELL_PIXELS)

          # Draw the agent
          r.push()
          r.translate(
              CELL_PIXELS * (self.agent_pos[0] + 0.5),
              CELL_PIXELS * (self.agent_pos[1] + 0.5)
          )
          r.rotate(self.agent_dir * 90)
          r.setLineColor(255, 0, 0)
          r.setColor(255, 0, 0)
          r.drawPolygon([
              (-12, 10),
              (12, 0),
              (-12, -10)
          ])
          r.pop()

          # Highlight what the agent can see
          topX, topY, botX, botY = self.get_view_exts()
          r.fillRect(
              topX * CELL_PIXELS,
              topY * CELL_PIXELS,
              AGENT_VIEW_SIZE * CELL_PIXELS,
              AGENT_VIEW_SIZE * CELL_PIXELS,
              200, 200, 200, 75
          )

          r.endFrame()

          if mode == 'rgb_array':
              return r.getArray()
          elif mode == 'pixmap':
              return r.getPixmap()

          return r
  928. return r.getArray()
  929. elif mode == 'pixmap':
  930. return r.getPixmap()
  931. return r