minigrid.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310
  1. import hashlib
  2. import math
  3. import string
  4. from enum import IntEnum
  5. import gym
  6. import numpy as np
  7. from gym import spaces
  8. # Size in pixels of a tile in the full-scale human view
  9. from gym_minigrid.rendering import (
  10. downsample,
  11. fill_coords,
  12. highlight_img,
  13. point_in_circle,
  14. point_in_line,
  15. point_in_rect,
  16. point_in_triangle,
  17. rotate_fn,
  18. )
  19. from gym_minigrid.window import Window
  20. TILE_PIXELS = 32
  21. # Map of color names to RGB values
  22. COLORS = {
  23. "red": np.array([255, 0, 0]),
  24. "green": np.array([0, 255, 0]),
  25. "blue": np.array([0, 0, 255]),
  26. "purple": np.array([112, 39, 195]),
  27. "yellow": np.array([255, 255, 0]),
  28. "grey": np.array([100, 100, 100]),
  29. }
  30. COLOR_NAMES = sorted(list(COLORS.keys()))
  31. # Used to map colors to integers
  32. COLOR_TO_IDX = {"red": 0, "green": 1, "blue": 2, "purple": 3, "yellow": 4, "grey": 5}
  33. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  34. # Map of object type to integers
  35. OBJECT_TO_IDX = {
  36. "unseen": 0,
  37. "empty": 1,
  38. "wall": 2,
  39. "floor": 3,
  40. "door": 4,
  41. "key": 5,
  42. "ball": 6,
  43. "box": 7,
  44. "goal": 8,
  45. "lava": 9,
  46. "agent": 10,
  47. }
  48. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  49. # Map of state names to integers
  50. STATE_TO_IDX = {
  51. "open": 0,
  52. "closed": 1,
  53. "locked": 2,
  54. }
  55. # Map of agent direction indices to vectors
  56. DIR_TO_VEC = [
  57. # Pointing right (positive X)
  58. np.array((1, 0)),
  59. # Down (positive Y)
  60. np.array((0, 1)),
  61. # Pointing left (negative X)
  62. np.array((-1, 0)),
  63. # Up (negative Y)
  64. np.array((0, -1)),
  65. ]
  66. class WorldObj:
  67. """
  68. Base class for grid world objects
  69. """
  70. def __init__(self, type, color):
  71. assert type in OBJECT_TO_IDX, type
  72. assert color in COLOR_TO_IDX, color
  73. self.type = type
  74. self.color = color
  75. self.contains = None
  76. # Initial position of the object
  77. self.init_pos = None
  78. # Current position of the object
  79. self.cur_pos = None
  80. def can_overlap(self):
  81. """Can the agent overlap with this?"""
  82. return False
  83. def can_pickup(self):
  84. """Can the agent pick this up?"""
  85. return False
  86. def can_contain(self):
  87. """Can this contain another object?"""
  88. return False
  89. def see_behind(self):
  90. """Can the agent see behind this object?"""
  91. return True
  92. def toggle(self, env, pos):
  93. """Method to trigger/toggle an action this object performs"""
  94. return False
  95. def encode(self):
  96. """Encode the a description of this object as a 3-tuple of integers"""
  97. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0)
  98. @staticmethod
  99. def decode(type_idx, color_idx, state):
  100. """Create an object from a 3-tuple state description"""
  101. obj_type = IDX_TO_OBJECT[type_idx]
  102. color = IDX_TO_COLOR[color_idx]
  103. if obj_type == "empty" or obj_type == "unseen":
  104. return None
  105. # State, 0: open, 1: closed, 2: locked
  106. is_open = state == 0
  107. is_locked = state == 2
  108. if obj_type == "wall":
  109. v = Wall(color)
  110. elif obj_type == "floor":
  111. v = Floor(color)
  112. elif obj_type == "ball":
  113. v = Ball(color)
  114. elif obj_type == "key":
  115. v = Key(color)
  116. elif obj_type == "box":
  117. v = Box(color)
  118. elif obj_type == "door":
  119. v = Door(color, is_open, is_locked)
  120. elif obj_type == "goal":
  121. v = Goal()
  122. elif obj_type == "lava":
  123. v = Lava()
  124. else:
  125. assert False, "unknown object type in decode '%s'" % obj_type
  126. return v
  127. def render(self, r):
  128. """Draw this object with the given renderer"""
  129. raise NotImplementedError
  130. class Goal(WorldObj):
  131. def __init__(self):
  132. super().__init__("goal", "green")
  133. def can_overlap(self):
  134. return True
  135. def render(self, img):
  136. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  137. class Floor(WorldObj):
  138. """
  139. Colored floor tile the agent can walk over
  140. """
  141. def __init__(self, color="blue"):
  142. super().__init__("floor", color)
  143. def can_overlap(self):
  144. return True
  145. def render(self, img):
  146. # Give the floor a pale color
  147. color = COLORS[self.color] / 2
  148. fill_coords(img, point_in_rect(0.031, 1, 0.031, 1), color)
  149. class Lava(WorldObj):
  150. def __init__(self):
  151. super().__init__("lava", "red")
  152. def can_overlap(self):
  153. return True
  154. def render(self, img):
  155. c = (255, 128, 0)
  156. # Background color
  157. fill_coords(img, point_in_rect(0, 1, 0, 1), c)
  158. # Little waves
  159. for i in range(3):
  160. ylo = 0.3 + 0.2 * i
  161. yhi = 0.4 + 0.2 * i
  162. fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0, 0, 0))
  163. fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0, 0, 0))
  164. fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0, 0, 0))
  165. fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0, 0, 0))
  166. class Wall(WorldObj):
  167. def __init__(self, color="grey"):
  168. super().__init__("wall", color)
  169. def see_behind(self):
  170. return False
  171. def render(self, img):
  172. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  173. class Door(WorldObj):
  174. def __init__(self, color, is_open=False, is_locked=False):
  175. super().__init__("door", color)
  176. self.is_open = is_open
  177. self.is_locked = is_locked
  178. def can_overlap(self):
  179. """The agent can only walk over this cell when the door is open"""
  180. return self.is_open
  181. def see_behind(self):
  182. return self.is_open
  183. def toggle(self, env, pos):
  184. # If the player has the right key to open the door
  185. if self.is_locked:
  186. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  187. self.is_locked = False
  188. self.is_open = True
  189. return True
  190. return False
  191. self.is_open = not self.is_open
  192. return True
  193. def encode(self):
  194. """Encode the a description of this object as a 3-tuple of integers"""
  195. # State, 0: open, 1: closed, 2: locked
  196. if self.is_open:
  197. state = 0
  198. elif self.is_locked:
  199. state = 2
  200. # if door is closed and unlocked
  201. elif not self.is_open:
  202. state = 1
  203. else:
  204. raise ValueError(
  205. "There is no possible state encoding for the state:\n -Door Open: {}\n -Door Closed: {}\n -Door Locked: {}".format(
  206. self.is_open, not self.is_open, self.is_locked
  207. )
  208. )
  209. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state)
  210. def render(self, img):
  211. c = COLORS[self.color]
  212. if self.is_open:
  213. fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c)
  214. fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0, 0, 0))
  215. return
  216. # Door frame and door
  217. if self.is_locked:
  218. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  219. fill_coords(img, point_in_rect(0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))
  220. # Draw key slot
  221. fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)
  222. else:
  223. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  224. fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0, 0, 0))
  225. fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c)
  226. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0, 0, 0))
  227. # Draw door handle
  228. fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c)
  229. class Key(WorldObj):
  230. def __init__(self, color="blue"):
  231. super().__init__("key", color)
  232. def can_pickup(self):
  233. return True
  234. def render(self, img):
  235. c = COLORS[self.color]
  236. # Vertical quad
  237. fill_coords(img, point_in_rect(0.50, 0.63, 0.31, 0.88), c)
  238. # Teeth
  239. fill_coords(img, point_in_rect(0.38, 0.50, 0.59, 0.66), c)
  240. fill_coords(img, point_in_rect(0.38, 0.50, 0.81, 0.88), c)
  241. # Ring
  242. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.190), c)
  243. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0, 0, 0))
  244. class Ball(WorldObj):
  245. def __init__(self, color="blue"):
  246. super().__init__("ball", color)
  247. def can_pickup(self):
  248. return True
  249. def render(self, img):
  250. fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color])
  251. class Box(WorldObj):
  252. def __init__(self, color, contains=None):
  253. super().__init__("box", color)
  254. self.contains = contains
  255. def can_pickup(self):
  256. return True
  257. def render(self, img):
  258. c = COLORS[self.color]
  259. # Outline
  260. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), c)
  261. fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0, 0, 0))
  262. # Horizontal slit
  263. fill_coords(img, point_in_rect(0.16, 0.84, 0.47, 0.53), c)
  264. def toggle(self, env, pos):
  265. # Replace the box by its contents
  266. env.grid.set(*pos, self.contains)
  267. return True
  268. class Grid:
  269. """
  270. Represent a grid and operations on it
  271. """
  272. # Static cache of pre-renderer tiles
  273. tile_cache = {}
  274. def __init__(self, width, height):
  275. assert width >= 3
  276. assert height >= 3
  277. self.width = width
  278. self.height = height
  279. self.grid = [None] * width * height
  280. def __contains__(self, key):
  281. if isinstance(key, WorldObj):
  282. for e in self.grid:
  283. if e is key:
  284. return True
  285. elif isinstance(key, tuple):
  286. for e in self.grid:
  287. if e is None:
  288. continue
  289. if (e.color, e.type) == key:
  290. return True
  291. if key[0] is None and key[1] == e.type:
  292. return True
  293. return False
  294. def __eq__(self, other):
  295. grid1 = self.encode()
  296. grid2 = other.encode()
  297. return np.array_equal(grid2, grid1)
  298. def __ne__(self, other):
  299. return not self == other
  300. def copy(self):
  301. from copy import deepcopy
  302. return deepcopy(self)
  303. def set(self, i, j, v):
  304. assert i >= 0 and i < self.width
  305. assert j >= 0 and j < self.height
  306. self.grid[j * self.width + i] = v
  307. def get(self, i, j):
  308. assert i >= 0 and i < self.width
  309. assert j >= 0 and j < self.height
  310. return self.grid[j * self.width + i]
  311. def horz_wall(self, x, y, length=None, obj_type=Wall):
  312. if length is None:
  313. length = self.width - x
  314. for i in range(0, length):
  315. self.set(x + i, y, obj_type())
  316. def vert_wall(self, x, y, length=None, obj_type=Wall):
  317. if length is None:
  318. length = self.height - y
  319. for j in range(0, length):
  320. self.set(x, y + j, obj_type())
  321. def wall_rect(self, x, y, w, h):
  322. self.horz_wall(x, y, w)
  323. self.horz_wall(x, y + h - 1, w)
  324. self.vert_wall(x, y, h)
  325. self.vert_wall(x + w - 1, y, h)
  326. def rotate_left(self):
  327. """
  328. Rotate the grid to the left (counter-clockwise)
  329. """
  330. grid = Grid(self.height, self.width)
  331. for i in range(self.width):
  332. for j in range(self.height):
  333. v = self.get(i, j)
  334. grid.set(j, grid.height - 1 - i, v)
  335. return grid
  336. def slice(self, topX, topY, width, height):
  337. """
  338. Get a subset of the grid
  339. """
  340. grid = Grid(width, height)
  341. for j in range(0, height):
  342. for i in range(0, width):
  343. x = topX + i
  344. y = topY + j
  345. if x >= 0 and x < self.width and y >= 0 and y < self.height:
  346. v = self.get(x, y)
  347. else:
  348. v = Wall()
  349. grid.set(i, j, v)
  350. return grid
  351. @classmethod
  352. def render_tile(
  353. cls, obj, agent_dir=None, highlight=False, tile_size=TILE_PIXELS, subdivs=3
  354. ):
  355. """
  356. Render a tile and cache the result
  357. """
  358. # Hash map lookup key for the cache
  359. key = (agent_dir, highlight, tile_size)
  360. key = obj.encode() + key if obj else key
  361. if key in cls.tile_cache:
  362. return cls.tile_cache[key]
  363. img = np.zeros(
  364. shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8
  365. )
  366. # Draw the grid lines (top and left edges)
  367. fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
  368. fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100))
  369. if obj is not None:
  370. obj.render(img)
  371. # Overlay the agent on top
  372. if agent_dir is not None:
  373. tri_fn = point_in_triangle(
  374. (0.12, 0.19),
  375. (0.87, 0.50),
  376. (0.12, 0.81),
  377. )
  378. # Rotate the agent based on its direction
  379. tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5 * math.pi * agent_dir)
  380. fill_coords(img, tri_fn, (255, 0, 0))
  381. # Highlight the cell if needed
  382. if highlight:
  383. highlight_img(img)
  384. # Downsample the image to perform supersampling/anti-aliasing
  385. img = downsample(img, subdivs)
  386. # Cache the rendered tile
  387. cls.tile_cache[key] = img
  388. return img
  389. def render(self, tile_size, agent_pos=None, agent_dir=None, highlight_mask=None):
  390. """
  391. Render this grid at a given scale
  392. :param r: target renderer object
  393. :param tile_size: tile size in pixels
  394. """
  395. if highlight_mask is None:
  396. highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  397. # Compute the total grid size
  398. width_px = self.width * tile_size
  399. height_px = self.height * tile_size
  400. img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)
  401. # Render the grid
  402. for j in range(0, self.height):
  403. for i in range(0, self.width):
  404. cell = self.get(i, j)
  405. agent_here = np.array_equal(agent_pos, (i, j))
  406. tile_img = Grid.render_tile(
  407. cell,
  408. agent_dir=agent_dir if agent_here else None,
  409. highlight=highlight_mask[i, j],
  410. tile_size=tile_size,
  411. )
  412. ymin = j * tile_size
  413. ymax = (j + 1) * tile_size
  414. xmin = i * tile_size
  415. xmax = (i + 1) * tile_size
  416. img[ymin:ymax, xmin:xmax, :] = tile_img
  417. return img
  418. def encode(self, vis_mask=None):
  419. """
  420. Produce a compact numpy encoding of the grid
  421. """
  422. if vis_mask is None:
  423. vis_mask = np.ones((self.width, self.height), dtype=bool)
  424. array = np.zeros((self.width, self.height, 3), dtype="uint8")
  425. for i in range(self.width):
  426. for j in range(self.height):
  427. if vis_mask[i, j]:
  428. v = self.get(i, j)
  429. if v is None:
  430. array[i, j, 0] = OBJECT_TO_IDX["empty"]
  431. array[i, j, 1] = 0
  432. array[i, j, 2] = 0
  433. else:
  434. array[i, j, :] = v.encode()
  435. return array
  436. @staticmethod
  437. def decode(array):
  438. """
  439. Decode an array grid encoding back into a grid
  440. """
  441. width, height, channels = array.shape
  442. assert channels == 3
  443. vis_mask = np.ones(shape=(width, height), dtype=bool)
  444. grid = Grid(width, height)
  445. for i in range(width):
  446. for j in range(height):
  447. type_idx, color_idx, state = array[i, j]
  448. v = WorldObj.decode(type_idx, color_idx, state)
  449. grid.set(i, j, v)
  450. vis_mask[i, j] = type_idx != OBJECT_TO_IDX["unseen"]
  451. return grid, vis_mask
  452. def process_vis(self, agent_pos):
  453. mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  454. mask[agent_pos[0], agent_pos[1]] = True
  455. for j in reversed(range(0, self.height)):
  456. for i in range(0, self.width - 1):
  457. if not mask[i, j]:
  458. continue
  459. cell = self.get(i, j)
  460. if cell and not cell.see_behind():
  461. continue
  462. mask[i + 1, j] = True
  463. if j > 0:
  464. mask[i + 1, j - 1] = True
  465. mask[i, j - 1] = True
  466. for i in reversed(range(1, self.width)):
  467. if not mask[i, j]:
  468. continue
  469. cell = self.get(i, j)
  470. if cell and not cell.see_behind():
  471. continue
  472. mask[i - 1, j] = True
  473. if j > 0:
  474. mask[i - 1, j - 1] = True
  475. mask[i, j - 1] = True
  476. for j in range(0, self.height):
  477. for i in range(0, self.width):
  478. if not mask[i, j]:
  479. self.set(i, j, None)
  480. return mask
  481. class MiniGridEnv(gym.Env):
  482. """
  483. 2D grid world game environment
  484. """
  485. metadata = {
  486. # Deprecated: use 'render_modes' instead
  487. "render.modes": ["human", "rgb_array"],
  488. "video.frames_per_second": 10, # Deprecated: use 'render_fps' instead
  489. "render_modes": ["human", "rgb_array"],
  490. "render_fps": 10,
  491. }
  492. # Enumeration of possible actions
  493. class Actions(IntEnum):
  494. # Turn left, turn right, move forward
  495. left = 0
  496. right = 1
  497. forward = 2
  498. # Pick up an object
  499. pickup = 3
  500. # Drop an object
  501. drop = 4
  502. # Toggle/activate an object
  503. toggle = 5
  504. # Done completing task
  505. done = 6
  506. def __init__(
  507. self,
  508. grid_size: int = None,
  509. width: int = None,
  510. height: int = None,
  511. max_steps: int = 100,
  512. see_through_walls: bool = False,
  513. agent_view_size: int = 7,
  514. render_mode: str = None,
  515. **kwargs
  516. ):
  517. # Can't set both grid_size and width/height
  518. if grid_size:
  519. assert width is None and height is None
  520. width = grid_size
  521. height = grid_size
  522. # Action enumeration for this environment
  523. self.actions = MiniGridEnv.Actions
  524. # Actions are discrete integer values
  525. self.action_space = spaces.Discrete(len(self.actions))
  526. # Number of cells (width and height) in the agent view
  527. assert agent_view_size % 2 == 1
  528. assert agent_view_size >= 3
  529. self.agent_view_size = agent_view_size
  530. # Observations are dictionaries containing an
  531. # encoding of the grid and a textual 'mission' string
  532. self.observation_space = spaces.Box(
  533. low=0,
  534. high=255,
  535. shape=(self.agent_view_size, self.agent_view_size, 3),
  536. dtype="uint8",
  537. )
  538. self.observation_space = spaces.Dict(
  539. {
  540. "image": self.observation_space,
  541. "direction": spaces.Discrete(4),
  542. "mission": spaces.Text(
  543. max_length=200,
  544. charset=string.ascii_letters + string.digits + " .,!-",
  545. ),
  546. }
  547. )
  548. # render mode
  549. self.render_mode = render_mode
  550. # Range of possible rewards
  551. self.reward_range = (0, 1)
  552. self.window: Window = None
  553. # Environment configuration
  554. self.width = width
  555. self.height = height
  556. self.max_steps = max_steps
  557. self.see_through_walls = see_through_walls
  558. # Current position and direction of the agent
  559. self.agent_pos: np.ndarray = None
  560. self.agent_dir: int = None
  561. # Initialize the state
  562. self.reset()
  563. def reset(self, *, seed=None, return_info=False, options=None):
  564. super().reset(seed=seed)
  565. # Current position and direction of the agent
  566. self.agent_pos = None
  567. self.agent_dir = None
  568. # Generate a new random grid at the start of each episode
  569. self._gen_grid(self.width, self.height)
  570. # These fields should be defined by _gen_grid
  571. assert self.agent_pos is not None
  572. assert self.agent_dir is not None
  573. # Check that the agent doesn't overlap with an object
  574. start_cell = self.grid.get(*self.agent_pos)
  575. assert start_cell is None or start_cell.can_overlap()
  576. # Item picked up, being carried, initially nothing
  577. self.carrying = None
  578. # Step count since episode start
  579. self.step_count = 0
  580. # Return first observation
  581. obs = self.gen_obs()
  582. return obs
  583. def hash(self, size=16):
  584. """Compute a hash that uniquely identifies the current state of the environment.
  585. :param size: Size of the hashing
  586. """
  587. sample_hash = hashlib.sha256()
  588. to_encode = [self.grid.encode().tolist(), self.agent_pos, self.agent_dir]
  589. for item in to_encode:
  590. sample_hash.update(str(item).encode("utf8"))
  591. return sample_hash.hexdigest()[:size]
  592. @property
  593. def steps_remaining(self):
  594. return self.max_steps - self.step_count
  595. def __str__(self):
  596. """
  597. Produce a pretty string of the environment's grid along with the agent.
  598. A grid cell is represented by 2-character string, the first one for
  599. the object and the second one for the color.
  600. """
  601. # Map of object types to short string
  602. OBJECT_TO_STR = {
  603. "wall": "W",
  604. "floor": "F",
  605. "door": "D",
  606. "key": "K",
  607. "ball": "A",
  608. "box": "B",
  609. "goal": "G",
  610. "lava": "V",
  611. }
  612. # Map agent's direction to short string
  613. AGENT_DIR_TO_STR = {0: ">", 1: "V", 2: "<", 3: "^"}
  614. str = ""
  615. for j in range(self.grid.height):
  616. for i in range(self.grid.width):
  617. if i == self.agent_pos[0] and j == self.agent_pos[1]:
  618. str += 2 * AGENT_DIR_TO_STR[self.agent_dir]
  619. continue
  620. c = self.grid.get(i, j)
  621. if c is None:
  622. str += " "
  623. continue
  624. if c.type == "door":
  625. if c.is_open:
  626. str += "__"
  627. elif c.is_locked:
  628. str += "L" + c.color[0].upper()
  629. else:
  630. str += "D" + c.color[0].upper()
  631. continue
  632. str += OBJECT_TO_STR[c.type] + c.color[0].upper()
  633. if j < self.grid.height - 1:
  634. str += "\n"
  635. return str
  636. def _gen_grid(self, width, height):
  637. assert False, "_gen_grid needs to be implemented by each environment"
  638. def _reward(self):
  639. """
  640. Compute the reward to be given upon success
  641. """
  642. return 1 - 0.9 * (self.step_count / self.max_steps)
  643. def _rand_int(self, low, high):
  644. """
  645. Generate random integer in [low,high[
  646. """
  647. return self.np_random.integers(low, high)
  648. def _rand_float(self, low, high):
  649. """
  650. Generate random float in [low,high[
  651. """
  652. return self.np_random.uniform(low, high)
  653. def _rand_bool(self):
  654. """
  655. Generate random boolean value
  656. """
  657. return self.np_random.integers(0, 2) == 0
  658. def _rand_elem(self, iterable):
  659. """
  660. Pick a random element in a list
  661. """
  662. lst = list(iterable)
  663. idx = self._rand_int(0, len(lst))
  664. return lst[idx]
  665. def _rand_subset(self, iterable, num_elems):
  666. """
  667. Sample a random subset of distinct elements of a list
  668. """
  669. lst = list(iterable)
  670. assert num_elems <= len(lst)
  671. out = []
  672. while len(out) < num_elems:
  673. elem = self._rand_elem(lst)
  674. lst.remove(elem)
  675. out.append(elem)
  676. return out
  677. def _rand_color(self):
  678. """
  679. Generate a random color name (string)
  680. """
  681. return self._rand_elem(COLOR_NAMES)
  682. def _rand_pos(self, xLow, xHigh, yLow, yHigh):
  683. """
  684. Generate a random (x,y) position tuple
  685. """
  686. return (
  687. self.np_random.integers(xLow, xHigh),
  688. self.np_random.integers(yLow, yHigh),
  689. )
  690. def place_obj(self, obj, top=None, size=None, reject_fn=None, max_tries=math.inf):
  691. """
  692. Place an object at an empty position in the grid
  693. :param top: top-left position of the rectangle where to place
  694. :param size: size of the rectangle where to place
  695. :param reject_fn: function to filter out potential positions
  696. """
  697. if top is None:
  698. top = (0, 0)
  699. else:
  700. top = (max(top[0], 0), max(top[1], 0))
  701. if size is None:
  702. size = (self.grid.width, self.grid.height)
  703. num_tries = 0
  704. while True:
  705. # This is to handle with rare cases where rejection sampling
  706. # gets stuck in an infinite loop
  707. if num_tries > max_tries:
  708. raise RecursionError("rejection sampling failed in place_obj")
  709. num_tries += 1
  710. pos = np.array(
  711. (
  712. self._rand_int(top[0], min(top[0] + size[0], self.grid.width)),
  713. self._rand_int(top[1], min(top[1] + size[1], self.grid.height)),
  714. )
  715. )
  716. # Don't place the object on top of another object
  717. if self.grid.get(*pos) is not None:
  718. continue
  719. # Don't place the object where the agent is
  720. if np.array_equal(pos, self.agent_pos):
  721. continue
  722. # Check if there is a filtering criterion
  723. if reject_fn and reject_fn(self, pos):
  724. continue
  725. break
  726. self.grid.set(*pos, obj)
  727. if obj is not None:
  728. obj.init_pos = pos
  729. obj.cur_pos = pos
  730. return pos
  731. def put_obj(self, obj, i, j):
  732. """
  733. Put an object at a specific position in the grid
  734. """
  735. self.grid.set(i, j, obj)
  736. obj.init_pos = (i, j)
  737. obj.cur_pos = (i, j)
  738. def place_agent(self, top=None, size=None, rand_dir=True, max_tries=math.inf):
  739. """
  740. Set the agent's starting point at an empty position in the grid
  741. """
  742. self.agent_pos = None
  743. pos = self.place_obj(None, top, size, max_tries=max_tries)
  744. self.agent_pos = pos
  745. if rand_dir:
  746. self.agent_dir = self._rand_int(0, 4)
  747. return pos
  748. @property
  749. def dir_vec(self):
  750. """
  751. Get the direction vector for the agent, pointing in the direction
  752. of forward movement.
  753. """
  754. assert self.agent_dir >= 0 and self.agent_dir < 4
  755. return DIR_TO_VEC[self.agent_dir]
  756. @property
  757. def right_vec(self):
  758. """
  759. Get the vector pointing to the right of the agent.
  760. """
  761. dx, dy = self.dir_vec
  762. return np.array((-dy, dx))
  763. @property
  764. def front_pos(self):
  765. """
  766. Get the position of the cell that is right in front of the agent
  767. """
  768. return self.agent_pos + self.dir_vec
  769. def get_view_coords(self, i, j):
  770. """
  771. Translate and rotate absolute grid coordinates (i, j) into the
  772. agent's partially observable view (sub-grid). Note that the resulting
  773. coordinates may be negative or outside of the agent's view size.
  774. """
  775. ax, ay = self.agent_pos
  776. dx, dy = self.dir_vec
  777. rx, ry = self.right_vec
  778. # Compute the absolute coordinates of the top-left view corner
  779. sz = self.agent_view_size
  780. hs = self.agent_view_size // 2
  781. tx = ax + (dx * (sz - 1)) - (rx * hs)
  782. ty = ay + (dy * (sz - 1)) - (ry * hs)
  783. lx = i - tx
  784. ly = j - ty
  785. # Project the coordinates of the object relative to the top-left
  786. # corner onto the agent's own coordinate system
  787. vx = rx * lx + ry * ly
  788. vy = -(dx * lx + dy * ly)
  789. return vx, vy
  def get_view_exts(self, agent_view_size=None):
      """
      Get the extents of the square set of tiles visible to the agent.

      Returns ``(topX, topY, botX, botY)``. The bottom extents are exclusive:
      ``botX = topX + size`` and ``botY = topY + size``, where ``size`` is
      ``agent_view_size`` when truthy, otherwise ``self.agent_view_size``.
      The square lies in front of the agent according to ``agent_dir``
      (0 right, 1 down, 2 left, 3 up); any other value fails an assertion.
      """
      size = agent_view_size or self.agent_view_size
      half = size // 2
      reach = size - 1

      # Offset of the top-left corner from the agent for each facing.
      facing = self.agent_dir
      if facing == 0:  # right: view extends along +x, centred on agent's y
          shift_x, shift_y = 0, -half
      elif facing == 1:  # down: view extends along +y, centred on agent's x
          shift_x, shift_y = -half, 0
      elif facing == 2:  # left: view extends along -x
          shift_x, shift_y = -reach, -half
      elif facing == 3:  # up: view extends along -y
          shift_x, shift_y = -half, -reach
      else:
          assert False, "invalid agent direction"

      left = self.agent_pos[0] + shift_x
      top = self.agent_pos[1] + shift_y
      return (left, top, left + size, top + size)
  def relative_coords(self, x, y):
      """
      Map world cell (x, y) to the agent's view coordinates.

      Returns the ``(vx, vy)`` pair from ``get_view_coords`` when it lies
      inside the ``agent_view_size`` x ``agent_view_size`` view square,
      otherwise None.
      """
      vx, vy = self.get_view_coords(x, y)
      size = self.agent_view_size
      inside = 0 <= vx < size and 0 <= vy < size
      return (vx, vy) if inside else None
  def in_view(self, x, y):
      """
      Check whether world cell (x, y) falls inside the agent's view square.

      Only the view bounds are tested (via ``relative_coords``); occlusion
      by walls is not considered.
      """
      coords = self.relative_coords(x, y)
      return coords is not None
  def agent_sees(self, x, y):
      """
      Check if a non-empty grid position is visible to the agent.

      Returns True only when (x, y) lies inside the agent's view square, the
      world cell there is non-empty, and the agent's decoded observation
      shows an object of the same type at the corresponding view cell.
      An empty world cell always yields False.
      """
      coordinates = self.relative_coords(x, y)
      if coordinates is None:
          return False

      world_cell = self.grid.get(x, y)
      if world_cell is None:
          # An empty cell is never "seen". Checking this up front also
          # prevents an AttributeError on `world_cell.type` below: the
          # observation can be non-empty where the world is empty, because
          # gen_obs_grid writes the carried object into the agent's own
          # view cell. It also skips rendering an observation needlessly.
          return False

      vx, vy = coordinates
      obs = self.gen_obs()
      obs_grid, _ = Grid.decode(obs["image"])
      obs_cell = obs_grid.get(vx, vy)
      return obs_cell is not None and obs_cell.type == world_cell.type
  def step(self, action):
      """
      Advance the environment by one action.

      Each call increments ``step_count``. Action effects:
        - ``left``/``right`` rotate ``agent_dir`` by a quarter turn;
        - ``forward`` moves into the front cell if it is empty or
          overlappable, ends the episode with ``self._reward()`` on a goal,
          and ends it (reward 0) on lava;
        - ``pickup`` takes a pickable front object when nothing is carried;
        - ``drop`` puts the carried object into an empty front cell;
        - ``toggle`` calls ``toggle(self, pos)`` on the front object;
        - ``done`` does nothing.
      Any other action fails an assertion. The episode also ends once
      ``step_count`` reaches ``max_steps``.

      Returns ``(obs, reward, done, info)`` with ``obs`` from ``gen_obs()``
      and an empty ``info`` dict.
      """
      self.step_count += 1
      reward = 0
      done = False

      # Position in front of the agent and whatever occupies it (None if empty)
      ahead = self.front_pos
      target = self.grid.get(*ahead)
      acts = self.actions

      if action == acts.left:
          # Quarter turn counter-clockwise, wrapping below 0 back to 3
          new_dir = self.agent_dir - 1
          if new_dir < 0:
              new_dir += 4
          self.agent_dir = new_dir
      elif action == acts.right:
          self.agent_dir = (self.agent_dir + 1) % 4
      elif action == acts.forward:
          if target is None:
              self.agent_pos = ahead
          else:
              if target.can_overlap():
                  self.agent_pos = ahead
              if target.type == "goal":
                  done = True
                  reward = self._reward()
              if target.type == "lava":
                  done = True
      elif action == acts.pickup:
          # Only pick something up when the hands are free
          if target and target.can_pickup() and self.carrying is None:
              self.carrying = target
              target.cur_pos = np.array([-1, -1])
              self.grid.set(*ahead, None)
      elif action == acts.drop:
          # Only drop into an empty cell
          if not target and self.carrying:
              held = self.carrying
              self.grid.set(*ahead, held)
              held.cur_pos = ahead
              self.carrying = None
      elif action == acts.toggle:
          if target:
              target.toggle(self, ahead)
      elif action == acts.done:
          pass  # not used by default
      else:
          assert False, "unknown action"

      if self.step_count >= self.max_steps:
          done = True

      return self.gen_obs(), reward, done, {}
  def gen_obs_grid(self, agent_view_size=None):
      """
      Generate the sub-grid observed by the agent, plus a visibility mask.

      The view square (side ``agent_view_size``, or ``self.agent_view_size``
      when that argument is None/0) is sliced from the world at the extents
      given by ``get_view_exts`` and rotated left ``agent_dir + 1`` times.
      The agent's own cell in the view is at (size // 2, size - 1).
      Unless ``see_through_walls`` is set, ``process_vis`` on the view
      computes which cells are visible; otherwise every cell is visible.
      The agent's own cell is overwritten with the carried object (or
      cleared), so the agent can see what it holds.

      Returns ``(grid, vis_mask)``.
      """
      top_x, top_y, _, _ = self.get_view_exts(agent_view_size)
      size = agent_view_size or self.agent_view_size

      view = self.grid.slice(top_x, top_y, size, size)
      for _ in range(self.agent_dir + 1):
          view = view.rotate_left()

      # Occluder processing has a performance cost; skip it when walls
      # are treated as transparent.
      if self.see_through_walls:
          vis_mask = np.ones(shape=(view.width, view.height), dtype=bool)
      else:
          vis_mask = view.process_vis(agent_pos=(size // 2, size - 1))

      # Show the carried object (or nothing) on the agent's own cell
      view.set(view.width // 2, view.height - 1, self.carrying or None)
      return view, vis_mask
  def gen_obs(self):
      """
      Generate the agent's observation (partially observable, low-resolution
      encoding).

      Returns a dict with keys, in order:
        - "image": the view from ``gen_obs_grid`` encoded with its
          visibility mask;
        - "direction": ``self.agent_dir`` (acts as a compass);
        - "mission": ``self.mission``, a textual instruction that every
          environment must define (checked by an assertion).
      """
      view_grid, visible = self.gen_obs_grid()
      encoded = view_grid.encode(visible)

      assert hasattr(
          self, "mission"
      ), "environments must define a textual mission string"

      return dict(
          image=encoded,
          direction=self.agent_dir,
          mission=self.mission,
      )
  def get_obs_render(self, obs, tile_size=TILE_PIXELS // 2):
      """
      Render an agent observation for visualization.

      ``obs`` is either the encoded view image (the "image" entry produced
      by ``gen_obs``) or the full observation dict returned by ``gen_obs``,
      in which case its "image" entry is used. ``tile_size`` is the pixel
      size of each tile. Returns the image produced by ``Grid.render``.
      """
      if isinstance(obs, dict):
          obs = obs["image"]

      grid, vis_mask = Grid.decode(obs)

      # The agent occupies the bottom-centre cell of its own view and faces
      # up (direction 3). Derive that cell from the decoded grid instead of
      # self.agent_view_size, so observations generated with a custom view
      # size (gen_obs_grid(agent_view_size=...)) place the agent correctly.
      img = grid.render(
          tile_size,
          agent_pos=(grid.width // 2, grid.height - 1),
          agent_dir=3,
          highlight_mask=vis_mask,
      )
      return img
  def render(self, mode="human", close=False, highlight=True, tile_size=TILE_PIXELS):
      """
      Render the whole grid, optionally highlighting the cells the agent sees.

      A non-None ``self.render_mode`` overrides ``mode``. With
      ``close=True`` the window (if any) is closed and None is returned.
      In "human" mode a Window is created on first use, captioned with the
      mission, and shown the frame. Every non-closing call returns the
      rendered image.
      """
      if self.render_mode is not None:
          mode = self.render_mode

      if close:
          if self.window:
              self.window.close()
          return

      human = mode == "human"
      if human and not self.window:
          self.window = Window("gym_minigrid")
          self.window.show(block=False)

      # Which cells of the agent's egocentric view are visible
      _, vis_mask = self.gen_obs_grid()

      fwd = self.dir_vec
      right = self.right_vec
      view = self.agent_view_size
      # World position of view cell (0, 0): the far row ahead of the agent,
      # shifted half a view width to the agent's left.
      origin = self.agent_pos + fwd * (view - 1) - right * (view // 2)

      # Project every visible view cell back into world coordinates
      highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
      for col in range(view):
          for row in range(view):
              if vis_mask[col, row]:
                  wx, wy = origin - (fwd * row) + (right * col)
                  if 0 <= wx < self.width and 0 <= wy < self.height:
                      highlight_mask[wx, wy] = True

      img = self.grid.render(
          tile_size,
          self.agent_pos,
          self.agent_dir,
          highlight_mask=(highlight_mask if highlight else None),
      )

      if human:
          self.window.set_caption(self.mission)
          self.window.show_img(img)

      return img
  def close(self):
      """Close the rendering window if one has been opened; returns None."""
      window = self.window
      if window:
          window.close()