minigrid.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317
  1. import hashlib
  2. import math
  3. import string
  4. from abc import abstractmethod
  5. from enum import IntEnum
  6. import gym
  7. import numpy as np
  8. from gym import spaces
  9. # Size in pixels of a tile in the full-scale human view
  10. from gym_minigrid.rendering import (
  11. downsample,
  12. fill_coords,
  13. highlight_img,
  14. point_in_circle,
  15. point_in_line,
  16. point_in_rect,
  17. point_in_triangle,
  18. rotate_fn,
  19. )
  20. TILE_PIXELS = 32
  21. # Map of color names to RGB values
  22. COLORS = {
  23. "red": np.array([255, 0, 0]),
  24. "green": np.array([0, 255, 0]),
  25. "blue": np.array([0, 0, 255]),
  26. "purple": np.array([112, 39, 195]),
  27. "yellow": np.array([255, 255, 0]),
  28. "grey": np.array([100, 100, 100]),
  29. }
  30. COLOR_NAMES = sorted(list(COLORS.keys()))
  31. # Used to map colors to integers
  32. COLOR_TO_IDX = {"red": 0, "green": 1, "blue": 2, "purple": 3, "yellow": 4, "grey": 5}
  33. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  34. # Map of object type to integers
  35. OBJECT_TO_IDX = {
  36. "unseen": 0,
  37. "empty": 1,
  38. "wall": 2,
  39. "floor": 3,
  40. "door": 4,
  41. "key": 5,
  42. "ball": 6,
  43. "box": 7,
  44. "goal": 8,
  45. "lava": 9,
  46. "agent": 10,
  47. }
  48. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  49. # Map of state names to integers
  50. STATE_TO_IDX = {
  51. "open": 0,
  52. "closed": 1,
  53. "locked": 2,
  54. }
  55. # Map of agent direction indices to vectors
  56. DIR_TO_VEC = [
  57. # Pointing right (positive X)
  58. np.array((1, 0)),
  59. # Down (positive Y)
  60. np.array((0, 1)),
  61. # Pointing left (negative X)
  62. np.array((-1, 0)),
  63. # Up (negative Y)
  64. np.array((0, -1)),
  65. ]
  66. class WorldObj:
  67. """
  68. Base class for grid world objects
  69. """
  70. def __init__(self, type, color):
  71. assert type in OBJECT_TO_IDX, type
  72. assert color in COLOR_TO_IDX, color
  73. self.type = type
  74. self.color = color
  75. self.contains = None
  76. # Initial position of the object
  77. self.init_pos = None
  78. # Current position of the object
  79. self.cur_pos = None
  80. def can_overlap(self):
  81. """Can the agent overlap with this?"""
  82. return False
  83. def can_pickup(self):
  84. """Can the agent pick this up?"""
  85. return False
  86. def can_contain(self):
  87. """Can this contain another object?"""
  88. return False
  89. def see_behind(self):
  90. """Can the agent see behind this object?"""
  91. return True
  92. def toggle(self, env, pos):
  93. """Method to trigger/toggle an action this object performs"""
  94. return False
  95. def encode(self):
  96. """Encode the a description of this object as a 3-tuple of integers"""
  97. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0)
  98. @staticmethod
  99. def decode(type_idx, color_idx, state):
  100. """Create an object from a 3-tuple state description"""
  101. obj_type = IDX_TO_OBJECT[type_idx]
  102. color = IDX_TO_COLOR[color_idx]
  103. if obj_type == "empty" or obj_type == "unseen":
  104. return None
  105. # State, 0: open, 1: closed, 2: locked
  106. is_open = state == 0
  107. is_locked = state == 2
  108. if obj_type == "wall":
  109. v = Wall(color)
  110. elif obj_type == "floor":
  111. v = Floor(color)
  112. elif obj_type == "ball":
  113. v = Ball(color)
  114. elif obj_type == "key":
  115. v = Key(color)
  116. elif obj_type == "box":
  117. v = Box(color)
  118. elif obj_type == "door":
  119. v = Door(color, is_open, is_locked)
  120. elif obj_type == "goal":
  121. v = Goal()
  122. elif obj_type == "lava":
  123. v = Lava()
  124. else:
  125. assert False, "unknown object type in decode '%s'" % obj_type
  126. return v
  127. def render(self, r):
  128. """Draw this object with the given renderer"""
  129. raise NotImplementedError
  130. class Goal(WorldObj):
  131. def __init__(self):
  132. super().__init__("goal", "green")
  133. def can_overlap(self):
  134. return True
  135. def render(self, img):
  136. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  137. class Floor(WorldObj):
  138. """
  139. Colored floor tile the agent can walk over
  140. """
  141. def __init__(self, color="blue"):
  142. super().__init__("floor", color)
  143. def can_overlap(self):
  144. return True
  145. def render(self, img):
  146. # Give the floor a pale color
  147. color = COLORS[self.color] / 2
  148. fill_coords(img, point_in_rect(0.031, 1, 0.031, 1), color)
  149. class Lava(WorldObj):
  150. def __init__(self):
  151. super().__init__("lava", "red")
  152. def can_overlap(self):
  153. return True
  154. def render(self, img):
  155. c = (255, 128, 0)
  156. # Background color
  157. fill_coords(img, point_in_rect(0, 1, 0, 1), c)
  158. # Little waves
  159. for i in range(3):
  160. ylo = 0.3 + 0.2 * i
  161. yhi = 0.4 + 0.2 * i
  162. fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0, 0, 0))
  163. fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0, 0, 0))
  164. fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0, 0, 0))
  165. fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0, 0, 0))
  166. class Wall(WorldObj):
  167. def __init__(self, color="grey"):
  168. super().__init__("wall", color)
  169. def see_behind(self):
  170. return False
  171. def render(self, img):
  172. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  173. class Door(WorldObj):
  174. def __init__(self, color, is_open=False, is_locked=False):
  175. super().__init__("door", color)
  176. self.is_open = is_open
  177. self.is_locked = is_locked
  178. def can_overlap(self):
  179. """The agent can only walk over this cell when the door is open"""
  180. return self.is_open
  181. def see_behind(self):
  182. return self.is_open
  183. def toggle(self, env, pos):
  184. # If the player has the right key to open the door
  185. if self.is_locked:
  186. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  187. self.is_locked = False
  188. self.is_open = True
  189. return True
  190. return False
  191. self.is_open = not self.is_open
  192. return True
  193. def encode(self):
  194. """Encode the a description of this object as a 3-tuple of integers"""
  195. # State, 0: open, 1: closed, 2: locked
  196. if self.is_open:
  197. state = 0
  198. elif self.is_locked:
  199. state = 2
  200. else:
  201. state = 1
  202. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state)
  203. def render(self, img):
  204. c = COLORS[self.color]
  205. if self.is_open:
  206. fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c)
  207. fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0, 0, 0))
  208. return
  209. # Door frame and door
  210. if self.is_locked:
  211. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  212. fill_coords(img, point_in_rect(0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))
  213. # Draw key slot
  214. fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)
  215. else:
  216. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  217. fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0, 0, 0))
  218. fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c)
  219. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0, 0, 0))
  220. # Draw door handle
  221. fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c)
  222. class Key(WorldObj):
  223. def __init__(self, color="blue"):
  224. super().__init__("key", color)
  225. def can_pickup(self):
  226. return True
  227. def render(self, img):
  228. c = COLORS[self.color]
  229. # Vertical quad
  230. fill_coords(img, point_in_rect(0.50, 0.63, 0.31, 0.88), c)
  231. # Teeth
  232. fill_coords(img, point_in_rect(0.38, 0.50, 0.59, 0.66), c)
  233. fill_coords(img, point_in_rect(0.38, 0.50, 0.81, 0.88), c)
  234. # Ring
  235. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.190), c)
  236. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0, 0, 0))
  237. class Ball(WorldObj):
  238. def __init__(self, color="blue"):
  239. super().__init__("ball", color)
  240. def can_pickup(self):
  241. return True
  242. def render(self, img):
  243. fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color])
  244. class Box(WorldObj):
  245. def __init__(self, color, contains=None):
  246. super().__init__("box", color)
  247. self.contains = contains
  248. def can_pickup(self):
  249. return True
  250. def set_contains(self, contains):
  251. self.contains = contains
  252. def render(self, img):
  253. c = COLORS[self.color]
  254. # Outline
  255. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), c)
  256. fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0, 0, 0))
  257. # Horizontal slit
  258. fill_coords(img, point_in_rect(0.16, 0.84, 0.47, 0.53), c)
  259. def toggle(self, env, pos):
  260. # Replace the box by its contents
  261. env.grid.set(pos[0], pos[1], self.contains)
  262. return True
  263. class Grid:
  264. """
  265. Represent a grid and operations on it
  266. """
  267. # Static cache of pre-renderer tiles
  268. tile_cache = {}
  269. def __init__(self, width, height):
  270. assert width >= 3
  271. assert height >= 3
  272. self.width = width
  273. self.height = height
  274. self.grid = [None] * width * height
  275. def __contains__(self, key):
  276. if isinstance(key, WorldObj):
  277. for e in self.grid:
  278. if e is key:
  279. return True
  280. elif isinstance(key, tuple):
  281. for e in self.grid:
  282. if e is None:
  283. continue
  284. if (e.color, e.type) == key:
  285. return True
  286. if key[0] is None and key[1] == e.type:
  287. return True
  288. return False
  289. def __eq__(self, other):
  290. grid1 = self.encode()
  291. grid2 = other.encode()
  292. return np.array_equal(grid2, grid1)
  293. def __ne__(self, other):
  294. return not self == other
  295. def copy(self):
  296. from copy import deepcopy
  297. return deepcopy(self)
  298. def set(self, i, j, v):
  299. assert i >= 0 and i < self.width
  300. assert j >= 0 and j < self.height
  301. self.grid[j * self.width + i] = v
  302. def get(self, i, j):
  303. assert i >= 0 and i < self.width
  304. assert j >= 0 and j < self.height
  305. return self.grid[j * self.width + i]
  306. def horz_wall(self, x, y, length=None, obj_type=Wall):
  307. if length is None:
  308. length = self.width - x
  309. for i in range(0, length):
  310. self.set(x + i, y, obj_type())
  311. def vert_wall(self, x, y, length=None, obj_type=Wall):
  312. if length is None:
  313. length = self.height - y
  314. for j in range(0, length):
  315. self.set(x, y + j, obj_type())
  316. def wall_rect(self, x, y, w, h):
  317. self.horz_wall(x, y, w)
  318. self.horz_wall(x, y + h - 1, w)
  319. self.vert_wall(x, y, h)
  320. self.vert_wall(x + w - 1, y, h)
  321. def rotate_left(self):
  322. """
  323. Rotate the grid to the left (counter-clockwise)
  324. """
  325. grid = Grid(self.height, self.width)
  326. for i in range(self.width):
  327. for j in range(self.height):
  328. v = self.get(i, j)
  329. grid.set(j, grid.height - 1 - i, v)
  330. return grid
  331. def slice(self, topX, topY, width, height):
  332. """
  333. Get a subset of the grid
  334. """
  335. grid = Grid(width, height)
  336. for j in range(0, height):
  337. for i in range(0, width):
  338. x = topX + i
  339. y = topY + j
  340. if x >= 0 and x < self.width and y >= 0 and y < self.height:
  341. v = self.get(x, y)
  342. else:
  343. v = Wall()
  344. grid.set(i, j, v)
  345. return grid
  346. @classmethod
  347. def render_tile(
  348. cls, obj, agent_dir=None, highlight=False, tile_size=TILE_PIXELS, subdivs=3
  349. ):
  350. """
  351. Render a tile and cache the result
  352. """
  353. # Hash map lookup key for the cache
  354. key = (agent_dir, highlight, tile_size)
  355. key = obj.encode() + key if obj else key
  356. if key in cls.tile_cache:
  357. return cls.tile_cache[key]
  358. img = np.zeros(
  359. shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8
  360. )
  361. # Draw the grid lines (top and left edges)
  362. fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
  363. fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100))
  364. if obj is not None:
  365. obj.render(img)
  366. # Overlay the agent on top
  367. if agent_dir is not None:
  368. tri_fn = point_in_triangle(
  369. (0.12, 0.19),
  370. (0.87, 0.50),
  371. (0.12, 0.81),
  372. )
  373. # Rotate the agent based on its direction
  374. tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5 * math.pi * agent_dir)
  375. fill_coords(img, tri_fn, (255, 0, 0))
  376. # Highlight the cell if needed
  377. if highlight:
  378. highlight_img(img)
  379. # Downsample the image to perform supersampling/anti-aliasing
  380. img = downsample(img, subdivs)
  381. # Cache the rendered tile
  382. cls.tile_cache[key] = img
  383. return img
  384. def render(self, tile_size, agent_pos, agent_dir=None, highlight_mask=None):
  385. """
  386. Render this grid at a given scale
  387. :param r: target renderer object
  388. :param tile_size: tile size in pixels
  389. """
  390. if highlight_mask is None:
  391. highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  392. # Compute the total grid size
  393. width_px = self.width * tile_size
  394. height_px = self.height * tile_size
  395. img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)
  396. # Render the grid
  397. for j in range(0, self.height):
  398. for i in range(0, self.width):
  399. cell = self.get(i, j)
  400. agent_here = np.array_equal(agent_pos, (i, j))
  401. tile_img = Grid.render_tile(
  402. cell,
  403. agent_dir=agent_dir if agent_here else None,
  404. highlight=highlight_mask[i, j],
  405. tile_size=tile_size,
  406. )
  407. ymin = j * tile_size
  408. ymax = (j + 1) * tile_size
  409. xmin = i * tile_size
  410. xmax = (i + 1) * tile_size
  411. img[ymin:ymax, xmin:xmax, :] = tile_img
  412. return img
  413. def encode(self, vis_mask=None):
  414. """
  415. Produce a compact numpy encoding of the grid
  416. """
  417. if vis_mask is None:
  418. vis_mask = np.ones((self.width, self.height), dtype=bool)
  419. array = np.zeros((self.width, self.height, 3), dtype="uint8")
  420. for i in range(self.width):
  421. for j in range(self.height):
  422. if vis_mask[i, j]:
  423. v = self.get(i, j)
  424. if v is None:
  425. array[i, j, 0] = OBJECT_TO_IDX["empty"]
  426. array[i, j, 1] = 0
  427. array[i, j, 2] = 0
  428. else:
  429. array[i, j, :] = v.encode()
  430. return array
  431. @staticmethod
  432. def decode(array):
  433. """
  434. Decode an array grid encoding back into a grid
  435. """
  436. width, height, channels = array.shape
  437. assert channels == 3
  438. vis_mask = np.ones(shape=(width, height), dtype=bool)
  439. grid = Grid(width, height)
  440. for i in range(width):
  441. for j in range(height):
  442. type_idx, color_idx, state = array[i, j]
  443. v = WorldObj.decode(type_idx, color_idx, state)
  444. grid.set(i, j, v)
  445. vis_mask[i, j] = type_idx != OBJECT_TO_IDX["unseen"]
  446. return grid, vis_mask
  447. def process_vis(self, agent_pos):
  448. mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  449. mask[agent_pos[0], agent_pos[1]] = True
  450. for j in reversed(range(0, self.height)):
  451. for i in range(0, self.width - 1):
  452. if not mask[i, j]:
  453. continue
  454. cell = self.get(i, j)
  455. if cell and not cell.see_behind():
  456. continue
  457. mask[i + 1, j] = True
  458. if j > 0:
  459. mask[i + 1, j - 1] = True
  460. mask[i, j - 1] = True
  461. for i in reversed(range(1, self.width)):
  462. if not mask[i, j]:
  463. continue
  464. cell = self.get(i, j)
  465. if cell and not cell.see_behind():
  466. continue
  467. mask[i - 1, j] = True
  468. if j > 0:
  469. mask[i - 1, j - 1] = True
  470. mask[i, j - 1] = True
  471. for j in range(0, self.height):
  472. for i in range(0, self.width):
  473. if not mask[i, j]:
  474. self.set(i, j, None)
  475. return mask
  476. class MiniGridEnv(gym.Env):
  477. """
  478. 2D grid world game environment
  479. """
  480. metadata = {
  481. # Deprecated: use 'render_modes' instead
  482. "render.modes": ["human", "rgb_array"],
  483. "video.frames_per_second": 10, # Deprecated: use 'render_fps' instead
  484. "render_modes": ["human", "rgb_array"],
  485. "render_fps": 10,
  486. }
  487. # Enumeration of possible actions
  488. class Actions(IntEnum):
  489. # Turn left, turn right, move forward
  490. left = 0
  491. right = 1
  492. forward = 2
  493. # Pick up an object
  494. pickup = 3
  495. # Drop an object
  496. drop = 4
  497. # Toggle/activate an object
  498. toggle = 5
  499. # Done completing task
  500. done = 6
  501. def __init__(
  502. self,
  503. grid_size=None,
  504. width=None,
  505. height=None,
  506. max_steps=100,
  507. see_through_walls=False,
  508. agent_view_size=7,
  509. render_mode=None,
  510. **kwargs
  511. ):
  512. # Can't set both grid_size and width/height
  513. if grid_size:
  514. assert width is None and height is None
  515. width = grid_size
  516. height = grid_size
  517. # Action enumeration for this environment
  518. self.actions = MiniGridEnv.Actions
  519. # Actions are discrete integer values
  520. self.action_space = spaces.Discrete(len(self.actions))
  521. # Number of cells (width and height) in the agent view
  522. assert agent_view_size % 2 == 1
  523. assert agent_view_size >= 3
  524. self.agent_view_size = agent_view_size
  525. # Observations are dictionaries containing an
  526. # encoding of the grid and a textual 'mission' string
  527. self.observation_space = spaces.Box(
  528. low=0,
  529. high=255,
  530. shape=(self.agent_view_size, self.agent_view_size, 3),
  531. dtype="uint8",
  532. )
  533. self.observation_space = spaces.Dict(
  534. {
  535. "image": self.observation_space,
  536. "direction": spaces.Discrete(4),
  537. "mission": spaces.Text(
  538. max_length=200,
  539. charset=string.ascii_letters + string.digits + " .,!-",
  540. ),
  541. }
  542. )
  543. # render mode
  544. self.render_mode = render_mode
  545. # Range of possible rewards
  546. self.reward_range = (0, 1)
  547. # Window to use for human rendering mode
  548. self.window = None
  549. # Environment configuration
  550. self.width = width
  551. self.height = height
  552. self.max_steps = max_steps
  553. self.see_through_walls = see_through_walls
  554. # Current position and direction of the agent
  555. self.agent_pos = (-1, -1)
  556. self.agent_dir = -1
  557. # Current grid and mission and carryinh
  558. self.grid = Grid(width, height)
  559. self.mission = ""
  560. self.carrying = None
  561. # Initialize the state
  562. self.reset()
  563. def reset(self, *, seed=None, return_info=False, options=None):
  564. super().reset(seed=seed)
  565. # Reinitialize episode-specific variables
  566. self.agent_pos = (-1, -1)
  567. self.agent_dir = -1
  568. # Generate a new random grid at the start of each episode
  569. self._gen_grid(self.width, self.height)
  570. # These fields should be defined by _gen_grid
  571. assert self.agent_pos >= (0, 0) and self.agent_dir >= 0
  572. # Check that the agent doesn't overlap with an object
  573. start_cell = self.grid.get(*self.agent_pos)
  574. assert start_cell is None or start_cell.can_overlap()
  575. # Item picked up, being carried, initially nothing
  576. self.carrying = None
  577. # Step count since episode start
  578. self.step_count = 0
  579. # Return first observation
  580. obs = self.gen_obs()
  581. return obs
  582. def hash(self, size=16):
  583. """Compute a hash that uniquely identifies the current state of the environment.
  584. :param size: Size of the hashing
  585. """
  586. sample_hash = hashlib.sha256()
  587. to_encode = [self.grid.encode().tolist(), self.agent_pos, self.agent_dir]
  588. for item in to_encode:
  589. sample_hash.update(str(item).encode("utf8"))
  590. return sample_hash.hexdigest()[:size]
  591. @property
  592. def steps_remaining(self):
  593. return self.max_steps - self.step_count
  594. def __str__(self):
  595. """
  596. Produce a pretty string of the environment's grid along with the agent.
  597. A grid cell is represented by 2-character string, the first one for
  598. the object and the second one for the color.
  599. """
  600. # Map of object types to short string
  601. OBJECT_TO_STR = {
  602. "wall": "W",
  603. "floor": "F",
  604. "door": "D",
  605. "key": "K",
  606. "ball": "A",
  607. "box": "B",
  608. "goal": "G",
  609. "lava": "V",
  610. }
  611. # Map agent's direction to short string
  612. AGENT_DIR_TO_STR = {0: ">", 1: "V", 2: "<", 3: "^"}
  613. str = ""
  614. for j in range(self.grid.height):
  615. for i in range(self.grid.width):
  616. if i == self.agent_pos[0] and j == self.agent_pos[1]:
  617. str += 2 * AGENT_DIR_TO_STR[self.agent_dir]
  618. continue
  619. c = self.grid.get(i, j)
  620. if c is None:
  621. str += " "
  622. continue
  623. if c.type == "door":
  624. if c.is_open:
  625. str += "__"
  626. elif c.is_locked:
  627. str += "L" + c.color[0].upper()
  628. else:
  629. str += "D" + c.color[0].upper()
  630. continue
  631. str += OBJECT_TO_STR[c.type] + c.color[0].upper()
  632. if j < self.grid.height - 1:
  633. str += "\n"
  634. return str
  635. @abstractmethod
  636. def _gen_grid(self, width, height):
  637. pass
  638. def _reward(self):
  639. """
  640. Compute the reward to be given upon success
  641. """
  642. return 1 - 0.9 * (self.step_count / self.max_steps)
  643. def _rand_int(self, low, high):
  644. """
  645. Generate random integer in [low,high[
  646. """
  647. return self.np_random.integers(low, high)
  648. def _rand_float(self, low, high):
  649. """
  650. Generate random float in [low,high[
  651. """
  652. return self.np_random.uniform(low, high)
  653. def _rand_bool(self):
  654. """
  655. Generate random boolean value
  656. """
  657. return self.np_random.integers(0, 2) == 0
  658. def _rand_elem(self, iterable):
  659. """
  660. Pick a random element in a list
  661. """
  662. lst = list(iterable)
  663. idx = self._rand_int(0, len(lst))
  664. return lst[idx]
  665. def _rand_subset(self, iterable, num_elems):
  666. """
  667. Sample a random subset of distinct elements of a list
  668. """
  669. lst = list(iterable)
  670. assert num_elems <= len(lst)
  671. out = []
  672. while len(out) < num_elems:
  673. elem = self._rand_elem(lst)
  674. lst.remove(elem)
  675. out.append(elem)
  676. return out
  677. def _rand_color(self):
  678. """
  679. Generate a random color name (string)
  680. """
  681. return self._rand_elem(COLOR_NAMES)
  682. def _rand_pos(self, xLow, xHigh, yLow, yHigh):
  683. """
  684. Generate a random (x,y) position tuple
  685. """
  686. return (
  687. self.np_random.integers(xLow, xHigh),
  688. self.np_random.integers(yLow, yHigh),
  689. )
  690. def place_obj(self, obj, top=None, size=None, reject_fn=None, max_tries=math.inf):
  691. """
  692. Place an object at an empty position in the grid
  693. :param top: top-left position of the rectangle where to place
  694. :param size: size of the rectangle where to place
  695. :param reject_fn: function to filter out potential positions
  696. """
  697. if top is None:
  698. top = (0, 0)
  699. else:
  700. top = (max(top[0], 0), max(top[1], 0))
  701. if size is None:
  702. size = (self.grid.width, self.grid.height)
  703. num_tries = 0
  704. while True:
  705. # This is to handle with rare cases where rejection sampling
  706. # gets stuck in an infinite loop
  707. if num_tries > max_tries:
  708. raise RecursionError("rejection sampling failed in place_obj")
  709. num_tries += 1
  710. pos = np.array(
  711. (
  712. self._rand_int(top[0], min(top[0] + size[0], self.grid.width)),
  713. self._rand_int(top[1], min(top[1] + size[1], self.grid.height)),
  714. )
  715. )
  716. pos = tuple(pos)
  717. # Don't place the object on top of another object
  718. if self.grid.get(*pos) is not None:
  719. continue
  720. # Don't place the object where the agent is
  721. if np.array_equal(pos, self.agent_pos):
  722. continue
  723. # Check if there is a filtering criterion
  724. if reject_fn and reject_fn(self, pos):
  725. continue
  726. break
  727. self.grid.set(pos[0], pos[1], obj)
  728. if obj is not None:
  729. obj.init_pos = pos
  730. obj.cur_pos = pos
  731. return pos
  732. def put_obj(self, obj, i, j):
  733. """
  734. Put an object at a specific position in the grid
  735. """
  736. self.grid.set(i, j, obj)
  737. obj.init_pos = (i, j)
  738. obj.cur_pos = (i, j)
  739. def place_agent(self, top=None, size=None, rand_dir=True, max_tries=math.inf):
  740. """
  741. Set the agent's starting point at an empty position in the grid
  742. """
  743. self.agent_pos = (-1, -1)
  744. pos = self.place_obj(None, top, size, max_tries=max_tries)
  745. self.agent_pos = pos
  746. if rand_dir:
  747. self.agent_dir = self._rand_int(0, 4)
  748. return pos
  749. @property
  750. def dir_vec(self):
  751. """
  752. Get the direction vector for the agent, pointing in the direction
  753. of forward movement.
  754. """
  755. assert self.agent_dir >= 0 and self.agent_dir < 4
  756. return DIR_TO_VEC[self.agent_dir]
  757. @property
  758. def right_vec(self):
  759. """
  760. Get the vector pointing to the right of the agent.
  761. """
  762. dx, dy = self.dir_vec
  763. return np.array((-dy, dx))
  764. @property
  765. def front_pos(self):
  766. """
  767. Get the position of the cell that is right in front of the agent
  768. """
  769. return self.agent_pos + self.dir_vec
  770. def get_view_coords(self, i, j):
  771. """
  772. Translate and rotate absolute grid coordinates (i, j) into the
  773. agent's partially observable view (sub-grid). Note that the resulting
  774. coordinates may be negative or outside of the agent's view size.
  775. """
  776. ax, ay = self.agent_pos
  777. dx, dy = self.dir_vec
  778. rx, ry = self.right_vec
  779. # Compute the absolute coordinates of the top-left view corner
  780. sz = self.agent_view_size
  781. hs = self.agent_view_size // 2
  782. tx = ax + (dx * (sz - 1)) - (rx * hs)
  783. ty = ay + (dy * (sz - 1)) - (ry * hs)
  784. lx = i - tx
  785. ly = j - ty
  786. # Project the coordinates of the object relative to the top-left
  787. # corner onto the agent's own coordinate system
  788. vx = rx * lx + ry * ly
  789. vy = -(dx * lx + dy * ly)
  790. return vx, vy
  def get_view_exts(self, agent_view_size=None):
      """
      Get the extents of the square set of tiles visible to the agent.

      :param agent_view_size: side length of the view square; falls back to
          ``self.agent_view_size`` when None (or any other falsy value).
      :return: (topX, topY, botX, botY) in absolute grid coordinates. The
          bottom extents are exclusive, i.e. the tiles covered are
          topX <= x < botX and topY <= y < botY. The values are not clipped
          and may fall outside the grid near its borders.
      :raises ValueError: if ``self.agent_dir`` is not one of 0, 1, 2, 3.
      """
      agent_view_size = agent_view_size or self.agent_view_size
      half = agent_view_size // 2
      ax = self.agent_pos[0]
      ay = self.agent_pos[1]

      # Facing right: the agent is on the leftmost column, centred vertically
      if self.agent_dir == 0:
          topX = ax
          topY = ay - half
      # Facing down: the agent is on the top row, centred horizontally
      elif self.agent_dir == 1:
          topX = ax - half
          topY = ay
      # Facing left: the agent is on the rightmost column
      elif self.agent_dir == 2:
          topX = ax - agent_view_size + 1
          topY = ay - half
      # Facing up: the agent is on the bottom row
      elif self.agent_dir == 3:
          topX = ax - half
          topY = ay - agent_view_size + 1
      else:
          # An explicit raise (unlike `assert False`) still fires under
          # `python -O`, instead of falling through to an UnboundLocalError.
          raise ValueError(f"invalid agent direction: {self.agent_dir}")

      botX = topX + agent_view_size
      botY = topY + agent_view_size
      return (topX, topY, botX, botY)
  def relative_coords(self, x, y):
      """
      Map the absolute grid position (x, y) into the agent's view square.

      Returns the (vx, vy) view coordinates when they fall inside the
      agent_view_size x agent_view_size square, otherwise None. This is a
      pure bounds check; occlusion is not considered here.
      """
      vx, vy = self.get_view_coords(x, y)
      size = self.agent_view_size
      inside = 0 <= vx < size and 0 <= vy < size
      return (vx, vy) if inside else None
  def in_view(self, x, y):
      """
      Return True if the grid position (x, y) falls inside the agent's view
      square, as decided by relative_coords (no occlusion check here).
      """
      coords = self.relative_coords(x, y)
      return coords is not None
  def agent_sees(self, x, y):
      """
      Check whether the non-empty grid cell at (x, y) is actually visible to
      the agent.

      Unlike in_view, this goes through the agent's real observation
      (gen_obs), so cells hidden by the visibility mask decode as empty and
      count as not seen. The cell is considered seen when the decoded
      observation holds an object of the same type at the matching view
      position.

      The world cell at (x, y) is asserted to be non-empty.
      """
      view_pos = self.relative_coords(x, y)
      if view_pos is None:
          return False

      observed_grid, _ = Grid.decode(self.gen_obs()["image"])
      seen = observed_grid.get(*view_pos)
      actual = self.grid.get(x, y)
      assert actual is not None
      if seen is None:
          return False
      return seen.type == actual.type
  def step(self, action):
      """
      Advance the environment by one action.

      :param action: one of ``self.actions`` (left, right, forward, pickup,
          drop, toggle, done).
      :return: (obs, reward, terminated, truncated, info). obs comes from
          gen_obs(); reward is 0 unless the agent stepped onto a goal, in which
          case it is self._reward(); terminated is set when stepping onto a
          goal or lava cell; truncated is set once step_count reaches
          max_steps; info is always an empty dict.
      :raises ValueError: for any other action (step_count has already been
          incremented at that point).
      """
      self.step_count += 1
      reward = 0
      terminated = False

      # Cell directly ahead of the agent and its contents (None when empty)
      ahead = self.front_pos
      ahead_cell = self.grid.get(*ahead)
      acts = self.actions

      if action == acts.left:
          # Counter-clockwise turn, wrapping -1 back to 3
          self.agent_dir -= 1
          if self.agent_dir < 0:
              self.agent_dir += 4

      elif action == acts.right:
          self.agent_dir = (self.agent_dir + 1) % 4

      elif action == acts.forward:
          blocked = ahead_cell is not None and not ahead_cell.can_overlap()
          if not blocked:
              self.agent_pos = tuple(ahead)
              kind = None if ahead_cell is None else ahead_cell.type
              if kind == "goal":
                  terminated = True
                  reward = self._reward()
              elif kind == "lava":
                  terminated = True

      elif action == acts.pickup:
          # Only one object may be carried at a time
          can_take = ahead_cell and ahead_cell.can_pickup()
          if can_take and self.carrying is None:
              self.carrying = ahead_cell
              # Carried objects are taken off the grid
              self.carrying.cur_pos = np.array([-1, -1])
              self.grid.set(ahead[0], ahead[1], None)

      elif action == acts.drop:
          # Drop only onto an empty cell
          if not ahead_cell and self.carrying:
              self.grid.set(ahead[0], ahead[1], self.carrying)
              self.carrying.cur_pos = ahead
              self.carrying = None

      elif action == acts.toggle:
          if ahead_cell:
              ahead_cell.toggle(self, ahead)

      elif action == acts.done:
          # Not used by default
          pass

      else:
          raise ValueError(f"Unknown action: {action}")

      truncated = bool(self.step_count >= self.max_steps)
      return self.gen_obs(), reward, terminated, truncated, {}
  def gen_obs_grid(self, agent_view_size=None):
      """
      Build the sub-grid observed by the agent, plus a visibility mask.

      :param agent_view_size: side length of the view square; when None,
          self.agent_view_size is used.
      :return: (grid, vis_mask). grid is the rotated slice of the world around
          the agent. vis_mask marks which of its cells the agent can see:
          computed by process_vis when walls block sight, all True otherwise.
      """
      top_x, top_y, _, _ = self.get_view_exts(agent_view_size)
      size = agent_view_size or self.agent_view_size
      view = self.grid.slice(top_x, top_y, size, size)

      # Rotate the slice into the agent's egocentric frame; afterwards the
      # agent sits at (size // 2, size - 1), the position used below.
      for _ in range(self.agent_dir + 1):
          view = view.rotate_left()

      # Occlusion processing has a performance cost, so skip it when the
      # agent can see through walls.
      if self.see_through_walls:
          vis_mask = np.ones(shape=(view.width, view.height), dtype=bool)
      else:
          vis_mask = view.process_vis(agent_pos=(size // 2, size - 1))

      # Let the agent observe what it is carrying by placing that object on
      # the agent's own cell in the view (cleared when carrying nothing).
      held = self.carrying if self.carrying else None
      view.set(view.width // 2, view.height - 1, held)
      return view, vis_mask
  def gen_obs(self):
      """
      Generate the agent's partially observable, low-resolution observation.

      :return: a dict with
          - "image": the encoded egocentric view from gen_obs_grid (cells the
            agent cannot see are encoded through the visibility mask),
          - "direction": the agent's current orientation (acts as a compass),
          - "mission": the textual mission string.
      """
      view, vis_mask = self.gen_obs_grid()
      return {
          "image": view.encode(vis_mask),
          "direction": self.agent_dir,
          "mission": self.mission,
      }
  def get_obs_render(self, obs, tile_size=TILE_PIXELS // 2):
      """
      Render an encoded agent observation for visualization.

      :param obs: the encoded view passed to Grid.decode (presumably the
          "image" entry of gen_obs(), not the whole dict; confirm at call sites).
      :param tile_size: size in pixels of each rendered tile.
      :return: the image produced by the decoded grid's render, with the agent
          drawn at the bottom-centre cell facing direction 3 and the visible
          cells highlighted.
      """
      decoded, vis_mask = Grid.decode(obs)
      size = self.agent_view_size
      return decoded.render(
          tile_size,
          agent_pos=(size // 2, size - 1),
          agent_dir=3,
          highlight_mask=vis_mask,
      )
  def render(self, mode="human", close=False, highlight=True, tile_size=TILE_PIXELS):
      """
      Render the whole-grid human view.

      :param mode: "human" to draw into an interactive window; any other value
          only returns the image. Overridden by self.render_mode when it is set.
      :param close: close the window (if one is open) and return None without
          rendering.
      :param highlight: shade the grid cells the agent can currently see.
      :param tile_size: size in pixels of each rendered tile.
      :return: the image produced by self.grid.render.
      """
      if self.render_mode is not None:
          mode = self.render_mode

      if close:
          if self.window:
              self.window.close()
              # Forget the closed window so a later "human" render opens a
              # fresh one instead of reusing the closed one.
              self.window = None
          return

      if mode == "human" and not self.window:
          import gym_minigrid.window

          self.window = gym_minigrid.window.Window("gym_minigrid")
          self.window.show(block=False)

      # The visibility computation only feeds the highlight mask, so skip it
      # entirely when highlighting is off.
      highlight_mask = None
      if highlight:
          # Compute which cells are visible to the agent
          _, vis_mask = self.gen_obs_grid()

          # World coordinates of the top-left corner of the agent's view area
          # (view cell (0, 0)): (agent_view_size - 1) steps forward, then half
          # the view width to the left.
          f_vec = self.dir_vec
          r_vec = self.right_vec
          top_left = (
              self.agent_pos
              + f_vec * (self.agent_view_size - 1)
              - r_vec * (self.agent_view_size // 2)
          )

          highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
          for vis_j in range(self.agent_view_size):
              for vis_i in range(self.agent_view_size):
                  # Cells hidden from the agent are not highlighted
                  if not vis_mask[vis_i, vis_j]:
                      continue
                  # View row vis_j is vis_j steps back from the far edge;
                  # view column vis_i is vis_i steps to the right.
                  abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)
                  # The view square may extend past the grid border
                  if not (0 <= abs_i < self.width and 0 <= abs_j < self.height):
                      continue
                  highlight_mask[abs_i, abs_j] = True

      img = self.grid.render(
          tile_size,
          self.agent_pos,
          self.agent_dir,
          highlight_mask=highlight_mask,
      )

      if mode == "human":
          assert self.window is not None
          self.window.set_caption(self.mission)
          self.window.show_img(img)

      return img
  def close(self):
      """
      Close the render window if one is open.

      The window reference is cleared afterwards. Calling close() again is
      then a no-op, and a later "human" render can open a new window.
      """
      if self.window:
          self.window.close()
          self.window = None