minigrid.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334
  1. import math
  2. import hashlib
  3. import string
  4. import gym
  5. from enum import IntEnum
  6. import numpy as np
  7. from gym import error, spaces, utils
  8. from .rendering import *
  9. # Size in pixels of a tile in the full-scale human view
  10. TILE_PIXELS = 32
  11. # Map of color names to RGB values
  12. COLORS = {
  13. 'red' : np.array([255, 0, 0]),
  14. 'green' : np.array([0, 255, 0]),
  15. 'blue' : np.array([0, 0, 255]),
  16. 'purple': np.array([112, 39, 195]),
  17. 'yellow': np.array([255, 255, 0]),
  18. 'grey' : np.array([100, 100, 100])
  19. }
  20. COLOR_NAMES = sorted(list(COLORS.keys()))
  21. # Used to map colors to integers
  22. COLOR_TO_IDX = {
  23. 'red' : 0,
  24. 'green' : 1,
  25. 'blue' : 2,
  26. 'purple': 3,
  27. 'yellow': 4,
  28. 'grey' : 5
  29. }
  30. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  31. # Map of object type to integers
  32. OBJECT_TO_IDX = {
  33. 'unseen' : 0,
  34. 'empty' : 1,
  35. 'wall' : 2,
  36. 'floor' : 3,
  37. 'door' : 4,
  38. 'key' : 5,
  39. 'ball' : 6,
  40. 'box' : 7,
  41. 'goal' : 8,
  42. 'lava' : 9,
  43. 'agent' : 10,
  44. }
  45. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  46. # Map of state names to integers
  47. STATE_TO_IDX = {
  48. 'open' : 0,
  49. 'closed': 1,
  50. 'locked': 2,
  51. }
  52. # Map of agent direction indices to vectors
  53. DIR_TO_VEC = [
  54. # Pointing right (positive X)
  55. np.array((1, 0)),
  56. # Down (positive Y)
  57. np.array((0, 1)),
  58. # Pointing left (negative X)
  59. np.array((-1, 0)),
  60. # Up (negative Y)
  61. np.array((0, -1)),
  62. ]
  63. class WorldObj:
  64. """
  65. Base class for grid world objects
  66. """
  67. def __init__(self, type, color):
  68. assert type in OBJECT_TO_IDX, type
  69. assert color in COLOR_TO_IDX, color
  70. self.type = type
  71. self.color = color
  72. self.contains = None
  73. # Initial position of the object
  74. self.init_pos = None
  75. # Current position of the object
  76. self.cur_pos = None
  77. def can_overlap(self):
  78. """Can the agent overlap with this?"""
  79. return False
  80. def can_pickup(self):
  81. """Can the agent pick this up?"""
  82. return False
  83. def can_contain(self):
  84. """Can this contain another object?"""
  85. return False
  86. def see_behind(self):
  87. """Can the agent see behind this object?"""
  88. return True
  89. def toggle(self, env, pos):
  90. """Method to trigger/toggle an action this object performs"""
  91. return False
  92. def encode(self):
  93. """Encode the a description of this object as a 3-tuple of integers"""
  94. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0)
  95. @staticmethod
  96. def decode(type_idx, color_idx, state):
  97. """Create an object from a 3-tuple state description"""
  98. obj_type = IDX_TO_OBJECT[type_idx]
  99. color = IDX_TO_COLOR[color_idx]
  100. if obj_type == 'empty' or obj_type == 'unseen':
  101. return None
  102. # State, 0: open, 1: closed, 2: locked
  103. is_open = state == 0
  104. is_locked = state == 2
  105. if obj_type == 'wall':
  106. v = Wall(color)
  107. elif obj_type == 'floor':
  108. v = Floor(color)
  109. elif obj_type == 'ball':
  110. v = Ball(color)
  111. elif obj_type == 'key':
  112. v = Key(color)
  113. elif obj_type == 'box':
  114. v = Box(color)
  115. elif obj_type == 'door':
  116. v = Door(color, is_open, is_locked)
  117. elif obj_type == 'goal':
  118. v = Goal()
  119. elif obj_type == 'lava':
  120. v = Lava()
  121. else:
  122. assert False, "unknown object type in decode '%s'" % obj_type
  123. return v
  124. def render(self, r):
  125. """Draw this object with the given renderer"""
  126. raise NotImplementedError
  127. class Goal(WorldObj):
  128. def __init__(self):
  129. super().__init__('goal', 'green')
  130. def can_overlap(self):
  131. return True
  132. def render(self, img):
  133. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  134. class Floor(WorldObj):
  135. """
  136. Colored floor tile the agent can walk over
  137. """
  138. def __init__(self, color='blue'):
  139. super().__init__('floor', color)
  140. def can_overlap(self):
  141. return True
  142. def render(self, img):
  143. # Give the floor a pale color
  144. color = COLORS[self.color] / 2
  145. fill_coords(img, point_in_rect(0.031, 1, 0.031, 1), color)
  146. class Lava(WorldObj):
  147. def __init__(self):
  148. super().__init__('lava', 'red')
  149. def can_overlap(self):
  150. return True
  151. def render(self, img):
  152. c = (255, 128, 0)
  153. # Background color
  154. fill_coords(img, point_in_rect(0, 1, 0, 1), c)
  155. # Little waves
  156. for i in range(3):
  157. ylo = 0.3 + 0.2 * i
  158. yhi = 0.4 + 0.2 * i
  159. fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0,0,0))
  160. fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0,0,0))
  161. fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0,0,0))
  162. fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0,0,0))
  163. class Wall(WorldObj):
  164. def __init__(self, color='grey'):
  165. super().__init__('wall', color)
  166. def see_behind(self):
  167. return False
  168. def render(self, img):
  169. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  170. class Door(WorldObj):
  171. def __init__(self, color, is_open=False, is_locked=False):
  172. super().__init__('door', color)
  173. self.is_open = is_open
  174. self.is_locked = is_locked
  175. def can_overlap(self):
  176. """The agent can only walk over this cell when the door is open"""
  177. return self.is_open
  178. def see_behind(self):
  179. return self.is_open
  180. def toggle(self, env, pos):
  181. # If the player has the right key to open the door
  182. if self.is_locked:
  183. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  184. self.is_locked = False
  185. self.is_open = True
  186. return True
  187. return False
  188. self.is_open = not self.is_open
  189. return True
  190. def encode(self):
  191. """Encode the a description of this object as a 3-tuple of integers"""
  192. # State, 0: open, 1: closed, 2: locked
  193. if self.is_open:
  194. state = 0
  195. elif self.is_locked:
  196. state = 2
  197. elif not self.is_open:
  198. state = 1
  199. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state)
  200. def render(self, img):
  201. c = COLORS[self.color]
  202. if self.is_open:
  203. fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c)
  204. fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0,0,0))
  205. return
  206. # Door frame and door
  207. if self.is_locked:
  208. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  209. fill_coords(img, point_in_rect(0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))
  210. # Draw key slot
  211. fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)
  212. else:
  213. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  214. fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0,0,0))
  215. fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c)
  216. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0,0,0))
  217. # Draw door handle
  218. fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c)
  219. class Key(WorldObj):
  220. def __init__(self, color='blue'):
  221. super(Key, self).__init__('key', color)
  222. def can_pickup(self):
  223. return True
  224. def render(self, img):
  225. c = COLORS[self.color]
  226. # Vertical quad
  227. fill_coords(img, point_in_rect(0.50, 0.63, 0.31, 0.88), c)
  228. # Teeth
  229. fill_coords(img, point_in_rect(0.38, 0.50, 0.59, 0.66), c)
  230. fill_coords(img, point_in_rect(0.38, 0.50, 0.81, 0.88), c)
  231. # Ring
  232. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.190), c)
  233. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0,0,0))
  234. class Ball(WorldObj):
  235. def __init__(self, color='blue'):
  236. super(Ball, self).__init__('ball', color)
  237. def can_pickup(self):
  238. return True
  239. def render(self, img):
  240. fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color])
  241. class Box(WorldObj):
  242. def __init__(self, color, contains=None):
  243. super(Box, self).__init__('box', color)
  244. self.contains = contains
  245. def can_pickup(self):
  246. return True
  247. def render(self, img):
  248. c = COLORS[self.color]
  249. # Outline
  250. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), c)
  251. fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0,0,0))
  252. # Horizontal slit
  253. fill_coords(img, point_in_rect(0.16, 0.84, 0.47, 0.53), c)
  254. def toggle(self, env, pos):
  255. # Replace the box by its contents
  256. env.grid.set(*pos, self.contains)
  257. return True
  258. class Grid:
  259. """
  260. Represent a grid and operations on it
  261. """
  262. # Static cache of pre-renderer tiles
  263. tile_cache = {}
  264. def __init__(self, width, height):
  265. assert width >= 3
  266. assert height >= 3
  267. self.width = width
  268. self.height = height
  269. self.grid = [None] * width * height
  270. def __contains__(self, key):
  271. if isinstance(key, WorldObj):
  272. for e in self.grid:
  273. if e is key:
  274. return True
  275. elif isinstance(key, tuple):
  276. for e in self.grid:
  277. if e is None:
  278. continue
  279. if (e.color, e.type) == key:
  280. return True
  281. if key[0] is None and key[1] == e.type:
  282. return True
  283. return False
  284. def __eq__(self, other):
  285. grid1 = self.encode()
  286. grid2 = other.encode()
  287. return np.array_equal(grid2, grid1)
  288. def __ne__(self, other):
  289. return not self == other
  290. def copy(self):
  291. from copy import deepcopy
  292. return deepcopy(self)
  293. def set(self, i, j, v):
  294. assert i >= 0 and i < self.width
  295. assert j >= 0 and j < self.height
  296. self.grid[j * self.width + i] = v
  297. def get(self, i, j):
  298. assert i >= 0 and i < self.width
  299. assert j >= 0 and j < self.height
  300. return self.grid[j * self.width + i]
  301. def horz_wall(self, x, y, length=None, obj_type=Wall):
  302. if length is None:
  303. length = self.width - x
  304. for i in range(0, length):
  305. self.set(x + i, y, obj_type())
  306. def vert_wall(self, x, y, length=None, obj_type=Wall):
  307. if length is None:
  308. length = self.height - y
  309. for j in range(0, length):
  310. self.set(x, y + j, obj_type())
  311. def wall_rect(self, x, y, w, h):
  312. self.horz_wall(x, y, w)
  313. self.horz_wall(x, y+h-1, w)
  314. self.vert_wall(x, y, h)
  315. self.vert_wall(x+w-1, y, h)
  316. def rotate_left(self):
  317. """
  318. Rotate the grid to the left (counter-clockwise)
  319. """
  320. grid = Grid(self.height, self.width)
  321. for i in range(self.width):
  322. for j in range(self.height):
  323. v = self.get(i, j)
  324. grid.set(j, grid.height - 1 - i, v)
  325. return grid
  326. def slice(self, topX, topY, width, height):
  327. """
  328. Get a subset of the grid
  329. """
  330. grid = Grid(width, height)
  331. for j in range(0, height):
  332. for i in range(0, width):
  333. x = topX + i
  334. y = topY + j
  335. if x >= 0 and x < self.width and \
  336. y >= 0 and y < self.height:
  337. v = self.get(x, y)
  338. else:
  339. v = Wall()
  340. grid.set(i, j, v)
  341. return grid
  342. @classmethod
  343. def render_tile(
  344. cls,
  345. obj,
  346. agent_dir=None,
  347. highlight=False,
  348. tile_size=TILE_PIXELS,
  349. subdivs=3
  350. ):
  351. """
  352. Render a tile and cache the result
  353. """
  354. # Hash map lookup key for the cache
  355. key = (agent_dir, highlight, tile_size)
  356. key = obj.encode() + key if obj else key
  357. if key in cls.tile_cache:
  358. return cls.tile_cache[key]
  359. img = np.zeros(shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8)
  360. # Draw the grid lines (top and left edges)
  361. fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
  362. fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100))
  363. if obj != None:
  364. obj.render(img)
  365. # Overlay the agent on top
  366. if agent_dir is not None:
  367. tri_fn = point_in_triangle(
  368. (0.12, 0.19),
  369. (0.87, 0.50),
  370. (0.12, 0.81),
  371. )
  372. # Rotate the agent based on its direction
  373. tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5*math.pi*agent_dir)
  374. fill_coords(img, tri_fn, (255, 0, 0))
  375. # Highlight the cell if needed
  376. if highlight:
  377. highlight_img(img)
  378. # Downsample the image to perform supersampling/anti-aliasing
  379. img = downsample(img, subdivs)
  380. # Cache the rendered tile
  381. cls.tile_cache[key] = img
  382. return img
  383. def render(
  384. self,
  385. tile_size,
  386. agent_pos=None,
  387. agent_dir=None,
  388. highlight_mask=None
  389. ):
  390. """
  391. Render this grid at a given scale
  392. :param r: target renderer object
  393. :param tile_size: tile size in pixels
  394. """
  395. if highlight_mask is None:
  396. highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  397. # Compute the total grid size
  398. width_px = self.width * tile_size
  399. height_px = self.height * tile_size
  400. img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)
  401. # Render the grid
  402. for j in range(0, self.height):
  403. for i in range(0, self.width):
  404. cell = self.get(i, j)
  405. agent_here = np.array_equal(agent_pos, (i, j))
  406. tile_img = Grid.render_tile(
  407. cell,
  408. agent_dir=agent_dir if agent_here else None,
  409. highlight=highlight_mask[i, j],
  410. tile_size=tile_size
  411. )
  412. ymin = j * tile_size
  413. ymax = (j+1) * tile_size
  414. xmin = i * tile_size
  415. xmax = (i+1) * tile_size
  416. img[ymin:ymax, xmin:xmax, :] = tile_img
  417. return img
  418. def encode(self, vis_mask=None):
  419. """
  420. Produce a compact numpy encoding of the grid
  421. """
  422. if vis_mask is None:
  423. vis_mask = np.ones((self.width, self.height), dtype=bool)
  424. array = np.zeros((self.width, self.height, 3), dtype='uint8')
  425. for i in range(self.width):
  426. for j in range(self.height):
  427. if vis_mask[i, j]:
  428. v = self.get(i, j)
  429. if v is None:
  430. array[i, j, 0] = OBJECT_TO_IDX['empty']
  431. array[i, j, 1] = 0
  432. array[i, j, 2] = 0
  433. else:
  434. array[i, j, :] = v.encode()
  435. return array
  436. @staticmethod
  437. def decode(array):
  438. """
  439. Decode an array grid encoding back into a grid
  440. """
  441. width, height, channels = array.shape
  442. assert channels == 3
  443. vis_mask = np.ones(shape=(width, height), dtype=bool)
  444. grid = Grid(width, height)
  445. for i in range(width):
  446. for j in range(height):
  447. type_idx, color_idx, state = array[i, j]
  448. v = WorldObj.decode(type_idx, color_idx, state)
  449. grid.set(i, j, v)
  450. vis_mask[i, j] = (type_idx != OBJECT_TO_IDX['unseen'])
  451. return grid, vis_mask
def process_vis(grid, agent_pos):
    """
    Compute which cells of ``grid`` are visible from ``agent_pos`` and blank
    out the rest.

    Visibility starts at the agent's cell. Rows are then swept from the
    bottom row (highest j) upwards. A visible cell that can be seen through
    (empty, or ``see_behind()`` is True) makes its horizontal neighbour
    visible, plus the cells diagonally and directly above it. Opaque cells
    stop the spread.

    NOTE(review): propagation only goes upward (j-1), so this presumably
    expects the agent in the bottom row, e.g. a view grid already rotated
    into the agent's frame. Confirm against callers.

    :param grid: Grid to process; cells that end up invisible are set to
        None IN PLACE
    :param agent_pos: (x, y) position of the agent inside ``grid``
    :return: (width, height) boolean numpy mask of visible cells
    """
    mask = np.zeros(shape=(grid.width, grid.height), dtype=bool)

    # The agent always sees its own cell
    mask[agent_pos[0], agent_pos[1]] = True

    for j in reversed(range(0, grid.height)):
        # Left-to-right pass: spread visibility to the right and upward
        for i in range(0, grid.width-1):
            if not mask[i, j]:
                continue

            cell = grid.get(i, j)
            # Opaque cells are visible themselves but hide what is behind them
            if cell and not cell.see_behind():
                continue

            mask[i+1, j] = True
            if j > 0:
                mask[i+1, j-1] = True
                mask[i, j-1] = True

        # Right-to-left pass: spread visibility to the left and upward
        for i in reversed(range(1, grid.width)):
            if not mask[i, j]:
                continue

            cell = grid.get(i, j)
            if cell and not cell.see_behind():
                continue

            mask[i-1, j] = True
            if j > 0:
                mask[i-1, j-1] = True
                mask[i, j-1] = True

    # Erase everything the agent cannot see (mutates the caller's grid)
    for j in range(0, grid.height):
        for i in range(0, grid.width):
            if not mask[i, j]:
                grid.set(i, j, None)

    return mask
  481. class StringGymSpace(gym.spaces.space.Space):
  482. """
  483. A gym space that represents a string of characters of bounded length
  484. """
  485. def __init__(self, min_length=0, max_length=1000):
  486. super().__init__(shape=(), dtype='U')
  487. self.min_length = min_length
  488. self.max_length = max_length
  489. self.letters = string.ascii_letters + string.digits + ' .,!- '
  490. def sample(self):
  491. length = np.random.randint(self.min_length, self.max_length)
  492. string = ''.join(np.random.choice(list(self.letters), size=length))
  493. return string
  494. def contains(self, x):
  495. return isinstance(x, str) and len(x) >= self.min_length and len(x) <= self.max_length
  496. def __repr__(self):
  497. return "StringGymSpace(min_length={}, max_length={})".format(self.min_length, self.max_length)
  498. def __eq__(self, other):
  499. return (isinstance(other, StringGymSpace)
  500. and self.min_length == other.min_length
  501. and self.max_length == other.max_length
  502. and self.letters == other.letters
  503. )
  504. class MiniGridEnv(gym.Env):
  505. """
  506. 2D grid world game environment
  507. """
  508. metadata = {
  509. 'render_modes': ['human', 'rgb_array'],
  510. 'render_fps' : 10
  511. }
  512. # Enumeration of possible actions
  513. class Actions(IntEnum):
  514. # Turn left, turn right, move forward
  515. left = 0
  516. right = 1
  517. forward = 2
  518. # Pick up an object
  519. pickup = 3
  520. # Drop an object
  521. drop = 4
  522. # Toggle/activate an object
  523. toggle = 5
  524. # Done completing task
  525. done = 6
  526. def __init__(
  527. self,
  528. grid_size=None,
  529. width=None,
  530. height=None,
  531. max_steps=100,
  532. see_through_walls=False,
  533. agent_view_size=7,
  534. render_mode=None
  535. ):
  536. # Can't set both grid_size and width/height
  537. if grid_size:
  538. assert width == None and height == None
  539. width = grid_size
  540. height = grid_size
  541. # Action enumeration for this environment
  542. self.actions = MiniGridEnv.Actions
  543. # Actions are discrete integer values
  544. self.action_space = spaces.Discrete(len(self.actions))
  545. # Number of cells (width and height) in the agent view
  546. assert agent_view_size % 2 == 1
  547. assert agent_view_size >= 3
  548. self.agent_view_size = agent_view_size
  549. # Observations are dictionaries containing an
  550. # encoding of the grid and a textual 'mission' string
  551. self.observation_space = spaces.Box(
  552. low=0,
  553. high=255,
  554. shape=(self.agent_view_size, self.agent_view_size, 3),
  555. dtype='uint8'
  556. )
  557. self.observation_space = spaces.Dict({
  558. 'image': self.observation_space,
  559. 'direction': spaces.Discrete(4),
  560. 'mission': StringGymSpace(min_length=0, max_length=200),
  561. })
  562. # render mode
  563. self.render_mode = render_mode
  564. # Range of possible rewards
  565. self.reward_range = (0, 1)
  566. # Window to use for human rendering mode
  567. self.window = None
  568. # Environment configuration
  569. self.width = width
  570. self.height = height
  571. self.max_steps = max_steps
  572. self.see_through_walls = see_through_walls
  573. # Current position and direction of the agent
  574. self.agent_pos = None
  575. self.agent_dir = None
  576. # Initialize the state
  577. self.reset()
  578. def reset(self, *, seed=None, return_info=False, options=None):
  579. super().reset(seed=seed)
  580. # Current position and direction of the agent
  581. self.agent_pos = None
  582. self.agent_dir = None
  583. # Generate a new random grid at the start of each episode
  584. self._gen_grid(self.width, self.height)
  585. # These fields should be defined by _gen_grid
  586. assert self.agent_pos is not None
  587. assert self.agent_dir is not None
  588. # Check that the agent doesn't overlap with an object
  589. start_cell = self.grid.get(*self.agent_pos)
  590. assert start_cell is None or start_cell.can_overlap()
  591. # Item picked up, being carried, initially nothing
  592. self.carrying = None
  593. # Step count since episode start
  594. self.step_count = 0
  595. # Return first observation
  596. obs = self.gen_obs()
  597. return obs
  598. def hash(self, size=16):
  599. """Compute a hash that uniquely identifies the current state of the environment.
  600. :param size: Size of the hashing
  601. """
  602. sample_hash = hashlib.sha256()
  603. to_encode = [self.grid.encode().tolist(), self.agent_pos, self.agent_dir]
  604. for item in to_encode:
  605. sample_hash.update(str(item).encode('utf8'))
  606. return sample_hash.hexdigest()[:size]
  607. @property
  608. def steps_remaining(self):
  609. return self.max_steps - self.step_count
  610. def __str__(self):
  611. """
  612. Produce a pretty string of the environment's grid along with the agent.
  613. A grid cell is represented by 2-character string, the first one for
  614. the object and the second one for the color.
  615. """
  616. # Map of object types to short string
  617. OBJECT_TO_STR = {
  618. 'wall' : 'W',
  619. 'floor' : 'F',
  620. 'door' : 'D',
  621. 'key' : 'K',
  622. 'ball' : 'A',
  623. 'box' : 'B',
  624. 'goal' : 'G',
  625. 'lava' : 'V',
  626. }
  627. # Short string for opened door
  628. OPENDED_DOOR_IDS = '_'
  629. # Map agent's direction to short string
  630. AGENT_DIR_TO_STR = {
  631. 0: '>',
  632. 1: 'V',
  633. 2: '<',
  634. 3: '^'
  635. }
  636. str = ''
  637. for j in range(self.grid.height):
  638. for i in range(self.grid.width):
  639. if i == self.agent_pos[0] and j == self.agent_pos[1]:
  640. str += 2 * AGENT_DIR_TO_STR[self.agent_dir]
  641. continue
  642. c = self.grid.get(i, j)
  643. if c == None:
  644. str += ' '
  645. continue
  646. if c.type == 'door':
  647. if c.is_open:
  648. str += '__'
  649. elif c.is_locked:
  650. str += 'L' + c.color[0].upper()
  651. else:
  652. str += 'D' + c.color[0].upper()
  653. continue
  654. str += OBJECT_TO_STR[c.type] + c.color[0].upper()
  655. if j < self.grid.height - 1:
  656. str += '\n'
  657. return str
  658. def _gen_grid(self, width, height):
  659. assert False, "_gen_grid needs to be implemented by each environment"
  660. def _reward(self):
  661. """
  662. Compute the reward to be given upon success
  663. """
  664. return 1 - 0.9 * (self.step_count / self.max_steps)
  665. def _rand_int(self, low, high):
  666. """
  667. Generate random integer in [low,high[
  668. """
  669. return self.np_random.integers(low, high)
  670. def _rand_float(self, low, high):
  671. """
  672. Generate random float in [low,high[
  673. """
  674. return self.np_random.uniform(low, high)
  675. def _rand_bool(self):
  676. """
  677. Generate random boolean value
  678. """
  679. return (self.np_random.integers(0, 2) == 0)
  680. def _rand_elem(self, iterable):
  681. """
  682. Pick a random element in a list
  683. """
  684. lst = list(iterable)
  685. idx = self._rand_int(0, len(lst))
  686. return lst[idx]
  687. def _rand_subset(self, iterable, num_elems):
  688. """
  689. Sample a random subset of distinct elements of a list
  690. """
  691. lst = list(iterable)
  692. assert num_elems <= len(lst)
  693. out = []
  694. while len(out) < num_elems:
  695. elem = self._rand_elem(lst)
  696. lst.remove(elem)
  697. out.append(elem)
  698. return out
  699. def _rand_color(self):
  700. """
  701. Generate a random color name (string)
  702. """
  703. return self._rand_elem(COLOR_NAMES)
  704. def _rand_pos(self, xLow, xHigh, yLow, yHigh):
  705. """
  706. Generate a random (x,y) position tuple
  707. """
  708. return (
  709. self.np_random.integers(xLow, xHigh),
  710. self.np_random.integers(yLow, yHigh)
  711. )
  712. def place_obj(self,
  713. obj,
  714. top=None,
  715. size=None,
  716. reject_fn=None,
  717. max_tries=math.inf
  718. ):
  719. """
  720. Place an object at an empty position in the grid
  721. :param top: top-left position of the rectangle where to place
  722. :param size: size of the rectangle where to place
  723. :param reject_fn: function to filter out potential positions
  724. """
  725. if top is None:
  726. top = (0, 0)
  727. else:
  728. top = (max(top[0], 0), max(top[1], 0))
  729. if size is None:
  730. size = (self.grid.width, self.grid.height)
  731. num_tries = 0
  732. while True:
  733. # This is to handle with rare cases where rejection sampling
  734. # gets stuck in an infinite loop
  735. if num_tries > max_tries:
  736. raise RecursionError('rejection sampling failed in place_obj')
  737. num_tries += 1
  738. pos = np.array((
  739. self._rand_int(top[0], min(top[0] + size[0], self.grid.width)),
  740. self._rand_int(top[1], min(top[1] + size[1], self.grid.height))
  741. ))
  742. # Don't place the object on top of another object
  743. if self.grid.get(*pos) != None:
  744. continue
  745. # Don't place the object where the agent is
  746. if np.array_equal(pos, self.agent_pos):
  747. continue
  748. # Check if there is a filtering criterion
  749. if reject_fn and reject_fn(self, pos):
  750. continue
  751. break
  752. self.grid.set(*pos, obj)
  753. if obj is not None:
  754. obj.init_pos = pos
  755. obj.cur_pos = pos
  756. return pos
  757. def put_obj(self, obj, i, j):
  758. """
  759. Put an object at a specific position in the grid
  760. """
  761. self.grid.set(i, j, obj)
  762. obj.init_pos = (i, j)
  763. obj.cur_pos = (i, j)
  764. def place_agent(
  765. self,
  766. top=None,
  767. size=None,
  768. rand_dir=True,
  769. max_tries=math.inf
  770. ):
  771. """
  772. Set the agent's starting point at an empty position in the grid
  773. """
  774. self.agent_pos = None
  775. pos = self.place_obj(None, top, size, max_tries=max_tries)
  776. self.agent_pos = pos
  777. if rand_dir:
  778. self.agent_dir = self._rand_int(0, 4)
  779. return pos
  780. @property
  781. def dir_vec(self):
  782. """
  783. Get the direction vector for the agent, pointing in the direction
  784. of forward movement.
  785. """
  786. assert self.agent_dir >= 0 and self.agent_dir < 4
  787. return DIR_TO_VEC[self.agent_dir]
  788. @property
  789. def right_vec(self):
  790. """
  791. Get the vector pointing to the right of the agent.
  792. """
  793. dx, dy = self.dir_vec
  794. return np.array((-dy, dx))
  795. @property
  796. def front_pos(self):
  797. """
  798. Get the position of the cell that is right in front of the agent
  799. """
  800. return self.agent_pos + self.dir_vec
  801. def get_view_coords(self, i, j):
  802. """
  803. Translate and rotate absolute grid coordinates (i, j) into the
  804. agent's partially observable view (sub-grid). Note that the resulting
  805. coordinates may be negative or outside of the agent's view size.
  806. """
  807. ax, ay = self.agent_pos
  808. dx, dy = self.dir_vec
  809. rx, ry = self.right_vec
  810. # Compute the absolute coordinates of the top-left view corner
  811. sz = self.agent_view_size
  812. hs = self.agent_view_size // 2
  813. tx = ax + (dx * (sz-1)) - (rx * hs)
  814. ty = ay + (dy * (sz-1)) - (ry * hs)
  815. lx = i - tx
  816. ly = j - ty
  817. # Project the coordinates of the object relative to the top-left
  818. # corner onto the agent's own coordinate system
  819. vx = (rx*lx + ry*ly)
  820. vy = -(dx*lx + dy*ly)
  821. return vx, vy
  def get_view_exts(self, agent_view_size=None):
      """
      Get the extents of the square set of tiles visible to the agent.

      Args:
          agent_view_size: side length of the view square. When None (or
              any other falsy value), ``self.agent_view_size`` is used.

      Returns:
          (topX, topY, botX, botY) in world coordinates. The bottom extents
          are exclusive, so the view covers topX <= x < botX and
          topY <= y < botY.
      """
      size = agent_view_size or self.agent_view_size
      half = size // 2
      # Offset from the agent's position to the view's top-left corner,
      # keyed by direction: 0 = right, 1 = down, 2 = left, 3 = up.
      corner_offsets = {
          0: (0, -half),
          1: (-half, 0),
          2: (1 - size, -half),
          3: (-half, 1 - size),
      }
      if self.agent_dir not in corner_offsets:
          assert False, "invalid agent direction"
      off_x, off_y = corner_offsets[self.agent_dir]
      top_x = self.agent_pos[0] + off_x
      top_y = self.agent_pos[1] + off_y
      return (top_x, top_y, top_x + size, top_y + size)
  def relative_coords(self, x, y):
      """
      Return the agent-view coordinates of world position (x, y).

      Returns (vx, vy) when the position falls inside the agent's
      ``agent_view_size`` x ``agent_view_size`` view square, otherwise None.
      Only the square bounds are checked; occlusion is not considered here.
      """
      vx, vy = self.get_view_coords(x, y)
      size = self.agent_view_size
      inside = 0 <= vx < size and 0 <= vy < size
      return (vx, vy) if inside else None
  def in_view(self, x, y):
      """
      Return True if world position (x, y) lies inside the agent's view square.

      This is a pure bounds test via ``relative_coords``. Cells hidden
      behind walls still count as in view.
      """
      coords = self.relative_coords(x, y)
      return coords is not None
  def agent_sees(self, x, y):
      """
      Check if a non-empty grid position is visible to the agent.

      Returns True only when (x, y) lies inside the agent's view and the cell
      shown at that spot in the agent's current observation has the same
      object type as the world cell. Positions outside the view and empty
      world cells yield False.
      """
      coordinates = self.relative_coords(x, y)
      if coordinates is None:
          return False

      world_cell = self.grid.get(x, y)
      # An empty cell is never "seen" as a non-empty position. Checking this
      # before building the observation also avoids dereferencing None below:
      # gen_obs_grid writes the carried object into the agent's own view cell,
      # so obs_cell can be non-None while the world cell is empty.
      if world_cell is None:
          return False

      vx, vy = coordinates
      obs = self.gen_obs()
      obs_grid, _ = Grid.decode(obs['image'])
      obs_cell = obs_grid.get(vx, vy)
      return obs_cell is not None and obs_cell.type == world_cell.type
  def step(self, action):
      """
      Apply one action and advance the environment by one time step.

      Args:
          action: one of the values in ``self.actions`` (left, right,
              forward, pickup, drop, toggle, done).

      Returns:
          A tuple ``(obs, reward, done, info)``:
          - ``obs`` is ``self.gen_obs()``.
          - ``reward`` is 0 unless the agent steps forward into a goal cell,
            in which case it is ``self._reward()``.
          - ``done`` is True after reaching a goal, stepping into lava, or once
            ``step_count`` reaches ``max_steps``.
          - ``info`` is an empty dict.

      Raises:
          ValueError: if ``action`` is not one of the known actions.
      """
      self.step_count += 1
      reward = 0
      done = False

      # Position and contents of the cell in front of the agent, taken before
      # any rotation so every action refers to the same cell.
      fwd_pos = self.front_pos
      fwd_cell = self.grid.get(*fwd_pos)

      if action == self.actions.left:
          # Rotate left; direction 0 wraps around to 3
          self.agent_dir -= 1
          if self.agent_dir < 0:
              self.agent_dir += 4

      elif action == self.actions.right:
          # Rotate right; direction 3 wraps around to 0
          self.agent_dir = (self.agent_dir + 1) % 4

      elif action == self.actions.forward:
          if fwd_cell is None or fwd_cell.can_overlap():
              self.agent_pos = fwd_pos
          if fwd_cell is not None and fwd_cell.type == 'goal':
              done = True
              reward = self._reward()
          if fwd_cell is not None and fwd_cell.type == 'lava':
              done = True

      elif action == self.actions.pickup:
          if fwd_cell and fwd_cell.can_pickup():
              if self.carrying is None:
                  self.carrying = fwd_cell
                  # (-1, -1) marks an object that is held rather than on the grid
                  self.carrying.cur_pos = np.array([-1, -1])
                  self.grid.set(*fwd_pos, None)

      elif action == self.actions.drop:
          # Only drop into an empty cell
          if not fwd_cell and self.carrying:
              self.grid.set(*fwd_pos, self.carrying)
              self.carrying.cur_pos = fwd_pos
              self.carrying = None

      elif action == self.actions.toggle:
          if fwd_cell:
              fwd_cell.toggle(self, fwd_pos)

      elif action == self.actions.done:
          # Done action (not used by default)
          pass

      else:
          # A real exception rather than `assert`, which is stripped under -O
          # and would silently accept invalid actions.
          raise ValueError("unknown action: {!r}".format(action))

      if self.step_count >= self.max_steps:
          done = True

      obs = self.gen_obs()
      return obs, reward, done, {}
  def gen_obs_grid(self, agent_view_size=None):
      """
      Generate the sub-grid observed by the agent, plus a visibility mask.

      The view square from ``get_view_exts`` is sliced out of the world grid
      and rotated left ``agent_dir + 1`` times. This puts the agent at the
      bottom-centre cell ``(width // 2, height - 1)``. The carried object
      (or None) is written into that cell so the agent sees what it holds.

      Args:
          agent_view_size: side length of the view; when None,
              ``self.agent_view_size`` is used.

      Returns:
          (grid, vis_mask): the rotated view grid and a boolean mask of the
          cells the agent can see. With ``see_through_walls`` every cell is
          visible; otherwise ``grid.process_vis`` accounts for occluders,
          which costs some performance.
      """
      top_x, top_y, _, _ = self.get_view_exts(agent_view_size)
      view_size = agent_view_size or self.agent_view_size
      view = self.grid.slice(top_x, top_y, view_size, view_size)
      for _ in range(self.agent_dir + 1):
          view = view.rotate_left()

      if self.see_through_walls:
          vis_mask = np.ones(shape=(view.width, view.height), dtype=bool)
      else:
          vis_mask = view.process_vis(agent_pos=(view_size // 2, view_size - 1))

      # Show the carried object (or nothing) in the agent's own view cell
      held = self.carrying if self.carrying else None
      view.set(view.width // 2, view.height - 1, held)
      return view, vis_mask
  def gen_obs(self):
      """
      Generate the agent's observation (partially observable, low-resolution).

      Returns:
          A dict with:
          - 'image': the encoded view from ``gen_obs_grid``, with the
            visibility mask applied by ``encode``.
          - 'direction': the agent's orientation (acts as a compass).
          - 'mission': the environment's textual mission string.

      Raises:
          AssertionError: if the environment has no ``mission`` attribute.
      """
      view, visible = self.gen_obs_grid()
      encoded = view.encode(visible)
      assert hasattr(self, 'mission'), "environments must define a textual mission string"
      return {
          'image': encoded,
          'direction': self.agent_dir,
          'mission': self.mission,
      }
  def get_obs_render(self, obs, tile_size=TILE_PIXELS//2):
      """
      Render an agent observation for visualization.

      Args:
          obs: the encoded view image, e.g. ``self.gen_obs()['image']``. This
              is not the full observation dict.
          tile_size: pixel size of one rendered tile.

      Returns:
          The rendered image of the agent's view. The agent is drawn facing
          up (direction 3) and the visible cells are highlighted.
      """
      grid, vis_mask = Grid.decode(obs)
      # Place the agent at the bottom-centre cell of the decoded view, as
      # gen_obs_grid does. Deriving it from the grid's own size rather than
      # self.agent_view_size keeps this correct for observations generated
      # with a non-default view size.
      img = grid.render(
          tile_size,
          agent_pos=(grid.width // 2, grid.height - 1),
          agent_dir=3,
          highlight_mask=vis_mask
      )
      return img
  def render(self, mode='human', close=False, highlight=True, tile_size=TILE_PIXELS):
      """
      Render the whole-grid human view.

      Args:
          mode: 'human' also shows the frame in an interactive window (created
              on first use); any other value only returns the image. When
              ``self.render_mode`` is not None it overrides this argument.
          close: if True, close the window (if any) and return None.
          highlight: if True, highlight the world cells currently visible to
              the agent.
          tile_size: pixel size of one rendered tile.

      Returns:
          The rendered image, or None when ``close`` is True.
      """
      if self.render_mode is not None:
          mode = self.render_mode

      if close:
          if self.window:
              self.window.close()
          return

      if mode == 'human' and not self.window:
          import gym_minigrid.window
          self.window = gym_minigrid.window.Window('gym_minigrid')
          self.window.show(block=False)

      highlight_mask = None
      if highlight:
          # Visibility processing has a real cost, so only do it when the
          # result is actually displayed.
          _, vis_mask = self.gen_obs_grid()

          f_vec = self.dir_vec
          r_vec = self.right_vec
          # World coordinates of the top-left corner of the agent's view area
          # (view cell (0, 0)): the far edge of the view, half a view-width to
          # the agent's left.
          top_left = (self.agent_pos
                      + f_vec * (self.agent_view_size - 1)
                      - r_vec * (self.agent_view_size // 2))

          highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
          for vis_j in range(0, self.agent_view_size):
              for vis_i in range(0, self.agent_view_size):
                  if not vis_mask[vis_i, vis_j]:
                      continue
                  # Map view cell (vis_i, vis_j) back to world coordinates
                  abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)
                  # Skip view cells that fall outside the world grid
                  if abs_i < 0 or abs_i >= self.width:
                      continue
                  if abs_j < 0 or abs_j >= self.height:
                      continue
                  highlight_mask[abs_i, abs_j] = True

      img = self.grid.render(
          tile_size,
          self.agent_pos,
          self.agent_dir,
          highlight_mask=highlight_mask
      )

      if mode == 'human':
          self.window.set_caption(self.mission)
          self.window.show_img(img)

      return img
  def close(self):
      """Close the rendering window if one has been opened; always returns None."""
      window = self.window
      if not window:
          return None
      window.close()
      return None