# minigrid.py
  1. import math
  2. import gym
  3. from enum import IntEnum
  4. import numpy as np
  5. from gym import error, spaces, utils
  6. from gym.utils import seeding
  7. from .rendering import *
  # Size in pixels of a tile in the full-scale human view
  TILE_PIXELS = 32

  # Map of color names to RGB values
  COLORS = {
      'red': np.array([255, 0, 0]),
      'green': np.array([0, 255, 0]),
      'blue': np.array([0, 0, 255]),
      'purple': np.array([112, 39, 195]),
      'yellow': np.array([255, 255, 0]),
      'grey': np.array([100, 100, 100]),
  }

  # Alphabetically sorted list of the color names
  COLOR_NAMES = sorted(COLORS)

  # Used to map colors to integers
  COLOR_TO_IDX = {
      'red': 0,
      'green': 1,
      'blue': 2,
      'purple': 3,
      'yellow': 4,
      'grey': 5,
  }

  # Inverse of COLOR_TO_IDX
  IDX_TO_COLOR = {idx: name for name, idx in COLOR_TO_IDX.items()}

  # Map of object type to integers
  OBJECT_TO_IDX = {
      'unseen': 0,
      'empty': 1,
      'wall': 2,
      'floor': 3,
      'door': 4,
      'key': 5,
      'ball': 6,
      'box': 7,
      'goal': 8,
      'lava': 9,
      'agent': 10,
  }

  # Inverse of OBJECT_TO_IDX
  IDX_TO_OBJECT = {idx: name for name, idx in OBJECT_TO_IDX.items()}

  # Map of state names to integers
  STATE_TO_IDX = {
      'open': 0,
      'closed': 1,
      'locked': 2,
  }

  # Map of agent direction indices to vectors
  DIR_TO_VEC = [
      # Pointing right (positive X)
      np.array((1, 0)),
      # Down (positive Y)
      np.array((0, 1)),
      # Pointing left (negative X)
      np.array((-1, 0)),
      # Up (negative Y)
      np.array((0, -1)),
  ]
  class WorldObj:
      """
      Base class for grid world objects

      Subclasses override the predicate methods below (all of which return
      a conservative default here) and must implement render().
      """

      def __init__(self, type, color):
          assert type in OBJECT_TO_IDX, type
          assert color in COLOR_TO_IDX, color
          self.type = type
          self.color = color
          self.contains = None

          # Initial position of the object
          self.init_pos = None

          # Current position of the object
          self.cur_pos = None

      def can_overlap(self):
          """Can the agent overlap with this?"""
          return False

      def can_pickup(self):
          """Can the agent pick this up?"""
          return False

      def can_contain(self):
          """Can this contain another object?"""
          return False

      def see_behind(self):
          """Can the agent see behind this object?"""
          return True

      def toggle(self, env, pos):
          """Method to trigger/toggle an action this object performs"""
          return False

      def encode(self):
          """Encode a description of this object as a 3-tuple of integers"""
          return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0)

      @staticmethod
      def decode(type_idx, color_idx, state):
          """
          Create an object from a 3-tuple state description

          :param type_idx: object type index (key of IDX_TO_OBJECT)
          :param color_idx: color index (key of IDX_TO_COLOR)
          :param state: door state, 0: open, 1: closed, 2: locked; only
              used when the decoded object is a door
          :return: the decoded object, or None for 'empty'/'unseen'
          :raises AssertionError: if the type has no decodable class
          """
          obj_type = IDX_TO_OBJECT[type_idx]
          color = IDX_TO_COLOR[color_idx]

          if obj_type in ('empty', 'unseen'):
              return None

          # State, 0: open, 1: closed, 2: locked
          is_open = state == 0
          is_locked = state == 2

          if obj_type == 'wall':
              v = Wall(color)
          elif obj_type == 'floor':
              v = Floor(color)
          elif obj_type == 'ball':
              v = Ball(color)
          elif obj_type == 'key':
              v = Key(color)
          elif obj_type == 'box':
              v = Box(color)
          elif obj_type == 'door':
              v = Door(color, is_open, is_locked)
          elif obj_type == 'goal':
              v = Goal()
          elif obj_type == 'lava':
              v = Lava()
          else:
              # Raised explicitly so the failure survives `python -O`
              # (a stripped assert would leave `v` unbound).
              raise AssertionError(
                  "unknown object type in decode '%s'" % obj_type
              )

          return v

      def render(self, r):
          """Draw this object with the given renderer"""
          raise NotImplementedError
  class Goal(WorldObj):
      """
      Green goal tile; the agent is allowed to step onto it
      """

      def __init__(self):
          super().__init__('goal', 'green')

      def can_overlap(self):
          return True

      def render(self, img):
          # Fill the whole tile with the goal color
          fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  class Floor(WorldObj):
      """
      Colored floor tile the agent can walk over
      """

      def __init__(self, color='blue'):
          super().__init__('floor', color)

      def can_overlap(self):
          return True

      def render(self, r):
          """
          Draw the floor tile into the tile image `r`.

          Grid.render_tile passes the tile image to render(), so this
          paints with fill_coords like the other objects instead of
          calling renderer methods (setColor/drawPolygon), which the
          image does not provide.
          """
          # Give the floor a pale color
          pale = COLORS[self.color] / 2

          # Start one pixel in (1 / 32 of the tile) so the grid lines
          # drawn along the top and left edges stay visible.
          fill_coords(r, point_in_rect(0.031, 1, 0.031, 1), pale)
  class Lava(WorldObj):
      """
      Lava tile; the agent can step onto it
      """

      def __init__(self):
          super().__init__('lava', 'red')

      def can_overlap(self):
          return True

      def render(self, r):
          # NOTE(review): this still drives a renderer object
          # (setLineColor/setColor/drawPolygon/drawPolyline), whereas
          # Grid.render_tile hands render() a numpy image. Port to
          # fill_coords once a line-drawing helper is confirmed to exist
          # in the rendering module.
          orange = 255, 128, 0
          r.setLineColor(*orange)
          r.setColor(*orange)
          r.drawPolygon([
              (0, TILE_PIXELS),
              (TILE_PIXELS, TILE_PIXELS),
              (TILE_PIXELS, 0),
              (0, 0)
          ])

          # drawing the waves
          r.setLineColor(0, 0, 0)

          r.drawPolyline([
              (.1 * TILE_PIXELS, .3 * TILE_PIXELS),
              (.3 * TILE_PIXELS, .4 * TILE_PIXELS),
              (.5 * TILE_PIXELS, .3 * TILE_PIXELS),
              (.7 * TILE_PIXELS, .4 * TILE_PIXELS),
              (.9 * TILE_PIXELS, .3 * TILE_PIXELS),
          ])

          r.drawPolyline([
              (.1 * TILE_PIXELS, .5 * TILE_PIXELS),
              (.3 * TILE_PIXELS, .6 * TILE_PIXELS),
              (.5 * TILE_PIXELS, .5 * TILE_PIXELS),
              (.7 * TILE_PIXELS, .6 * TILE_PIXELS),
              (.9 * TILE_PIXELS, .5 * TILE_PIXELS),
          ])

          r.drawPolyline([
              (.1 * TILE_PIXELS, .7 * TILE_PIXELS),
              (.3 * TILE_PIXELS, .8 * TILE_PIXELS),
              (.5 * TILE_PIXELS, .7 * TILE_PIXELS),
              (.7 * TILE_PIXELS, .8 * TILE_PIXELS),
              (.9 * TILE_PIXELS, .7 * TILE_PIXELS),
          ])
  class Wall(WorldObj):
      """
      Wall tile; it cannot be overlapped and blocks the agent's view
      """

      def __init__(self, color='grey'):
          super().__init__('wall', color)

      def see_behind(self):
          return False

      def render(self, img):
          # Fill the whole tile with the wall color
          fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  class Door(WorldObj):
      """
      Door that can be open, closed or locked
      """

      def __init__(self, color, is_open=False, is_locked=False):
          super().__init__('door', color)
          self.is_open = is_open
          self.is_locked = is_locked

      def can_overlap(self):
          """The agent can only walk over this cell when the door is open"""
          return self.is_open

      def see_behind(self):
          return self.is_open

      def toggle(self, env, pos):
          """
          Open/close the door, unlocking it first if the agent holds a key
          of the same color.

          :return: True if the door changed state, False if it is locked
              and env.carrying is not a matching Key
          """
          # If the player has the right key to open the door
          if self.is_locked:
              if isinstance(env.carrying, Key) and env.carrying.color == self.color:
                  self.is_locked = False
                  self.is_open = True
                  return True
              return False

          self.is_open = not self.is_open
          return True

      def encode(self):
          """Encode a description of this object as a 3-tuple of integers"""

          # State, 0: open, 1: closed, 2: locked
          # An open door reports 0 even if is_locked is also set.
          if self.is_open:
              state = 0
          elif self.is_locked:
              state = 2
          else:
              state = 1

          return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state)

      def render(self, img):
          c = COLORS[self.color]

          # Open door: only a thin slab along one edge of the tile
          if self.is_open:
              fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c)
              fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0, 0, 0))
              return

          # Door frame and door
          if self.is_locked:
              fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
              fill_coords(img, point_in_rect(0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))

              # Draw key slot
              fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)
          else:
              fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
              fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0, 0, 0))
              fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c)
              fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0, 0, 0))

              # Draw door handle
              fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c)
  class Key(WorldObj):
      """
      Key that the agent can pick up
      """

      def __init__(self, color='blue'):
          super().__init__('key', color)

      def can_pickup(self):
          return True

      def render(self, img):
          c = COLORS[self.color]

          # Vertical quad
          fill_coords(img, point_in_rect(0.50, 0.63, 0.31, 0.89), c)

          # Teeth
          fill_coords(img, point_in_rect(0.38, 0.50, 0.59, 0.66), c)
          fill_coords(img, point_in_rect(0.38, 0.50, 0.81, 0.88), c)

          # Ring (colored disc with a black hole in the middle)
          fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.190), c)
          fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0, 0, 0))
  class Ball(WorldObj):
      """
      Ball that the agent can pick up
      """

      def __init__(self, color='blue'):
          super().__init__('ball', color)

      def can_pickup(self):
          return True

      def render(self, img):
          # Disc centered in the tile
          fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color])
  class Box(WorldObj):
      """
      Box that can be picked up and may hold another object in `contains`
      """

      def __init__(self, color, contains=None):
          super().__init__('box', color)
          self.contains = contains

      def can_pickup(self):
          return True

      def render(self, img):
          c = COLORS[self.color]

          # Outline
          fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), c)
          fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0, 0, 0))

          # Vertical slit
          fill_coords(img, point_in_rect(0.48, 0.52, 0.16, 0.84), c)

      def toggle(self, env, pos):
          # Replace the box by its contents (the cell becomes empty when
          # `contains` is None)
          env.grid.set(*pos, self.contains)
          return True
  class Grid:
      """
      Represent a grid and operations on it

      Cells are kept in a flat row-major list (index j * width + i); a cell
      holds a WorldObj, or None when it is empty.
      """

      # Static cache of pre-rendered tiles
      tile_cache = {}

      def __init__(self, width, height):
          assert width >= 3
          assert height >= 3

          self.width = width
          self.height = height

          self.grid = [None] * width * height

      def __contains__(self, key):
          """
          Membership test. `key` is either a WorldObj (matched by
          identity) or a (color, type) tuple, where a color of None
          matches any color. Any other key yields False.
          """
          if isinstance(key, WorldObj):
              return any(e is key for e in self.grid)

          if isinstance(key, tuple):
              for e in self.grid:
                  if e is None:
                      continue
                  if (e.color, e.type) == key:
                      return True
                  if key[0] is None and key[1] == e.type:
                      return True

          return False

      def __eq__(self, other):
          # Let Python fall back to its default handling for non-Grid
          # operands instead of failing on a missing encode() method.
          if not isinstance(other, Grid):
              return NotImplemented
          grid1 = self.encode()
          grid2 = other.encode()
          return np.array_equal(grid2, grid1)

      def __ne__(self, other):
          return not self == other

      def copy(self):
          """Return a deep copy of this grid"""
          from copy import deepcopy
          return deepcopy(self)

      def set(self, i, j, v):
          """Store `v` in the cell at column i, row j"""
          assert i >= 0 and i < self.width
          assert j >= 0 and j < self.height
          self.grid[j * self.width + i] = v

      def get(self, i, j):
          """Return the content of the cell at column i, row j"""
          assert i >= 0 and i < self.width
          assert j >= 0 and j < self.height
          return self.grid[j * self.width + i]

      def horz_wall(self, x, y, length=None, obj_type=Wall):
          """
          Fill a horizontal run of cells starting at (x, y) with new
          obj_type() instances; by default the run extends to the right
          edge of the grid.
          """
          if length is None:
              length = self.width - x
          for i in range(0, length):
              self.set(x + i, y, obj_type())

      def vert_wall(self, x, y, length=None, obj_type=Wall):
          """
          Fill a vertical run of cells starting at (x, y) with new
          obj_type() instances; by default the run extends to the bottom
          edge of the grid.
          """
          if length is None:
              length = self.height - y
          for j in range(0, length):
              self.set(x, y + j, obj_type())

      def wall_rect(self, x, y, w, h):
          """Draw the outline of a w-by-h rectangle of walls at (x, y)"""
          self.horz_wall(x, y, w)
          self.horz_wall(x, y+h-1, w)
          self.vert_wall(x, y, h)
          self.vert_wall(x+w-1, y, h)

      def rotate_left(self):
          """
          Rotate the grid to the left (counter-clockwise)
          """

          grid = Grid(self.height, self.width)

          for i in range(self.width):
              for j in range(self.height):
                  v = self.get(i, j)
                  grid.set(j, grid.height - 1 - i, v)

          return grid

      def slice(self, topX, topY, width, height):
          """
          Get a subset of the grid

          Positions that fall outside this grid are filled with new Wall
          objects. Cell objects are shared with this grid, not copied.
          """

          grid = Grid(width, height)

          for j in range(0, height):
              for i in range(0, width):
                  x = topX + i
                  y = topY + j

                  if x >= 0 and x < self.width and \
                     y >= 0 and y < self.height:
                      v = self.get(x, y)
                  else:
                      v = Wall()

                  grid.set(i, j, v)

          return grid

      @classmethod
      def render_tile(
          cls,
          obj,
          agent_dir=None,
          highlight=False,
          tile_size=TILE_PIXELS
      ):
          """
          Render a tile and cache the result

          :param obj: object occupying the tile, or None for an empty tile
          :param agent_dir: direction index of the agent if it stands on
              this tile, otherwise None
          :param highlight: whether to highlight the tile
          :param tile_size: tile size in pixels
          :return: (tile_size, tile_size, 3) uint8 image; the array is
              shared with the cache
          """

          # Hash map lookup key for the cache
          key = (agent_dir, highlight, tile_size)
          key = obj.encode() + key if obj is not None else key

          if key in cls.tile_cache:
              return cls.tile_cache[key]

          img = np.zeros(shape=(tile_size, tile_size, 3), dtype=np.uint8)

          # Draw the grid lines (top and left edges)
          fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
          fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100))

          if obj is not None:
              obj.render(img)

          # Overlay the agent on top
          if agent_dir is not None:
              tri_fn = point_in_triangle(
                  (0.12, 0.19),
                  (0.87, 0.50),
                  (0.12, 0.81),
              )

              # Rotate the agent based on its direction
              tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5*math.pi*agent_dir)
              fill_coords(img, tri_fn, (255, 0, 0))

          # Highlight the cell if needed
          if highlight:
              highlight_img(img)

          # Cache the rendered tile
          cls.tile_cache[key] = img

          return img

      def render(
          self,
          tile_size,
          agent_pos=None,
          agent_dir=None,
          highlight_mask=None
      ):
          """
          Render this grid at a given scale
          :param tile_size: tile size in pixels
          :param agent_pos: (i, j) cell of the agent, or None
          :param agent_dir: direction index of the agent
          :param highlight_mask: boolean array indexed [i, j] of cells to
              highlight; nothing is highlighted when omitted
          :return: (height * tile_size, width * tile_size, 3) uint8 image
          """

          if highlight_mask is None:
              # The builtin bool is used because the np.bool alias was
              # removed from NumPy.
              highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)

          # Compute the total grid size
          width_px = self.width * tile_size
          height_px = self.height * tile_size

          img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)

          # Render the grid
          for j in range(0, self.height):
              for i in range(0, self.width):
                  cell = self.get(i, j)

                  agent_here = np.array_equal(agent_pos, (i, j))
                  tile_img = Grid.render_tile(
                      cell,
                      agent_dir=agent_dir if agent_here else None,
                      highlight=highlight_mask[i, j],
                      tile_size=tile_size
                  )

                  ymin = j * tile_size
                  ymax = (j+1) * tile_size
                  xmin = i * tile_size
                  xmax = (i+1) * tile_size
                  img[ymin:ymax, xmin:xmax, :] = tile_img

          return img

      def encode(self, vis_mask=None):
          """
          Produce a compact numpy encoding of the grid

          :param vis_mask: boolean array indexed [i, j]; cells where it is
              False are left as zeros. Defaults to every cell.
          :return: (width, height, 3) uint8 array
          """

          if vis_mask is None:
              vis_mask = np.ones((self.width, self.height), dtype=bool)

          array = np.zeros((self.width, self.height, 3), dtype='uint8')

          for i in range(self.width):
              for j in range(self.height):
                  if vis_mask[i, j]:
                      v = self.get(i, j)

                      if v is None:
                          array[i, j, 0] = OBJECT_TO_IDX['empty']
                          array[i, j, 1] = 0
                          array[i, j, 2] = 0
                      else:
                          array[i, j, :] = v.encode()

          return array

      @staticmethod
      def decode(array):
          """
          Decode an array grid encoding back into a grid
          """

          width, height, channels = array.shape
          assert channels == 3

          grid = Grid(width, height)
          for i in range(width):
              for j in range(height):
                  type_idx, color_idx, state = array[i, j]
                  v = WorldObj.decode(type_idx, color_idx, state)
                  grid.set(i, j, v)

          return grid

      def process_vis(grid, agent_pos):
          """
          Compute which cells of `grid` are visible from `agent_pos` and
          clear the cells that are not.

          Visibility is propagated row by row from the last row to the
          first, sweeping each row left-to-right and then right-to-left;
          it does not pass through cells whose object reports
          see_behind() as False.

          :return: boolean mask indexed [i, j]
          """
          mask = np.zeros(shape=(grid.width, grid.height), dtype=bool)

          mask[agent_pos[0], agent_pos[1]] = True

          for j in reversed(range(0, grid.height)):
              for i in range(0, grid.width-1):
                  if not mask[i, j]:
                      continue

                  cell = grid.get(i, j)
                  if cell and not cell.see_behind():
                      continue

                  mask[i+1, j] = True
                  if j > 0:
                      mask[i+1, j-1] = True
                      mask[i, j-1] = True

              for i in reversed(range(1, grid.width)):
                  if not mask[i, j]:
                      continue

                  cell = grid.get(i, j)
                  if cell and not cell.see_behind():
                      continue

                  mask[i-1, j] = True
                  if j > 0:
                      mask[i-1, j-1] = True
                      mask[i, j-1] = True

          # Blank out everything that ended up not visible
          for j in range(0, grid.height):
              for i in range(0, grid.width):
                  if not mask[i, j]:
                      grid.set(i, j, None)

          return mask
  503. class MiniGridEnv(gym.Env):
  504. """
  505. 2D grid world game environment
  506. """
  metadata = {
      'render.modes': ['human', 'rgb_array', 'pixmap'],
      'video.frames_per_second': 10,
  }

  # Enumeration of possible actions
  class Actions(IntEnum):
      # Turn left, turn right, move forward
      left = 0
      right = 1
      forward = 2

      # Pick up an object
      pickup = 3
      # Drop an object
      drop = 4
      # Toggle/activate an object
      toggle = 5

      # Done completing task
      done = 6
  def __init__(
      self,
      grid_size=None,
      width=None,
      height=None,
      max_steps=100,
      see_through_walls=False,
      seed=1337,
      agent_view_size=7
  ):
      """
      Set up the spaces and configuration, seed the RNG and generate the
      first episode via reset().

      :param grid_size: side length of a square grid; mutually exclusive
          with width/height
      :param width: grid width in cells
      :param height: grid height in cells
      :param max_steps: step budget used by _reward and steps_remaining
      :param see_through_walls: stored on the instance as given
      :param seed: seed passed to self.seed()
      :param agent_view_size: number of cells (width and height) in the
          agent view
      """
      # Can't set both grid_size and width/height
      if grid_size:
          assert width is None and height is None
          width = grid_size
          height = grid_size

      # Action enumeration for this environment
      self.actions = MiniGridEnv.Actions

      # Actions are discrete integer values
      self.action_space = spaces.Discrete(len(self.actions))

      # Number of cells (width and height) in the agent view
      self.agent_view_size = agent_view_size

      # Observations are dictionaries whose 'image' entry is an
      # encoding of the agent's view of the grid
      self.observation_space = spaces.Box(
          low=0,
          high=255,
          shape=(self.agent_view_size, self.agent_view_size, 3),
          dtype='uint8'
      )
      self.observation_space = spaces.Dict({
          'image': self.observation_space
      })

      # Range of possible rewards
      self.reward_range = (0, 1)

      # Renderer object used to render the whole grid (full-scale)
      self.grid_render = None

      # Renderer used to render observations (small-scale agent view)
      self.obs_render = None

      # Environment configuration
      self.width = width
      self.height = height
      self.max_steps = max_steps
      self.see_through_walls = see_through_walls

      # Current position and direction of the agent
      self.agent_pos = None
      self.agent_dir = None

      # Initialize the RNG
      self.seed(seed=seed)

      # Initialize the state
      self.reset()
  def reset(self):
      """
      Start a new episode: regenerate the grid, clear the carried item
      and the step counter, and return the first observation.
      """
      # Current position and direction of the agent
      self.agent_pos = None
      self.agent_dir = None

      # Generate a new random grid at the start of each episode
      # To keep the same grid for each episode, call env.seed() with
      # the same seed before calling env.reset()
      self._gen_grid(self.width, self.height)

      # These fields should be defined by _gen_grid
      assert self.agent_pos is not None
      assert self.agent_dir is not None

      # Check that the agent doesn't overlap with an object
      start_cell = self.grid.get(*self.agent_pos)
      assert start_cell is None or start_cell.can_overlap()

      # Item picked up, being carried, initially nothing
      self.carrying = None

      # Step count since episode start
      self.step_count = 0

      # Return first observation
      obs = self.gen_obs()
      return obs
  def seed(self, seed=1337):
      """
      Re-create self.np_random from `seed` and return the seed in a list.
      """
      # Seed the random number generator
      self.np_random, _ = seeding.np_random(seed)
      return [seed]
  @property
  def steps_remaining(self):
      """Number of steps left before max_steps is reached"""
      return self.max_steps - self.step_count
  603. def __str__(self):
  604. """
  605. Produce a pretty string of the environment's grid along with the agent.
  606. A grid cell is represented by 2-character string, the first one for
  607. the object and the second one for the color.
  608. """
  609. # Map of object types to short string
  610. OBJECT_TO_STR = {
  611. 'wall' : 'W',
  612. 'floor' : 'F',
  613. 'door' : 'D',
  614. 'key' : 'K',
  615. 'ball' : 'A',
  616. 'box' : 'B',
  617. 'goal' : 'G',
  618. 'lava' : 'V',
  619. }
  620. # Short string for opened door
  621. OPENDED_DOOR_IDS = '_'
  622. # Map agent's direction to short string
  623. AGENT_DIR_TO_STR = {
  624. 0: '>',
  625. 1: 'V',
  626. 2: '<',
  627. 3: '^'
  628. }
  629. str = ''
  630. for j in range(self.grid.height):
  631. for i in range(self.grid.width):
  632. if i == self.agent_pos[0] and j == self.agent_pos[1]:
  633. str += 2 * AGENT_DIR_TO_STR[self.agent_dir]
  634. continue
  635. c = self.grid.get(i, j)
  636. if c == None:
  637. str += ' '
  638. continue
  639. if c.type == 'door':
  640. if c.is_open:
  641. str += '__'
  642. elif c.is_locked:
  643. str += 'L' + c.color[0].upper()
  644. else:
  645. str += 'D' + c.color[0].upper()
  646. continue
  647. str += OBJECT_TO_STR[c.type] + c.color[0].upper()
  648. if j < self.grid.height - 1:
  649. str += '\n'
  650. return str
  651. def _gen_grid(self, width, height):
  652. assert False, "_gen_grid needs to be implemented by each environment"
  653. def _reward(self):
  654. """
  655. Compute the reward to be given upon success
  656. """
  657. return 1 - 0.9 * (self.step_count / self.max_steps)
  658. def _rand_int(self, low, high):
  659. """
  660. Generate random integer in [low,high[
  661. """
  662. return self.np_random.randint(low, high)
  663. def _rand_float(self, low, high):
  664. """
  665. Generate random float in [low,high[
  666. """
  667. return self.np_random.uniform(low, high)
  668. def _rand_bool(self):
  669. """
  670. Generate random boolean value
  671. """
  672. return (self.np_random.randint(0, 2) == 0)
  673. def _rand_elem(self, iterable):
  674. """
  675. Pick a random element in a list
  676. """
  677. lst = list(iterable)
  678. idx = self._rand_int(0, len(lst))
  679. return lst[idx]
  680. def _rand_subset(self, iterable, num_elems):
  681. """
  682. Sample a random subset of distinct elements of a list
  683. """
  684. lst = list(iterable)
  685. assert num_elems <= len(lst)
  686. out = []
  687. while len(out) < num_elems:
  688. elem = self._rand_elem(lst)
  689. lst.remove(elem)
  690. out.append(elem)
  691. return out
  def _rand_color(self):
      """
      Generate a random color name (string)
      """

      return self._rand_elem(COLOR_NAMES)
  697. def _rand_pos(self, xLow, xHigh, yLow, yHigh):
  698. """
  699. Generate a random (x,y) position tuple
  700. """
  701. return (
  702. self.np_random.randint(xLow, xHigh),
  703. self.np_random.randint(yLow, yHigh)
  704. )
  def place_obj(self,
      obj,
      top=None,
      size=None,
      reject_fn=None,
      max_tries=math.inf
  ):
      """
      Place an object at an empty position in the grid

      :param obj: object to place; None only picks a free position
      :param top: top-left position of the rectangle where to place
      :param size: size of the rectangle where to place
      :param reject_fn: function to filter out potential positions
      :param max_tries: number of rejected samples tolerated
      :return: the chosen position as a numpy array
      :raises RecursionError: when more than max_tries samples fail
      """

      if top is None:
          top = (0, 0)
      else:
          top = (max(top[0], 0), max(top[1], 0))

      if size is None:
          size = (self.grid.width, self.grid.height)

      num_tries = 0

      while True:
          # This is to handle with rare cases where rejection sampling
          # gets stuck in an infinite loop
          if num_tries > max_tries:
              raise RecursionError('rejection sampling failed in place_obj')

          num_tries += 1

          # Sample inside the rectangle, clipped to the grid bounds
          pos = np.array((
              self._rand_int(top[0], min(top[0] + size[0], self.grid.width)),
              self._rand_int(top[1], min(top[1] + size[1], self.grid.height))
          ))

          # Don't place the object on top of another object
          if self.grid.get(*pos) is not None:
              continue

          # Don't place the object where the agent is
          if np.array_equal(pos, self.agent_pos):
              continue

          # Check if there is a filtering criterion
          if reject_fn and reject_fn(self, pos):
              continue

          break

      self.grid.set(*pos, obj)

      if obj is not None:
          obj.init_pos = pos
          obj.cur_pos = pos

      return pos
  def put_obj(self, obj, i, j):
      """
      Put an object at a specific position in the grid

      Also records (i, j) as the object's initial and current position.
      """

      self.grid.set(i, j, obj)
      obj.init_pos = (i, j)
      obj.cur_pos = (i, j)
  def place_agent(
      self,
      top=None,
      size=None,
      rand_dir=True,
      max_tries=math.inf
  ):
      """
      Set the agent's starting point at an empty position in the grid

      :param top: top-left position of the rectangle to sample from
      :param size: size of the rectangle to sample from
      :param rand_dir: also pick a random direction for the agent
      :param max_tries: forwarded to place_obj
      :return: the chosen position
      """

      # Clear the old position so place_obj does not treat it as taken
      self.agent_pos = None
      pos = self.place_obj(None, top, size, max_tries=max_tries)
      self.agent_pos = pos

      if rand_dir:
          self.agent_dir = self._rand_int(0, 4)

      return pos
  @property
  def dir_vec(self):
      """
      Get the direction vector for the agent, pointing in the direction
      of forward movement.
      """

      assert self.agent_dir >= 0 and self.agent_dir < 4
      return DIR_TO_VEC[self.agent_dir]
  @property
  def right_vec(self):
      """
      Get the vector pointing to the right of the agent.
      """

      # (dx, dy) -> (-dy, dx)
      dx, dy = self.dir_vec
      return np.array((-dy, dx))
  @property
  def front_pos(self):
      """
      Get the position of the cell that is right in front of the agent
      """

      return self.agent_pos + self.dir_vec
  def get_view_coords(self, i, j):
      """
      Translate and rotate absolute grid coordinates (i, j) into the
      agent's partially observable view (sub-grid). Note that the resulting
      coordinates may be negative or outside of the agent's view size.
      """

      ax, ay = self.agent_pos
      dx, dy = self.dir_vec
      rx, ry = self.right_vec

      # Compute the absolute coordinates of the top-left view corner
      sz = self.agent_view_size
      hs = self.agent_view_size // 2
      tx = ax + (dx * (sz-1)) - (rx * hs)
      ty = ay + (dy * (sz-1)) - (ry * hs)

      lx = i - tx
      ly = j - ty

      # Project the coordinates of the object relative to the top-left
      # corner onto the agent's own coordinate system
      vx = (rx*lx + ry*ly)
      vy = -(dx*lx + dy*ly)

      return vx, vy
  def get_view_exts(self):
      """
      Get the extents of the square set of tiles visible to the agent
      Note: the bottom extent indices are not included in the set

      :return: (topX, topY, botX, botY)
      :raises AssertionError: if agent_dir is not one of 0..3
      """

      # Facing right
      if self.agent_dir == 0:
          topX = self.agent_pos[0]
          topY = self.agent_pos[1] - self.agent_view_size // 2
      # Facing down
      elif self.agent_dir == 1:
          topX = self.agent_pos[0] - self.agent_view_size // 2
          topY = self.agent_pos[1]
      # Facing left
      elif self.agent_dir == 2:
          topX = self.agent_pos[0] - self.agent_view_size + 1
          topY = self.agent_pos[1] - self.agent_view_size // 2
      # Facing up
      elif self.agent_dir == 3:
          topX = self.agent_pos[0] - self.agent_view_size // 2
          topY = self.agent_pos[1] - self.agent_view_size + 1
      else:
          # Raised explicitly: a stripped assert (python -O) would leave
          # topX/topY unbound below.
          raise AssertionError("invalid agent direction")

      botX = topX + self.agent_view_size
      botY = topY + self.agent_view_size

      return (topX, topY, botX, botY)
  def relative_coords(self, x, y):
      """
      Check if a grid position belongs to the agent's field of view, and returns the corresponding coordinates

      :return: (vx, vy) view coordinates, or None when the position lies
          outside the agent_view_size square
      """

      vx, vy = self.get_view_coords(x, y)

      if vx < 0 or vy < 0 or vx >= self.agent_view_size or vy >= self.agent_view_size:
          return None

      return vx, vy
  849. def in_view(self, x, y):
  850. """
  851. check if a grid position is visible to the agent
  852. """
  853. return self.relative_coords(x, y) is not None
  854. def agent_sees(self, x, y):
  855. """
  856. Check if a non-empty grid position is visible to the agent
  857. """
  858. coordinates = self.relative_coords(x, y)
  859. if coordinates is None:
  860. return False
  861. vx, vy = coordinates
  862. obs = self.gen_obs()
  863. obs_grid = Grid.decode(obs['image'])
  864. obs_cell = obs_grid.get(vx, vy)
  865. world_cell = self.grid.get(x, y)
  866. return obs_cell is not None and obs_cell.type == world_cell.type
  867. def step(self, action):
  868. self.step_count += 1
  869. reward = 0
  870. done = False
  871. # Get the position in front of the agent
  872. fwd_pos = self.front_pos
  873. # Get the contents of the cell in front of the agent
  874. fwd_cell = self.grid.get(*fwd_pos)
  875. # Rotate left
  876. if action == self.actions.left:
  877. self.agent_dir -= 1
  878. if self.agent_dir < 0:
  879. self.agent_dir += 4
  880. # Rotate right
  881. elif action == self.actions.right:
  882. self.agent_dir = (self.agent_dir + 1) % 4
  883. # Move forward
  884. elif action == self.actions.forward:
  885. if fwd_cell == None or fwd_cell.can_overlap():
  886. self.agent_pos = fwd_pos
  887. if fwd_cell != None and fwd_cell.type == 'goal':
  888. done = True
  889. reward = self._reward()
  890. if fwd_cell != None and fwd_cell.type == 'lava':
  891. done = True
  892. # Pick up an object
  893. elif action == self.actions.pickup:
  894. if fwd_cell and fwd_cell.can_pickup():
  895. if self.carrying is None:
  896. self.carrying = fwd_cell
  897. self.carrying.cur_pos = np.array([-1, -1])
  898. self.grid.set(*fwd_pos, None)
  899. # Drop an object
  900. elif action == self.actions.drop:
  901. if not fwd_cell and self.carrying:
  902. self.grid.set(*fwd_pos, self.carrying)
  903. self.carrying.cur_pos = fwd_pos
  904. self.carrying = None
  905. # Toggle/activate an object
  906. elif action == self.actions.toggle:
  907. if fwd_cell:
  908. fwd_cell.toggle(self, fwd_pos)
  909. # Done action (not used by default)
  910. elif action == self.actions.done:
  911. pass
  912. else:
  913. assert False, "unknown action"
  914. if self.step_count >= self.max_steps:
  915. done = True
  916. obs = self.gen_obs()
  917. return obs, reward, done, {}
  918. def gen_obs_grid(self):
  919. """
  920. Generate the sub-grid observed by the agent.
  921. This method also outputs a visibility mask telling us which grid
  922. cells the agent can actually see.
  923. """
  924. topX, topY, botX, botY = self.get_view_exts()
  925. grid = self.grid.slice(topX, topY, self.agent_view_size, self.agent_view_size)
  926. for i in range(self.agent_dir + 1):
  927. grid = grid.rotate_left()
  928. # Process occluders and visibility
  929. # Note that this incurs some performance cost
  930. if not self.see_through_walls:
  931. vis_mask = grid.process_vis(agent_pos=(self.agent_view_size // 2 , self.agent_view_size - 1))
  932. else:
  933. vis_mask = np.ones(shape=(grid.width, grid.height), dtype=np.bool)
  934. # Make it so the agent sees what it's carrying
  935. # We do this by placing the carried object at the agent's position
  936. # in the agent's partially observable view
  937. agent_pos = grid.width // 2, grid.height - 1
  938. if self.carrying:
  939. grid.set(*agent_pos, self.carrying)
  940. else:
  941. grid.set(*agent_pos, None)
  942. return grid, vis_mask
  943. def gen_obs(self):
  944. """
  945. Generate the agent's view (partially observable, low-resolution encoding)
  946. """
  947. grid, vis_mask = self.gen_obs_grid()
  948. # Encode the partially observable view into a numpy array
  949. image = grid.encode(vis_mask)
  950. assert hasattr(self, 'mission'), "environments must define a textual mission string"
  951. # Observations are dictionaries containing:
  952. # - an image (partially observable view of the environment)
  953. # - the agent's direction/orientation (acting as a compass)
  954. # - a textual mission string (instructions for the agent)
  955. obs = {
  956. 'image': image,
  957. 'direction': self.agent_dir,
  958. 'mission': self.mission
  959. }
  960. return obs
  961. def get_obs_render(self, obs, tile_size=TILE_PIXELS//2):
  962. """
  963. Render an agent observation for visualization
  964. """
  965. grid = Grid.decode(obs)
  966. # Render the whole grid
  967. img = grid.render(r, tile_size)
  968. assert False
  969. """
  970. # Draw the agent
  971. ratio = tile_size / TILE_PIXELS
  972. r.push()
  973. r.scale(ratio, ratio)
  974. r.translate(
  975. TILE_PIXELS * (0.5 + self.agent_view_size // 2),
  976. TILE_PIXELS * (self.agent_view_size - 0.5)
  977. )
  978. r.rotate(3 * 90)
  979. r.setLineColor(255, 0, 0)
  980. r.setColor(255, 0, 0)
  981. r.drawPolygon([
  982. (-12, 10),
  983. ( 12, 0),
  984. (-12, -10)
  985. ])
  986. r.pop()
  987. """
  988. return img
  989. def render(self, mode='human', close=False, highlight=True, tile_size=TILE_PIXELS):
  990. """
  991. Render the whole-grid human view
  992. """
  993. """
  994. if close:
  995. if self.grid_render:
  996. self.grid_render.close()
  997. return
  998. """
  999. # Compute which cells are visible to the agent
  1000. _, vis_mask = self.gen_obs_grid()
  1001. # Compute the world coordinates of the bottom-left corner
  1002. # of the agent's view area
  1003. f_vec = self.dir_vec
  1004. r_vec = self.right_vec
  1005. top_left = self.agent_pos + f_vec * (self.agent_view_size-1) - r_vec * (self.agent_view_size // 2)
  1006. # Mask of which cells to highlight
  1007. highlight_mask = np.zeros(shape=(self.width, self.height), dtype=np.bool)
  1008. # For each cell in the visibility mask
  1009. for vis_j in range(0, self.agent_view_size):
  1010. for vis_i in range(0, self.agent_view_size):
  1011. # If this cell is not visible, don't highlight it
  1012. if not vis_mask[vis_i, vis_j]:
  1013. continue
  1014. # Compute the world coordinates of this cell
  1015. abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)
  1016. if abs_i < 0 or abs_i >= self.width:
  1017. continue
  1018. if abs_j < 0 or abs_j >= self.height:
  1019. continue
  1020. # Mark this cell to be highlighted
  1021. highlight_mask[abs_i, abs_j] = True
  1022. # Render the whole grid
  1023. img = self.grid.render(
  1024. tile_size,
  1025. self.agent_pos,
  1026. self.agent_dir,
  1027. highlight_mask=highlight_mask if highlight else None
  1028. )
  1029. return img