# minigrid.py (38 KB)
  1. import hashlib
  2. import math
  3. import string
  4. from abc import abstractmethod
  5. from enum import IntEnum
  6. from functools import partial
  7. import gym
  8. import numpy as np
  9. from gym import spaces
  10. from gym.utils.renderer import Renderer
  11. # Size in pixels of a tile in the full-scale human view
  12. from gym_minigrid.rendering import (
  13. downsample,
  14. fill_coords,
  15. highlight_img,
  16. point_in_circle,
  17. point_in_line,
  18. point_in_rect,
  19. point_in_triangle,
  20. rotate_fn,
  21. )
  22. from gym_minigrid.window import Window
  23. TILE_PIXELS = 32
  24. # Map of color names to RGB values
  25. COLORS = {
  26. "red": np.array([255, 0, 0]),
  27. "green": np.array([0, 255, 0]),
  28. "blue": np.array([0, 0, 255]),
  29. "purple": np.array([112, 39, 195]),
  30. "yellow": np.array([255, 255, 0]),
  31. "grey": np.array([100, 100, 100]),
  32. }
  33. COLOR_NAMES = sorted(list(COLORS.keys()))
  34. # Used to map colors to integers
  35. COLOR_TO_IDX = {"red": 0, "green": 1, "blue": 2, "purple": 3, "yellow": 4, "grey": 5}
  36. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  37. # Map of object type to integers
  38. OBJECT_TO_IDX = {
  39. "unseen": 0,
  40. "empty": 1,
  41. "wall": 2,
  42. "floor": 3,
  43. "door": 4,
  44. "key": 5,
  45. "ball": 6,
  46. "box": 7,
  47. "goal": 8,
  48. "lava": 9,
  49. "agent": 10,
  50. }
  51. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  52. # Map of state names to integers
  53. STATE_TO_IDX = {
  54. "open": 0,
  55. "closed": 1,
  56. "locked": 2,
  57. }
  58. # Map of agent direction indices to vectors
  59. DIR_TO_VEC = [
  60. # Pointing right (positive X)
  61. np.array((1, 0)),
  62. # Down (positive Y)
  63. np.array((0, 1)),
  64. # Pointing left (negative X)
  65. np.array((-1, 0)),
  66. # Up (negative Y)
  67. np.array((0, -1)),
  68. ]
  69. class WorldObj:
  70. """
  71. Base class for grid world objects
  72. """
  73. def __init__(self, type, color):
  74. assert type in OBJECT_TO_IDX, type
  75. assert color in COLOR_TO_IDX, color
  76. self.type = type
  77. self.color = color
  78. self.contains = None
  79. # Initial position of the object
  80. self.init_pos = None
  81. # Current position of the object
  82. self.cur_pos = None
  83. def can_overlap(self):
  84. """Can the agent overlap with this?"""
  85. return False
  86. def can_pickup(self):
  87. """Can the agent pick this up?"""
  88. return False
  89. def can_contain(self):
  90. """Can this contain another object?"""
  91. return False
  92. def see_behind(self):
  93. """Can the agent see behind this object?"""
  94. return True
  95. def toggle(self, env, pos):
  96. """Method to trigger/toggle an action this object performs"""
  97. return False
  98. def encode(self):
  99. """Encode the a description of this object as a 3-tuple of integers"""
  100. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0)
  101. @staticmethod
  102. def decode(type_idx, color_idx, state):
  103. """Create an object from a 3-tuple state description"""
  104. obj_type = IDX_TO_OBJECT[type_idx]
  105. color = IDX_TO_COLOR[color_idx]
  106. if obj_type == "empty" or obj_type == "unseen":
  107. return None
  108. # State, 0: open, 1: closed, 2: locked
  109. is_open = state == 0
  110. is_locked = state == 2
  111. if obj_type == "wall":
  112. v = Wall(color)
  113. elif obj_type == "floor":
  114. v = Floor(color)
  115. elif obj_type == "ball":
  116. v = Ball(color)
  117. elif obj_type == "key":
  118. v = Key(color)
  119. elif obj_type == "box":
  120. v = Box(color)
  121. elif obj_type == "door":
  122. v = Door(color, is_open, is_locked)
  123. elif obj_type == "goal":
  124. v = Goal()
  125. elif obj_type == "lava":
  126. v = Lava()
  127. else:
  128. assert False, "unknown object type in decode '%s'" % obj_type
  129. return v
  130. def render(self, r):
  131. """Draw this object with the given renderer"""
  132. raise NotImplementedError
  133. class Goal(WorldObj):
  134. def __init__(self):
  135. super().__init__("goal", "green")
  136. def can_overlap(self):
  137. return True
  138. def render(self, img):
  139. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  140. class Floor(WorldObj):
  141. """
  142. Colored floor tile the agent can walk over
  143. """
  144. def __init__(self, color="blue"):
  145. super().__init__("floor", color)
  146. def can_overlap(self):
  147. return True
  148. def render(self, img):
  149. # Give the floor a pale color
  150. color = COLORS[self.color] / 2
  151. fill_coords(img, point_in_rect(0.031, 1, 0.031, 1), color)
  152. class Lava(WorldObj):
  153. def __init__(self):
  154. super().__init__("lava", "red")
  155. def can_overlap(self):
  156. return True
  157. def render(self, img):
  158. c = (255, 128, 0)
  159. # Background color
  160. fill_coords(img, point_in_rect(0, 1, 0, 1), c)
  161. # Little waves
  162. for i in range(3):
  163. ylo = 0.3 + 0.2 * i
  164. yhi = 0.4 + 0.2 * i
  165. fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0, 0, 0))
  166. fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0, 0, 0))
  167. fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0, 0, 0))
  168. fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0, 0, 0))
  169. class Wall(WorldObj):
  170. def __init__(self, color="grey"):
  171. super().__init__("wall", color)
  172. def see_behind(self):
  173. return False
  174. def render(self, img):
  175. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  176. class Door(WorldObj):
  177. def __init__(self, color, is_open=False, is_locked=False):
  178. super().__init__("door", color)
  179. self.is_open = is_open
  180. self.is_locked = is_locked
  181. def can_overlap(self):
  182. """The agent can only walk over this cell when the door is open"""
  183. return self.is_open
  184. def see_behind(self):
  185. return self.is_open
  186. def toggle(self, env, pos):
  187. # If the player has the right key to open the door
  188. if self.is_locked:
  189. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  190. self.is_locked = False
  191. self.is_open = True
  192. return True
  193. return False
  194. self.is_open = not self.is_open
  195. return True
  196. def encode(self):
  197. """Encode the a description of this object as a 3-tuple of integers"""
  198. # State, 0: open, 1: closed, 2: locked
  199. if self.is_open:
  200. state = 0
  201. elif self.is_locked:
  202. state = 2
  203. # if door is closed and unlocked
  204. elif not self.is_open:
  205. state = 1
  206. else:
  207. raise ValueError(
  208. "There is no possible state encoding for the state:\n -Door Open: {}\n -Door Closed: {}\n -Door Locked: {}".format(
  209. self.is_open, not self.is_open, self.is_locked
  210. )
  211. )
  212. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state)
  213. def render(self, img):
  214. c = COLORS[self.color]
  215. if self.is_open:
  216. fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c)
  217. fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0, 0, 0))
  218. return
  219. # Door frame and door
  220. if self.is_locked:
  221. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  222. fill_coords(img, point_in_rect(0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))
  223. # Draw key slot
  224. fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)
  225. else:
  226. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  227. fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0, 0, 0))
  228. fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c)
  229. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0, 0, 0))
  230. # Draw door handle
  231. fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c)
  232. class Key(WorldObj):
  233. def __init__(self, color="blue"):
  234. super().__init__("key", color)
  235. def can_pickup(self):
  236. return True
  237. def render(self, img):
  238. c = COLORS[self.color]
  239. # Vertical quad
  240. fill_coords(img, point_in_rect(0.50, 0.63, 0.31, 0.88), c)
  241. # Teeth
  242. fill_coords(img, point_in_rect(0.38, 0.50, 0.59, 0.66), c)
  243. fill_coords(img, point_in_rect(0.38, 0.50, 0.81, 0.88), c)
  244. # Ring
  245. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.190), c)
  246. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0, 0, 0))
  247. class Ball(WorldObj):
  248. def __init__(self, color="blue"):
  249. super().__init__("ball", color)
  250. def can_pickup(self):
  251. return True
  252. def render(self, img):
  253. fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color])
  254. class Box(WorldObj):
  255. def __init__(self, color, contains=None):
  256. super().__init__("box", color)
  257. self.contains = contains
  258. def can_pickup(self):
  259. return True
  260. def render(self, img):
  261. c = COLORS[self.color]
  262. # Outline
  263. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), c)
  264. fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0, 0, 0))
  265. # Horizontal slit
  266. fill_coords(img, point_in_rect(0.16, 0.84, 0.47, 0.53), c)
  267. def toggle(self, env, pos):
  268. # Replace the box by its contents
  269. env.grid.set(*pos, self.contains)
  270. return True
  271. class Grid:
  272. """
  273. Represent a grid and operations on it
  274. """
  275. # Static cache of pre-renderer tiles
  276. tile_cache = {}
  277. def __init__(self, width, height):
  278. assert width >= 3
  279. assert height >= 3
  280. self.width = width
  281. self.height = height
  282. self.grid = [None] * width * height
  283. def __contains__(self, key):
  284. if isinstance(key, WorldObj):
  285. for e in self.grid:
  286. if e is key:
  287. return True
  288. elif isinstance(key, tuple):
  289. for e in self.grid:
  290. if e is None:
  291. continue
  292. if (e.color, e.type) == key:
  293. return True
  294. if key[0] is None and key[1] == e.type:
  295. return True
  296. return False
  297. def __eq__(self, other):
  298. grid1 = self.encode()
  299. grid2 = other.encode()
  300. return np.array_equal(grid2, grid1)
  301. def __ne__(self, other):
  302. return not self == other
  303. def copy(self):
  304. from copy import deepcopy
  305. return deepcopy(self)
  306. def set(self, i, j, v):
  307. assert i >= 0 and i < self.width
  308. assert j >= 0 and j < self.height
  309. self.grid[j * self.width + i] = v
  310. def get(self, i, j):
  311. assert i >= 0 and i < self.width
  312. assert j >= 0 and j < self.height
  313. return self.grid[j * self.width + i]
  314. def horz_wall(self, x, y, length=None, obj_type=Wall):
  315. if length is None:
  316. length = self.width - x
  317. for i in range(0, length):
  318. self.set(x + i, y, obj_type())
  319. def vert_wall(self, x, y, length=None, obj_type=Wall):
  320. if length is None:
  321. length = self.height - y
  322. for j in range(0, length):
  323. self.set(x, y + j, obj_type())
  324. def wall_rect(self, x, y, w, h):
  325. self.horz_wall(x, y, w)
  326. self.horz_wall(x, y + h - 1, w)
  327. self.vert_wall(x, y, h)
  328. self.vert_wall(x + w - 1, y, h)
  329. def rotate_left(self):
  330. """
  331. Rotate the grid to the left (counter-clockwise)
  332. """
  333. grid = Grid(self.height, self.width)
  334. for i in range(self.width):
  335. for j in range(self.height):
  336. v = self.get(i, j)
  337. grid.set(j, grid.height - 1 - i, v)
  338. return grid
  339. def slice(self, topX, topY, width, height):
  340. """
  341. Get a subset of the grid
  342. """
  343. grid = Grid(width, height)
  344. for j in range(0, height):
  345. for i in range(0, width):
  346. x = topX + i
  347. y = topY + j
  348. if x >= 0 and x < self.width and y >= 0 and y < self.height:
  349. v = self.get(x, y)
  350. else:
  351. v = Wall()
  352. grid.set(i, j, v)
  353. return grid
  354. @classmethod
  355. def render_tile(
  356. cls, obj, agent_dir=None, highlight=False, tile_size=TILE_PIXELS, subdivs=3
  357. ):
  358. """
  359. Render a tile and cache the result
  360. """
  361. # Hash map lookup key for the cache
  362. key = (agent_dir, highlight, tile_size)
  363. key = obj.encode() + key if obj else key
  364. if key in cls.tile_cache:
  365. return cls.tile_cache[key]
  366. img = np.zeros(
  367. shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8
  368. )
  369. # Draw the grid lines (top and left edges)
  370. fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
  371. fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100))
  372. if obj is not None:
  373. obj.render(img)
  374. # Overlay the agent on top
  375. if agent_dir is not None:
  376. tri_fn = point_in_triangle(
  377. (0.12, 0.19),
  378. (0.87, 0.50),
  379. (0.12, 0.81),
  380. )
  381. # Rotate the agent based on its direction
  382. tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5 * math.pi * agent_dir)
  383. fill_coords(img, tri_fn, (255, 0, 0))
  384. # Highlight the cell if needed
  385. if highlight:
  386. highlight_img(img)
  387. # Downsample the image to perform supersampling/anti-aliasing
  388. img = downsample(img, subdivs)
  389. # Cache the rendered tile
  390. cls.tile_cache[key] = img
  391. return img
  392. def render(self, tile_size, agent_pos=None, agent_dir=None, highlight_mask=None):
  393. """
  394. Render this grid at a given scale
  395. :param r: target renderer object
  396. :param tile_size: tile size in pixels
  397. """
  398. if highlight_mask is None:
  399. highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  400. # Compute the total grid size
  401. width_px = self.width * tile_size
  402. height_px = self.height * tile_size
  403. img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)
  404. # Render the grid
  405. for j in range(0, self.height):
  406. for i in range(0, self.width):
  407. cell = self.get(i, j)
  408. agent_here = np.array_equal(agent_pos, (i, j))
  409. tile_img = Grid.render_tile(
  410. cell,
  411. agent_dir=agent_dir if agent_here else None,
  412. highlight=highlight_mask[i, j],
  413. tile_size=tile_size,
  414. )
  415. ymin = j * tile_size
  416. ymax = (j + 1) * tile_size
  417. xmin = i * tile_size
  418. xmax = (i + 1) * tile_size
  419. img[ymin:ymax, xmin:xmax, :] = tile_img
  420. return img
  421. def encode(self, vis_mask=None):
  422. """
  423. Produce a compact numpy encoding of the grid
  424. """
  425. if vis_mask is None:
  426. vis_mask = np.ones((self.width, self.height), dtype=bool)
  427. array = np.zeros((self.width, self.height, 3), dtype="uint8")
  428. for i in range(self.width):
  429. for j in range(self.height):
  430. if vis_mask[i, j]:
  431. v = self.get(i, j)
  432. if v is None:
  433. array[i, j, 0] = OBJECT_TO_IDX["empty"]
  434. array[i, j, 1] = 0
  435. array[i, j, 2] = 0
  436. else:
  437. array[i, j, :] = v.encode()
  438. return array
  439. @staticmethod
  440. def decode(array):
  441. """
  442. Decode an array grid encoding back into a grid
  443. """
  444. width, height, channels = array.shape
  445. assert channels == 3
  446. vis_mask = np.ones(shape=(width, height), dtype=bool)
  447. grid = Grid(width, height)
  448. for i in range(width):
  449. for j in range(height):
  450. type_idx, color_idx, state = array[i, j]
  451. v = WorldObj.decode(type_idx, color_idx, state)
  452. grid.set(i, j, v)
  453. vis_mask[i, j] = type_idx != OBJECT_TO_IDX["unseen"]
  454. return grid, vis_mask
  455. def process_vis(self, agent_pos):
  456. mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  457. mask[agent_pos[0], agent_pos[1]] = True
  458. for j in reversed(range(0, self.height)):
  459. for i in range(0, self.width - 1):
  460. if not mask[i, j]:
  461. continue
  462. cell = self.get(i, j)
  463. if cell and not cell.see_behind():
  464. continue
  465. mask[i + 1, j] = True
  466. if j > 0:
  467. mask[i + 1, j - 1] = True
  468. mask[i, j - 1] = True
  469. for i in reversed(range(1, self.width)):
  470. if not mask[i, j]:
  471. continue
  472. cell = self.get(i, j)
  473. if cell and not cell.see_behind():
  474. continue
  475. mask[i - 1, j] = True
  476. if j > 0:
  477. mask[i - 1, j - 1] = True
  478. mask[i, j - 1] = True
  479. for j in range(0, self.height):
  480. for i in range(0, self.width):
  481. if not mask[i, j]:
  482. self.set(i, j, None)
  483. return mask
  484. class MiniGridEnv(gym.Env):
  485. """
  486. 2D grid world game environment
  487. """
    # Rendering metadata, given under both the deprecated key names and
    # the key names that replace them
    metadata = {
        # Deprecated: use 'render_modes' instead
        "render.modes": ["human", "rgb_array"],
        "video.frames_per_second": 10,  # Deprecated: use 'render_fps' instead
        "render_modes": ["human", "rgb_array", "single_rgb_array"],
        "render_fps": 10,
    }
    # Enumeration of possible actions
    class Actions(IntEnum):
        """Integer-valued actions of the environment, numbered 0 to 6"""

        # Turn left, turn right, move forward
        left = 0
        right = 1
        forward = 2
        # Pick up an object
        pickup = 3
        # Drop an object
        drop = 4
        # Toggle/activate an object
        toggle = 5

        # Done completing task
        done = 6
  509. def __init__(
  510. self,
  511. grid_size: int = None,
  512. width: int = None,
  513. height: int = None,
  514. max_steps: int = 100,
  515. see_through_walls: bool = False,
  516. agent_view_size: int = 7,
  517. render_mode: str = None,
  518. highlight: bool = True,
  519. tile_size: int = TILE_PIXELS,
  520. **kwargs
  521. ):
  522. # Can't set both grid_size and width/height
  523. if grid_size:
  524. assert width is None and height is None
  525. width = grid_size
  526. height = grid_size
  527. # Action enumeration for this environment
  528. self.actions = MiniGridEnv.Actions
  529. # Actions are discrete integer values
  530. self.action_space = spaces.Discrete(len(self.actions))
  531. # Number of cells (width and height) in the agent view
  532. assert agent_view_size % 2 == 1
  533. assert agent_view_size >= 3
  534. self.agent_view_size = agent_view_size
  535. # Observations are dictionaries containing an
  536. # encoding of the grid and a textual 'mission' string
  537. self.observation_space = spaces.Box(
  538. low=0,
  539. high=255,
  540. shape=(self.agent_view_size, self.agent_view_size, 3),
  541. dtype="uint8",
  542. )
  543. self.observation_space = spaces.Dict(
  544. {
  545. "image": self.observation_space,
  546. "direction": spaces.Discrete(4),
  547. "mission": spaces.Text(
  548. max_length=200,
  549. charset=string.ascii_letters + string.digits + " .,!-",
  550. ),
  551. }
  552. )
  553. # render mode
  554. self.render_mode = render_mode
  555. render_frame = partial(
  556. self._render,
  557. highlight=highlight,
  558. tile_size=tile_size,
  559. )
  560. self.renderer = Renderer(self.render_mode, render_frame)
  561. # Range of possible rewards
  562. self.reward_range = (0, 1)
  563. self.window: Window = None
  564. # Environment configuration
  565. self.width = width
  566. self.height = height
  567. self.max_steps = max_steps
  568. self.see_through_walls = see_through_walls
  569. # Current position and direction of the agent
  570. self.agent_pos: np.ndarray = None
  571. self.agent_dir: int = None
  572. # Initialize the state
  573. self.reset()
  574. def reset(self, *, seed=None, return_info=False, options=None):
  575. super().reset(seed=seed)
  576. # Current position and direction of the agent
  577. self.agent_pos = None
  578. self.agent_dir = None
  579. # Generate a new random grid at the start of each episode
  580. self._gen_grid(self.width, self.height)
  581. # These fields should be defined by _gen_grid
  582. assert self.agent_pos is not None
  583. assert self.agent_dir is not None
  584. # Check that the agent doesn't overlap with an object
  585. start_cell = self.grid.get(*self.agent_pos)
  586. assert start_cell is None or start_cell.can_overlap()
  587. # Item picked up, being carried, initially nothing
  588. self.carrying = None
  589. # Step count since episode start
  590. self.step_count = 0
  591. # Return first observation
  592. obs = self.gen_obs()
  593. self.renderer.reset()
  594. self.renderer.render_step()
  595. if not return_info:
  596. return obs
  597. else:
  598. return obs, {}
  599. def hash(self, size=16):
  600. """Compute a hash that uniquely identifies the current state of the environment.
  601. :param size: Size of the hashing
  602. """
  603. sample_hash = hashlib.sha256()
  604. to_encode = [self.grid.encode().tolist(), self.agent_pos, self.agent_dir]
  605. for item in to_encode:
  606. sample_hash.update(str(item).encode("utf8"))
  607. return sample_hash.hexdigest()[:size]
  608. @property
  609. def steps_remaining(self):
  610. return self.max_steps - self.step_count
  611. def __str__(self):
  612. """
  613. Produce a pretty string of the environment's grid along with the agent.
  614. A grid cell is represented by 2-character string, the first one for
  615. the object and the second one for the color.
  616. """
  617. # Map of object types to short string
  618. OBJECT_TO_STR = {
  619. "wall": "W",
  620. "floor": "F",
  621. "door": "D",
  622. "key": "K",
  623. "ball": "A",
  624. "box": "B",
  625. "goal": "G",
  626. "lava": "V",
  627. }
  628. # Map agent's direction to short string
  629. AGENT_DIR_TO_STR = {0: ">", 1: "V", 2: "<", 3: "^"}
  630. str = ""
  631. for j in range(self.grid.height):
  632. for i in range(self.grid.width):
  633. if i == self.agent_pos[0] and j == self.agent_pos[1]:
  634. str += 2 * AGENT_DIR_TO_STR[self.agent_dir]
  635. continue
  636. c = self.grid.get(i, j)
  637. if c is None:
  638. str += " "
  639. continue
  640. if c.type == "door":
  641. if c.is_open:
  642. str += "__"
  643. elif c.is_locked:
  644. str += "L" + c.color[0].upper()
  645. else:
  646. str += "D" + c.color[0].upper()
  647. continue
  648. str += OBJECT_TO_STR[c.type] + c.color[0].upper()
  649. if j < self.grid.height - 1:
  650. str += "\n"
  651. return str
  652. @abstractmethod
  653. def _gen_grid(self, width, height):
  654. pass
  655. def _reward(self):
  656. """
  657. Compute the reward to be given upon success
  658. """
  659. return 1 - 0.9 * (self.step_count / self.max_steps)
  660. def _rand_int(self, low, high):
  661. """
  662. Generate random integer in [low,high[
  663. """
  664. return self.np_random.integers(low, high)
  665. def _rand_float(self, low, high):
  666. """
  667. Generate random float in [low,high[
  668. """
  669. return self.np_random.uniform(low, high)
  670. def _rand_bool(self):
  671. """
  672. Generate random boolean value
  673. """
  674. return self.np_random.integers(0, 2) == 0
  675. def _rand_elem(self, iterable):
  676. """
  677. Pick a random element in a list
  678. """
  679. lst = list(iterable)
  680. idx = self._rand_int(0, len(lst))
  681. return lst[idx]
  682. def _rand_subset(self, iterable, num_elems):
  683. """
  684. Sample a random subset of distinct elements of a list
  685. """
  686. lst = list(iterable)
  687. assert num_elems <= len(lst)
  688. out = []
  689. while len(out) < num_elems:
  690. elem = self._rand_elem(lst)
  691. lst.remove(elem)
  692. out.append(elem)
  693. return out
  694. def _rand_color(self):
  695. """
  696. Generate a random color name (string)
  697. """
  698. return self._rand_elem(COLOR_NAMES)
  699. def _rand_pos(self, xLow, xHigh, yLow, yHigh):
  700. """
  701. Generate a random (x,y) position tuple
  702. """
  703. return (
  704. self.np_random.integers(xLow, xHigh),
  705. self.np_random.integers(yLow, yHigh),
  706. )
  707. def place_obj(self, obj, top=None, size=None, reject_fn=None, max_tries=math.inf):
  708. """
  709. Place an object at an empty position in the grid
  710. :param top: top-left position of the rectangle where to place
  711. :param size: size of the rectangle where to place
  712. :param reject_fn: function to filter out potential positions
  713. """
  714. if top is None:
  715. top = (0, 0)
  716. else:
  717. top = (max(top[0], 0), max(top[1], 0))
  718. if size is None:
  719. size = (self.grid.width, self.grid.height)
  720. num_tries = 0
  721. while True:
  722. # This is to handle with rare cases where rejection sampling
  723. # gets stuck in an infinite loop
  724. if num_tries > max_tries:
  725. raise RecursionError("rejection sampling failed in place_obj")
  726. num_tries += 1
  727. pos = np.array(
  728. (
  729. self._rand_int(top[0], min(top[0] + size[0], self.grid.width)),
  730. self._rand_int(top[1], min(top[1] + size[1], self.grid.height)),
  731. )
  732. )
  733. # Don't place the object on top of another object
  734. if self.grid.get(*pos) is not None:
  735. continue
  736. # Don't place the object where the agent is
  737. if np.array_equal(pos, self.agent_pos):
  738. continue
  739. # Check if there is a filtering criterion
  740. if reject_fn and reject_fn(self, pos):
  741. continue
  742. break
  743. self.grid.set(*pos, obj)
  744. if obj is not None:
  745. obj.init_pos = pos
  746. obj.cur_pos = pos
  747. return pos
  748. def put_obj(self, obj, i, j):
  749. """
  750. Put an object at a specific position in the grid
  751. """
  752. self.grid.set(i, j, obj)
  753. obj.init_pos = (i, j)
  754. obj.cur_pos = (i, j)
  755. def place_agent(self, top=None, size=None, rand_dir=True, max_tries=math.inf):
  756. """
  757. Set the agent's starting point at an empty position in the grid
  758. """
  759. self.agent_pos = None
  760. pos = self.place_obj(None, top, size, max_tries=max_tries)
  761. self.agent_pos = pos
  762. if rand_dir:
  763. self.agent_dir = self._rand_int(0, 4)
  764. return pos
  765. @property
  766. def dir_vec(self):
  767. """
  768. Get the direction vector for the agent, pointing in the direction
  769. of forward movement.
  770. """
  771. assert self.agent_dir >= 0 and self.agent_dir < 4
  772. return DIR_TO_VEC[self.agent_dir]
  773. @property
  774. def right_vec(self):
  775. """
  776. Get the vector pointing to the right of the agent.
  777. """
  778. dx, dy = self.dir_vec
  779. return np.array((-dy, dx))
  780. @property
  781. def front_pos(self):
  782. """
  783. Get the position of the cell that is right in front of the agent
  784. """
  785. return self.agent_pos + self.dir_vec
  786. def get_view_coords(self, i, j):
  787. """
  788. Translate and rotate absolute grid coordinates (i, j) into the
  789. agent's partially observable view (sub-grid). Note that the resulting
  790. coordinates may be negative or outside of the agent's view size.
  791. """
  792. ax, ay = self.agent_pos
  793. dx, dy = self.dir_vec
  794. rx, ry = self.right_vec
  795. # Compute the absolute coordinates of the top-left view corner
  796. sz = self.agent_view_size
  797. hs = self.agent_view_size // 2
  798. tx = ax + (dx * (sz - 1)) - (rx * hs)
  799. ty = ay + (dy * (sz - 1)) - (ry * hs)
  800. lx = i - tx
  801. ly = j - ty
  802. # Project the coordinates of the object relative to the top-left
  803. # corner onto the agent's own coordinate system
  804. vx = rx * lx + ry * ly
  805. vy = -(dx * lx + dy * ly)
  806. return vx, vy
  def get_view_exts(self, agent_view_size=None):
      """
      Get the extents of the square set of tiles visible to the agent.

      Returns a tuple ``(topX, topY, botX, botY)`` in absolute grid
      coordinates. The bottom extent indices are not included in the set.
      If agent_view_size is None (or otherwise falsy), self.agent_view_size
      is used.

      Raises AssertionError if ``self.agent_dir`` is not 0, 1, 2 or 3.
      """
      agent_view_size = agent_view_size or self.agent_view_size
      half = agent_view_size // 2

      # Facing right
      if self.agent_dir == 0:
          topX = self.agent_pos[0]
          topY = self.agent_pos[1] - half
      # Facing down
      elif self.agent_dir == 1:
          topX = self.agent_pos[0] - half
          topY = self.agent_pos[1]
      # Facing left
      elif self.agent_dir == 2:
          topX = self.agent_pos[0] - agent_view_size + 1
          topY = self.agent_pos[1] - half
      # Facing up
      elif self.agent_dir == 3:
          topX = self.agent_pos[0] - half
          topY = self.agent_pos[1] - agent_view_size + 1
      else:
          # Raised explicitly rather than via `assert False`, so the check
          # is not stripped when Python runs with -O
          raise AssertionError("invalid agent direction")

      botX = topX + agent_view_size
      botY = topY + agent_view_size

      return (topX, topY, botX, botY)
  def relative_coords(self, x, y):
      """
      Map a grid position into the agent's field of view.

      Returns the view coordinates (vx, vy) when the position falls inside
      the agent_view_size x agent_view_size window, otherwise None.
      """
      vx, vy = self.get_view_coords(x, y)

      inside_x = 0 <= vx < self.agent_view_size
      if inside_x and 0 <= vy < self.agent_view_size:
          return vx, vy

      return None
  def in_view(self, x, y):
      """
      Tell whether a grid position lies within the agent's field of view,
      i.e. whether relative_coords() yields coordinates for it.
      """
      view_coords = self.relative_coords(x, y)
      return view_coords is not None
  def agent_sees(self, x, y):
      """
      Check if a non-empty grid position is visible to the agent.

      Returns False when (x, y) is outside the agent's view, when the world
      cell at (x, y) is empty, or when the decoded observation does not show
      an object of the same type at the corresponding view cell. Returns True
      only when the observed cell and the world cell have the same type.
      """
      coordinates = self.relative_coords(x, y)
      if coordinates is None:
          return False
      vx, vy = coordinates

      # Regenerate the observation and decode it back into a grid so that
      # the check reflects what the agent actually receives
      obs = self.gen_obs()
      obs_grid, _ = Grid.decode(obs["image"])
      obs_cell = obs_grid.get(vx, vy)
      world_cell = self.grid.get(x, y)

      # The observation can hold an object where the world cell is empty:
      # gen_obs_grid() places the carried object at the agent's own view
      # cell. Without this guard, `world_cell.type` raises AttributeError.
      if obs_cell is None or world_cell is None:
          return False

      return obs_cell.type == world_cell.type
  def step(self, action):
      """
      Advance the environment by one step using the given action.

      The action is compared against the members of ``self.actions``
      (left, right, forward, pickup, drop, toggle, done). Returns the tuple
      ``(obs, reward, done, info)``: ``obs`` comes from gen_obs(), ``reward``
      is 0 unless a forward action faces a "goal" cell (then it is
      ``self._reward()``), ``done`` is set by a forward action facing a
      "goal" or "lava" cell or once ``step_count`` reaches ``max_steps``,
      and ``info`` is an empty dict.

      Raises AssertionError if the action matches none of the known actions.
      """
      self.step_count += 1

      reward = 0
      done = False

      # Get the position in front of the agent
      fwd_pos = self.front_pos

      # Get the contents of the cell in front of the agent
      fwd_cell = self.grid.get(*fwd_pos)

      # Rotate left
      if action == self.actions.left:
          self.agent_dir = (self.agent_dir - 1) % 4

      # Rotate right
      elif action == self.actions.right:
          self.agent_dir = (self.agent_dir + 1) % 4

      # Move forward
      elif action == self.actions.forward:
          if fwd_cell is None or fwd_cell.can_overlap():
              self.agent_pos = fwd_pos
          if fwd_cell is not None and fwd_cell.type == "goal":
              done = True
              reward = self._reward()
          if fwd_cell is not None and fwd_cell.type == "lava":
              done = True

      # Pick up an object
      elif action == self.actions.pickup:
          if fwd_cell and fwd_cell.can_pickup():
              if self.carrying is None:
                  self.carrying = fwd_cell
                  # A carried object is no longer on the grid
                  self.carrying.cur_pos = np.array([-1, -1])
                  self.grid.set(*fwd_pos, None)

      # Drop an object
      elif action == self.actions.drop:
          if not fwd_cell and self.carrying:
              self.grid.set(*fwd_pos, self.carrying)
              self.carrying.cur_pos = fwd_pos
              self.carrying = None

      # Toggle/activate an object
      elif action == self.actions.toggle:
          if fwd_cell:
              fwd_cell.toggle(self, fwd_pos)

      # Done action (not used by default)
      elif action == self.actions.done:
          pass

      else:
          # Raised explicitly rather than via `assert False`, so an unknown
          # action is still rejected when Python runs with -O
          raise AssertionError("unknown action")

      if self.step_count >= self.max_steps:
          done = True

      obs = self.gen_obs()

      self.renderer.render_step()

      return obs, reward, done, {}
  def gen_obs_grid(self, agent_view_size=None):
      """
      Build the sub-grid the agent observes, together with a visibility
      mask saying which of its cells the agent can actually see.

      When agent_view_size is None, self.agent_view_size is used.
      Returns the pair ``(grid, vis_mask)``.
      """
      left, top, _, _ = self.get_view_exts(agent_view_size)
      view = agent_view_size or self.agent_view_size

      # Cut the view window out of the full grid, then rotate it
      # (agent_dir + 1) times to the left
      sub_grid = self.grid.slice(left, top, view, view)
      for _ in range(self.agent_dir + 1):
          sub_grid = sub_grid.rotate_left()

      # Process occluders and visibility (this has some performance cost);
      # with see-through walls every cell of the window counts as visible
      if self.see_through_walls:
          vis_mask = np.ones(shape=(sub_grid.width, sub_grid.height), dtype=bool)
      else:
          vis_mask = sub_grid.process_vis(agent_pos=(view // 2, view - 1))

      # Let the agent see what it is carrying: the carried object (or
      # nothing) is written into the agent's own cell of the view
      own_cell = sub_grid.width // 2, sub_grid.height - 1
      carried = self.carrying if self.carrying else None
      sub_grid.set(*own_cell, carried)

      return sub_grid, vis_mask
  def gen_obs(self):
      """
      Produce the agent's observation (partially observable view in a
      low-resolution encoding).

      The returned dictionary contains:
      - "image": the encoded partially observable view of the environment
      - "direction": the agent's direction/orientation (acting as a compass)
      - "mission": a textual mission string (instructions for the agent)
      """
      view_grid, visible = self.gen_obs_grid()

      # Turn the partially observable view into its array encoding
      encoded = view_grid.encode(visible)

      assert hasattr(
          self, "mission"
      ), "environments must define a textual mission string"

      return {
          "image": encoded,
          "direction": self.agent_dir,
          "mission": self.mission,
      }
  def get_obs_render(self, obs, tile_size=TILE_PIXELS // 2):
      """
      Turn an agent observation into an image for visualization.

      The observation is decoded back into a grid plus visibility mask, and
      the grid is rendered with the agent drawn at the bottom-centre cell of
      the view, using direction 3, with the visible cells highlighted.
      """
      view_grid, visible = Grid.decode(obs)

      # The agent's cell inside its own view
      agent_cell = (self.agent_view_size // 2, self.agent_view_size - 1)

      # Render the whole view grid
      return view_grid.render(
          tile_size,
          agent_pos=agent_cell,
          agent_dir=3,
          highlight_mask=visible,
      )
  def _render(self, mode="human", highlight=True, tile_size=TILE_PIXELS):
      """
      Render the whole-grid human view.

      ``mode`` must be listed in ``self.metadata["render_modes"]``. In
      "human" mode the image is shown in a window (created on first use)
      and nothing is returned; in any other mode the rendered image is
      returned. When ``highlight`` is true, the cells currently visible to
      the agent are highlighted. ``tile_size`` is passed through to the
      grid renderer.
      """
      assert mode in self.metadata["render_modes"]

      if mode == "human" and not self.window:
          self.window = Window("gym_minigrid")
          self.window.show(block=False)

      # Compute which cells are visible to the agent
      _, vis_mask = self.gen_obs_grid()

      # Compute the world coordinates of the top-left corner
      # of the agent's view area
      f_vec = self.dir_vec
      r_vec = self.right_vec
      top_left = (
          self.agent_pos
          + f_vec * (self.agent_view_size - 1)
          - r_vec * (self.agent_view_size // 2)
      )

      # Mask of which cells to highlight
      highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)

      # For each cell in the visibility mask
      for vis_j in range(self.agent_view_size):
          for vis_i in range(self.agent_view_size):
              # If this cell is not visible, don't highlight it
              if not vis_mask[vis_i, vis_j]:
                  continue

              # Compute the world coordinates of this cell
              abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)

              # The view can extend past the grid; skip such cells
              if abs_i < 0 or abs_i >= self.width:
                  continue
              if abs_j < 0 or abs_j >= self.height:
                  continue

              # Mark this cell to be highlighted
              highlight_mask[abs_i, abs_j] = True

      # Render the whole grid
      img = self.grid.render(
          tile_size,
          self.agent_pos,
          self.agent_dir,
          highlight_mask=highlight_mask if highlight else None,
      )

      if mode == "human":
          self.window.set_caption(self.mission)
          self.window.show_img(img)
      else:
          return img
  def render(self, mode="human", close=False, highlight=True, tile_size=TILE_PIXELS):
      """
      Render the environment.

      Passing ``close=True`` is no longer supported and raises an Exception;
      use close() instead. When ``self.render_mode`` is set, the frames
      collected by ``self.renderer`` are returned; otherwise the call is
      forwarded to _render() with the given mode, highlight and tile_size.
      """
      if close:
          raise Exception(
              "Please close the rendering window using env.close(). Closing the rendering window with the render method is no longer allowed."
          )

      if self.render_mode is None:
          return self._render(mode, highlight=highlight, tile_size=tile_size)

      return self.renderer.get_renders()
  def close(self):
      """
      Close the rendering window, if one is set on this environment.
      """
      if not self.window:
          return
      self.window.close()