minigrid.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370
  1. import math
  2. import gym
  3. from enum import IntEnum
  4. import numpy as np
  5. from gym import error, spaces, utils
  6. from gym.utils import seeding
  7. from .rendering import *
  # Edge length, in pixels, of one tile in the full-scale human view
  TILE_PIXELS = 32

  # RGB value of every named color
  COLORS = {
      'red': np.array([255, 0, 0]),
      'green': np.array([0, 255, 0]),
      'blue': np.array([0, 0, 255]),
      'purple': np.array([112, 39, 195]),
      'yellow': np.array([255, 255, 0]),
      'grey': np.array([100, 100, 100]),
  }

  # Color names in alphabetical order
  COLOR_NAMES = sorted(COLORS)

  # Integer code of each color: a color's code is its position in the tuple
  COLOR_TO_IDX = {
      name: idx
      for idx, name in enumerate(
          ('red', 'green', 'blue', 'purple', 'yellow', 'grey')
      )
  }

  # Reverse lookup: integer code -> color name
  IDX_TO_COLOR = {idx: name for name, idx in COLOR_TO_IDX.items()}

  # Integer code of each object type: a type's code is its position in the tuple
  OBJECT_TO_IDX = {
      name: idx
      for idx, name in enumerate((
          'unseen',
          'empty',
          'wall',
          'floor',
          'door',
          'key',
          'ball',
          'box',
          'goal',
          'lava',
          'agent',
      ))
  }

  # Reverse lookup: integer code -> object type name
  IDX_TO_OBJECT = {idx: name for name, idx in OBJECT_TO_IDX.items()}

  # Integer code of each door state
  STATE_TO_IDX = dict(open=0, closed=1, locked=2)

  # Unit vector for each agent direction index:
  # 0 = right (+X), 1 = down (+Y), 2 = left (-X), 3 = up (-Y)
  DIR_TO_VEC = [
      np.array(vec)
      for vec in ((1, 0), (0, 1), (-1, 0), (0, -1))
  ]
  class WorldObj:
      """
      Base class for grid world objects

      Every object has a type (a key of OBJECT_TO_IDX) and a color (a key of
      COLOR_TO_IDX). Subclasses override the capability queries below and
      must implement render().
      """

      def __init__(self, type, color):
          assert type in OBJECT_TO_IDX, type
          assert color in COLOR_TO_IDX, color
          self.type = type
          self.color = color
          self.contains = None

          # Initial position of the object
          self.init_pos = None

          # Current position of the object
          self.cur_pos = None

      def can_overlap(self):
          """Can the agent overlap with this?"""
          return False

      def can_pickup(self):
          """Can the agent pick this up?"""
          return False

      def can_contain(self):
          """Can this contain another object?"""
          return False

      def see_behind(self):
          """Can the agent see behind this object?"""
          return True

      def toggle(self, env, pos):
          """
          Method to trigger/toggle an action this object performs

          :param env: environment the object lives in
          :param pos: grid position of the object
          :return: True if the toggle had an effect (never for the base class)
          """
          return False

      def encode(self):
          """Encode a description of this object as a 3-tuple of integers"""
          return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0)

      @staticmethod
      def decode(type_idx, color_idx, state):
          """
          Create an object from a 3-tuple state description

          :param type_idx: object type code (a key of IDX_TO_OBJECT)
          :param color_idx: color code (a key of IDX_TO_COLOR)
          :param state: door state code; only used for doors
          :return: the decoded object, or None for 'empty'/'unseen' cells
          :raises AssertionError: if the type has no decodable class
          """
          obj_type = IDX_TO_OBJECT[type_idx]
          color = IDX_TO_COLOR[color_idx]

          if obj_type == 'empty' or obj_type == 'unseen':
              return None

          # State, 0: open, 1: closed, 2: locked
          is_open = state == 0
          is_locked = state == 2

          if obj_type == 'wall':
              v = Wall(color)
          elif obj_type == 'floor':
              v = Floor(color)
          elif obj_type == 'ball':
              v = Ball(color)
          elif obj_type == 'key':
              v = Key(color)
          elif obj_type == 'box':
              v = Box(color)
          elif obj_type == 'door':
              v = Door(color, is_open, is_locked)
          elif obj_type == 'goal':
              v = Goal()
          elif obj_type == 'lava':
              v = Lava()
          else:
              # Raised explicitly (rather than `assert False`) so the failure
              # is still reported when assertions are stripped with -O.
              # The message previously referenced an undefined name.
              raise AssertionError(
                  "unknown object type in decode '%s'" % obj_type
              )

          return v

      def render(self, r):
          """Draw this object with the given renderer"""
          raise NotImplementedError
  class Goal(WorldObj):
      """
      Green goal tile the agent can walk onto
      """

      def __init__(self):
          super().__init__('goal', 'green')

      def can_overlap(self):
          return True

      def render(self, r):
          """
          Fill the tile with this object's color

          :param r: drawing target handed to fill_coords
          """
          # The body previously referred to an undefined name `img`; the
          # drawing target is the `r` argument.
          # NOTE(review): fill_coords/point_in_rect presumably come from the
          # `.rendering` star import -- confirm their expected target type.
          fill_coords(r, point_in_rect(0.5, 0.5, 0.5, 0.5), COLORS[self.color])
  class Floor(WorldObj):
      """
      Colored floor tile; the agent may stand on it
      """

      def __init__(self, color='blue'):
          super().__init__('floor', color)

      def can_overlap(self):
          return True

      def render(self, r):
          """Draw the tile as a filled square in a dimmed version of its color"""
          # Halve the RGB value so the floor looks paler than other objects
          dimmed = COLORS[self.color] / 2
          r.setLineColor(100, 100, 100, 0)
          r.setColor(*dimmed)

          near, far = 1, TILE_PIXELS
          corners = [
              (near, far),
              (far, far),
              (far, near),
              (near, near),
          ]
          r.drawPolygon(corners)
  class Lava(WorldObj):
      """
      Red-typed lava tile the agent can walk onto, drawn orange with waves
      """

      def __init__(self):
          super().__init__('lava', 'red')

      def can_overlap(self):
          return True

      def render(self, r):
          """Draw an orange square covered by three black zig-zag wave lines"""
          orange = 255, 128, 0
          r.setLineColor(*orange)
          r.setColor(*orange)

          # Background: the whole tile
          full_tile = [
              (0, TILE_PIXELS),
              (TILE_PIXELS, TILE_PIXELS),
              (TILE_PIXELS, 0),
              (0, 0),
          ]
          r.drawPolygon(full_tile)

          # Waves: each line alternates between two heights, given here as
          # fractions of the tile size
          r.setLineColor(0, 0, 0)
          for y_outer, y_inner in ((.3, .4), (.5, .6), (.7, .8)):
              r.drawPolyline([
                  (.1 * TILE_PIXELS, y_outer * TILE_PIXELS),
                  (.3 * TILE_PIXELS, y_inner * TILE_PIXELS),
                  (.5 * TILE_PIXELS, y_outer * TILE_PIXELS),
                  (.7 * TILE_PIXELS, y_inner * TILE_PIXELS),
                  (.9 * TILE_PIXELS, y_outer * TILE_PIXELS),
              ])
  class Wall(WorldObj):
      """
      Wall tile: cannot be walked onto and blocks the agent's line of sight
      """

      def __init__(self, color='grey'):
          super().__init__('wall', color)

      def see_behind(self):
          return False

      def render(self, r):
          """
          Fill the tile with this object's color

          :param r: drawing target handed to fill_coords
          """
          # The body previously referred to an undefined name `img`; the
          # drawing target is the `r` argument.
          # NOTE(review): fill_coords/point_in_rect presumably come from the
          # `.rendering` star import -- confirm their expected target type.
          fill_coords(r, point_in_rect(0.5, 0.5, 0.5, 0.5), COLORS[self.color])
  class Door(WorldObj):
      """
      Door that can be open, closed, or locked

      A locked door only opens when the agent carries a Key of the same color.
      """

      def __init__(self, color, is_open=False, is_locked=False):
          super().__init__('door', color)
          self.is_open = is_open
          self.is_locked = is_locked

      def can_overlap(self):
          """The agent can only walk over this cell when the door is open"""
          return self.is_open

      def see_behind(self):
          """The agent can only see past the door when it is open"""
          return self.is_open

      def toggle(self, env, pos):
          """
          Open/close the door, unlocking it first if the agent has the key

          :return: True if the door changed state, False if it stayed locked
          """
          if not self.is_locked:
              self.is_open = not self.is_open
              return True

          # Locked: the carried item must be a key of the door's color
          held = env.carrying
          if isinstance(held, Key) and held.color == self.color:
              self.is_locked = False
              self.is_open = True
              return True

          return False

      def encode(self):
          """Encode a description of this object as a 3-tuple of integers"""
          # State, 0: open, 1: closed, 2: locked (open takes precedence)
          if self.is_open:
              state = 0
          elif self.is_locked:
              state = 2
          else:
              state = 1

          return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state)

      def render(self, r):
          """Draw the door; an open door is a thin slab on the tile's edge"""
          red, green, blue = COLORS[self.color]
          fill_alpha = 50 if self.is_locked else 0
          r.setLineColor(red, green, blue)
          r.setColor(red, green, blue, fill_alpha)

          edge = TILE_PIXELS
          inset = TILE_PIXELS - 2

          if self.is_open:
              r.drawPolygon([
                  (inset, edge),
                  (edge, edge),
                  (edge, 0),
                  (inset, 0),
              ])
              return

          # Closed or locked: outer frame, then an inner frame 2 pixels in
          r.drawPolygon([
              (0, edge),
              (edge, edge),
              (edge, 0),
              (0, 0),
          ])
          r.drawPolygon([
              (2, inset),
              (inset, inset),
              (inset, 2),
              (2, 2),
          ])

          mid_y = TILE_PIXELS * 0.5
          if self.is_locked:
              # Draw key slot
              r.drawLine(TILE_PIXELS * 0.55, mid_y, TILE_PIXELS * 0.75, mid_y)
          else:
              # Draw door handle
              r.drawCircle(TILE_PIXELS * 0.75, mid_y, 2)
  class Key(WorldObj):
      """
      Key the agent can pick up; Door.toggle checks its color to unlock
      """

      def __init__(self, color='blue'):
          super(Key, self).__init__('key', color)

      def can_pickup(self):
          return True

      def render(self, r):
          """Draw the key (shaft, two teeth, ring) with the given renderer"""
          # Select the key's color for outline and fill. This used to call
          # self._set_color(r), which neither Key nor WorldObj defines.
          c = COLORS[self.color]
          r.setLineColor(c[0], c[1], c[2])
          r.setColor(c[0], c[1], c[2])

          # Vertical quad
          r.drawPolygon([
              (16, 10),
              (20, 10),
              (20, 28),
              (16, 28)
          ])

          # Teeth
          r.drawPolygon([
              (12, 19),
              (16, 19),
              (16, 21),
              (12, 21)
          ])
          r.drawPolygon([
              (12, 26),
              (16, 26),
              (16, 28),
              (12, 28)
          ])

          # Ring: a colored disc with a black hole in its centre
          r.drawCircle(18, 9, 6)
          r.setLineColor(0, 0, 0)
          r.setColor(0, 0, 0)
          r.drawCircle(18, 9, 2)
  class Ball(WorldObj):
      """
      Ball the agent can pick up
      """

      def __init__(self, color='blue'):
          super().__init__('ball', color)

      def can_pickup(self):
          return True

      def render(self, img):
          """Fill a centred circle of radius 0.31 with the ball's color"""
          circle = point_in_circle(0.5, 0.5, 0.31)
          fill_coords(img, circle, COLORS[self.color])
  class Box(WorldObj):
      """
      Box the agent can pick up; toggling it swaps it for its contents
      """

      def __init__(self, color, contains=None):
          super().__init__('box', color)
          self.contains = contains

      def can_pickup(self):
          return True

      def render(self, r):
          """Draw a black square with a colored outline and a middle line"""
          red, green, blue = COLORS[self.color]
          r.setLineColor(red, green, blue)
          r.setColor(0, 0, 0)
          r.setLineWidth(2)

          near = 4
          far = TILE_PIXELS - 4
          r.drawPolygon([
              (near, far),
              (far, far),
              (far, near),
              (near, near),
          ])

          half = TILE_PIXELS / 2
          r.drawLine(near, half, far, half)

          # Restore the line width for whatever is drawn next
          r.setLineWidth(1)

      def toggle(self, env, pos):
          """Replace the box by its contents (None leaves the cell empty)"""
          x, y = pos
          env.grid.set(x, y, self.contains)
          return True
  class Grid:
      """
      Represent a grid and operations on it

      Cells are stored row-major in a flat list; a cell holds a WorldObj or
      None when empty.
      """

      # Static cache of pre-rendered tiles, shared by all grids
      tile_cache = {}

      def __init__(self, width, height):
          assert width >= 3
          assert height >= 3

          self.width = width
          self.height = height

          self.grid = [None] * width * height

      def __contains__(self, key):
          """
          Membership test supporting two kinds of key:

          - a WorldObj: True if that very object (identity) is in the grid
          - a (color, type) tuple: True if some object matches both, or
            matches the type alone when color is None
          """
          if isinstance(key, WorldObj):
              return any(e is key for e in self.grid)

          if isinstance(key, tuple):
              for e in self.grid:
                  if e is None:
                      continue
                  if (e.color, e.type) == key:
                      return True
                  if key[0] is None and key[1] == e.type:
                      return True

          return False

      def __eq__(self, other):
          """Two grids are equal when their encodings are equal"""
          grid1 = self.encode()
          grid2 = other.encode()
          return np.array_equal(grid2, grid1)

      def __ne__(self, other):
          return not self == other

      def copy(self):
          """Return a deep copy of the grid and the objects in it"""
          from copy import deepcopy
          return deepcopy(self)

      def set(self, i, j, v):
          """Store v (a WorldObj or None) at column i, row j"""
          assert 0 <= i < self.width
          assert 0 <= j < self.height
          self.grid[j * self.width + i] = v

      def get(self, i, j):
          """Return the content of the cell at column i, row j"""
          assert 0 <= i < self.width
          assert 0 <= j < self.height
          return self.grid[j * self.width + i]

      def horz_wall(self, x, y, length=None, obj_type=Wall):
          """
          Fill a horizontal run of cells starting at (x, y) with new objects

          :param length: number of cells; defaults to the rest of the row
          :param obj_type: zero-argument callable creating each object
          """
          if length is None:
              length = self.width - x
          for i in range(length):
              self.set(x + i, y, obj_type())

      def vert_wall(self, x, y, length=None, obj_type=Wall):
          """
          Fill a vertical run of cells starting at (x, y) with new objects

          :param length: number of cells; defaults to the rest of the column
          :param obj_type: zero-argument callable creating each object
          """
          if length is None:
              length = self.height - y
          for j in range(length):
              self.set(x, y + j, obj_type())

      def wall_rect(self, x, y, w, h):
          """Draw the outline of a w-by-h rectangle of walls at (x, y)"""
          self.horz_wall(x, y, w)
          self.horz_wall(x, y+h-1, w)
          self.vert_wall(x, y, h)
          self.vert_wall(x+w-1, y, h)

      def rotate_left(self):
          """
          Rotate the grid to the left (counter-clockwise)

          Returns a new grid; cells are shared with this grid, not copied.
          """
          grid = Grid(self.height, self.width)

          for i in range(self.width):
              for j in range(self.height):
                  v = self.get(i, j)
                  grid.set(j, grid.height - 1 - i, v)

          return grid

      def slice(self, topX, topY, width, height):
          """
          Get a subset of the grid

          Positions falling outside this grid are filled with new walls.
          """
          grid = Grid(width, height)

          for j in range(height):
              for i in range(width):
                  x = topX + i
                  y = topY + j

                  if 0 <= x < self.width and 0 <= y < self.height:
                      v = self.get(x, y)
                  else:
                      v = Wall()

                  grid.set(i, j, v)

          return grid

      @classmethod
      def render_tile(
          cls,
          obj,
          agent_dir=None,
          highlight=False,
          tile_size=TILE_PIXELS
      ):
          """
          Render a tile and cache the result

          :param obj: object to draw; its encode() value is part of the key
          :param agent_dir: agent direction to overlay (not drawn yet)
          :param highlight: whether to highlight the tile (not drawn yet)
          :param tile_size: tile edge length in pixels
          :return: (tile_size, tile_size, 3) uint8 image; the same array is
              returned again on a cache hit, so callers should not mutate it
          """
          # Hash map lookup key for the cache
          key = obj.encode() + (agent_dir, highlight, tile_size)

          # The cache is a class attribute and must be reached through cls;
          # a bare `tile_cache` is not a name in scope here.
          if key in cls.tile_cache:
              return cls.tile_cache[key]

          img = np.zeros(shape=(tile_size, tile_size, 3), dtype=np.uint8)

          # World objects expose render(), not render_tile()
          obj.render(img)

          # TODO: overlay agent on top
          if agent_dir is not None:
              pass

          # TODO: highlighting
          if highlight:
              pass

          # Cache the rendered tile
          cls.tile_cache[key] = img

          return img

      def render(self, tile_size):
          """
          Render this grid at a given scale

          :param tile_size: tile size in pixels
          :return: (height * tile_size, width * tile_size, 3) uint8 image
          """
          # Compute the total grid size from the requested tile size (the
          # parameter was previously ignored in favour of TILE_PIXELS)
          width_px = self.width * tile_size
          height_px = self.height * tile_size

          img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)

          # TODO: draw grid lines and paste Grid.render_tile(cell,
          # agent_dir=None, highlight=False, tile_size=tile_size) for every
          # non-empty cell; per-cell drawing is not enabled yet, so the
          # image is returned blank.

          return img

      def encode(self, vis_mask=None):
          """
          Produce a compact numpy encoding of the grid

          :param vis_mask: optional (width, height) boolean array; cells
              where it is False are left as zeros ('unseen')
          :return: (width, height, 3) uint8 array of object encodings
          """
          if vis_mask is None:
              vis_mask = np.ones((self.width, self.height), dtype=bool)

          array = np.zeros((self.width, self.height, 3), dtype='uint8')

          for i in range(self.width):
              for j in range(self.height):
                  if not vis_mask[i, j]:
                      continue

                  v = self.get(i, j)
                  if v is None:
                      array[i, j, 0] = OBJECT_TO_IDX['empty']
                      array[i, j, 1] = 0
                      array[i, j, 2] = 0
                  else:
                      array[i, j, :] = v.encode()

          return array

      @staticmethod
      def decode(array):
          """
          Decode an array grid encoding back into a grid

          :param array: (width, height, 3) array as produced by encode()
          """
          width, height, channels = array.shape
          assert channels == 3

          grid = Grid(width, height)
          for i in range(width):
              for j in range(height):
                  type_idx, color_idx, state = array[i, j]
                  v = WorldObj.decode(type_idx, color_idx, state)
                  grid.set(i, j, v)

          return grid

      def process_vis(grid, agent_pos):
          """
          Compute which cells are visible from agent_pos and blank the rest

          Visibility spreads from the agent's cell sideways along each row
          and to the row above (smaller j), stopping at cells whose object
          reports see_behind() as False. Rows are processed bottom-up.

          Note: the first parameter plays the role of `self`. Cells that end
          up not visible are set to None in place.

          :return: (width, height) boolean visibility mask
          """
          # Builtin bool, as in encode(); the np.bool alias is gone from
          # recent NumPy releases
          mask = np.zeros(shape=(grid.width, grid.height), dtype=bool)

          mask[agent_pos[0], agent_pos[1]] = True

          for j in reversed(range(0, grid.height)):
              # Left-to-right sweep
              for i in range(0, grid.width-1):
                  if not mask[i, j]:
                      continue

                  cell = grid.get(i, j)
                  if cell and not cell.see_behind():
                      continue

                  mask[i+1, j] = True
                  if j > 0:
                      mask[i+1, j-1] = True
                      mask[i, j-1] = True

              # Right-to-left sweep
              for i in reversed(range(1, grid.width)):
                  if not mask[i, j]:
                      continue

                  cell = grid.get(i, j)
                  if cell and not cell.see_behind():
                      continue

                  mask[i-1, j] = True
                  if j > 0:
                      mask[i-1, j-1] = True
                      mask[i, j-1] = True

          # Erase everything the agent cannot see
          for j in range(0, grid.height):
              for i in range(0, grid.width):
                  if not mask[i, j]:
                      grid.set(i, j, None)

          return mask
  534. class MiniGridEnv(gym.Env):
  535. """
  536. 2D grid world game environment
  537. """
  # Rendering capabilities advertised by the environment
  metadata = dict([
      ('render.modes', ['human', 'rgb_array', 'pixmap']),
      ('video.frames_per_second', 10),
  ])
  # Enumeration of possible actions
  class Actions(IntEnum):
      """Discrete actions; the integer values are the action codes"""

      # Movement: rotate in place or step forward
      left = 0
      right = 1
      forward = 2

      # Object interaction: pick up, drop, toggle/activate
      pickup = 3
      drop = 4
      toggle = 5

      # Declare the task complete
      done = 6
  def __init__(
      self,
      grid_size=None,
      width=None,
      height=None,
      max_steps=100,
      see_through_walls=False,
      seed=1337,
      agent_view_size=7
  ):
      """
      Configure the environment, seed its RNG and generate the first grid

      :param grid_size: side of a square grid; exclusive with width/height
      :param width: grid width in cells
      :param height: grid height in cells
      :param max_steps: stored as self.max_steps
      :param see_through_walls: stored as self.see_through_walls
      :param seed: seed passed to self.seed()
      :param agent_view_size: width and height of the agent view, in cells
      """
      # grid_size is shorthand for a square grid and cannot be combined
      # with an explicit width/height
      if grid_size:
          assert width is None and height is None
          width = height = grid_size

      # Actions are discrete integer values taken from the enumeration
      self.actions = MiniGridEnv.Actions
      self.action_space = spaces.Discrete(len(self.actions))

      # Number of cells (width and height) in the agent view
      self.agent_view_size = agent_view_size

      # Observations are dictionaries whose 'image' entry is an encoding
      # of the agent's view of the grid
      image_space = spaces.Box(
          low=0,
          high=255,
          shape=(self.agent_view_size, self.agent_view_size, 3),
          dtype='uint8'
      )
      self.observation_space = spaces.Dict({'image': image_space})

      # Range of possible rewards
      self.reward_range = (0, 1)

      # Renderers for the full grid and for the agent view, unset for now
      self.grid_render = None
      self.obs_render = None

      # Environment configuration
      self.width = width
      self.height = height
      self.max_steps = max_steps
      self.see_through_walls = see_through_walls

      # Agent position and direction; assigned during reset()
      self.agent_pos = None
      self.agent_dir = None

      # Seed the RNG first: reset() generates the initial grid
      self.seed(seed=seed)
      self.reset()
  def reset(self):
      """
      Start a new episode and return its first observation

      A new grid is generated on every call. To keep the same grid for
      each episode, call env.seed() with the same seed before env.reset().
      """
      # Clear the agent state so we can check that _gen_grid sets it
      self.agent_pos = None
      self.agent_dir = None

      self._gen_grid(self.width, self.height)

      # These fields should be defined by _gen_grid
      assert self.agent_pos is not None
      assert self.agent_dir is not None

      # The agent may only start on an empty or overlappable cell
      cell_under_agent = self.grid.get(*self.agent_pos)
      assert cell_under_agent is None or cell_under_agent.can_overlap()

      # Nothing is carried and no step has been taken yet
      self.carrying = None
      self.step_count = 0

      return self.gen_obs()
  def seed(self, seed=1337):
      """
      Re-create the environment's random number generator

      :param seed: seed handed to seeding.np_random
      :return: one-element list holding the seed that was passed in
      """
      rng, _ = seeding.np_random(seed)
      self.np_random = rng
      return [seed]
  631. @property
  632. def steps_remaining(self):
  633. return self.max_steps - self.step_count
  634. def __str__(self):
  635. """
  636. Produce a pretty string of the environment's grid along with the agent.
  637. A grid cell is represented by 2-character string, the first one for
  638. the object and the second one for the color.
  639. """
  640. # Map of object types to short string
  641. OBJECT_TO_STR = {
  642. 'wall' : 'W',
  643. 'floor' : 'F',
  644. 'door' : 'D',
  645. 'key' : 'K',
  646. 'ball' : 'A',
  647. 'box' : 'B',
  648. 'goal' : 'G',
  649. 'lava' : 'V',
  650. }
  651. # Short string for opened door
  652. OPENDED_DOOR_IDS = '_'
  653. # Map agent's direction to short string
  654. AGENT_DIR_TO_STR = {
  655. 0: '>',
  656. 1: 'V',
  657. 2: '<',
  658. 3: '^'
  659. }
  660. str = ''
  661. for j in range(self.grid.height):
  662. for i in range(self.grid.width):
  663. if i == self.agent_pos[0] and j == self.agent_pos[1]:
  664. str += 2 * AGENT_DIR_TO_STR[self.agent_dir]
  665. continue
  666. c = self.grid.get(i, j)
  667. if c == None:
  668. str += ' '
  669. continue
  670. if c.type == 'door':
  671. if c.is_open:
  672. str += '__'
  673. elif c.is_locked:
  674. str += 'L' + c.color[0].upper()
  675. else:
  676. str += 'D' + c.color[0].upper()
  677. continue
  678. str += OBJECT_TO_STR[c.type] + c.color[0].upper()
  679. if j < self.grid.height - 1:
  680. str += '\n'
  681. return str
  682. def _gen_grid(self, width, height):
  683. assert False, "_gen_grid needs to be implemented by each environment"
  684. def _reward(self):
  685. """
  686. Compute the reward to be given upon success
  687. """
  688. return 1 - 0.9 * (self.step_count / self.max_steps)
  689. def _rand_int(self, low, high):
  690. """
  691. Generate random integer in [low,high[
  692. """
  693. return self.np_random.randint(low, high)
  694. def _rand_float(self, low, high):
  695. """
  696. Generate random float in [low,high[
  697. """
  698. return self.np_random.uniform(low, high)
  699. def _rand_bool(self):
  700. """
  701. Generate random boolean value
  702. """
  703. return (self.np_random.randint(0, 2) == 0)
  704. def _rand_elem(self, iterable):
  705. """
  706. Pick a random element in a list
  707. """
  708. lst = list(iterable)
  709. idx = self._rand_int(0, len(lst))
  710. return lst[idx]
  711. def _rand_subset(self, iterable, num_elems):
  712. """
  713. Sample a random subset of distinct elements of a list
  714. """
  715. lst = list(iterable)
  716. assert num_elems <= len(lst)
  717. out = []
  718. while len(out) < num_elems:
  719. elem = self._rand_elem(lst)
  720. lst.remove(elem)
  721. out.append(elem)
  722. return out
  def _rand_color(self):
      """
      Pick a random color name (string) from COLOR_NAMES
      """
      color_name = self._rand_elem(COLOR_NAMES)
      return color_name
  728. def _rand_pos(self, xLow, xHigh, yLow, yHigh):
  729. """
  730. Generate a random (x,y) position tuple
  731. """
  732. return (
  733. self.np_random.randint(xLow, xHigh),
  734. self.np_random.randint(yLow, yHigh)
  735. )
  736. def place_obj(self,
  737. obj,
  738. top=None,
  739. size=None,
  740. reject_fn=None,
  741. max_tries=math.inf
  742. ):
  743. """
  744. Place an object at an empty position in the grid
  745. :param top: top-left position of the rectangle where to place
  746. :param size: size of the rectangle where to place
  747. :param reject_fn: function to filter out potential positions
  748. """
  749. if top is None:
  750. top = (0, 0)
  751. else:
  752. top = (max(top[0], 0), max(top[1], 0))
  753. if size is None:
  754. size = (self.grid.width, self.grid.height)
  755. num_tries = 0
  756. while True:
  757. # This is to handle with rare cases where rejection sampling
  758. # gets stuck in an infinite loop
  759. if num_tries > max_tries:
  760. raise RecursionError('rejection sampling failed in place_obj')
  761. num_tries += 1
  762. pos = np.array((
  763. self._rand_int(top[0], min(top[0] + size[0], self.grid.width)),
  764. self._rand_int(top[1], min(top[1] + size[1], self.grid.height))
  765. ))
  766. # Don't place the object on top of another object
  767. if self.grid.get(*pos) != None:
  768. continue
  769. # Don't place the object where the agent is
  770. if np.array_equal(pos, self.agent_pos):
  771. continue
  772. # Check if there is a filtering criterion
  773. if reject_fn and reject_fn(self, pos):
  774. continue
  775. break
  776. self.grid.set(*pos, obj)
  777. if obj is not None:
  778. obj.init_pos = pos
  779. obj.cur_pos = pos
  780. return pos
  781. def put_obj(self, obj, i, j):
  782. """
  783. Put an object at a specific position in the grid
  784. """
  785. self.grid.set(i, j, obj)
  786. obj.init_pos = (i, j)
  787. obj.cur_pos = (i, j)
  788. def place_agent(
  789. self,
  790. top=None,
  791. size=None,
  792. rand_dir=True,
  793. max_tries=math.inf
  794. ):
  795. """
  796. Set the agent's starting point at an empty position in the grid
  797. """
  798. self.agent_pos = None
  799. pos = self.place_obj(None, top, size, max_tries=max_tries)
  800. self.agent_pos = pos
  801. if rand_dir:
  802. self.agent_dir = self._rand_int(0, 4)
  803. return pos
  @property
  def dir_vec(self):
      """
      Unit vector of the agent's forward movement, looked up in
      DIR_TO_VEC by the agent's direction index.
      """
      assert 0 <= self.agent_dir < 4
      forward = DIR_TO_VEC[self.agent_dir]
      return forward
  @property
  def right_vec(self):
      """
      Vector pointing to the agent's right: the forward vector (dx, dy)
      turned into (-dy, dx).
      """
      fwd_x, fwd_y = self.dir_vec
      return np.array((-fwd_y, fwd_x))
  @property
  def front_pos(self):
      """
      Position of the cell directly in front of the agent, i.e. the
      agent's position plus its forward vector
      """
      ahead = self.agent_pos + self.dir_vec
      return ahead
  def get_view_coords(self, i, j):
      """
      Translate and rotate absolute grid coordinates (i, j) into the
      agent's partially observable view (sub-grid). Note that the resulting
      coordinates may be negative or outside of the agent's view size.
      """
      agent_x, agent_y = self.agent_pos
      fwd_x, fwd_y = self.dir_vec
      right_x, right_y = self.right_vec

      view = self.agent_view_size
      half = self.agent_view_size // 2

      # Absolute coordinates of the top-left view corner: view - 1 cells
      # ahead of the agent and half a view to its left
      corner_x = agent_x + (fwd_x * (view-1)) - (right_x * half)
      corner_y = agent_y + (fwd_y * (view-1)) - (right_y * half)

      # Offset of the queried cell from that corner
      off_x = i - corner_x
      off_y = j - corner_y

      # Project the offset onto the agent's own axes: right is +x in the
      # view, backwards (opposite to forward) is +y
      vx = (right_x*off_x + right_y*off_y)
      vy = -(fwd_x*off_x + fwd_y*off_y)

      return vx, vy
  846. def get_view_exts(self):
  847. """
  848. Get the extents of the square set of tiles visible to the agent
  849. Note: the bottom extent indices are not included in the set
  850. """
  851. # Facing right
  852. if self.agent_dir == 0:
  853. topX = self.agent_pos[0]
  854. topY = self.agent_pos[1] - self.agent_view_size // 2
  855. # Facing down
  856. elif self.agent_dir == 1:
  857. topX = self.agent_pos[0] - self.agent_view_size // 2
  858. topY = self.agent_pos[1]
  859. # Facing left
  860. elif self.agent_dir == 2:
  861. topX = self.agent_pos[0] - self.agent_view_size + 1
  862. topY = self.agent_pos[1] - self.agent_view_size // 2
  863. # Facing up
  864. elif self.agent_dir == 3:
  865. topX = self.agent_pos[0] - self.agent_view_size // 2
  866. topY = self.agent_pos[1] - self.agent_view_size + 1
  867. else:
  868. assert False, "invalid agent direction"
  869. botX = topX + self.agent_view_size
  870. botY = topY + self.agent_view_size
  871. return (topX, topY, botX, botY)
  def relative_coords(self, x, y):
      """
      Return the view coordinates of grid position (x, y) if it lies
      inside the agent's field of view, otherwise None
      """
      vx, vy = self.get_view_coords(x, y)
      view = self.agent_view_size

      outside = vx < 0 or vy < 0 or vx >= view or vy >= view
      if outside:
          return None

      return vx, vy
  def in_view(self, x, y):
      """
      Check if a grid position falls inside the agent's view square
      (as decided by relative_coords)
      """
      view_coords = self.relative_coords(x, y)
      return view_coords is not None
  def agent_sees(self, x, y):
      """
      Check if a non-empty grid position is visible to the agent

      The position counts as seen when the cell decoded from a fresh
      observation is non-empty and has the same type as the world cell.
      """
      view_coords = self.relative_coords(x, y)
      if view_coords is None:
          return False

      # Decode what the agent currently observes at that view position
      observed_grid = Grid.decode(self.gen_obs()['image'])
      seen = observed_grid.get(*view_coords)
      actual = self.grid.get(x, y)

      return seen is not None and seen.type == actual.type
  def step(self, action):
      """
      Advance the environment by one step using the given action

      Returns the tuple (obs, reward, done, info):
      - obs is the result of gen_obs() after the action was applied
      - reward is 0 unless the agent moved forward onto a 'goal' cell, in
        which case it is the value of self._reward()
      - done is True when the forward cell was a 'goal' or 'lava' cell on
        a forward action, or when step_count has reached max_steps
      - info is always an empty dict

      An action that matches none of self.actions trips an assertion.
      """

      self.step_count += 1

      reward = 0
      done = False

      # Get the position in front of the agent
      fwd_pos = self.front_pos

      # Get the contents of the cell in front of the agent
      fwd_cell = self.grid.get(*fwd_pos)

      # Cell presence is tested with `is None` / `is not None` throughout so
      # that a cell object's own __eq__ or truthiness cannot change the result

      # Rotate left
      if action == self.actions.left:
          self.agent_dir = (self.agent_dir - 1) % 4

      # Rotate right
      elif action == self.actions.right:
          self.agent_dir = (self.agent_dir + 1) % 4

      # Move forward
      elif action == self.actions.forward:
          if fwd_cell is None or fwd_cell.can_overlap():
              self.agent_pos = fwd_pos
          if fwd_cell is not None and fwd_cell.type == 'goal':
              done = True
              reward = self._reward()
          if fwd_cell is not None and fwd_cell.type == 'lava':
              done = True

      # Pick up an object
      elif action == self.actions.pickup:
          if fwd_cell is not None and fwd_cell.can_pickup() and self.carrying is None:
              self.carrying = fwd_cell
              # A carried object is no longer on the grid
              self.carrying.cur_pos = np.array([-1, -1])
              self.grid.set(*fwd_pos, None)

      # Drop an object
      elif action == self.actions.drop:
          if fwd_cell is None and self.carrying is not None:
              self.grid.set(*fwd_pos, self.carrying)
              self.carrying.cur_pos = fwd_pos
              self.carrying = None

      # Toggle/activate an object
      elif action == self.actions.toggle:
          if fwd_cell is not None:
              fwd_cell.toggle(self, fwd_pos)

      # Done action (not used by default)
      elif action == self.actions.done:
          pass

      else:
          assert False, "unknown action"

      if self.step_count >= self.max_steps:
          done = True

      obs = self.gen_obs()

      return obs, reward, done, {}
  def gen_obs_grid(self):
      """
      Generate the sub-grid observed by the agent.
      This method also outputs a visibility mask telling us which grid
      cells the agent can actually see.

      Returns the tuple (grid, vis_mask). The grid is the
      agent_view_size x agent_view_size slice of the world starting at the
      top-left view extent, rotated left agent_dir + 1 times. When
      see_through_walls is set, the mask is all True.
      """

      # Only the top-left corner is needed; the slice size fixes the rest
      topX, topY, _, _ = self.get_view_exts()

      grid = self.grid.slice(topX, topY, self.agent_view_size, self.agent_view_size)

      for _ in range(self.agent_dir + 1):
          grid = grid.rotate_left()

      # Process occluders and visibility
      # Note that this incurs some performance cost
      if not self.see_through_walls:
          vis_mask = grid.process_vis(agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1))
      else:
          # The builtin bool is used because the np.bool alias no longer
          # exists in current NumPy releases
          vis_mask = np.ones(shape=(grid.width, grid.height), dtype=bool)

      # Make it so the agent sees what it's carrying
      # We do this by placing the carried object at the agent's position
      # in the agent's partially observable view
      agent_pos = grid.width // 2, grid.height - 1
      if self.carrying:
          grid.set(*agent_pos, self.carrying)
      else:
          grid.set(*agent_pos, None)

      return grid, vis_mask
  def gen_obs(self):
      """
      Build the agent's observation (partially observable, low-resolution
      encoding of its view, plus its direction and the mission text)
      """

      view_grid, visible = self.gen_obs_grid()

      # Turn the partially observable view into its encoded form
      encoded_view = view_grid.encode(visible)

      assert hasattr(self, 'mission'), "environments must define a textual mission string"

      # The observation dictionary holds:
      # - 'image': the encoded partially observable view
      # - 'direction': the agent's direction/orientation (acting as a compass)
      # - 'mission': a textual mission string (instructions for the agent)
      return {
          'image': encoded_view,
          'direction': self.agent_dir,
          'mission': self.mission,
      }
  def get_obs_render(self, obs, tile_size=TILE_PIXELS//2, mode='pixmap'):
      """
      Render an agent observation for visualization

      obs is the encoded observation handed to Grid.decode(), and tile_size
      is forwarded to the decoded grid's render(). The rendered image is
      returned.

      mode is kept for backward compatibility and currently has no effect.
      """

      grid = Grid.decode(obs)

      # Render the whole grid
      # NOTE(review): this previously passed an undefined name `r` as the
      # first argument. The call now mirrors render(), which invokes
      # self.grid.render(tile_size) - confirm against Grid.render's signature.
      img = grid.render(tile_size)

      # TODO: the agent marker and the 'rgb_array' / 'pixmap' mode handling
      # from the earlier renderer-object code path have not been ported to
      # the image-based renderer yet.

      return img
  def render(self, mode='human', close=False, highlight=True, tile_size=TILE_PIXELS):
      """
      Render the whole-grid human view

      Returns the result of self.grid.render(tile_size).

      mode, close and highlight are accepted but currently have no effect:
      the code that handled them is disabled.
      """

      # TODO: the disabled steps from the earlier renderer-object code path
      # still need porting: closing the renderer when `close` is set,
      # drawing the agent marker, highlighting the cells in the agent's
      # visibility mask when `highlight` is set, and the 'rgb_array' /
      # 'pixmap' conversions selected by `mode`.

      # Render the whole grid
      return self.grid.render(tile_size)