minigrid.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328
  1. import math
  2. import hashlib
  3. import string
  4. import gym
  5. from enum import IntEnum
  6. import numpy as np
  7. from gym import error, spaces, utils
  8. from .rendering import *
  9. # Size in pixels of a tile in the full-scale human view
  10. TILE_PIXELS = 32
  11. # Map of color names to RGB values
  12. COLORS = {
  13. 'red' : np.array([255, 0, 0]),
  14. 'green' : np.array([0, 255, 0]),
  15. 'blue' : np.array([0, 0, 255]),
  16. 'purple': np.array([112, 39, 195]),
  17. 'yellow': np.array([255, 255, 0]),
  18. 'grey' : np.array([100, 100, 100])
  19. }
  20. COLOR_NAMES = sorted(list(COLORS.keys()))
  21. # Used to map colors to integers
  22. COLOR_TO_IDX = {
  23. 'red' : 0,
  24. 'green' : 1,
  25. 'blue' : 2,
  26. 'purple': 3,
  27. 'yellow': 4,
  28. 'grey' : 5
  29. }
  30. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  31. # Map of object type to integers
  32. OBJECT_TO_IDX = {
  33. 'unseen' : 0,
  34. 'empty' : 1,
  35. 'wall' : 2,
  36. 'floor' : 3,
  37. 'door' : 4,
  38. 'key' : 5,
  39. 'ball' : 6,
  40. 'box' : 7,
  41. 'goal' : 8,
  42. 'lava' : 9,
  43. 'agent' : 10,
  44. }
  45. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  46. # Map of state names to integers
  47. STATE_TO_IDX = {
  48. 'open' : 0,
  49. 'closed': 1,
  50. 'locked': 2,
  51. }
  52. # Map of agent direction indices to vectors
  53. DIR_TO_VEC = [
  54. # Pointing right (positive X)
  55. np.array((1, 0)),
  56. # Down (positive Y)
  57. np.array((0, 1)),
  58. # Pointing left (negative X)
  59. np.array((-1, 0)),
  60. # Up (negative Y)
  61. np.array((0, -1)),
  62. ]
  63. class WorldObj:
  64. """
  65. Base class for grid world objects
  66. """
  67. def __init__(self, type, color):
  68. assert type in OBJECT_TO_IDX, type
  69. assert color in COLOR_TO_IDX, color
  70. self.type = type
  71. self.color = color
  72. self.contains = None
  73. # Initial position of the object
  74. self.init_pos = None
  75. # Current position of the object
  76. self.cur_pos = None
  77. def can_overlap(self):
  78. """Can the agent overlap with this?"""
  79. return False
  80. def can_pickup(self):
  81. """Can the agent pick this up?"""
  82. return False
  83. def can_contain(self):
  84. """Can this contain another object?"""
  85. return False
  86. def see_behind(self):
  87. """Can the agent see behind this object?"""
  88. return True
  89. def toggle(self, env, pos):
  90. """Method to trigger/toggle an action this object performs"""
  91. return False
  92. def encode(self):
  93. """Encode the a description of this object as a 3-tuple of integers"""
  94. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0)
  95. @staticmethod
  96. def decode(type_idx, color_idx, state):
  97. """Create an object from a 3-tuple state description"""
  98. obj_type = IDX_TO_OBJECT[type_idx]
  99. color = IDX_TO_COLOR[color_idx]
  100. if obj_type == 'empty' or obj_type == 'unseen':
  101. return None
  102. # State, 0: open, 1: closed, 2: locked
  103. is_open = state == 0
  104. is_locked = state == 2
  105. if obj_type == 'wall':
  106. v = Wall(color)
  107. elif obj_type == 'floor':
  108. v = Floor(color)
  109. elif obj_type == 'ball':
  110. v = Ball(color)
  111. elif obj_type == 'key':
  112. v = Key(color)
  113. elif obj_type == 'box':
  114. v = Box(color)
  115. elif obj_type == 'door':
  116. v = Door(color, is_open, is_locked)
  117. elif obj_type == 'goal':
  118. v = Goal()
  119. elif obj_type == 'lava':
  120. v = Lava()
  121. else:
  122. assert False, "unknown object type in decode '%s'" % obj_type
  123. return v
  124. def render(self, r):
  125. """Draw this object with the given renderer"""
  126. raise NotImplementedError
  127. class Goal(WorldObj):
  128. def __init__(self):
  129. super().__init__('goal', 'green')
  130. def can_overlap(self):
  131. return True
  132. def render(self, img):
  133. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  134. class Floor(WorldObj):
  135. """
  136. Colored floor tile the agent can walk over
  137. """
  138. def __init__(self, color='blue'):
  139. super().__init__('floor', color)
  140. def can_overlap(self):
  141. return True
  142. def render(self, img):
  143. # Give the floor a pale color
  144. color = COLORS[self.color] / 2
  145. fill_coords(img, point_in_rect(0.031, 1, 0.031, 1), color)
  146. class Lava(WorldObj):
  147. def __init__(self):
  148. super().__init__('lava', 'red')
  149. def can_overlap(self):
  150. return True
  151. def render(self, img):
  152. c = (255, 128, 0)
  153. # Background color
  154. fill_coords(img, point_in_rect(0, 1, 0, 1), c)
  155. # Little waves
  156. for i in range(3):
  157. ylo = 0.3 + 0.2 * i
  158. yhi = 0.4 + 0.2 * i
  159. fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0,0,0))
  160. fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0,0,0))
  161. fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0,0,0))
  162. fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0,0,0))
  163. class Wall(WorldObj):
  164. def __init__(self, color='grey'):
  165. super().__init__('wall', color)
  166. def see_behind(self):
  167. return False
  168. def render(self, img):
  169. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  170. class Door(WorldObj):
  171. def __init__(self, color, is_open=False, is_locked=False):
  172. super().__init__('door', color)
  173. self.is_open = is_open
  174. self.is_locked = is_locked
  175. def can_overlap(self):
  176. """The agent can only walk over this cell when the door is open"""
  177. return self.is_open
  178. def see_behind(self):
  179. return self.is_open
  180. def toggle(self, env, pos):
  181. # If the player has the right key to open the door
  182. if self.is_locked:
  183. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  184. self.is_locked = False
  185. self.is_open = True
  186. return True
  187. return False
  188. self.is_open = not self.is_open
  189. return True
  190. def encode(self):
  191. """Encode the a description of this object as a 3-tuple of integers"""
  192. # State, 0: open, 1: closed, 2: locked
  193. if self.is_open:
  194. state = 0
  195. elif self.is_locked:
  196. state = 2
  197. elif not self.is_open:
  198. state = 1
  199. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state)
  200. def render(self, img):
  201. c = COLORS[self.color]
  202. if self.is_open:
  203. fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c)
  204. fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0,0,0))
  205. return
  206. # Door frame and door
  207. if self.is_locked:
  208. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  209. fill_coords(img, point_in_rect(0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))
  210. # Draw key slot
  211. fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)
  212. else:
  213. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  214. fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0,0,0))
  215. fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c)
  216. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0,0,0))
  217. # Draw door handle
  218. fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c)
  219. class Key(WorldObj):
  220. def __init__(self, color='blue'):
  221. super(Key, self).__init__('key', color)
  222. def can_pickup(self):
  223. return True
  224. def render(self, img):
  225. c = COLORS[self.color]
  226. # Vertical quad
  227. fill_coords(img, point_in_rect(0.50, 0.63, 0.31, 0.88), c)
  228. # Teeth
  229. fill_coords(img, point_in_rect(0.38, 0.50, 0.59, 0.66), c)
  230. fill_coords(img, point_in_rect(0.38, 0.50, 0.81, 0.88), c)
  231. # Ring
  232. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.190), c)
  233. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0,0,0))
  234. class Ball(WorldObj):
  235. def __init__(self, color='blue'):
  236. super(Ball, self).__init__('ball', color)
  237. def can_pickup(self):
  238. return True
  239. def render(self, img):
  240. fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color])
  241. class Box(WorldObj):
  242. def __init__(self, color, contains=None):
  243. super(Box, self).__init__('box', color)
  244. self.contains = contains
  245. def can_pickup(self):
  246. return True
  247. def render(self, img):
  248. c = COLORS[self.color]
  249. # Outline
  250. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), c)
  251. fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0,0,0))
  252. # Horizontal slit
  253. fill_coords(img, point_in_rect(0.16, 0.84, 0.47, 0.53), c)
  254. def toggle(self, env, pos):
  255. # Replace the box by its contents
  256. env.grid.set(*pos, self.contains)
  257. return True
  258. class Grid:
  259. """
  260. Represent a grid and operations on it
  261. """
  262. # Static cache of pre-renderer tiles
  263. tile_cache = {}
  264. def __init__(self, width, height):
  265. assert width >= 3
  266. assert height >= 3
  267. self.width = width
  268. self.height = height
  269. self.grid = [None] * width * height
  270. def __contains__(self, key):
  271. if isinstance(key, WorldObj):
  272. for e in self.grid:
  273. if e is key:
  274. return True
  275. elif isinstance(key, tuple):
  276. for e in self.grid:
  277. if e is None:
  278. continue
  279. if (e.color, e.type) == key:
  280. return True
  281. if key[0] is None and key[1] == e.type:
  282. return True
  283. return False
  284. def __eq__(self, other):
  285. grid1 = self.encode()
  286. grid2 = other.encode()
  287. return np.array_equal(grid2, grid1)
  288. def __ne__(self, other):
  289. return not self == other
  290. def copy(self):
  291. from copy import deepcopy
  292. return deepcopy(self)
  293. def set(self, i, j, v):
  294. assert i >= 0 and i < self.width
  295. assert j >= 0 and j < self.height
  296. self.grid[j * self.width + i] = v
  297. def get(self, i, j):
  298. assert i >= 0 and i < self.width
  299. assert j >= 0 and j < self.height
  300. return self.grid[j * self.width + i]
  301. def horz_wall(self, x, y, length=None, obj_type=Wall):
  302. if length is None:
  303. length = self.width - x
  304. for i in range(0, length):
  305. self.set(x + i, y, obj_type())
  306. def vert_wall(self, x, y, length=None, obj_type=Wall):
  307. if length is None:
  308. length = self.height - y
  309. for j in range(0, length):
  310. self.set(x, y + j, obj_type())
  311. def wall_rect(self, x, y, w, h):
  312. self.horz_wall(x, y, w)
  313. self.horz_wall(x, y+h-1, w)
  314. self.vert_wall(x, y, h)
  315. self.vert_wall(x+w-1, y, h)
  316. def rotate_left(self):
  317. """
  318. Rotate the grid to the left (counter-clockwise)
  319. """
  320. grid = Grid(self.height, self.width)
  321. for i in range(self.width):
  322. for j in range(self.height):
  323. v = self.get(i, j)
  324. grid.set(j, grid.height - 1 - i, v)
  325. return grid
  326. def slice(self, topX, topY, width, height):
  327. """
  328. Get a subset of the grid
  329. """
  330. grid = Grid(width, height)
  331. for j in range(0, height):
  332. for i in range(0, width):
  333. x = topX + i
  334. y = topY + j
  335. if x >= 0 and x < self.width and \
  336. y >= 0 and y < self.height:
  337. v = self.get(x, y)
  338. else:
  339. v = Wall()
  340. grid.set(i, j, v)
  341. return grid
  342. @classmethod
  343. def render_tile(
  344. cls,
  345. obj,
  346. agent_dir=None,
  347. highlight=False,
  348. tile_size=TILE_PIXELS,
  349. subdivs=3
  350. ):
  351. """
  352. Render a tile and cache the result
  353. """
  354. # Hash map lookup key for the cache
  355. key = (agent_dir, highlight, tile_size)
  356. key = obj.encode() + key if obj else key
  357. if key in cls.tile_cache:
  358. return cls.tile_cache[key]
  359. img = np.zeros(shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8)
  360. # Draw the grid lines (top and left edges)
  361. fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
  362. fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100))
  363. if obj != None:
  364. obj.render(img)
  365. # Overlay the agent on top
  366. if agent_dir is not None:
  367. tri_fn = point_in_triangle(
  368. (0.12, 0.19),
  369. (0.87, 0.50),
  370. (0.12, 0.81),
  371. )
  372. # Rotate the agent based on its direction
  373. tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5*math.pi*agent_dir)
  374. fill_coords(img, tri_fn, (255, 0, 0))
  375. # Highlight the cell if needed
  376. if highlight:
  377. highlight_img(img)
  378. # Downsample the image to perform supersampling/anti-aliasing
  379. img = downsample(img, subdivs)
  380. # Cache the rendered tile
  381. cls.tile_cache[key] = img
  382. return img
  383. def render(
  384. self,
  385. tile_size,
  386. agent_pos=None,
  387. agent_dir=None,
  388. highlight_mask=None
  389. ):
  390. """
  391. Render this grid at a given scale
  392. :param r: target renderer object
  393. :param tile_size: tile size in pixels
  394. """
  395. if highlight_mask is None:
  396. highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  397. # Compute the total grid size
  398. width_px = self.width * tile_size
  399. height_px = self.height * tile_size
  400. img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)
  401. # Render the grid
  402. for j in range(0, self.height):
  403. for i in range(0, self.width):
  404. cell = self.get(i, j)
  405. agent_here = np.array_equal(agent_pos, (i, j))
  406. tile_img = Grid.render_tile(
  407. cell,
  408. agent_dir=agent_dir if agent_here else None,
  409. highlight=highlight_mask[i, j],
  410. tile_size=tile_size
  411. )
  412. ymin = j * tile_size
  413. ymax = (j+1) * tile_size
  414. xmin = i * tile_size
  415. xmax = (i+1) * tile_size
  416. img[ymin:ymax, xmin:xmax, :] = tile_img
  417. return img
  418. def encode(self, vis_mask=None):
  419. """
  420. Produce a compact numpy encoding of the grid
  421. """
  422. if vis_mask is None:
  423. vis_mask = np.ones((self.width, self.height), dtype=bool)
  424. array = np.zeros((self.width, self.height, 3), dtype='uint8')
  425. for i in range(self.width):
  426. for j in range(self.height):
  427. if vis_mask[i, j]:
  428. v = self.get(i, j)
  429. if v is None:
  430. array[i, j, 0] = OBJECT_TO_IDX['empty']
  431. array[i, j, 1] = 0
  432. array[i, j, 2] = 0
  433. else:
  434. array[i, j, :] = v.encode()
  435. return array
  436. @staticmethod
  437. def decode(array):
  438. """
  439. Decode an array grid encoding back into a grid
  440. """
  441. width, height, channels = array.shape
  442. assert channels == 3
  443. vis_mask = np.ones(shape=(width, height), dtype=bool)
  444. grid = Grid(width, height)
  445. for i in range(width):
  446. for j in range(height):
  447. type_idx, color_idx, state = array[i, j]
  448. v = WorldObj.decode(type_idx, color_idx, state)
  449. grid.set(i, j, v)
  450. vis_mask[i, j] = (type_idx != OBJECT_TO_IDX['unseen'])
  451. return grid, vis_mask
  452. def process_vis(grid, agent_pos):
  453. mask = np.zeros(shape=(grid.width, grid.height), dtype=bool)
  454. mask[agent_pos[0], agent_pos[1]] = True
  455. for j in reversed(range(0, grid.height)):
  456. for i in range(0, grid.width-1):
  457. if not mask[i, j]:
  458. continue
  459. cell = grid.get(i, j)
  460. if cell and not cell.see_behind():
  461. continue
  462. mask[i+1, j] = True
  463. if j > 0:
  464. mask[i+1, j-1] = True
  465. mask[i, j-1] = True
  466. for i in reversed(range(1, grid.width)):
  467. if not mask[i, j]:
  468. continue
  469. cell = grid.get(i, j)
  470. if cell and not cell.see_behind():
  471. continue
  472. mask[i-1, j] = True
  473. if j > 0:
  474. mask[i-1, j-1] = True
  475. mask[i, j-1] = True
  476. for j in range(0, grid.height):
  477. for i in range(0, grid.width):
  478. if not mask[i, j]:
  479. grid.set(i, j, None)
  480. return mask
  481. class StringGymSpace(gym.spaces.space.Space):
  482. """
  483. A gym space that represents a string of characters of bounded length
  484. """
  485. def __init__(self, min_length=0, max_length=1000):
  486. self.min_length = min_length
  487. self.max_length = max_length
  488. self.letters = string.ascii_letters + string.digits + ' .,!- '
  489. self._shape = ()
  490. self.dtype = np.dtype('U')
  491. def sample(self):
  492. length = np.random.randint(self.min_length, self.max_length)
  493. string = ''.join(np.random.choice(list(self.letters), size=length))
  494. return string
  495. def contains(self, x):
  496. return isinstance(x, str) and len(x) >= self.min_length and len(x) <= self.max_length
  497. def __repr__(self):
  498. return "StringGymSpace(min_length={}, max_length={})".format(self.min_length, self.max_length)
  499. class MiniGridEnv(gym.Env):
  500. """
  501. 2D grid world game environment
  502. """
  503. metadata = {
  504. 'render_modes': ['human', 'rgb_array'],
  505. 'render_fps' : 10
  506. }
  507. # Enumeration of possible actions
  508. class Actions(IntEnum):
  509. # Turn left, turn right, move forward
  510. left = 0
  511. right = 1
  512. forward = 2
  513. # Pick up an object
  514. pickup = 3
  515. # Drop an object
  516. drop = 4
  517. # Toggle/activate an object
  518. toggle = 5
  519. # Done completing task
  520. done = 6
  521. def __init__(
  522. self,
  523. grid_size=None,
  524. width=None,
  525. height=None,
  526. max_steps=100,
  527. see_through_walls=False,
  528. agent_view_size=7,
  529. render_mode=None
  530. ):
  531. # Can't set both grid_size and width/height
  532. if grid_size:
  533. assert width == None and height == None
  534. width = grid_size
  535. height = grid_size
  536. # Action enumeration for this environment
  537. self.actions = MiniGridEnv.Actions
  538. # Actions are discrete integer values
  539. self.action_space = spaces.Discrete(len(self.actions))
  540. # Number of cells (width and height) in the agent view
  541. assert agent_view_size % 2 == 1
  542. assert agent_view_size >= 3
  543. self.agent_view_size = agent_view_size
  544. # Observations are dictionaries containing an
  545. # encoding of the grid and a textual 'mission' string
  546. self.observation_space = spaces.Box(
  547. low=0,
  548. high=255,
  549. shape=(self.agent_view_size, self.agent_view_size, 3),
  550. dtype='uint8'
  551. )
  552. self.observation_space = spaces.Dict({
  553. 'image': self.observation_space,
  554. 'direction': spaces.Discrete(4),
  555. 'mission': StringGymSpace(min_length=0, max_length=200),
  556. })
  557. # render mode
  558. self.render_mode = render_mode
  559. # Range of possible rewards
  560. self.reward_range = (0, 1)
  561. # Window to use for human rendering mode
  562. self.window = None
  563. # Environment configuration
  564. self.width = width
  565. self.height = height
  566. self.max_steps = max_steps
  567. self.see_through_walls = see_through_walls
  568. # Current position and direction of the agent
  569. self.agent_pos = None
  570. self.agent_dir = None
  571. # Initialize the state
  572. self.reset()
  573. def reset(self, *, seed=None, return_info=False, options=None):
  574. super().reset(seed=seed)
  575. # Current position and direction of the agent
  576. self.agent_pos = None
  577. self.agent_dir = None
  578. # Generate a new random grid at the start of each episode
  579. self._gen_grid(self.width, self.height)
  580. # These fields should be defined by _gen_grid
  581. assert self.agent_pos is not None
  582. assert self.agent_dir is not None
  583. # Check that the agent doesn't overlap with an object
  584. start_cell = self.grid.get(*self.agent_pos)
  585. assert start_cell is None or start_cell.can_overlap()
  586. # Item picked up, being carried, initially nothing
  587. self.carrying = None
  588. # Step count since episode start
  589. self.step_count = 0
  590. # Return first observation
  591. obs = self.gen_obs()
  592. return obs
  593. def hash(self, size=16):
  594. """Compute a hash that uniquely identifies the current state of the environment.
  595. :param size: Size of the hashing
  596. """
  597. sample_hash = hashlib.sha256()
  598. to_encode = [self.grid.encode().tolist(), self.agent_pos, self.agent_dir]
  599. for item in to_encode:
  600. sample_hash.update(str(item).encode('utf8'))
  601. return sample_hash.hexdigest()[:size]
  602. @property
  603. def steps_remaining(self):
  604. return self.max_steps - self.step_count
  605. def __str__(self):
  606. """
  607. Produce a pretty string of the environment's grid along with the agent.
  608. A grid cell is represented by 2-character string, the first one for
  609. the object and the second one for the color.
  610. """
  611. # Map of object types to short string
  612. OBJECT_TO_STR = {
  613. 'wall' : 'W',
  614. 'floor' : 'F',
  615. 'door' : 'D',
  616. 'key' : 'K',
  617. 'ball' : 'A',
  618. 'box' : 'B',
  619. 'goal' : 'G',
  620. 'lava' : 'V',
  621. }
  622. # Short string for opened door
  623. OPENDED_DOOR_IDS = '_'
  624. # Map agent's direction to short string
  625. AGENT_DIR_TO_STR = {
  626. 0: '>',
  627. 1: 'V',
  628. 2: '<',
  629. 3: '^'
  630. }
  631. str = ''
  632. for j in range(self.grid.height):
  633. for i in range(self.grid.width):
  634. if i == self.agent_pos[0] and j == self.agent_pos[1]:
  635. str += 2 * AGENT_DIR_TO_STR[self.agent_dir]
  636. continue
  637. c = self.grid.get(i, j)
  638. if c == None:
  639. str += ' '
  640. continue
  641. if c.type == 'door':
  642. if c.is_open:
  643. str += '__'
  644. elif c.is_locked:
  645. str += 'L' + c.color[0].upper()
  646. else:
  647. str += 'D' + c.color[0].upper()
  648. continue
  649. str += OBJECT_TO_STR[c.type] + c.color[0].upper()
  650. if j < self.grid.height - 1:
  651. str += '\n'
  652. return str
  653. def _gen_grid(self, width, height):
  654. assert False, "_gen_grid needs to be implemented by each environment"
  655. def _reward(self):
  656. """
  657. Compute the reward to be given upon success
  658. """
  659. return 1 - 0.9 * (self.step_count / self.max_steps)
  660. def _rand_int(self, low, high):
  661. """
  662. Generate random integer in [low,high[
  663. """
  664. return self.np_random.integers(low, high)
  665. def _rand_float(self, low, high):
  666. """
  667. Generate random float in [low,high[
  668. """
  669. return self.np_random.uniform(low, high)
  670. def _rand_bool(self):
  671. """
  672. Generate random boolean value
  673. """
  674. return (self.np_random.integers(0, 2) == 0)
  675. def _rand_elem(self, iterable):
  676. """
  677. Pick a random element in a list
  678. """
  679. lst = list(iterable)
  680. idx = self._rand_int(0, len(lst))
  681. return lst[idx]
  682. def _rand_subset(self, iterable, num_elems):
  683. """
  684. Sample a random subset of distinct elements of a list
  685. """
  686. lst = list(iterable)
  687. assert num_elems <= len(lst)
  688. out = []
  689. while len(out) < num_elems:
  690. elem = self._rand_elem(lst)
  691. lst.remove(elem)
  692. out.append(elem)
  693. return out
  694. def _rand_color(self):
  695. """
  696. Generate a random color name (string)
  697. """
  698. return self._rand_elem(COLOR_NAMES)
  699. def _rand_pos(self, xLow, xHigh, yLow, yHigh):
  700. """
  701. Generate a random (x,y) position tuple
  702. """
  703. return (
  704. self.np_random.integers(xLow, xHigh),
  705. self.np_random.integers(yLow, yHigh)
  706. )
  707. def place_obj(self,
  708. obj,
  709. top=None,
  710. size=None,
  711. reject_fn=None,
  712. max_tries=math.inf
  713. ):
  714. """
  715. Place an object at an empty position in the grid
  716. :param top: top-left position of the rectangle where to place
  717. :param size: size of the rectangle where to place
  718. :param reject_fn: function to filter out potential positions
  719. """
  720. if top is None:
  721. top = (0, 0)
  722. else:
  723. top = (max(top[0], 0), max(top[1], 0))
  724. if size is None:
  725. size = (self.grid.width, self.grid.height)
  726. num_tries = 0
  727. while True:
  728. # This is to handle with rare cases where rejection sampling
  729. # gets stuck in an infinite loop
  730. if num_tries > max_tries:
  731. raise RecursionError('rejection sampling failed in place_obj')
  732. num_tries += 1
  733. pos = np.array((
  734. self._rand_int(top[0], min(top[0] + size[0], self.grid.width)),
  735. self._rand_int(top[1], min(top[1] + size[1], self.grid.height))
  736. ))
  737. # Don't place the object on top of another object
  738. if self.grid.get(*pos) != None:
  739. continue
  740. # Don't place the object where the agent is
  741. if np.array_equal(pos, self.agent_pos):
  742. continue
  743. # Check if there is a filtering criterion
  744. if reject_fn and reject_fn(self, pos):
  745. continue
  746. break
  747. self.grid.set(*pos, obj)
  748. if obj is not None:
  749. obj.init_pos = pos
  750. obj.cur_pos = pos
  751. return pos
  752. def put_obj(self, obj, i, j):
  753. """
  754. Put an object at a specific position in the grid
  755. """
  756. self.grid.set(i, j, obj)
  757. obj.init_pos = (i, j)
  758. obj.cur_pos = (i, j)
  759. def place_agent(
  760. self,
  761. top=None,
  762. size=None,
  763. rand_dir=True,
  764. max_tries=math.inf
  765. ):
  766. """
  767. Set the agent's starting point at an empty position in the grid
  768. """
  769. self.agent_pos = None
  770. pos = self.place_obj(None, top, size, max_tries=max_tries)
  771. self.agent_pos = pos
  772. if rand_dir:
  773. self.agent_dir = self._rand_int(0, 4)
  774. return pos
  775. @property
  776. def dir_vec(self):
  777. """
  778. Get the direction vector for the agent, pointing in the direction
  779. of forward movement.
  780. """
  781. assert self.agent_dir >= 0 and self.agent_dir < 4
  782. return DIR_TO_VEC[self.agent_dir]
  783. @property
  784. def right_vec(self):
  785. """
  786. Get the vector pointing to the right of the agent.
  787. """
  788. dx, dy = self.dir_vec
  789. return np.array((-dy, dx))
  790. @property
  791. def front_pos(self):
  792. """
  793. Get the position of the cell that is right in front of the agent
  794. """
  795. return self.agent_pos + self.dir_vec
  796. def get_view_coords(self, i, j):
  797. """
  798. Translate and rotate absolute grid coordinates (i, j) into the
  799. agent's partially observable view (sub-grid). Note that the resulting
  800. coordinates may be negative or outside of the agent's view size.
  801. """
  802. ax, ay = self.agent_pos
  803. dx, dy = self.dir_vec
  804. rx, ry = self.right_vec
  805. # Compute the absolute coordinates of the top-left view corner
  806. sz = self.agent_view_size
  807. hs = self.agent_view_size // 2
  808. tx = ax + (dx * (sz-1)) - (rx * hs)
  809. ty = ay + (dy * (sz-1)) - (ry * hs)
  810. lx = i - tx
  811. ly = j - ty
  812. # Project the coordinates of the object relative to the top-left
  813. # corner onto the agent's own coordinate system
  814. vx = (rx*lx + ry*ly)
  815. vy = -(dx*lx + dy*ly)
  816. return vx, vy
  def get_view_exts(self, agent_view_size=None):
      """
      Get the extents of the square set of tiles visible to the agent.

      agent_view_size: side length of the view square; when None (or any
      other falsy value) self.agent_view_size is used instead.

      Returns (topX, topY, botX, botY) in absolute grid coordinates.
      Note: the bottom extent indices are not included in the set.

      Raises AssertionError if self.agent_dir is not one of 0..3.
      """
      size = agent_view_size or self.agent_view_size
      half = size // 2
      pos = self.agent_pos
      heading = self.agent_dir

      if heading == 0:
          # Facing right: view starts at the agent's column, centred vertically
          top_x = pos[0]
          top_y = pos[1] - half
      elif heading == 1:
          # Facing down: view starts at the agent's row, centred horizontally
          top_x = pos[0] - half
          top_y = pos[1]
      elif heading == 2:
          # Facing left: view ends at the agent's column, centred vertically
          top_x = pos[0] - size + 1
          top_y = pos[1] - half
      elif heading == 3:
          # Facing up: view ends at the agent's row, centred horizontally
          top_x = pos[0] - half
          top_y = pos[1] - size + 1
      else:
          # Raised explicitly (rather than `assert False`) so the check is
          # not stripped when Python runs with -O
          raise AssertionError("invalid agent direction")

      return (top_x, top_y, top_x + size, top_y + size)
  def relative_coords(self, x, y):
      """
      Convert a grid position into coordinates within the agent's view.

      Returns the (vx, vy) view coordinates, or None when the position
      falls outside the agent's square field of view.
      """
      view_x, view_y = self.get_view_coords(x, y)
      limit = self.agent_view_size

      outside = view_x < 0 or view_y < 0 or view_x >= limit or view_y >= limit
      return None if outside else (view_x, view_y)
  def in_view(self, x, y):
      """
      Tell whether a grid position lies inside the agent's view window.

      This is a bounds test only (it relies on relative_coords); it does
      not consult the visibility mask.
      """
      coords = self.relative_coords(x, y)
      return coords is not None
  def agent_sees(self, x, y):
      """
      Check if a non-empty grid position is visible to the agent.

      Returns True only when (x, y) lies inside the agent's view and the
      cell decoded from the agent's observation at that spot has the same
      type as the object stored in the world grid.

      Returns False when the position is outside the view, when the world
      cell is empty, or when the observed cell is empty or of a different
      type.
      """
      coordinates = self.relative_coords(x, y)
      if coordinates is None:
          return False
      vx, vy = coordinates

      # An empty world cell has no object to see. Checking this first
      # avoids dereferencing None below and skips generating an observation.
      world_cell = self.grid.get(x, y)
      if world_cell is None:
          return False

      obs = self.gen_obs()
      obs_grid, _ = Grid.decode(obs['image'])
      obs_cell = obs_grid.get(vx, vy)

      return obs_cell is not None and obs_cell.type == world_cell.type
  def step(self, action):
      """
      Apply one agent action and advance the environment by one step.

      Every action is interpreted relative to the cell directly in front
      of the agent. A forward action that targets a 'goal' cell ends the
      episode with the reward from self._reward(); one that targets a
      'lava' cell ends it with no reward. The episode also ends once
      step_count reaches max_steps.

      Returns a tuple (obs, reward, done, info): obs comes from
      self.gen_obs(), reward is 0 unless a goal was reached, and info is
      always an empty dict.

      Raises AssertionError if the action is none of the actions handled
      below.
      """
      self.step_count += 1

      reward = 0
      done = False

      # Position and contents of the cell in front of the agent
      fwd_pos = self.front_pos
      fwd_cell = self.grid.get(*fwd_pos)

      actions = self.actions

      # Rotate left (wrapping so the heading stays within 0..3)
      if action == actions.left:
          self.agent_dir = (self.agent_dir - 1) % 4

      # Rotate right
      elif action == actions.right:
          self.agent_dir = (self.agent_dir + 1) % 4

      # Move forward
      elif action == actions.forward:
          if fwd_cell is None or fwd_cell.can_overlap():
              self.agent_pos = fwd_pos
          # Terminal cells are checked whether or not the move succeeded
          if fwd_cell is not None and fwd_cell.type == 'goal':
              done = True
              reward = self._reward()
          if fwd_cell is not None and fwd_cell.type == 'lava':
              done = True

      # Pick up an object
      elif action == actions.pickup:
          if fwd_cell and fwd_cell.can_pickup():
              if self.carrying is None:
                  self.carrying = fwd_cell
                  # A carried object is no longer on the grid
                  self.carrying.cur_pos = np.array([-1, -1])
                  self.grid.set(*fwd_pos, None)

      # Drop an object (only onto an empty cell)
      elif action == actions.drop:
          if not fwd_cell and self.carrying:
              self.grid.set(*fwd_pos, self.carrying)
              self.carrying.cur_pos = fwd_pos
              self.carrying = None

      # Toggle/activate an object
      elif action == actions.toggle:
          if fwd_cell:
              fwd_cell.toggle(self, fwd_pos)

      # Done action (not used by default)
      elif action == actions.done:
          pass

      else:
          # Raised explicitly (rather than `assert False`) so the check is
          # not stripped when Python runs with -O
          raise AssertionError("unknown action")

      if self.step_count >= self.max_steps:
          done = True

      obs = self.gen_obs()

      return obs, reward, done, {}
  def gen_obs_grid(self, agent_view_size=None):
      """
      Build the sub-grid the agent observes, together with a visibility
      mask saying which of its cells the agent can actually see.

      agent_view_size: side length of the view; when None (or otherwise
      falsy) self.agent_view_size is used.

      Returns (grid, vis_mask).
      """
      left, top, _, _ = self.get_view_exts(agent_view_size)
      size = agent_view_size or self.agent_view_size

      # Cut the view square out of the world grid, then rotate it once
      # more than the agent's direction index
      view = self.grid.slice(left, top, size, size)
      for _ in range(self.agent_dir + 1):
          view = view.rotate_left()

      if self.see_through_walls:
          # Nothing occludes: every cell of the view is visible
          vis_mask = np.ones(shape=(view.width, view.height), dtype=bool)
      else:
          # Occlusion processing starts from the agent's cell at the
          # bottom-centre of the view (this has some performance cost)
          vis_mask = view.process_vis(agent_pos=(size // 2 , size - 1))

      # Let the agent see what it is carrying: the carried object (or
      # nothing) is written into the agent's own cell of the view
      own_cell = view.width // 2, view.height - 1
      held = self.carrying
      view.set(*own_cell, held if held else None)

      return view, vis_mask
  def gen_obs(self):
      """
      Produce the agent's observation (partially observable view in a
      low-resolution encoding).

      The returned dictionary holds:
      - 'image': the encoded partially observable view
      - 'direction': the agent's direction/orientation (a compass)
      - 'mission': the textual mission string (instructions for the agent)
      """
      view, visible = self.gen_obs_grid()

      # Turn the observed sub-grid and its visibility mask into an array
      encoded = view.encode(visible)

      assert hasattr(self, 'mission'), "environments must define a textual mission string"

      return {
          'image': encoded,
          'direction': self.agent_dir,
          'mission': self.mission,
      }
  def get_obs_render(self, obs, tile_size=TILE_PIXELS//2):
      """
      Render an agent observation as an image for visualization.

      obs is handed straight to Grid.decode, so it is presumably the
      encoded image array rather than the whole observation dict --
      confirm against callers.
      tile_size is forwarded to the grid renderer.
      """
      view, visible = Grid.decode(obs)

      # The agent is drawn at the bottom-centre of its own view with
      # direction index 3; cells flagged in the decoded mask are highlighted
      size = self.agent_view_size
      own_cell = (size // 2, size - 1)
      return view.render(
          tile_size,
          agent_pos=own_cell,
          agent_dir=3,
          highlight_mask=visible
      )
  def render(self, mode='human', close=False, highlight=True, tile_size=TILE_PIXELS):
      """
      Render the whole-grid human view and return the rendered image.

      mode: when None, self.render_mode is used; in 'human' mode a window
      is created on first use and the image is shown in it.
      close: when true, only close the window (if any) and return None.
      highlight: when true, the cells visible to the agent are highlighted.
      tile_size: forwarded to the grid renderer.
      """
      if mode is None:
          mode = self.render_mode

      if close:
          if self.window:
              self.window.close()
          return

      human = mode == 'human'

      if human and not self.window:
          # Imported lazily, only when a window is actually needed
          from gym_minigrid.window import Window
          self.window = Window('gym_minigrid')
          self.window.show(block=False)

      # Which cells of the agent's view are visible to it
      _, vis_mask = self.gen_obs_grid()

      # World coordinates of the top-left corner of the agent's view area:
      # the farthest row ahead, shifted half a view width to the left
      forward = self.dir_vec
      right = self.right_vec
      view_size = self.agent_view_size
      corner = self.agent_pos + forward * (view_size-1) - right * (view_size // 2)

      # World-sized mask of the cells to highlight
      highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)

      for row in range(0, view_size):
          for col in range(0, view_size):
              # Cells the agent cannot see are never highlighted
              if not vis_mask[col, row]:
                  continue

              # Map the view cell back to world coordinates
              world_x, world_y = corner - (forward * row) + (right * col)

              # The view may hang over the edge of the world; skip those
              if world_x < 0 or world_x >= self.width or world_y < 0 or world_y >= self.height:
                  continue

              highlight_mask[world_x, world_y] = True

      # Draw the full grid with the agent on it
      img = self.grid.render(
          tile_size,
          self.agent_pos,
          self.agent_dir,
          highlight_mask=highlight_mask if highlight else None
      )

      if human:
          self.window.set_caption(self.mission)
          self.window.show_img(img)

      return img
  def close(self):
      """
      Close the rendering window, if one is set. Always returns None.
      """
      if not self.window:
          return
      self.window.close()