minigrid.py 45 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509
  1. import hashlib
  2. import math
  3. from abc import abstractmethod
  4. from enum import IntEnum
  5. from typing import Any, Callable, Optional, Union
  6. import gym
  7. import numpy as np
  8. from gym import spaces
  9. from gym.utils import seeding
  10. # Size in pixels of a tile in the full-scale human view
  11. from gym_minigrid.rendering import (
  12. downsample,
  13. fill_coords,
  14. highlight_img,
  15. point_in_circle,
  16. point_in_line,
  17. point_in_rect,
  18. point_in_triangle,
  19. rotate_fn,
  20. )
  21. from gym_minigrid.window import Window
  22. TILE_PIXELS = 32
  23. # Map of color names to RGB values
  24. COLORS = {
  25. "red": np.array([255, 0, 0]),
  26. "green": np.array([0, 255, 0]),
  27. "blue": np.array([0, 0, 255]),
  28. "purple": np.array([112, 39, 195]),
  29. "yellow": np.array([255, 255, 0]),
  30. "grey": np.array([100, 100, 100]),
  31. }
  32. COLOR_NAMES = sorted(list(COLORS.keys()))
  33. # Used to map colors to integers
  34. COLOR_TO_IDX = {"red": 0, "green": 1, "blue": 2, "purple": 3, "yellow": 4, "grey": 5}
  35. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  36. # Map of object type to integers
  37. OBJECT_TO_IDX = {
  38. "unseen": 0,
  39. "empty": 1,
  40. "wall": 2,
  41. "floor": 3,
  42. "door": 4,
  43. "key": 5,
  44. "ball": 6,
  45. "box": 7,
  46. "goal": 8,
  47. "lava": 9,
  48. "agent": 10,
  49. }
  50. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  51. # Map of state names to integers
  52. STATE_TO_IDX = {
  53. "open": 0,
  54. "closed": 1,
  55. "locked": 2,
  56. }
  57. # Map of agent direction indices to vectors
  58. DIR_TO_VEC = [
  59. # Pointing right (positive X)
  60. np.array((1, 0)),
  61. # Down (positive Y)
  62. np.array((0, 1)),
  63. # Pointing left (negative X)
  64. np.array((-1, 0)),
  65. # Up (negative Y)
  66. np.array((0, -1)),
  67. ]
  68. def check_if_no_duplicate(duplicate_list: list) -> bool:
  69. """Check if given list contains any duplicates"""
  70. return len(set(duplicate_list)) == len(duplicate_list)
  71. class MissionSpace(spaces.Space[str]):
  72. r"""A space representing a mission for the Gym-Minigrid environments.
  73. The space allows generating random mission strings constructed with an input placeholder list.
  74. Example Usage::
  75. >>> observation_space = MissionSpace(mission_func=lambda color: "Get the {} ball.".format(color),
  76. ordered_placeholders=[["green", "blue"]])
  77. >>> observation_space.sample()
  78. "Get the green ball."
  79. >>> observation_space = MissionSpace(mission_func=lambda color: "Get the ball.".,
  80. ordered_placeholders=None)
  81. >>> observation_space.sample()
  82. "Get the ball."
  83. """
  84. def __init__(
  85. self,
  86. mission_func: Callable[..., str],
  87. ordered_placeholders: Optional["list[list[str]]"] = None,
  88. seed: Optional[Union[int, seeding.RandomNumberGenerator]] = None,
  89. ):
  90. r"""Constructor of :class:`MissionSpace` space.
  91. Args:
  92. mission_func (lambda _placeholders(str): _mission(str)): Function that generates a mission string from random placeholders.
  93. ordered_placeholders (Optional["list[list[str]]"]): List of lists of placeholders ordered in placing order in the mission function mission_func.
  94. seed: seed: The seed for sampling from the space.
  95. """
  96. # Check that the ordered placeholders and mission function are well defined.
  97. if ordered_placeholders is not None:
  98. assert (
  99. len(ordered_placeholders) == mission_func.__code__.co_argcount
  100. ), "The number of placeholders {} is different from the number of parameters in the mission function {}.".format(
  101. len(ordered_placeholders), mission_func.__code__.co_argcount
  102. )
  103. for placeholder_list in ordered_placeholders:
  104. assert check_if_no_duplicate(
  105. placeholder_list
  106. ), "Make sure that the placeholders don't have any duplicate values."
  107. else:
  108. assert (
  109. mission_func.__code__.co_argcount == 0
  110. ), "If the ordered placeholders are {}, the mission function shouldn't have any parameters.".format(
  111. ordered_placeholders
  112. )
  113. self.ordered_placeholders = ordered_placeholders
  114. self.mission_func = mission_func
  115. super().__init__(dtype=str, seed=seed)
  116. # Check that mission_func returns a string
  117. sampled_mission = self.sample()
  118. assert isinstance(
  119. sampled_mission, str
  120. ), f"mission_func must return type str not {type(sampled_mission)}"
  121. def sample(self) -> str:
  122. """Sample a random mission string."""
  123. if self.ordered_placeholders is not None:
  124. placeholders = []
  125. for rand_var_list in self.ordered_placeholders:
  126. idx = self.np_random.integers(0, len(rand_var_list))
  127. placeholders.append(rand_var_list[idx])
  128. return self.mission_func(*placeholders)
  129. else:
  130. return self.mission_func()
  131. def contains(self, x: Any) -> bool:
  132. """Return boolean specifying if x is a valid member of this space."""
  133. # Store a list of all the placeholders from self.ordered_placeholders that appear in x
  134. if self.ordered_placeholders is not None:
  135. check_placeholder_list = []
  136. for placeholder_list in self.ordered_placeholders:
  137. for placeholder in placeholder_list:
  138. if placeholder in x:
  139. check_placeholder_list.append(placeholder)
  140. # Remove duplicates from the list
  141. check_placeholder_list = list(set(check_placeholder_list))
  142. start_id_placeholder = []
  143. end_id_placeholder = []
  144. # Get the starting and ending id of the identified placeholders with possible duplicates
  145. new_check_placeholder_list = []
  146. for placeholder in check_placeholder_list:
  147. new_start_id_placeholder = [
  148. i for i in range(len(x)) if x.startswith(placeholder, i)
  149. ]
  150. new_check_placeholder_list += [placeholder] * len(
  151. new_start_id_placeholder
  152. )
  153. end_id_placeholder += [
  154. start_id + len(placeholder) - 1
  155. for start_id in new_start_id_placeholder
  156. ]
  157. start_id_placeholder += new_start_id_placeholder
  158. # Order by starting id the placeholders
  159. ordered_placeholder_list = sorted(
  160. zip(
  161. start_id_placeholder, end_id_placeholder, new_check_placeholder_list
  162. )
  163. )
  164. # Check for repeated placeholders contained in each other
  165. remove_placeholder_id = []
  166. for i, placeholder_1 in enumerate(ordered_placeholder_list):
  167. starting_id = i + 1
  168. for j, placeholder_2 in enumerate(
  169. ordered_placeholder_list[starting_id:]
  170. ):
  171. # Check if place holder ids overlap and keep the longest
  172. if max(placeholder_1[0], placeholder_2[0]) < min(
  173. placeholder_1[1], placeholder_2[1]
  174. ):
  175. remove_placeholder = min(
  176. placeholder_1[2], placeholder_2[2], key=len
  177. )
  178. if remove_placeholder == placeholder_1[2]:
  179. remove_placeholder_id.append(i)
  180. else:
  181. remove_placeholder_id.append(i + j + 1)
  182. for id in remove_placeholder_id:
  183. del ordered_placeholder_list[id]
  184. final_placeholders = [
  185. placeholder[2] for placeholder in ordered_placeholder_list
  186. ]
  187. # Check that the identified final placeholders are in the same order as the original placeholders.
  188. for orered_placeholder, final_placeholder in zip(
  189. self.ordered_placeholders, final_placeholders
  190. ):
  191. if final_placeholder in orered_placeholder:
  192. continue
  193. else:
  194. return False
  195. try:
  196. mission_string_with_placeholders = self.mission_func(
  197. *final_placeholders
  198. )
  199. except Exception as e:
  200. print(
  201. f"{x} is not contained in MissionSpace due to the following exception: {e}"
  202. )
  203. return False
  204. return bool(mission_string_with_placeholders == x)
  205. else:
  206. return bool(self.mission_func() == x)
  207. def __repr__(self) -> str:
  208. """Gives a string representation of this space."""
  209. return f"MissionSpace({self.mission_func}, {self.ordered_placeholders})"
  210. def __eq__(self, other) -> bool:
  211. """Check whether ``other`` is equivalent to this instance."""
  212. if isinstance(other, MissionSpace):
  213. # Check that place holder lists are the same
  214. if self.ordered_placeholders is not None:
  215. # Check length
  216. if len(self.order_placeholder) == len(other.order_placeholder):
  217. # Placeholder list are ordered in placing order in the mission string
  218. for placeholder, other_placeholder in zip(
  219. self.order_placeholder, other.order_placeholder
  220. ):
  221. if set(placeholder) == set(other_placeholder):
  222. continue
  223. else:
  224. return False
  225. # Check mission string is the same with dummy space placeholders
  226. test_placeholders = [""] * len(self.order_placeholder)
  227. mission = self.mission_func(*test_placeholders)
  228. other_mission = other.mission_func(*test_placeholders)
  229. return mission == other_mission
  230. else:
  231. # Check that other is also None
  232. if other.ordered_placeholders is None:
  233. # Check mission string is the same
  234. mission = self.mission_func()
  235. other_mission = other.mission_func()
  236. return mission == other_mission
  237. # If none of the statements above return then False
  238. return False
  239. class WorldObj:
  240. """
  241. Base class for grid world objects
  242. """
  243. def __init__(self, type, color):
  244. assert type in OBJECT_TO_IDX, type
  245. assert color in COLOR_TO_IDX, color
  246. self.type = type
  247. self.color = color
  248. self.contains = None
  249. # Initial position of the object
  250. self.init_pos = None
  251. # Current position of the object
  252. self.cur_pos = None
  253. def can_overlap(self):
  254. """Can the agent overlap with this?"""
  255. return False
  256. def can_pickup(self):
  257. """Can the agent pick this up?"""
  258. return False
  259. def can_contain(self):
  260. """Can this contain another object?"""
  261. return False
  262. def see_behind(self):
  263. """Can the agent see behind this object?"""
  264. return True
  265. def toggle(self, env, pos):
  266. """Method to trigger/toggle an action this object performs"""
  267. return False
  268. def encode(self):
  269. """Encode the a description of this object as a 3-tuple of integers"""
  270. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0)
  271. @staticmethod
  272. def decode(type_idx, color_idx, state):
  273. """Create an object from a 3-tuple state description"""
  274. obj_type = IDX_TO_OBJECT[type_idx]
  275. color = IDX_TO_COLOR[color_idx]
  276. if obj_type == "empty" or obj_type == "unseen":
  277. return None
  278. # State, 0: open, 1: closed, 2: locked
  279. is_open = state == 0
  280. is_locked = state == 2
  281. if obj_type == "wall":
  282. v = Wall(color)
  283. elif obj_type == "floor":
  284. v = Floor(color)
  285. elif obj_type == "ball":
  286. v = Ball(color)
  287. elif obj_type == "key":
  288. v = Key(color)
  289. elif obj_type == "box":
  290. v = Box(color)
  291. elif obj_type == "door":
  292. v = Door(color, is_open, is_locked)
  293. elif obj_type == "goal":
  294. v = Goal()
  295. elif obj_type == "lava":
  296. v = Lava()
  297. else:
  298. assert False, "unknown object type in decode '%s'" % obj_type
  299. return v
  300. def render(self, r):
  301. """Draw this object with the given renderer"""
  302. raise NotImplementedError
  303. class Goal(WorldObj):
  304. def __init__(self):
  305. super().__init__("goal", "green")
  306. def can_overlap(self):
  307. return True
  308. def render(self, img):
  309. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  310. class Floor(WorldObj):
  311. """
  312. Colored floor tile the agent can walk over
  313. """
  314. def __init__(self, color="blue"):
  315. super().__init__("floor", color)
  316. def can_overlap(self):
  317. return True
  318. def render(self, img):
  319. # Give the floor a pale color
  320. color = COLORS[self.color] / 2
  321. fill_coords(img, point_in_rect(0.031, 1, 0.031, 1), color)
  322. class Lava(WorldObj):
  323. def __init__(self):
  324. super().__init__("lava", "red")
  325. def can_overlap(self):
  326. return True
  327. def render(self, img):
  328. c = (255, 128, 0)
  329. # Background color
  330. fill_coords(img, point_in_rect(0, 1, 0, 1), c)
  331. # Little waves
  332. for i in range(3):
  333. ylo = 0.3 + 0.2 * i
  334. yhi = 0.4 + 0.2 * i
  335. fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0, 0, 0))
  336. fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0, 0, 0))
  337. fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0, 0, 0))
  338. fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0, 0, 0))
  339. class Wall(WorldObj):
  340. def __init__(self, color="grey"):
  341. super().__init__("wall", color)
  342. def see_behind(self):
  343. return False
  344. def render(self, img):
  345. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  346. class Door(WorldObj):
  347. def __init__(self, color, is_open=False, is_locked=False):
  348. super().__init__("door", color)
  349. self.is_open = is_open
  350. self.is_locked = is_locked
  351. def can_overlap(self):
  352. """The agent can only walk over this cell when the door is open"""
  353. return self.is_open
  354. def see_behind(self):
  355. return self.is_open
  356. def toggle(self, env, pos):
  357. # If the player has the right key to open the door
  358. if self.is_locked:
  359. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  360. self.is_locked = False
  361. self.is_open = True
  362. return True
  363. return False
  364. self.is_open = not self.is_open
  365. return True
  366. def encode(self):
  367. """Encode the a description of this object as a 3-tuple of integers"""
  368. # State, 0: open, 1: closed, 2: locked
  369. if self.is_open:
  370. state = 0
  371. elif self.is_locked:
  372. state = 2
  373. # if door is closed and unlocked
  374. elif not self.is_open:
  375. state = 1
  376. else:
  377. raise ValueError(
  378. "There is no possible state encoding for the state:\n -Door Open: {}\n -Door Closed: {}\n -Door Locked: {}".format(
  379. self.is_open, not self.is_open, self.is_locked
  380. )
  381. )
  382. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state)
  383. def render(self, img):
  384. c = COLORS[self.color]
  385. if self.is_open:
  386. fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c)
  387. fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0, 0, 0))
  388. return
  389. # Door frame and door
  390. if self.is_locked:
  391. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  392. fill_coords(img, point_in_rect(0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))
  393. # Draw key slot
  394. fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)
  395. else:
  396. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  397. fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0, 0, 0))
  398. fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c)
  399. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0, 0, 0))
  400. # Draw door handle
  401. fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c)
  402. class Key(WorldObj):
  403. def __init__(self, color="blue"):
  404. super().__init__("key", color)
  405. def can_pickup(self):
  406. return True
  407. def render(self, img):
  408. c = COLORS[self.color]
  409. # Vertical quad
  410. fill_coords(img, point_in_rect(0.50, 0.63, 0.31, 0.88), c)
  411. # Teeth
  412. fill_coords(img, point_in_rect(0.38, 0.50, 0.59, 0.66), c)
  413. fill_coords(img, point_in_rect(0.38, 0.50, 0.81, 0.88), c)
  414. # Ring
  415. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.190), c)
  416. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0, 0, 0))
  417. class Ball(WorldObj):
  418. def __init__(self, color="blue"):
  419. super().__init__("ball", color)
  420. def can_pickup(self):
  421. return True
  422. def render(self, img):
  423. fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color])
  424. class Box(WorldObj):
  425. def __init__(self, color, contains=None):
  426. super().__init__("box", color)
  427. self.contains = contains
  428. def can_pickup(self):
  429. return True
  430. def render(self, img):
  431. c = COLORS[self.color]
  432. # Outline
  433. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), c)
  434. fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0, 0, 0))
  435. # Horizontal slit
  436. fill_coords(img, point_in_rect(0.16, 0.84, 0.47, 0.53), c)
  437. def toggle(self, env, pos):
  438. # Replace the box by its contents
  439. env.grid.set(*pos, self.contains)
  440. return True
  441. class Grid:
  442. """
  443. Represent a grid and operations on it
  444. """
  445. # Static cache of pre-renderer tiles
  446. tile_cache = {}
  447. def __init__(self, width, height):
  448. assert width >= 3
  449. assert height >= 3
  450. self.width = width
  451. self.height = height
  452. self.grid = [None] * width * height
  453. def __contains__(self, key):
  454. if isinstance(key, WorldObj):
  455. for e in self.grid:
  456. if e is key:
  457. return True
  458. elif isinstance(key, tuple):
  459. for e in self.grid:
  460. if e is None:
  461. continue
  462. if (e.color, e.type) == key:
  463. return True
  464. if key[0] is None and key[1] == e.type:
  465. return True
  466. return False
  467. def __eq__(self, other):
  468. grid1 = self.encode()
  469. grid2 = other.encode()
  470. return np.array_equal(grid2, grid1)
  471. def __ne__(self, other):
  472. return not self == other
  473. def copy(self):
  474. from copy import deepcopy
  475. return deepcopy(self)
  476. def set(self, i, j, v):
  477. assert i >= 0 and i < self.width
  478. assert j >= 0 and j < self.height
  479. self.grid[j * self.width + i] = v
  480. def get(self, i, j):
  481. assert i >= 0 and i < self.width
  482. assert j >= 0 and j < self.height
  483. return self.grid[j * self.width + i]
  484. def horz_wall(self, x, y, length=None, obj_type=Wall):
  485. if length is None:
  486. length = self.width - x
  487. for i in range(0, length):
  488. self.set(x + i, y, obj_type())
  489. def vert_wall(self, x, y, length=None, obj_type=Wall):
  490. if length is None:
  491. length = self.height - y
  492. for j in range(0, length):
  493. self.set(x, y + j, obj_type())
  494. def wall_rect(self, x, y, w, h):
  495. self.horz_wall(x, y, w)
  496. self.horz_wall(x, y + h - 1, w)
  497. self.vert_wall(x, y, h)
  498. self.vert_wall(x + w - 1, y, h)
  499. def rotate_left(self):
  500. """
  501. Rotate the grid to the left (counter-clockwise)
  502. """
  503. grid = Grid(self.height, self.width)
  504. for i in range(self.width):
  505. for j in range(self.height):
  506. v = self.get(i, j)
  507. grid.set(j, grid.height - 1 - i, v)
  508. return grid
  509. def slice(self, topX, topY, width, height):
  510. """
  511. Get a subset of the grid
  512. """
  513. grid = Grid(width, height)
  514. for j in range(0, height):
  515. for i in range(0, width):
  516. x = topX + i
  517. y = topY + j
  518. if x >= 0 and x < self.width and y >= 0 and y < self.height:
  519. v = self.get(x, y)
  520. else:
  521. v = Wall()
  522. grid.set(i, j, v)
  523. return grid
  524. @classmethod
  525. def render_tile(
  526. cls, obj, agent_dir=None, highlight=False, tile_size=TILE_PIXELS, subdivs=3
  527. ):
  528. """
  529. Render a tile and cache the result
  530. """
  531. # Hash map lookup key for the cache
  532. key = (agent_dir, highlight, tile_size)
  533. key = obj.encode() + key if obj else key
  534. if key in cls.tile_cache:
  535. return cls.tile_cache[key]
  536. img = np.zeros(
  537. shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8
  538. )
  539. # Draw the grid lines (top and left edges)
  540. fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
  541. fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100))
  542. if obj is not None:
  543. obj.render(img)
  544. # Overlay the agent on top
  545. if agent_dir is not None:
  546. tri_fn = point_in_triangle(
  547. (0.12, 0.19),
  548. (0.87, 0.50),
  549. (0.12, 0.81),
  550. )
  551. # Rotate the agent based on its direction
  552. tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5 * math.pi * agent_dir)
  553. fill_coords(img, tri_fn, (255, 0, 0))
  554. # Highlight the cell if needed
  555. if highlight:
  556. highlight_img(img)
  557. # Downsample the image to perform supersampling/anti-aliasing
  558. img = downsample(img, subdivs)
  559. # Cache the rendered tile
  560. cls.tile_cache[key] = img
  561. return img
  562. def render(self, tile_size, agent_pos=None, agent_dir=None, highlight_mask=None):
  563. """
  564. Render this grid at a given scale
  565. :param r: target renderer object
  566. :param tile_size: tile size in pixels
  567. """
  568. if highlight_mask is None:
  569. highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  570. # Compute the total grid size
  571. width_px = self.width * tile_size
  572. height_px = self.height * tile_size
  573. img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)
  574. # Render the grid
  575. for j in range(0, self.height):
  576. for i in range(0, self.width):
  577. cell = self.get(i, j)
  578. agent_here = np.array_equal(agent_pos, (i, j))
  579. tile_img = Grid.render_tile(
  580. cell,
  581. agent_dir=agent_dir if agent_here else None,
  582. highlight=highlight_mask[i, j],
  583. tile_size=tile_size,
  584. )
  585. ymin = j * tile_size
  586. ymax = (j + 1) * tile_size
  587. xmin = i * tile_size
  588. xmax = (i + 1) * tile_size
  589. img[ymin:ymax, xmin:xmax, :] = tile_img
  590. return img
  591. def encode(self, vis_mask=None):
  592. """
  593. Produce a compact numpy encoding of the grid
  594. """
  595. if vis_mask is None:
  596. vis_mask = np.ones((self.width, self.height), dtype=bool)
  597. array = np.zeros((self.width, self.height, 3), dtype="uint8")
  598. for i in range(self.width):
  599. for j in range(self.height):
  600. if vis_mask[i, j]:
  601. v = self.get(i, j)
  602. if v is None:
  603. array[i, j, 0] = OBJECT_TO_IDX["empty"]
  604. array[i, j, 1] = 0
  605. array[i, j, 2] = 0
  606. else:
  607. array[i, j, :] = v.encode()
  608. return array
  609. @staticmethod
  610. def decode(array):
  611. """
  612. Decode an array grid encoding back into a grid
  613. """
  614. width, height, channels = array.shape
  615. assert channels == 3
  616. vis_mask = np.ones(shape=(width, height), dtype=bool)
  617. grid = Grid(width, height)
  618. for i in range(width):
  619. for j in range(height):
  620. type_idx, color_idx, state = array[i, j]
  621. v = WorldObj.decode(type_idx, color_idx, state)
  622. grid.set(i, j, v)
  623. vis_mask[i, j] = type_idx != OBJECT_TO_IDX["unseen"]
  624. return grid, vis_mask
  625. def process_vis(self, agent_pos):
  626. mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  627. mask[agent_pos[0], agent_pos[1]] = True
  628. for j in reversed(range(0, self.height)):
  629. for i in range(0, self.width - 1):
  630. if not mask[i, j]:
  631. continue
  632. cell = self.get(i, j)
  633. if cell and not cell.see_behind():
  634. continue
  635. mask[i + 1, j] = True
  636. if j > 0:
  637. mask[i + 1, j - 1] = True
  638. mask[i, j - 1] = True
  639. for i in reversed(range(1, self.width)):
  640. if not mask[i, j]:
  641. continue
  642. cell = self.get(i, j)
  643. if cell and not cell.see_behind():
  644. continue
  645. mask[i - 1, j] = True
  646. if j > 0:
  647. mask[i - 1, j - 1] = True
  648. mask[i, j - 1] = True
  649. for j in range(0, self.height):
  650. for i in range(0, self.width):
  651. if not mask[i, j]:
  652. self.set(i, j, None)
  653. return mask
  654. class MiniGridEnv(gym.Env):
  655. """
  656. 2D grid world game environment
  657. """
  658. metadata = {
  659. # Deprecated: use 'render_modes' instead
  660. "render.modes": ["human", "rgb_array"],
  661. "video.frames_per_second": 10, # Deprecated: use 'render_fps' instead
  662. "render_modes": ["human", "rgb_array", "single_rgb_array"],
  663. "render_fps": 10,
  664. }
  665. # Enumeration of possible actions
  666. class Actions(IntEnum):
  667. # Turn left, turn right, move forward
  668. left = 0
  669. right = 1
  670. forward = 2
  671. # Pick up an object
  672. pickup = 3
  673. # Drop an object
  674. drop = 4
  675. # Toggle/activate an object
  676. toggle = 5
  677. # Done completing task
  678. done = 6
  679. def __init__(
  680. self,
  681. mission_space: MissionSpace,
  682. grid_size: int = None,
  683. width: int = None,
  684. height: int = None,
  685. max_steps: int = 100,
  686. see_through_walls: bool = False,
  687. agent_view_size: int = 7,
  688. highlight: bool = True,
  689. tile_size: int = TILE_PIXELS,
  690. **kwargs,
  691. ):
  692. # Initialize mission
  693. self.mission = mission_space.sample()
  694. # Can't set both grid_size and width/height
  695. if grid_size:
  696. assert width is None and height is None
  697. width = grid_size
  698. height = grid_size
  699. # Action enumeration for this environment
  700. self.actions = MiniGridEnv.Actions
  701. # Actions are discrete integer values
  702. self.action_space = spaces.Discrete(len(self.actions))
  703. # Number of cells (width and height) in the agent view
  704. assert agent_view_size % 2 == 1
  705. assert agent_view_size >= 3
  706. self.agent_view_size = agent_view_size
  707. # Observations are dictionaries containing an
  708. # encoding of the grid and a textual 'mission' string
  709. image_observation_space = spaces.Box(
  710. low=0,
  711. high=255,
  712. shape=(self.agent_view_size, self.agent_view_size, 3),
  713. dtype="uint8",
  714. )
  715. self.observation_space = spaces.Dict(
  716. {
  717. "image": image_observation_space,
  718. "direction": spaces.Discrete(4),
  719. "mission": mission_space,
  720. }
  721. )
  722. # Range of possible rewards
  723. self.reward_range = (0, 1)
  724. self.window: Window = None
  725. # Environment configuration
  726. self.width = width
  727. self.height = height
  728. self.max_steps = max_steps
  729. self.see_through_walls = see_through_walls
  730. # Current position and direction of the agent
  731. self.agent_pos: np.ndarray = None
  732. self.agent_dir: int = None
  733. # Initialize the state
  734. self.reset()
  def reset(self, *, seed=None, return_info=False, options=None):
      """
      Start a new episode and return its first observation.

      :param seed: forwarded to the parent class's reset()
      :param return_info: when true, return (obs, {}) instead of obs alone
      :param options: accepted but not used by this method
      """
      super().reset(seed=seed)

      # Clear the agent state so we can verify below that _gen_grid sets it
      self.agent_pos = None
      self.agent_dir = None

      # Build a fresh grid for this episode
      self._gen_grid(self.width, self.height)

      # _gen_grid is responsible for positioning and orienting the agent
      assert self.agent_pos is not None
      assert self.agent_dir is not None

      # The agent may only start on an empty or overlappable cell
      cell_under_agent = self.grid.get(*self.agent_pos)
      assert cell_under_agent is None or cell_under_agent.can_overlap()

      # Nothing is carried and no steps have been taken yet
      self.carrying = None
      self.step_count = 0

      first_obs = self.gen_obs()
      if return_info:
          return first_obs, {}
      return first_obs
  def hash(self, size=16):
      """Compute a hash that identifies the current state of the environment.

      The digest covers the encoded grid, the agent position and the agent
      direction, in that order.

      :param size: Number of leading hex characters of the SHA-256 digest to keep
      """
      digest = hashlib.sha256()
      state_parts = (self.grid.encode().tolist(), self.agent_pos, self.agent_dir)
      for part in state_parts:
          digest.update(str(part).encode("utf8"))
      hex_digest = digest.hexdigest()
      return hex_digest[:size]
  @property
  def steps_remaining(self):
      """
      Number of steps left before step_count reaches max_steps
      """
      used = self.step_count
      return self.max_steps - used
  def __str__(self):
      """
      Produce a pretty string of the environment's grid along with the agent.
      A grid cell is represented by 2-character string, the first one for
      the object and the second one for the color. Empty cells are two
      spaces and the agent is its direction glyph repeated twice, so every
      row has exactly 2 * grid.width characters. Rows are separated by
      newlines, with no trailing newline.
      """
      # Map of object types to short string
      OBJECT_TO_STR = {
          "wall": "W",
          "floor": "F",
          "door": "D",
          "key": "K",
          "ball": "A",
          "box": "B",
          "goal": "G",
          "lava": "V",
      }
      # Map agent's direction to short string
      AGENT_DIR_TO_STR = {0: ">", 1: "V", 2: "<", 3: "^"}

      rows = []
      for j in range(self.grid.height):
          cells = []
          for i in range(self.grid.width):
              # The agent is drawn in place of whatever is in its cell
              if i == self.agent_pos[0] and j == self.agent_pos[1]:
                  cells.append(2 * AGENT_DIR_TO_STR[self.agent_dir])
                  continue

              c = self.grid.get(i, j)
              if c is None:
                  # Two characters, like every other cell, to keep columns aligned
                  cells.append("  ")
                  continue

              if c.type == "door":
                  # Doors show their state: open, locked or merely closed
                  if c.is_open:
                      cells.append("__")
                  elif c.is_locked:
                      cells.append("L" + c.color[0].upper())
                  else:
                      cells.append("D" + c.color[0].upper())
                  continue

              cells.append(OBJECT_TO_STR[c.type] + c.color[0].upper())
          rows.append("".join(cells))

      return "\n".join(rows)
  @abstractmethod
  def _gen_grid(self, width, height):
      """
      Build the grid for a new episode; the base implementation does nothing.
      reset() calls this with the configured width and height, then asserts
      that agent_pos and agent_dir have been set and reads self.grid.
      """
  def _reward(self):
      """
      Compute the reward to be given upon success: 1 minus 0.9 times the
      fraction of max_steps already used
      """
      elapsed_fraction = self.step_count / self.max_steps
      return 1 - 0.9 * elapsed_fraction
  def _rand_int(self, low, high):
      """
      Draw a random integer from the half-open interval [low, high)
      using the environment's np_random generator
      """
      rng = self.np_random
      return rng.integers(low, high)
  def _rand_float(self, low, high):
      """
      Draw a random float from the half-open interval [low, high)
      using the environment's np_random generator
      """
      rng = self.np_random
      return rng.uniform(low, high)
  def _rand_bool(self):
      """
      Draw a random boolean: true when an integer drawn from {0, 1} is 0
      """
      coin = self.np_random.integers(0, 2)
      return coin == 0
  def _rand_elem(self, iterable):
      """
      Return one randomly chosen element of an iterable.
      The iterable is first copied into a list, then indexed with _rand_int.
      """
      candidates = list(iterable)
      return candidates[self._rand_int(0, len(candidates))]
  def _rand_subset(self, iterable, num_elems):
      """
      Sample num_elems elements of an iterable without replacement.
      Elements are returned in the order they were drawn.
      """
      pool = list(iterable)
      assert num_elems <= len(pool)

      chosen = []
      while len(chosen) < num_elems:
          pick = self._rand_elem(pool)
          # Take the pick out of the pool so it cannot be drawn again
          pool.remove(pick)
          chosen.append(pick)
      return chosen
  def _rand_color(self):
      """
      Pick a random entry of COLOR_NAMES (a color name string)
      """
      palette = COLOR_NAMES
      return self._rand_elem(palette)
  def _rand_pos(self, xLow, xHigh, yLow, yHigh):
      """
      Generate a random (x, y) position tuple with x drawn from
      [xLow, xHigh) and y from [yLow, yHigh); x is drawn first
      """
      rng = self.np_random
      x = rng.integers(xLow, xHigh)
      y = rng.integers(yLow, yHigh)
      return (x, y)
  def place_obj(self, obj, top=None, size=None, reject_fn=None, max_tries=math.inf):
      """
      Place an object at an empty position in the grid

      Positions are rejection-sampled: a candidate is discarded when its
      cell is occupied, when it is the agent's position, or when reject_fn
      returns a true value for it.

      :param obj: object to place; may be None, in which case the cell is
          set to None and only the sampled position is of interest
      :param top: top-left position of the rectangle where to place
          (negative coordinates are clamped to 0; defaults to (0, 0))
      :param size: size of the rectangle where to place
          (defaults to the whole grid; the rectangle is clipped to the grid)
      :param reject_fn: function to filter out potential positions,
          called as reject_fn(env, pos)
      :param max_tries: maximum number of candidate positions to sample
      :return: the chosen position as a numpy array
      :raises RecursionError: if max_tries candidates were all rejected
      """
      if top is None:
          top = (0, 0)
      else:
          top = (max(top[0], 0), max(top[1], 0))

      if size is None:
          size = (self.grid.width, self.grid.height)

      # Exclusive upper bounds of the sampling rectangle, clipped to the grid
      x_end = min(top[0] + size[0], self.grid.width)
      y_end = min(top[1] + size[1], self.grid.height)

      num_tries = 0
      while True:
          # This is to handle with rare cases where rejection sampling
          # gets stuck in an infinite loop. ">=" so that exactly max_tries
          # candidates are sampled, not max_tries + 1.
          if num_tries >= max_tries:
              raise RecursionError("rejection sampling failed in place_obj")
          num_tries += 1

          pos = np.array(
              (
                  self._rand_int(top[0], x_end),
                  self._rand_int(top[1], y_end),
              )
          )

          # Don't place the object on top of another object
          if self.grid.get(*pos) is not None:
              continue

          # Don't place the object where the agent is
          if np.array_equal(pos, self.agent_pos):
              continue

          # Check if there is a filtering criterion
          if reject_fn and reject_fn(self, pos):
              continue

          break

      self.grid.set(*pos, obj)

      if obj is not None:
          obj.init_pos = pos
          obj.cur_pos = pos

      return pos
  def put_obj(self, obj, i, j):
      """
      Store an object in grid cell (i, j) and record (i, j) as both its
      initial and its current position
      """
      self.grid.set(i, j, obj)
      obj.init_pos = obj.cur_pos = (i, j)
  def place_agent(self, top=None, size=None, rand_dir=True, max_tries=math.inf):
      """
      Choose the agent's starting cell among the empty positions of the grid.

      top, size and max_tries are passed through to place_obj. When rand_dir
      is true the agent direction is also randomised in {0, 1, 2, 3}.
      Returns the chosen position.
      """
      # Forget the previous position first: place_obj refuses the cell the
      # agent currently occupies
      self.agent_pos = None

      start = self.place_obj(None, top, size, max_tries=max_tries)
      self.agent_pos = start

      if rand_dir:
          self.agent_dir = self._rand_int(0, 4)

      return start
  @property
  def dir_vec(self):
      """
      Direction vector of the agent's forward movement, looked up in
      DIR_TO_VEC by the current agent direction (which must be 0..3).
      """
      assert 0 <= self.agent_dir < 4
      return DIR_TO_VEC[self.agent_dir]
  @property
  def right_vec(self):
      """
      Vector pointing to the agent's right: the forward vector (dx, dy)
      mapped to (-dy, dx).
      """
      fwd_x, fwd_y = self.dir_vec
      right = (-fwd_y, fwd_x)
      return np.array(right)
  @property
  def front_pos(self):
      """
      Position of the cell directly in front of the agent, i.e. the agent
      position advanced by one forward vector
      """
      here = self.agent_pos
      forward = self.dir_vec
      return here + forward
  def get_view_coords(self, i, j):
      """
      Convert absolute grid coordinates (i, j) into coordinates of the
      agent's partially observable view (sub-grid), by translating to the
      view's top-left corner and rotating into the agent's frame. The
      result is not clipped: it may be negative or beyond the view size.
      """
      agent_x, agent_y = self.agent_pos
      fwd_x, fwd_y = self.dir_vec
      right_x, right_y = self.right_vec

      view = self.agent_view_size
      half = self.agent_view_size // 2

      # Absolute coordinates of the top-left corner of the view
      corner_x = agent_x + (fwd_x * (view - 1)) - (right_x * half)
      corner_y = agent_y + (fwd_y * (view - 1)) - (right_y * half)

      # Offset of the queried cell from that corner
      off_x = i - corner_x
      off_y = j - corner_y

      # Project the offset onto the agent's right axis (view x) and onto
      # the reversed forward axis (view y)
      view_x = right_x * off_x + right_y * off_y
      view_y = -(fwd_x * off_x + fwd_y * off_y)
      return view_x, view_y
  def get_view_exts(self, agent_view_size=None):
      """
      Get the extents of the square set of tiles visible to the agent
      Note: the bottom extent indices are not included in the set
      if agent_view_size is None, use self.agent_view_size

      :param agent_view_size: side length of the view square; any falsy
          value falls back to self.agent_view_size
      :return: (topX, topY, botX, botY); coordinates may lie outside the grid
      :raises AssertionError: if self.agent_dir is not one of 0, 1, 2, 3
      """
      agent_view_size = agent_view_size or self.agent_view_size

      # Facing right
      if self.agent_dir == 0:
          topX = self.agent_pos[0]
          topY = self.agent_pos[1] - agent_view_size // 2
      # Facing down
      elif self.agent_dir == 1:
          topX = self.agent_pos[0] - agent_view_size // 2
          topY = self.agent_pos[1]
      # Facing left
      elif self.agent_dir == 2:
          topX = self.agent_pos[0] - agent_view_size + 1
          topY = self.agent_pos[1] - agent_view_size // 2
      # Facing up
      elif self.agent_dir == 3:
          topX = self.agent_pos[0] - agent_view_size // 2
          topY = self.agent_pos[1] - agent_view_size + 1
      else:
          # Raised explicitly (rather than "assert False") so the check
          # is not stripped when Python runs with -O
          raise AssertionError("invalid agent direction")

      botX = topX + agent_view_size
      botY = topY + agent_view_size

      return (topX, topY, botX, botY)
  def relative_coords(self, x, y):
      """
      Return the view coordinates of grid position (x, y) when it lies inside the agent's field of view, otherwise None
      """
      vx, vy = self.get_view_coords(x, y)
      limit = self.agent_view_size

      if 0 <= vx < limit and 0 <= vy < limit:
          return vx, vy
      return None
  def in_view(self, x, y):
      """
      Tell whether grid position (x, y) falls inside the agent's view
      square, i.e. whether relative_coords finds view coordinates for it
      """
      coords = self.relative_coords(x, y)
      return coords is not None
  def agent_sees(self, x, y):
      """
      Check if a non-empty grid position is visible to the agent

      :param x: absolute x coordinate of the grid cell
      :param y: absolute y coordinate of the grid cell
      :return: True when (x, y) is inside the agent's view and the cell the
          agent observes there has the same type as the world cell. False
          when the position is outside the view, when the world cell is
          empty, or when the observed cell is empty or of another type.
      """
      coordinates = self.relative_coords(x, y)
      if coordinates is None:
          return False

      # An empty world cell has no type to compare against. The observation
      # can still be non-empty there (gen_obs_grid draws the carried object
      # at the agent's own cell), so bail out before touching ".type".
      world_cell = self.grid.get(x, y)
      if world_cell is None:
          return False

      vx, vy = coordinates

      # Decode the agent's current observation back into a grid
      obs = self.gen_obs()
      obs_grid, _ = Grid.decode(obs["image"])
      obs_cell = obs_grid.get(vx, vy)

      return obs_cell is not None and obs_cell.type == world_cell.type
  def step(self, action):
      """
      Advance the environment by one step.

      :param action: one of the values of self.actions
          (left, right, forward, pickup, drop, toggle, done)
      :return: (obs, reward, done, info) where info is always an empty dict.
          reward is non-zero only when the agent moves onto a goal cell;
          done is set on reaching a goal, on entering lava, or once
          step_count reaches max_steps.
      :raises AssertionError: if action is not a known action
      """
      self.step_count += 1

      reward = 0
      done = False

      # Get the position in front of the agent
      fwd_pos = self.front_pos

      # Get the contents of the cell in front of the agent
      fwd_cell = self.grid.get(*fwd_pos)

      # Rotate left
      if action == self.actions.left:
          self.agent_dir -= 1
          if self.agent_dir < 0:
              self.agent_dir += 4

      # Rotate right
      elif action == self.actions.right:
          self.agent_dir = (self.agent_dir + 1) % 4

      # Move forward
      elif action == self.actions.forward:
          if fwd_cell is None or fwd_cell.can_overlap():
              self.agent_pos = fwd_pos
          if fwd_cell is not None and fwd_cell.type == "goal":
              done = True
              reward = self._reward()
          if fwd_cell is not None and fwd_cell.type == "lava":
              done = True

      # Pick up an object
      elif action == self.actions.pickup:
          if fwd_cell and fwd_cell.can_pickup():
              # Only one object can be carried at a time
              if self.carrying is None:
                  self.carrying = fwd_cell
                  # (-1, -1) marks the object as no longer on the grid
                  self.carrying.cur_pos = np.array([-1, -1])
                  self.grid.set(*fwd_pos, None)

      # Drop an object
      elif action == self.actions.drop:
          if not fwd_cell and self.carrying:
              self.grid.set(*fwd_pos, self.carrying)
              self.carrying.cur_pos = fwd_pos
              self.carrying = None

      # Toggle/activate an object
      elif action == self.actions.toggle:
          if fwd_cell:
              fwd_cell.toggle(self, fwd_pos)

      # Done action (not used by default)
      elif action == self.actions.done:
          pass

      else:
          # Raised explicitly (rather than "assert False") so that invalid
          # actions are still rejected when Python runs with -O
          raise AssertionError("unknown action")

      if self.step_count >= self.max_steps:
          done = True

      obs = self.gen_obs()

      return obs, reward, done, {}
  def gen_obs_grid(self, agent_view_size=None):
      """
      Build the sub-grid the agent observes, together with a visibility
      mask telling which of its cells the agent can actually see.
      A falsy agent_view_size means self.agent_view_size is used.
      Returns (grid, vis_mask).
      """
      top_x, top_y, _bot_x, _bot_y = self.get_view_exts(agent_view_size)
      view_size = agent_view_size or self.agent_view_size

      # Cut the view square out of the world, then rotate it
      # agent_dir + 1 times to the left
      view = self.grid.slice(top_x, top_y, view_size, view_size)
      for _ in range(self.agent_dir + 1):
          view = view.rotate_left()

      # Process occluders and visibility
      # Note that this incurs some performance cost
      if self.see_through_walls:
          vis_mask = np.ones(shape=(view.width, view.height), dtype=bool)
      else:
          vis_mask = view.process_vis(agent_pos=(view_size // 2, view_size - 1))

      # Let the agent see what it is carrying: the carried object (or
      # nothing) is drawn at the agent's own cell of the view
      own_cell = (view.width // 2, view.height - 1)
      view.set(*own_cell, self.carrying if self.carrying else None)

      return view, vis_mask
  def gen_obs(self):
      """
      Build the agent's observation (partially observable, low-resolution encoding).

      The result is a dictionary with three entries:
      - "image": the encoded partially observable view of the environment
      - "direction": the agent's direction/orientation (acting as a compass)
      - "mission": the textual mission string (instructions for the agent)
      """
      view, visible = self.gen_obs_grid()

      # Turn the view and its visibility mask into the encoded image
      encoded = view.encode(visible)

      assert hasattr(
          self, "mission"
      ), "environments must define a textual mission string"

      return dict(image=encoded, direction=self.agent_dir, mission=self.mission)
  def get_obs_render(self, obs, tile_size=TILE_PIXELS // 2):
      """
      Render an encoded agent observation as an image for visualization.
      The agent is drawn at the bottom-centre cell of the view with
      direction 3, and the decoded visibility mask is used for highlighting.
      """
      decoded_grid, visible = Grid.decode(obs)

      view = self.agent_view_size
      agent_cell = (view // 2, view - 1)

      # Render the whole decoded view
      return decoded_grid.render(
          tile_size,
          agent_pos=agent_cell,
          agent_dir=3,
          highlight_mask=visible,
      )
  def render(self, mode="human", highlight=True, tile_size=TILE_PIXELS):
      """
      Render the whole-grid human view

      :param mode: must be one of self.metadata["render_modes"]; "human"
          shows the image in a window, any other accepted mode returns it
      :param highlight: whether to highlight the cells visible to the agent
      :param tile_size: tile size passed on to the grid renderer
      :return: the rendered image, or None when mode is "human"
      """
      assert mode in self.metadata["render_modes"]

      # Lazily create the window the first time a human render is requested
      if mode == "human" and not self.window:
          self.window = Window("gym_minigrid")
          self.window.show(block=False)

      # Compute which cells are visible to the agent
      _, vis_mask = self.gen_obs_grid()

      # Compute the world coordinates of the top-left corner
      # of the agent's view area
      f_vec = self.dir_vec
      r_vec = self.right_vec
      top_left = (
          self.agent_pos
          + f_vec * (self.agent_view_size - 1)
          - r_vec * (self.agent_view_size // 2)
      )

      # Mask of which cells to highlight
      highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)

      # For each cell in the visibility mask
      for vis_j in range(0, self.agent_view_size):
          for vis_i in range(0, self.agent_view_size):
              # If this cell is not visible, don't highlight it
              if not vis_mask[vis_i, vis_j]:
                  continue

              # Compute the world coordinates of this cell
              abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)

              # Parts of the view may fall outside the grid; skip those
              if abs_i < 0 or abs_i >= self.width:
                  continue
              if abs_j < 0 or abs_j >= self.height:
                  continue

              # Mark this cell to be highlighted
              highlight_mask[abs_i, abs_j] = True

      # Render the whole grid
      img = self.grid.render(
          tile_size,
          self.agent_pos,
          self.agent_dir,
          highlight_mask=highlight_mask if highlight else None,
      )

      if mode == "human":
          self.window.set_caption(self.mission)
          self.window.show_img(img)
      else:
          return img
  def close(self):
      """
      Close the rendering window, if one has been opened
      """
      window = self.window
      if not window:
          return
      window.close()