minigrid.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301
  1. import math
  2. import hashlib
  3. import gym
  4. from enum import IntEnum
  5. import numpy as np
  6. from gym import error, spaces, utils
  7. from gym.utils import seeding
  8. from .rendering import *
  9. # Size in pixels of a tile in the full-scale human view
  10. TILE_PIXELS = 32
  11. # Map of color names to RGB values
  12. COLORS = {
  13. 'red' : np.array([255, 0, 0]),
  14. 'green' : np.array([0, 255, 0]),
  15. 'blue' : np.array([0, 0, 255]),
  16. 'purple': np.array([112, 39, 195]),
  17. 'yellow': np.array([255, 255, 0]),
  18. 'grey' : np.array([100, 100, 100])
  19. }
  20. COLOR_NAMES = sorted(list(COLORS.keys()))
  21. # Used to map colors to integers
  22. COLOR_TO_IDX = {
  23. 'red' : 0,
  24. 'green' : 1,
  25. 'blue' : 2,
  26. 'purple': 3,
  27. 'yellow': 4,
  28. 'grey' : 5
  29. }
  30. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  31. # Map of object type to integers
  32. OBJECT_TO_IDX = {
  33. 'unseen' : 0,
  34. 'empty' : 1,
  35. 'wall' : 2,
  36. 'floor' : 3,
  37. 'door' : 4,
  38. 'key' : 5,
  39. 'ball' : 6,
  40. 'box' : 7,
  41. 'goal' : 8,
  42. 'lava' : 9,
  43. 'agent' : 10,
  44. }
  45. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  46. # Map of state names to integers
  47. STATE_TO_IDX = {
  48. 'open' : 0,
  49. 'closed': 1,
  50. 'locked': 2,
  51. }
  52. # Map of agent direction indices to vectors
  53. DIR_TO_VEC = [
  54. # Pointing right (positive X)
  55. np.array((1, 0)),
  56. # Down (positive Y)
  57. np.array((0, 1)),
  58. # Pointing left (negative X)
  59. np.array((-1, 0)),
  60. # Up (negative Y)
  61. np.array((0, -1)),
  62. ]
  63. class WorldObj:
  64. """
  65. Base class for grid world objects
  66. """
  67. def __init__(self, type, color):
  68. assert type in OBJECT_TO_IDX, type
  69. assert color in COLOR_TO_IDX, color
  70. self.type = type
  71. self.color = color
  72. self.contains = None
  73. # Initial position of the object
  74. self.init_pos = None
  75. # Current position of the object
  76. self.cur_pos = None
  77. def can_overlap(self):
  78. """Can the agent overlap with this?"""
  79. return False
  80. def can_pickup(self):
  81. """Can the agent pick this up?"""
  82. return False
  83. def can_contain(self):
  84. """Can this contain another object?"""
  85. return False
  86. def see_behind(self):
  87. """Can the agent see behind this object?"""
  88. return True
  89. def toggle(self, env, pos):
  90. """Method to trigger/toggle an action this object performs"""
  91. return False
  92. def encode(self):
  93. """Encode the a description of this object as a 3-tuple of integers"""
  94. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0)
  95. @staticmethod
  96. def decode(type_idx, color_idx, state):
  97. """Create an object from a 3-tuple state description"""
  98. obj_type = IDX_TO_OBJECT[type_idx]
  99. color = IDX_TO_COLOR[color_idx]
  100. if obj_type == 'empty' or obj_type == 'unseen':
  101. return None
  102. # State, 0: open, 1: closed, 2: locked
  103. is_open = state == 0
  104. is_locked = state == 2
  105. if obj_type == 'wall':
  106. v = Wall(color)
  107. elif obj_type == 'floor':
  108. v = Floor(color)
  109. elif obj_type == 'ball':
  110. v = Ball(color)
  111. elif obj_type == 'key':
  112. v = Key(color)
  113. elif obj_type == 'box':
  114. v = Box(color)
  115. elif obj_type == 'door':
  116. v = Door(color, is_open, is_locked)
  117. elif obj_type == 'goal':
  118. v = Goal()
  119. elif obj_type == 'lava':
  120. v = Lava()
  121. else:
  122. assert False, "unknown object type in decode '%s'" % obj_type
  123. return v
  124. def render(self, r):
  125. """Draw this object with the given renderer"""
  126. raise NotImplementedError
  127. class Goal(WorldObj):
  128. def __init__(self):
  129. super().__init__('goal', 'green')
  130. def can_overlap(self):
  131. return True
  132. def render(self, img):
  133. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  134. class Floor(WorldObj):
  135. """
  136. Colored floor tile the agent can walk over
  137. """
  138. def __init__(self, color='blue'):
  139. super().__init__('floor', color)
  140. def can_overlap(self):
  141. return True
  142. def render(self, img):
  143. # Give the floor a pale color
  144. color = COLORS[self.color] / 2
  145. fill_coords(img, point_in_rect(0.031, 1, 0.031, 1), color)
  146. class Lava(WorldObj):
  147. def __init__(self):
  148. super().__init__('lava', 'red')
  149. def can_overlap(self):
  150. return True
  151. def render(self, img):
  152. c = (255, 128, 0)
  153. # Background color
  154. fill_coords(img, point_in_rect(0, 1, 0, 1), c)
  155. # Little waves
  156. for i in range(3):
  157. ylo = 0.3 + 0.2 * i
  158. yhi = 0.4 + 0.2 * i
  159. fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0,0,0))
  160. fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0,0,0))
  161. fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0,0,0))
  162. fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0,0,0))
  163. class Wall(WorldObj):
  164. def __init__(self, color='grey'):
  165. super().__init__('wall', color)
  166. def see_behind(self):
  167. return False
  168. def render(self, img):
  169. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  170. class Door(WorldObj):
  171. def __init__(self, color, is_open=False, is_locked=False):
  172. super().__init__('door', color)
  173. self.is_open = is_open
  174. self.is_locked = is_locked
  175. def can_overlap(self):
  176. """The agent can only walk over this cell when the door is open"""
  177. return self.is_open
  178. def see_behind(self):
  179. return self.is_open
  180. def toggle(self, env, pos):
  181. # If the player has the right key to open the door
  182. if self.is_locked:
  183. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  184. self.is_locked = False
  185. self.is_open = True
  186. return True
  187. return False
  188. self.is_open = not self.is_open
  189. return True
  190. def encode(self):
  191. """Encode the a description of this object as a 3-tuple of integers"""
  192. # State, 0: open, 1: closed, 2: locked
  193. if self.is_open:
  194. state = 0
  195. elif self.is_locked:
  196. state = 2
  197. elif not self.is_open:
  198. state = 1
  199. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state)
  200. def render(self, img):
  201. c = COLORS[self.color]
  202. if self.is_open:
  203. fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c)
  204. fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0,0,0))
  205. return
  206. # Door frame and door
  207. if self.is_locked:
  208. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  209. fill_coords(img, point_in_rect(0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))
  210. # Draw key slot
  211. fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)
  212. else:
  213. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  214. fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0,0,0))
  215. fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c)
  216. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0,0,0))
  217. # Draw door handle
  218. fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c)
  219. class Key(WorldObj):
  220. def __init__(self, color='blue'):
  221. super(Key, self).__init__('key', color)
  222. def can_pickup(self):
  223. return True
  224. def render(self, img):
  225. c = COLORS[self.color]
  226. # Vertical quad
  227. fill_coords(img, point_in_rect(0.50, 0.63, 0.31, 0.88), c)
  228. # Teeth
  229. fill_coords(img, point_in_rect(0.38, 0.50, 0.59, 0.66), c)
  230. fill_coords(img, point_in_rect(0.38, 0.50, 0.81, 0.88), c)
  231. # Ring
  232. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.190), c)
  233. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0,0,0))
  234. class Ball(WorldObj):
  235. def __init__(self, color='blue'):
  236. super(Ball, self).__init__('ball', color)
  237. def can_pickup(self):
  238. return True
  239. def render(self, img):
  240. fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color])
  241. class Box(WorldObj):
  242. def __init__(self, color, contains=None):
  243. super(Box, self).__init__('box', color)
  244. self.contains = contains
  245. def can_pickup(self):
  246. return True
  247. def render(self, img):
  248. c = COLORS[self.color]
  249. # Outline
  250. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), c)
  251. fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0,0,0))
  252. # Horizontal slit
  253. fill_coords(img, point_in_rect(0.16, 0.84, 0.47, 0.53), c)
  254. def toggle(self, env, pos):
  255. # Replace the box by its contents
  256. env.grid.set(*pos, self.contains)
  257. return True
  258. class Grid:
  259. """
  260. Represent a grid and operations on it
  261. """
  262. # Static cache of pre-renderer tiles
  263. tile_cache = {}
  264. def __init__(self, width, height):
  265. assert width >= 3
  266. assert height >= 3
  267. self.width = width
  268. self.height = height
  269. self.grid = [None] * width * height
  270. def __contains__(self, key):
  271. if isinstance(key, WorldObj):
  272. for e in self.grid:
  273. if e is key:
  274. return True
  275. elif isinstance(key, tuple):
  276. for e in self.grid:
  277. if e is None:
  278. continue
  279. if (e.color, e.type) == key:
  280. return True
  281. if key[0] is None and key[1] == e.type:
  282. return True
  283. return False
  284. def __eq__(self, other):
  285. grid1 = self.encode()
  286. grid2 = other.encode()
  287. return np.array_equal(grid2, grid1)
  288. def __ne__(self, other):
  289. return not self == other
  290. def copy(self):
  291. from copy import deepcopy
  292. return deepcopy(self)
  293. def set(self, i, j, v):
  294. assert i >= 0 and i < self.width
  295. assert j >= 0 and j < self.height
  296. self.grid[j * self.width + i] = v
  297. def get(self, i, j):
  298. assert i >= 0 and i < self.width
  299. assert j >= 0 and j < self.height
  300. return self.grid[j * self.width + i]
  301. def horz_wall(self, x, y, length=None, obj_type=Wall):
  302. if length is None:
  303. length = self.width - x
  304. for i in range(0, length):
  305. self.set(x + i, y, obj_type())
  306. def vert_wall(self, x, y, length=None, obj_type=Wall):
  307. if length is None:
  308. length = self.height - y
  309. for j in range(0, length):
  310. self.set(x, y + j, obj_type())
  311. def wall_rect(self, x, y, w, h):
  312. self.horz_wall(x, y, w)
  313. self.horz_wall(x, y+h-1, w)
  314. self.vert_wall(x, y, h)
  315. self.vert_wall(x+w-1, y, h)
  316. def rotate_left(self):
  317. """
  318. Rotate the grid to the left (counter-clockwise)
  319. """
  320. grid = Grid(self.height, self.width)
  321. for i in range(self.width):
  322. for j in range(self.height):
  323. v = self.get(i, j)
  324. grid.set(j, grid.height - 1 - i, v)
  325. return grid
  326. def slice(self, topX, topY, width, height):
  327. """
  328. Get a subset of the grid
  329. """
  330. grid = Grid(width, height)
  331. for j in range(0, height):
  332. for i in range(0, width):
  333. x = topX + i
  334. y = topY + j
  335. if x >= 0 and x < self.width and \
  336. y >= 0 and y < self.height:
  337. v = self.get(x, y)
  338. else:
  339. v = Wall()
  340. grid.set(i, j, v)
  341. return grid
  342. @classmethod
  343. def render_tile(
  344. cls,
  345. obj,
  346. agent_dir=None,
  347. highlight=False,
  348. tile_size=TILE_PIXELS,
  349. subdivs=3
  350. ):
  351. """
  352. Render a tile and cache the result
  353. """
  354. # Hash map lookup key for the cache
  355. key = (agent_dir, highlight, tile_size)
  356. key = obj.encode() + key if obj else key
  357. if key in cls.tile_cache:
  358. return cls.tile_cache[key]
  359. img = np.zeros(shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8)
  360. # Draw the grid lines (top and left edges)
  361. fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
  362. fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100))
  363. if obj != None:
  364. obj.render(img)
  365. # Overlay the agent on top
  366. if agent_dir is not None:
  367. tri_fn = point_in_triangle(
  368. (0.12, 0.19),
  369. (0.87, 0.50),
  370. (0.12, 0.81),
  371. )
  372. # Rotate the agent based on its direction
  373. tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5*math.pi*agent_dir)
  374. fill_coords(img, tri_fn, (255, 0, 0))
  375. # Highlight the cell if needed
  376. if highlight:
  377. highlight_img(img)
  378. # Downsample the image to perform supersampling/anti-aliasing
  379. img = downsample(img, subdivs)
  380. # Cache the rendered tile
  381. cls.tile_cache[key] = img
  382. return img
  383. def render(
  384. self,
  385. tile_size,
  386. agent_pos=None,
  387. agent_dir=None,
  388. highlight_mask=None
  389. ):
  390. """
  391. Render this grid at a given scale
  392. :param r: target renderer object
  393. :param tile_size: tile size in pixels
  394. """
  395. if highlight_mask is None:
  396. highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  397. # Compute the total grid size
  398. width_px = self.width * tile_size
  399. height_px = self.height * tile_size
  400. img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)
  401. # Render the grid
  402. for j in range(0, self.height):
  403. for i in range(0, self.width):
  404. cell = self.get(i, j)
  405. agent_here = np.array_equal(agent_pos, (i, j))
  406. tile_img = Grid.render_tile(
  407. cell,
  408. agent_dir=agent_dir if agent_here else None,
  409. highlight=highlight_mask[i, j],
  410. tile_size=tile_size
  411. )
  412. ymin = j * tile_size
  413. ymax = (j+1) * tile_size
  414. xmin = i * tile_size
  415. xmax = (i+1) * tile_size
  416. img[ymin:ymax, xmin:xmax, :] = tile_img
  417. return img
  418. def encode(self, vis_mask=None):
  419. """
  420. Produce a compact numpy encoding of the grid
  421. """
  422. if vis_mask is None:
  423. vis_mask = np.ones((self.width, self.height), dtype=bool)
  424. array = np.zeros((self.width, self.height, 3), dtype='uint8')
  425. for i in range(self.width):
  426. for j in range(self.height):
  427. if vis_mask[i, j]:
  428. v = self.get(i, j)
  429. if v is None:
  430. array[i, j, 0] = OBJECT_TO_IDX['empty']
  431. array[i, j, 1] = 0
  432. array[i, j, 2] = 0
  433. else:
  434. array[i, j, :] = v.encode()
  435. return array
  436. @staticmethod
  437. def decode(array):
  438. """
  439. Decode an array grid encoding back into a grid
  440. """
  441. width, height, channels = array.shape
  442. assert channels == 3
  443. vis_mask = np.ones(shape=(width, height), dtype=bool)
  444. grid = Grid(width, height)
  445. for i in range(width):
  446. for j in range(height):
  447. type_idx, color_idx, state = array[i, j]
  448. v = WorldObj.decode(type_idx, color_idx, state)
  449. grid.set(i, j, v)
  450. vis_mask[i, j] = (type_idx != OBJECT_TO_IDX['unseen'])
  451. return grid, vis_mask
  452. def process_vis(grid, agent_pos):
  453. mask = np.zeros(shape=(grid.width, grid.height), dtype=bool)
  454. mask[agent_pos[0], agent_pos[1]] = True
  455. for j in reversed(range(0, grid.height)):
  456. for i in range(0, grid.width-1):
  457. if not mask[i, j]:
  458. continue
  459. cell = grid.get(i, j)
  460. if cell and not cell.see_behind():
  461. continue
  462. mask[i+1, j] = True
  463. if j > 0:
  464. mask[i+1, j-1] = True
  465. mask[i, j-1] = True
  466. for i in reversed(range(1, grid.width)):
  467. if not mask[i, j]:
  468. continue
  469. cell = grid.get(i, j)
  470. if cell and not cell.see_behind():
  471. continue
  472. mask[i-1, j] = True
  473. if j > 0:
  474. mask[i-1, j-1] = True
  475. mask[i, j-1] = True
  476. for j in range(0, grid.height):
  477. for i in range(0, grid.width):
  478. if not mask[i, j]:
  479. grid.set(i, j, None)
  480. return mask
class MiniGridEnv(gym.Env):
    """
    2D grid world game environment

    Subclasses must implement ``_gen_grid`` to lay out the grid and place
    the agent for each episode.
    """

    # Rendering metadata: supported render modes and video frame rate
    metadata = {
        'render.modes': ['human', 'rgb_array'],
        'video.frames_per_second' : 10
    }

    # Enumeration of possible actions (integer-valued, used as the
    # Discrete action space)
    class Actions(IntEnum):
        # Turn left, turn right, move forward
        left = 0
        right = 1
        forward = 2
        # Pick up an object
        pickup = 3
        # Drop an object
        drop = 4
        # Toggle/activate an object
        toggle = 5
        # Done completing task
        done = 6
  503. def __init__(
  504. self,
  505. grid_size=None,
  506. width=None,
  507. height=None,
  508. max_steps=100,
  509. see_through_walls=False,
  510. seed=1337,
  511. agent_view_size=7
  512. ):
  513. # Can't set both grid_size and width/height
  514. if grid_size:
  515. assert width == None and height == None
  516. width = grid_size
  517. height = grid_size
  518. # Action enumeration for this environment
  519. self.actions = MiniGridEnv.Actions
  520. # Actions are discrete integer values
  521. self.action_space = spaces.Discrete(len(self.actions))
  522. # Number of cells (width and height) in the agent view
  523. assert agent_view_size % 2 == 1
  524. assert agent_view_size >= 3
  525. self.agent_view_size = agent_view_size
  526. # Observations are dictionaries containing an
  527. # encoding of the grid and a textual 'mission' string
  528. self.observation_space = spaces.Box(
  529. low=0,
  530. high=255,
  531. shape=(self.agent_view_size, self.agent_view_size, 3),
  532. dtype='uint8'
  533. )
  534. self.observation_space = spaces.Dict({
  535. 'image': self.observation_space
  536. })
  537. # Range of possible rewards
  538. self.reward_range = (0, 1)
  539. # Window to use for human rendering mode
  540. self.window = None
  541. # Environment configuration
  542. self.width = width
  543. self.height = height
  544. self.max_steps = max_steps
  545. self.see_through_walls = see_through_walls
  546. # Current position and direction of the agent
  547. self.agent_pos = None
  548. self.agent_dir = None
  549. # Initialize the RNG
  550. self.seed(seed=seed)
  551. # Initialize the state
  552. self.reset()
  553. def reset(self):
  554. # Current position and direction of the agent
  555. self.agent_pos = None
  556. self.agent_dir = None
  557. # Generate a new random grid at the start of each episode
  558. # To keep the same grid for each episode, call env.seed() with
  559. # the same seed before calling env.reset()
  560. self._gen_grid(self.width, self.height)
  561. # These fields should be defined by _gen_grid
  562. assert self.agent_pos is not None
  563. assert self.agent_dir is not None
  564. # Check that the agent doesn't overlap with an object
  565. start_cell = self.grid.get(*self.agent_pos)
  566. assert start_cell is None or start_cell.can_overlap()
  567. # Item picked up, being carried, initially nothing
  568. self.carrying = None
  569. # Step count since episode start
  570. self.step_count = 0
  571. # Return first observation
  572. obs = self.gen_obs()
  573. return obs
  574. def seed(self, seed=1337):
  575. # Seed the random number generator
  576. self.np_random, _ = seeding.np_random(seed)
  577. return [seed]
  578. def hash(self, size=16):
  579. """Compute a hash that uniquely identifies the current state of the environment.
  580. :param size: Size of the hashing
  581. """
  582. sample_hash = hashlib.sha256()
  583. to_encode = [self.grid.encode().tolist(), self.agent_pos, self.agent_dir]
  584. for item in to_encode:
  585. sample_hash.update(str(item).encode('utf8'))
  586. return sample_hash.hexdigest()[:size]
  587. @property
  588. def steps_remaining(self):
  589. return self.max_steps - self.step_count
  590. def __str__(self):
  591. """
  592. Produce a pretty string of the environment's grid along with the agent.
  593. A grid cell is represented by 2-character string, the first one for
  594. the object and the second one for the color.
  595. """
  596. # Map of object types to short string
  597. OBJECT_TO_STR = {
  598. 'wall' : 'W',
  599. 'floor' : 'F',
  600. 'door' : 'D',
  601. 'key' : 'K',
  602. 'ball' : 'A',
  603. 'box' : 'B',
  604. 'goal' : 'G',
  605. 'lava' : 'V',
  606. }
  607. # Short string for opened door
  608. OPENDED_DOOR_IDS = '_'
  609. # Map agent's direction to short string
  610. AGENT_DIR_TO_STR = {
  611. 0: '>',
  612. 1: 'V',
  613. 2: '<',
  614. 3: '^'
  615. }
  616. str = ''
  617. for j in range(self.grid.height):
  618. for i in range(self.grid.width):
  619. if i == self.agent_pos[0] and j == self.agent_pos[1]:
  620. str += 2 * AGENT_DIR_TO_STR[self.agent_dir]
  621. continue
  622. c = self.grid.get(i, j)
  623. if c == None:
  624. str += ' '
  625. continue
  626. if c.type == 'door':
  627. if c.is_open:
  628. str += '__'
  629. elif c.is_locked:
  630. str += 'L' + c.color[0].upper()
  631. else:
  632. str += 'D' + c.color[0].upper()
  633. continue
  634. str += OBJECT_TO_STR[c.type] + c.color[0].upper()
  635. if j < self.grid.height - 1:
  636. str += '\n'
  637. return str
  638. def _gen_grid(self, width, height):
  639. assert False, "_gen_grid needs to be implemented by each environment"
  640. def _reward(self):
  641. """
  642. Compute the reward to be given upon success
  643. """
  644. return 1 - 0.9 * (self.step_count / self.max_steps)
  645. def _rand_int(self, low, high):
  646. """
  647. Generate random integer in [low,high[
  648. """
  649. return self.np_random.randint(low, high)
  650. def _rand_float(self, low, high):
  651. """
  652. Generate random float in [low,high[
  653. """
  654. return self.np_random.uniform(low, high)
  655. def _rand_bool(self):
  656. """
  657. Generate random boolean value
  658. """
  659. return (self.np_random.randint(0, 2) == 0)
  660. def _rand_elem(self, iterable):
  661. """
  662. Pick a random element in a list
  663. """
  664. lst = list(iterable)
  665. idx = self._rand_int(0, len(lst))
  666. return lst[idx]
  667. def _rand_subset(self, iterable, num_elems):
  668. """
  669. Sample a random subset of distinct elements of a list
  670. """
  671. lst = list(iterable)
  672. assert num_elems <= len(lst)
  673. out = []
  674. while len(out) < num_elems:
  675. elem = self._rand_elem(lst)
  676. lst.remove(elem)
  677. out.append(elem)
  678. return out
  679. def _rand_color(self):
  680. """
  681. Generate a random color name (string)
  682. """
  683. return self._rand_elem(COLOR_NAMES)
  684. def _rand_pos(self, xLow, xHigh, yLow, yHigh):
  685. """
  686. Generate a random (x,y) position tuple
  687. """
  688. return (
  689. self.np_random.randint(xLow, xHigh),
  690. self.np_random.randint(yLow, yHigh)
  691. )
  692. def place_obj(self,
  693. obj,
  694. top=None,
  695. size=None,
  696. reject_fn=None,
  697. max_tries=math.inf
  698. ):
  699. """
  700. Place an object at an empty position in the grid
  701. :param top: top-left position of the rectangle where to place
  702. :param size: size of the rectangle where to place
  703. :param reject_fn: function to filter out potential positions
  704. """
  705. if top is None:
  706. top = (0, 0)
  707. else:
  708. top = (max(top[0], 0), max(top[1], 0))
  709. if size is None:
  710. size = (self.grid.width, self.grid.height)
  711. num_tries = 0
  712. while True:
  713. # This is to handle with rare cases where rejection sampling
  714. # gets stuck in an infinite loop
  715. if num_tries > max_tries:
  716. raise RecursionError('rejection sampling failed in place_obj')
  717. num_tries += 1
  718. pos = np.array((
  719. self._rand_int(top[0], min(top[0] + size[0], self.grid.width)),
  720. self._rand_int(top[1], min(top[1] + size[1], self.grid.height))
  721. ))
  722. # Don't place the object on top of another object
  723. if self.grid.get(*pos) != None:
  724. continue
  725. # Don't place the object where the agent is
  726. if np.array_equal(pos, self.agent_pos):
  727. continue
  728. # Check if there is a filtering criterion
  729. if reject_fn and reject_fn(self, pos):
  730. continue
  731. break
  732. self.grid.set(*pos, obj)
  733. if obj is not None:
  734. obj.init_pos = pos
  735. obj.cur_pos = pos
  736. return pos
  737. def put_obj(self, obj, i, j):
  738. """
  739. Put an object at a specific position in the grid
  740. """
  741. self.grid.set(i, j, obj)
  742. obj.init_pos = (i, j)
  743. obj.cur_pos = (i, j)
  744. def place_agent(
  745. self,
  746. top=None,
  747. size=None,
  748. rand_dir=True,
  749. max_tries=math.inf
  750. ):
  751. """
  752. Set the agent's starting point at an empty position in the grid
  753. """
  754. self.agent_pos = None
  755. pos = self.place_obj(None, top, size, max_tries=max_tries)
  756. self.agent_pos = pos
  757. if rand_dir:
  758. self.agent_dir = self._rand_int(0, 4)
  759. return pos
  760. @property
  761. def dir_vec(self):
  762. """
  763. Get the direction vector for the agent, pointing in the direction
  764. of forward movement.
  765. """
  766. assert self.agent_dir >= 0 and self.agent_dir < 4
  767. return DIR_TO_VEC[self.agent_dir]
  768. @property
  769. def right_vec(self):
  770. """
  771. Get the vector pointing to the right of the agent.
  772. """
  773. dx, dy = self.dir_vec
  774. return np.array((-dy, dx))
  775. @property
  776. def front_pos(self):
  777. """
  778. Get the position of the cell that is right in front of the agent
  779. """
  780. return self.agent_pos + self.dir_vec
  781. def get_view_coords(self, i, j):
  782. """
  783. Translate and rotate absolute grid coordinates (i, j) into the
  784. agent's partially observable view (sub-grid). Note that the resulting
  785. coordinates may be negative or outside of the agent's view size.
  786. """
  787. ax, ay = self.agent_pos
  788. dx, dy = self.dir_vec
  789. rx, ry = self.right_vec
  790. # Compute the absolute coordinates of the top-left view corner
  791. sz = self.agent_view_size
  792. hs = self.agent_view_size // 2
  793. tx = ax + (dx * (sz-1)) - (rx * hs)
  794. ty = ay + (dy * (sz-1)) - (ry * hs)
  795. lx = i - tx
  796. ly = j - ty
  797. # Project the coordinates of the object relative to the top-left
  798. # corner onto the agent's own coordinate system
  799. vx = (rx*lx + ry*ly)
  800. vy = -(dx*lx + dy*ly)
  801. return vx, vy
  802. def get_view_exts(self):
  803. """
  804. Get the extents of the square set of tiles visible to the agent
  805. Note: the bottom extent indices are not included in the set
  806. """
  807. # Facing right
  808. if self.agent_dir == 0:
  809. topX = self.agent_pos[0]
  810. topY = self.agent_pos[1] - self.agent_view_size // 2
  811. # Facing down
  812. elif self.agent_dir == 1:
  813. topX = self.agent_pos[0] - self.agent_view_size // 2
  814. topY = self.agent_pos[1]
  815. # Facing left
  816. elif self.agent_dir == 2:
  817. topX = self.agent_pos[0] - self.agent_view_size + 1
  818. topY = self.agent_pos[1] - self.agent_view_size // 2
  819. # Facing up
  820. elif self.agent_dir == 3:
  821. topX = self.agent_pos[0] - self.agent_view_size // 2
  822. topY = self.agent_pos[1] - self.agent_view_size + 1
  823. else:
  824. assert False, "invalid agent direction"
  825. botX = topX + self.agent_view_size
  826. botY = topY + self.agent_view_size
  827. return (topX, topY, botX, botY)
  828. def relative_coords(self, x, y):
  829. """
  830. Check if a grid position belongs to the agent's field of view, and returns the corresponding coordinates
  831. """
  832. vx, vy = self.get_view_coords(x, y)
  833. if vx < 0 or vy < 0 or vx >= self.agent_view_size or vy >= self.agent_view_size:
  834. return None
  835. return vx, vy
  836. def in_view(self, x, y):
  837. """
  838. check if a grid position is visible to the agent
  839. """
  840. return self.relative_coords(x, y) is not None
  841. def agent_sees(self, x, y):
  842. """
  843. Check if a non-empty grid position is visible to the agent
  844. """
  845. coordinates = self.relative_coords(x, y)
  846. if coordinates is None:
  847. return False
  848. vx, vy = coordinates
  849. obs = self.gen_obs()
  850. obs_grid, _ = Grid.decode(obs['image'])
  851. obs_cell = obs_grid.get(vx, vy)
  852. world_cell = self.grid.get(x, y)
  853. return obs_cell is not None and obs_cell.type == world_cell.type
  854. def step(self, action):
  855. self.step_count += 1
  856. reward = 0
  857. done = False
  858. # Get the position in front of the agent
  859. fwd_pos = self.front_pos
  860. # Get the contents of the cell in front of the agent
  861. fwd_cell = self.grid.get(*fwd_pos)
  862. # Rotate left
  863. if action == self.actions.left:
  864. self.agent_dir -= 1
  865. if self.agent_dir < 0:
  866. self.agent_dir += 4
  867. # Rotate right
  868. elif action == self.actions.right:
  869. self.agent_dir = (self.agent_dir + 1) % 4
  870. # Move forward
  871. elif action == self.actions.forward:
  872. if fwd_cell == None or fwd_cell.can_overlap():
  873. self.agent_pos = fwd_pos
  874. if fwd_cell != None and fwd_cell.type == 'goal':
  875. done = True
  876. reward = self._reward()
  877. if fwd_cell != None and fwd_cell.type == 'lava':
  878. done = True
  879. # Pick up an object
  880. elif action == self.actions.pickup:
  881. if fwd_cell and fwd_cell.can_pickup():
  882. if self.carrying is None:
  883. self.carrying = fwd_cell
  884. self.carrying.cur_pos = np.array([-1, -1])
  885. self.grid.set(*fwd_pos, None)
  886. # Drop an object
  887. elif action == self.actions.drop:
  888. if not fwd_cell and self.carrying:
  889. self.grid.set(*fwd_pos, self.carrying)
  890. self.carrying.cur_pos = fwd_pos
  891. self.carrying = None
  892. # Toggle/activate an object
  893. elif action == self.actions.toggle:
  894. if fwd_cell:
  895. fwd_cell.toggle(self, fwd_pos)
  896. # Done action (not used by default)
  897. elif action == self.actions.done:
  898. pass
  899. else:
  900. assert False, "unknown action"
  901. if self.step_count >= self.max_steps:
  902. done = True
  903. obs = self.gen_obs()
  904. return obs, reward, done, {}
  905. def gen_obs_grid(self):
  906. """
  907. Generate the sub-grid observed by the agent.
  908. This method also outputs a visibility mask telling us which grid
  909. cells the agent can actually see.
  910. """
  911. topX, topY, botX, botY = self.get_view_exts()
  912. grid = self.grid.slice(topX, topY, self.agent_view_size, self.agent_view_size)
  913. for i in range(self.agent_dir + 1):
  914. grid = grid.rotate_left()
  915. # Process occluders and visibility
  916. # Note that this incurs some performance cost
  917. if not self.see_through_walls:
  918. vis_mask = grid.process_vis(agent_pos=(self.agent_view_size // 2 , self.agent_view_size - 1))
  919. else:
  920. vis_mask = np.ones(shape=(grid.width, grid.height), dtype=bool)
  921. # Make it so the agent sees what it's carrying
  922. # We do this by placing the carried object at the agent's position
  923. # in the agent's partially observable view
  924. agent_pos = grid.width // 2, grid.height - 1
  925. if self.carrying:
  926. grid.set(*agent_pos, self.carrying)
  927. else:
  928. grid.set(*agent_pos, None)
  929. return grid, vis_mask
  930. def gen_obs(self):
  931. """
  932. Generate the agent's view (partially observable, low-resolution encoding)
  933. """
  934. grid, vis_mask = self.gen_obs_grid()
  935. # Encode the partially observable view into a numpy array
  936. image = grid.encode(vis_mask)
  937. assert hasattr(self, 'mission'), "environments must define a textual mission string"
  938. # Observations are dictionaries containing:
  939. # - an image (partially observable view of the environment)
  940. # - the agent's direction/orientation (acting as a compass)
  941. # - a textual mission string (instructions for the agent)
  942. obs = {
  943. 'image': image,
  944. 'direction': self.agent_dir,
  945. 'mission': self.mission
  946. }
  947. return obs
  948. def get_obs_render(self, obs, tile_size=TILE_PIXELS//2):
  949. """
  950. Render an agent observation for visualization
  951. """
  952. grid, vis_mask = Grid.decode(obs)
  953. # Render the whole grid
  954. img = grid.render(
  955. tile_size,
  956. agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1),
  957. agent_dir=3,
  958. highlight_mask=vis_mask
  959. )
  960. return img
  961. def render(self, mode='human', close=False, highlight=True, tile_size=TILE_PIXELS):
  962. """
  963. Render the whole-grid human view
  964. """
  965. if close:
  966. if self.window:
  967. self.window.close()
  968. return
  969. if mode == 'human' and not self.window:
  970. import gym_minigrid.window
  971. self.window = gym_minigrid.window.Window('gym_minigrid')
  972. self.window.show(block=False)
  973. # Compute which cells are visible to the agent
  974. _, vis_mask = self.gen_obs_grid()
  975. # Compute the world coordinates of the bottom-left corner
  976. # of the agent's view area
  977. f_vec = self.dir_vec
  978. r_vec = self.right_vec
  979. top_left = self.agent_pos + f_vec * (self.agent_view_size-1) - r_vec * (self.agent_view_size // 2)
  980. # Mask of which cells to highlight
  981. highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  982. # For each cell in the visibility mask
  983. for vis_j in range(0, self.agent_view_size):
  984. for vis_i in range(0, self.agent_view_size):
  985. # If this cell is not visible, don't highlight it
  986. if not vis_mask[vis_i, vis_j]:
  987. continue
  988. # Compute the world coordinates of this cell
  989. abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)
  990. if abs_i < 0 or abs_i >= self.width:
  991. continue
  992. if abs_j < 0 or abs_j >= self.height:
  993. continue
  994. # Mark this cell to be highlighted
  995. highlight_mask[abs_i, abs_j] = True
  996. # Render the whole grid
  997. img = self.grid.render(
  998. tile_size,
  999. self.agent_pos,
  1000. self.agent_dir,
  1001. highlight_mask=highlight_mask if highlight else None
  1002. )
  1003. if mode == 'human':
  1004. self.window.set_caption(self.mission)
  1005. self.window.show_img(img)
  1006. return img
  1007. def close(self):
  1008. if self.window:
  1009. self.window.close()
  1010. return