minigrid_env.py 46 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548
  1. import hashlib
  2. import math
  3. from abc import abstractmethod
  4. from enum import IntEnum
  5. from typing import Any, Callable, Optional, Union
  6. import gymnasium as gym
  7. import numpy as np
  8. from gymnasium import spaces
  9. from gymnasium.utils import seeding
# Drawing helpers used to render tiles; TILE_PIXELS (defined below) is the size in pixels of a tile in the full-scale human view
  11. from minigrid.utils.rendering import (
  12. downsample,
  13. fill_coords,
  14. highlight_img,
  15. point_in_circle,
  16. point_in_line,
  17. point_in_rect,
  18. point_in_triangle,
  19. rotate_fn,
  20. )
  21. from minigrid.utils.window import Window
  22. TILE_PIXELS = 32
  23. # Map of color names to RGB values
  24. COLORS = {
  25. "red": np.array([255, 0, 0]),
  26. "green": np.array([0, 255, 0]),
  27. "blue": np.array([0, 0, 255]),
  28. "purple": np.array([112, 39, 195]),
  29. "yellow": np.array([255, 255, 0]),
  30. "grey": np.array([100, 100, 100]),
  31. }
  32. COLOR_NAMES = sorted(list(COLORS.keys()))
  33. # Used to map colors to integers
  34. COLOR_TO_IDX = {"red": 0, "green": 1, "blue": 2, "purple": 3, "yellow": 4, "grey": 5}
  35. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  36. # Map of object type to integers
  37. OBJECT_TO_IDX = {
  38. "unseen": 0,
  39. "empty": 1,
  40. "wall": 2,
  41. "floor": 3,
  42. "door": 4,
  43. "key": 5,
  44. "ball": 6,
  45. "box": 7,
  46. "goal": 8,
  47. "lava": 9,
  48. "agent": 10,
  49. }
  50. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  51. # Map of state names to integers
  52. STATE_TO_IDX = {
  53. "open": 0,
  54. "closed": 1,
  55. "locked": 2,
  56. }
  57. # Map of agent direction indices to vectors
  58. DIR_TO_VEC = [
  59. # Pointing right (positive X)
  60. np.array((1, 0)),
  61. # Down (positive Y)
  62. np.array((0, 1)),
  63. # Pointing left (negative X)
  64. np.array((-1, 0)),
  65. # Up (negative Y)
  66. np.array((0, -1)),
  67. ]
  68. def check_if_no_duplicate(duplicate_list: list) -> bool:
  69. """Check if given list contains any duplicates"""
  70. return len(set(duplicate_list)) == len(duplicate_list)
  71. class MissionSpace(spaces.Space[str]):
  72. r"""A space representing a mission for the Minigrid environments.
  73. The space allows generating random mission strings constructed with an input placeholder list.
  74. Example Usage::
  75. >>> def _gen_mission() -> str:
  76. >>> return "Get the ball."
  77. >>> observation_space = MissionSpace(mission_func=_gen_mission)
  78. >>> observation_space.sample()
  79. "Get the ball."
  80. >>> def _gen_mission(color: str, object_type:str) -> str:
  81. >>> return f"Get the {color} {object_type}."
  82. >>> observation_space = MissionSpace(
  83. >>> mission_func=_gen_mission,
  84. >>> ordered_placeholders=[["green", "blue"], ["ball", "key"]],
  85. >>> )
  86. >>> observation_space.sample()
  87. "Get the green ball."
  88. """
  89. def __init__(
  90. self,
  91. mission_func: Callable[..., str],
  92. ordered_placeholders: Optional["list[list[str]]"] = None,
  93. seed: Optional[Union[int, seeding.RandomNumberGenerator]] = None,
  94. ):
  95. r"""Constructor of :class:`MissionSpace` space.
  96. Args:
  97. mission_func (lambda _placeholders(str): _mission(str)): Function that generates a mission string from random placeholders.
  98. ordered_placeholders (Optional["list[list[str]]"]): List of lists of placeholders ordered in placing order in the mission function mission_func.
  99. seed: seed: The seed for sampling from the space.
  100. """
  101. # Check that the ordered placeholders and mission function are well defined.
  102. if ordered_placeholders is not None:
  103. assert (
  104. len(ordered_placeholders) == mission_func.__code__.co_argcount
  105. ), f"The number of placeholders {len(ordered_placeholders)} is different from the number of parameters in the mission function {mission_func.__code__.co_argcount}."
  106. for placeholder_list in ordered_placeholders:
  107. assert check_if_no_duplicate(
  108. placeholder_list
  109. ), "Make sure that the placeholders don't have any duplicate values."
  110. else:
  111. assert (
  112. mission_func.__code__.co_argcount == 0
  113. ), f"If the ordered placeholders are {ordered_placeholders}, the mission function shouldn't have any parameters."
  114. self.ordered_placeholders = ordered_placeholders
  115. self.mission_func = mission_func
  116. super().__init__(dtype=str, seed=seed)
  117. # Check that mission_func returns a string
  118. sampled_mission = self.sample()
  119. assert isinstance(
  120. sampled_mission, str
  121. ), f"mission_func must return type str not {type(sampled_mission)}"
  122. def sample(self) -> str:
  123. """Sample a random mission string."""
  124. if self.ordered_placeholders is not None:
  125. placeholders = []
  126. for rand_var_list in self.ordered_placeholders:
  127. idx = self.np_random.integers(0, len(rand_var_list))
  128. placeholders.append(rand_var_list[idx])
  129. return self.mission_func(*placeholders)
  130. else:
  131. return self.mission_func()
  132. def contains(self, x: Any) -> bool:
  133. """Return boolean specifying if x is a valid member of this space."""
  134. # Store a list of all the placeholders from self.ordered_placeholders that appear in x
  135. if self.ordered_placeholders is not None:
  136. check_placeholder_list = []
  137. for placeholder_list in self.ordered_placeholders:
  138. for placeholder in placeholder_list:
  139. if placeholder in x:
  140. check_placeholder_list.append(placeholder)
  141. # Remove duplicates from the list
  142. check_placeholder_list = list(set(check_placeholder_list))
  143. start_id_placeholder = []
  144. end_id_placeholder = []
  145. # Get the starting and ending id of the identified placeholders with possible duplicates
  146. new_check_placeholder_list = []
  147. for placeholder in check_placeholder_list:
  148. new_start_id_placeholder = [
  149. i for i in range(len(x)) if x.startswith(placeholder, i)
  150. ]
  151. new_check_placeholder_list += [placeholder] * len(
  152. new_start_id_placeholder
  153. )
  154. end_id_placeholder += [
  155. start_id + len(placeholder) - 1
  156. for start_id in new_start_id_placeholder
  157. ]
  158. start_id_placeholder += new_start_id_placeholder
  159. # Order by starting id the placeholders
  160. ordered_placeholder_list = sorted(
  161. zip(
  162. start_id_placeholder, end_id_placeholder, new_check_placeholder_list
  163. )
  164. )
  165. # Check for repeated placeholders contained in each other
  166. remove_placeholder_id = []
  167. for i, placeholder_1 in enumerate(ordered_placeholder_list):
  168. starting_id = i + 1
  169. for j, placeholder_2 in enumerate(
  170. ordered_placeholder_list[starting_id:]
  171. ):
  172. # Check if place holder ids overlap and keep the longest
  173. if max(placeholder_1[0], placeholder_2[0]) < min(
  174. placeholder_1[1], placeholder_2[1]
  175. ):
  176. remove_placeholder = min(
  177. placeholder_1[2], placeholder_2[2], key=len
  178. )
  179. if remove_placeholder == placeholder_1[2]:
  180. remove_placeholder_id.append(i)
  181. else:
  182. remove_placeholder_id.append(i + j + 1)
  183. for id in remove_placeholder_id:
  184. del ordered_placeholder_list[id]
  185. final_placeholders = [
  186. placeholder[2] for placeholder in ordered_placeholder_list
  187. ]
  188. # Check that the identified final placeholders are in the same order as the original placeholders.
  189. for orered_placeholder, final_placeholder in zip(
  190. self.ordered_placeholders, final_placeholders
  191. ):
  192. if final_placeholder in orered_placeholder:
  193. continue
  194. else:
  195. return False
  196. try:
  197. mission_string_with_placeholders = self.mission_func(
  198. *final_placeholders
  199. )
  200. except Exception as e:
  201. print(
  202. f"{x} is not contained in MissionSpace due to the following exception: {e}"
  203. )
  204. return False
  205. return bool(mission_string_with_placeholders == x)
  206. else:
  207. return bool(self.mission_func() == x)
  208. def __repr__(self) -> str:
  209. """Gives a string representation of this space."""
  210. return f"MissionSpace({self.mission_func}, {self.ordered_placeholders})"
  211. def __eq__(self, other) -> bool:
  212. """Check whether ``other`` is equivalent to this instance."""
  213. if isinstance(other, MissionSpace):
  214. # Check that place holder lists are the same
  215. if self.ordered_placeholders is not None:
  216. # Check length
  217. if (len(self.order_placeholder) == len(other.order_placeholder)) and (
  218. all(
  219. set(i) == set(j)
  220. for i, j in zip(self.order_placeholder, other.order_placeholder)
  221. )
  222. ):
  223. # Check mission string is the same with dummy space placeholders
  224. test_placeholders = [""] * len(self.order_placeholder)
  225. mission = self.mission_func(*test_placeholders)
  226. other_mission = other.mission_func(*test_placeholders)
  227. return mission == other_mission
  228. else:
  229. # Check that other is also None
  230. if other.ordered_placeholders is None:
  231. # Check mission string is the same
  232. mission = self.mission_func()
  233. other_mission = other.mission_func()
  234. return mission == other_mission
  235. # If none of the statements above return then False
  236. return False
  237. class WorldObj:
  238. """
  239. Base class for grid world objects
  240. """
  241. def __init__(self, type, color):
  242. assert type in OBJECT_TO_IDX, type
  243. assert color in COLOR_TO_IDX, color
  244. self.type = type
  245. self.color = color
  246. self.contains = None
  247. # Initial position of the object
  248. self.init_pos = None
  249. # Current position of the object
  250. self.cur_pos = None
  251. def can_overlap(self):
  252. """Can the agent overlap with this?"""
  253. return False
  254. def can_pickup(self):
  255. """Can the agent pick this up?"""
  256. return False
  257. def can_contain(self):
  258. """Can this contain another object?"""
  259. return False
  260. def see_behind(self):
  261. """Can the agent see behind this object?"""
  262. return True
  263. def toggle(self, env, pos):
  264. """Method to trigger/toggle an action this object performs"""
  265. return False
  266. def encode(self):
  267. """Encode the a description of this object as a 3-tuple of integers"""
  268. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0)
  269. @staticmethod
  270. def decode(type_idx, color_idx, state):
  271. """Create an object from a 3-tuple state description"""
  272. obj_type = IDX_TO_OBJECT[type_idx]
  273. color = IDX_TO_COLOR[color_idx]
  274. if obj_type == "empty" or obj_type == "unseen":
  275. return None
  276. # State, 0: open, 1: closed, 2: locked
  277. is_open = state == 0
  278. is_locked = state == 2
  279. if obj_type == "wall":
  280. v = Wall(color)
  281. elif obj_type == "floor":
  282. v = Floor(color)
  283. elif obj_type == "ball":
  284. v = Ball(color)
  285. elif obj_type == "key":
  286. v = Key(color)
  287. elif obj_type == "box":
  288. v = Box(color)
  289. elif obj_type == "door":
  290. v = Door(color, is_open, is_locked)
  291. elif obj_type == "goal":
  292. v = Goal()
  293. elif obj_type == "lava":
  294. v = Lava()
  295. else:
  296. assert False, "unknown object type in decode '%s'" % obj_type
  297. return v
  298. def render(self, r):
  299. """Draw this object with the given renderer"""
  300. raise NotImplementedError
  301. class Goal(WorldObj):
  302. def __init__(self):
  303. super().__init__("goal", "green")
  304. def can_overlap(self):
  305. return True
  306. def render(self, img):
  307. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  308. class Floor(WorldObj):
  309. """
  310. Colored floor tile the agent can walk over
  311. """
  312. def __init__(self, color="blue"):
  313. super().__init__("floor", color)
  314. def can_overlap(self):
  315. return True
  316. def render(self, img):
  317. # Give the floor a pale color
  318. color = COLORS[self.color] / 2
  319. fill_coords(img, point_in_rect(0.031, 1, 0.031, 1), color)
  320. class Lava(WorldObj):
  321. def __init__(self):
  322. super().__init__("lava", "red")
  323. def can_overlap(self):
  324. return True
  325. def render(self, img):
  326. c = (255, 128, 0)
  327. # Background color
  328. fill_coords(img, point_in_rect(0, 1, 0, 1), c)
  329. # Little waves
  330. for i in range(3):
  331. ylo = 0.3 + 0.2 * i
  332. yhi = 0.4 + 0.2 * i
  333. fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0, 0, 0))
  334. fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0, 0, 0))
  335. fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0, 0, 0))
  336. fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0, 0, 0))
  337. class Wall(WorldObj):
  338. def __init__(self, color="grey"):
  339. super().__init__("wall", color)
  340. def see_behind(self):
  341. return False
  342. def render(self, img):
  343. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  344. class Door(WorldObj):
  345. def __init__(self, color, is_open=False, is_locked=False):
  346. super().__init__("door", color)
  347. self.is_open = is_open
  348. self.is_locked = is_locked
  349. def can_overlap(self):
  350. """The agent can only walk over this cell when the door is open"""
  351. return self.is_open
  352. def see_behind(self):
  353. return self.is_open
  354. def toggle(self, env, pos):
  355. # If the player has the right key to open the door
  356. if self.is_locked:
  357. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  358. self.is_locked = False
  359. self.is_open = True
  360. return True
  361. return False
  362. self.is_open = not self.is_open
  363. return True
  364. def encode(self):
  365. """Encode the a description of this object as a 3-tuple of integers"""
  366. # State, 0: open, 1: closed, 2: locked
  367. if self.is_open:
  368. state = 0
  369. elif self.is_locked:
  370. state = 2
  371. # if door is closed and unlocked
  372. elif not self.is_open:
  373. state = 1
  374. else:
  375. raise ValueError(
  376. f"There is no possible state encoding for the state:\n -Door Open: {self.is_open}\n -Door Closed: {not self.is_open}\n -Door Locked: {self.is_locked}"
  377. )
  378. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state)
  379. def render(self, img):
  380. c = COLORS[self.color]
  381. if self.is_open:
  382. fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c)
  383. fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0, 0, 0))
  384. return
  385. # Door frame and door
  386. if self.is_locked:
  387. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  388. fill_coords(img, point_in_rect(0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))
  389. # Draw key slot
  390. fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)
  391. else:
  392. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  393. fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0, 0, 0))
  394. fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c)
  395. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0, 0, 0))
  396. # Draw door handle
  397. fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c)
  398. class Key(WorldObj):
  399. def __init__(self, color="blue"):
  400. super().__init__("key", color)
  401. def can_pickup(self):
  402. return True
  403. def render(self, img):
  404. c = COLORS[self.color]
  405. # Vertical quad
  406. fill_coords(img, point_in_rect(0.50, 0.63, 0.31, 0.88), c)
  407. # Teeth
  408. fill_coords(img, point_in_rect(0.38, 0.50, 0.59, 0.66), c)
  409. fill_coords(img, point_in_rect(0.38, 0.50, 0.81, 0.88), c)
  410. # Ring
  411. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.190), c)
  412. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0, 0, 0))
  413. class Ball(WorldObj):
  414. def __init__(self, color="blue"):
  415. super().__init__("ball", color)
  416. def can_pickup(self):
  417. return True
  418. def render(self, img):
  419. fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color])
  420. class Box(WorldObj):
  421. def __init__(self, color, contains=None):
  422. super().__init__("box", color)
  423. self.contains = contains
  424. def can_pickup(self):
  425. return True
  426. def render(self, img):
  427. c = COLORS[self.color]
  428. # Outline
  429. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), c)
  430. fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0, 0, 0))
  431. # Horizontal slit
  432. fill_coords(img, point_in_rect(0.16, 0.84, 0.47, 0.53), c)
  433. def toggle(self, env, pos):
  434. # Replace the box by its contents
  435. env.grid.set(pos[0], pos[1], self.contains)
  436. return True
  437. class Grid:
  438. """
  439. Represent a grid and operations on it
  440. """
  441. # Static cache of pre-renderer tiles
  442. tile_cache = {}
  443. def __init__(self, width, height):
  444. assert width >= 3
  445. assert height >= 3
  446. self.width = width
  447. self.height = height
  448. self.grid = [None] * width * height
  449. def __contains__(self, key):
  450. if isinstance(key, WorldObj):
  451. for e in self.grid:
  452. if e is key:
  453. return True
  454. elif isinstance(key, tuple):
  455. for e in self.grid:
  456. if e is None:
  457. continue
  458. if (e.color, e.type) == key:
  459. return True
  460. if key[0] is None and key[1] == e.type:
  461. return True
  462. return False
  463. def __eq__(self, other):
  464. grid1 = self.encode()
  465. grid2 = other.encode()
  466. return np.array_equal(grid2, grid1)
  467. def __ne__(self, other):
  468. return not self == other
  469. def copy(self):
  470. from copy import deepcopy
  471. return deepcopy(self)
  472. def set(self, i, j, v):
  473. assert i >= 0 and i < self.width
  474. assert j >= 0 and j < self.height
  475. self.grid[j * self.width + i] = v
  476. def get(self, i, j):
  477. assert i >= 0 and i < self.width
  478. assert j >= 0 and j < self.height
  479. return self.grid[j * self.width + i]
  480. def horz_wall(self, x, y, length=None, obj_type=Wall):
  481. if length is None:
  482. length = self.width - x
  483. for i in range(0, length):
  484. self.set(x + i, y, obj_type())
  485. def vert_wall(self, x, y, length=None, obj_type=Wall):
  486. if length is None:
  487. length = self.height - y
  488. for j in range(0, length):
  489. self.set(x, y + j, obj_type())
  490. def wall_rect(self, x, y, w, h):
  491. self.horz_wall(x, y, w)
  492. self.horz_wall(x, y + h - 1, w)
  493. self.vert_wall(x, y, h)
  494. self.vert_wall(x + w - 1, y, h)
  495. def rotate_left(self):
  496. """
  497. Rotate the grid to the left (counter-clockwise)
  498. """
  499. grid = Grid(self.height, self.width)
  500. for i in range(self.width):
  501. for j in range(self.height):
  502. v = self.get(i, j)
  503. grid.set(j, grid.height - 1 - i, v)
  504. return grid
  505. def slice(self, topX, topY, width, height):
  506. """
  507. Get a subset of the grid
  508. """
  509. grid = Grid(width, height)
  510. for j in range(0, height):
  511. for i in range(0, width):
  512. x = topX + i
  513. y = topY + j
  514. if x >= 0 and x < self.width and y >= 0 and y < self.height:
  515. v = self.get(x, y)
  516. else:
  517. v = Wall()
  518. grid.set(i, j, v)
  519. return grid
  520. @classmethod
  521. def render_tile(
  522. cls, obj, agent_dir=None, highlight=False, tile_size=TILE_PIXELS, subdivs=3
  523. ):
  524. """
  525. Render a tile and cache the result
  526. """
  527. # Hash map lookup key for the cache
  528. key = (agent_dir, highlight, tile_size)
  529. key = obj.encode() + key if obj else key
  530. if key in cls.tile_cache:
  531. return cls.tile_cache[key]
  532. img = np.zeros(
  533. shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8
  534. )
  535. # Draw the grid lines (top and left edges)
  536. fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
  537. fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100))
  538. if obj is not None:
  539. obj.render(img)
  540. # Overlay the agent on top
  541. if agent_dir is not None:
  542. tri_fn = point_in_triangle(
  543. (0.12, 0.19),
  544. (0.87, 0.50),
  545. (0.12, 0.81),
  546. )
  547. # Rotate the agent based on its direction
  548. tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5 * math.pi * agent_dir)
  549. fill_coords(img, tri_fn, (255, 0, 0))
  550. # Highlight the cell if needed
  551. if highlight:
  552. highlight_img(img)
  553. # Downsample the image to perform supersampling/anti-aliasing
  554. img = downsample(img, subdivs)
  555. # Cache the rendered tile
  556. cls.tile_cache[key] = img
  557. return img
  558. def render(self, tile_size, agent_pos, agent_dir=None, highlight_mask=None):
  559. """
  560. Render this grid at a given scale
  561. :param r: target renderer object
  562. :param tile_size: tile size in pixels
  563. """
  564. if highlight_mask is None:
  565. highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  566. # Compute the total grid size
  567. width_px = self.width * tile_size
  568. height_px = self.height * tile_size
  569. img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)
  570. # Render the grid
  571. for j in range(0, self.height):
  572. for i in range(0, self.width):
  573. cell = self.get(i, j)
  574. agent_here = np.array_equal(agent_pos, (i, j))
  575. tile_img = Grid.render_tile(
  576. cell,
  577. agent_dir=agent_dir if agent_here else None,
  578. highlight=highlight_mask[i, j],
  579. tile_size=tile_size,
  580. )
  581. ymin = j * tile_size
  582. ymax = (j + 1) * tile_size
  583. xmin = i * tile_size
  584. xmax = (i + 1) * tile_size
  585. img[ymin:ymax, xmin:xmax, :] = tile_img
  586. return img
  587. def encode(self, vis_mask=None):
  588. """
  589. Produce a compact numpy encoding of the grid
  590. """
  591. if vis_mask is None:
  592. vis_mask = np.ones((self.width, self.height), dtype=bool)
  593. array = np.zeros((self.width, self.height, 3), dtype="uint8")
  594. for i in range(self.width):
  595. for j in range(self.height):
  596. if vis_mask[i, j]:
  597. v = self.get(i, j)
  598. if v is None:
  599. array[i, j, 0] = OBJECT_TO_IDX["empty"]
  600. array[i, j, 1] = 0
  601. array[i, j, 2] = 0
  602. else:
  603. array[i, j, :] = v.encode()
  604. return array
  605. @staticmethod
  606. def decode(array):
  607. """
  608. Decode an array grid encoding back into a grid
  609. """
  610. width, height, channels = array.shape
  611. assert channels == 3
  612. vis_mask = np.ones(shape=(width, height), dtype=bool)
  613. grid = Grid(width, height)
  614. for i in range(width):
  615. for j in range(height):
  616. type_idx, color_idx, state = array[i, j]
  617. v = WorldObj.decode(type_idx, color_idx, state)
  618. grid.set(i, j, v)
  619. vis_mask[i, j] = type_idx != OBJECT_TO_IDX["unseen"]
  620. return grid, vis_mask
  621. def process_vis(self, agent_pos):
  622. mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  623. mask[agent_pos[0], agent_pos[1]] = True
  624. for j in reversed(range(0, self.height)):
  625. for i in range(0, self.width - 1):
  626. if not mask[i, j]:
  627. continue
  628. cell = self.get(i, j)
  629. if cell and not cell.see_behind():
  630. continue
  631. mask[i + 1, j] = True
  632. if j > 0:
  633. mask[i + 1, j - 1] = True
  634. mask[i, j - 1] = True
  635. for i in reversed(range(1, self.width)):
  636. if not mask[i, j]:
  637. continue
  638. cell = self.get(i, j)
  639. if cell and not cell.see_behind():
  640. continue
  641. mask[i - 1, j] = True
  642. if j > 0:
  643. mask[i - 1, j - 1] = True
  644. mask[i, j - 1] = True
  645. for j in range(0, self.height):
  646. for i in range(0, self.width):
  647. if not mask[i, j]:
  648. self.set(i, j, None)
  649. return mask
  650. class MiniGridEnv(gym.Env):
  651. """
  652. 2D grid world game environment
  653. """
  654. metadata = {
  655. "render_modes": ["human", "rgb_array"],
  656. "render_fps": 10,
  657. }
  658. # Enumeration of possible actions
  659. class Actions(IntEnum):
  660. # Turn left, turn right, move forward
  661. left = 0
  662. right = 1
  663. forward = 2
  664. # Pick up an object
  665. pickup = 3
  666. # Drop an object
  667. drop = 4
  668. # Toggle/activate an object
  669. toggle = 5
  670. # Done completing task
  671. done = 6
  672. def __init__(
  673. self,
  674. mission_space: MissionSpace,
  675. grid_size: int = None,
  676. width: int = None,
  677. height: int = None,
  678. max_steps: int = 100,
  679. see_through_walls: bool = False,
  680. agent_view_size: int = 7,
  681. render_mode: Optional[str] = None,
  682. highlight: bool = True,
  683. tile_size: int = TILE_PIXELS,
  684. agent_pov: bool = False,
  685. ):
  686. # Initialize mission
  687. self.mission = mission_space.sample()
  688. # Can't set both grid_size and width/height
  689. if grid_size:
  690. assert width is None and height is None
  691. width = grid_size
  692. height = grid_size
  693. # Action enumeration for this environment
  694. self.actions = MiniGridEnv.Actions
  695. # Actions are discrete integer values
  696. self.action_space = spaces.Discrete(len(self.actions))
  697. # Number of cells (width and height) in the agent view
  698. assert agent_view_size % 2 == 1
  699. assert agent_view_size >= 3
  700. self.agent_view_size = agent_view_size
  701. # Observations are dictionaries containing an
  702. # encoding of the grid and a textual 'mission' string
  703. image_observation_space = spaces.Box(
  704. low=0,
  705. high=255,
  706. shape=(self.agent_view_size, self.agent_view_size, 3),
  707. dtype="uint8",
  708. )
  709. self.observation_space = spaces.Dict(
  710. {
  711. "image": image_observation_space,
  712. "direction": spaces.Discrete(4),
  713. "mission": mission_space,
  714. }
  715. )
  716. # Range of possible rewards
  717. self.reward_range = (0, 1)
  718. self.window: Window = None
  719. # Environment configuration
  720. self.width = width
  721. self.height = height
  722. assert isinstance(
  723. max_steps, int
  724. ), f"The argument max_steps must be an integer, got: {type(max_steps)}"
  725. self.max_steps = max_steps
  726. self.see_through_walls = see_through_walls
  727. # Current position and direction of the agent
  728. self.agent_pos: np.ndarray = None
  729. self.agent_dir: int = None
  730. # Current grid and mission and carryinh
  731. self.grid = Grid(width, height)
  732. self.carrying = None
  733. # Rendering attributes
  734. self.render_mode = render_mode
  735. self.highlight = highlight
  736. self.tile_size = tile_size
  737. self.agent_pov = agent_pov
  738. def reset(self, *, seed=None, options=None):
  739. super().reset(seed=seed)
  740. # Reinitialize episode-specific variables
  741. self.agent_pos = (-1, -1)
  742. self.agent_dir = -1
  743. # Generate a new random grid at the start of each episode
  744. self._gen_grid(self.width, self.height)
  745. # These fields should be defined by _gen_grid
  746. assert (
  747. self.agent_pos >= (0, 0)
  748. if isinstance(self.agent_pos, tuple)
  749. else all(self.agent_pos >= 0) and self.agent_dir >= 0
  750. )
  751. # Check that the agent doesn't overlap with an object
  752. start_cell = self.grid.get(*self.agent_pos)
  753. assert start_cell is None or start_cell.can_overlap()
  754. # Item picked up, being carried, initially nothing
  755. self.carrying = None
  756. # Step count since episode start
  757. self.step_count = 0
  758. if self.render_mode == "human":
  759. self.render()
  760. # Return first observation
  761. obs = self.gen_obs()
  762. return obs, {}
  763. def hash(self, size=16):
  764. """Compute a hash that uniquely identifies the current state of the environment.
  765. :param size: Size of the hashing
  766. """
  767. sample_hash = hashlib.sha256()
  768. to_encode = [self.grid.encode().tolist(), self.agent_pos, self.agent_dir]
  769. for item in to_encode:
  770. sample_hash.update(str(item).encode("utf8"))
  771. return sample_hash.hexdigest()[:size]
  772. @property
  773. def steps_remaining(self):
  774. return self.max_steps - self.step_count
  775. def __str__(self):
  776. """
  777. Produce a pretty string of the environment's grid along with the agent.
  778. A grid cell is represented by 2-character string, the first one for
  779. the object and the second one for the color.
  780. """
  781. # Map of object types to short string
  782. OBJECT_TO_STR = {
  783. "wall": "W",
  784. "floor": "F",
  785. "door": "D",
  786. "key": "K",
  787. "ball": "A",
  788. "box": "B",
  789. "goal": "G",
  790. "lava": "V",
  791. }
  792. # Map agent's direction to short string
  793. AGENT_DIR_TO_STR = {0: ">", 1: "V", 2: "<", 3: "^"}
  794. str = ""
  795. for j in range(self.grid.height):
  796. for i in range(self.grid.width):
  797. if i == self.agent_pos[0] and j == self.agent_pos[1]:
  798. str += 2 * AGENT_DIR_TO_STR[self.agent_dir]
  799. continue
  800. c = self.grid.get(i, j)
  801. if c is None:
  802. str += " "
  803. continue
  804. if c.type == "door":
  805. if c.is_open:
  806. str += "__"
  807. elif c.is_locked:
  808. str += "L" + c.color[0].upper()
  809. else:
  810. str += "D" + c.color[0].upper()
  811. continue
  812. str += OBJECT_TO_STR[c.type] + c.color[0].upper()
  813. if j < self.grid.height - 1:
  814. str += "\n"
  815. return str
  816. @abstractmethod
  817. def _gen_grid(self, width, height):
  818. pass
  819. def _reward(self):
  820. """
  821. Compute the reward to be given upon success
  822. """
  823. return 1 - 0.9 * (self.step_count / self.max_steps)
  824. def _rand_int(self, low, high):
  825. """
  826. Generate random integer in [low,high[
  827. """
  828. return self.np_random.integers(low, high)
  829. def _rand_float(self, low, high):
  830. """
  831. Generate random float in [low,high[
  832. """
  833. return self.np_random.uniform(low, high)
  834. def _rand_bool(self):
  835. """
  836. Generate random boolean value
  837. """
  838. return self.np_random.integers(0, 2) == 0
  839. def _rand_elem(self, iterable):
  840. """
  841. Pick a random element in a list
  842. """
  843. lst = list(iterable)
  844. idx = self._rand_int(0, len(lst))
  845. return lst[idx]
  846. def _rand_subset(self, iterable, num_elems):
  847. """
  848. Sample a random subset of distinct elements of a list
  849. """
  850. lst = list(iterable)
  851. assert num_elems <= len(lst)
  852. out = []
  853. while len(out) < num_elems:
  854. elem = self._rand_elem(lst)
  855. lst.remove(elem)
  856. out.append(elem)
  857. return out
  858. def _rand_color(self):
  859. """
  860. Generate a random color name (string)
  861. """
  862. return self._rand_elem(COLOR_NAMES)
  863. def _rand_pos(self, xLow, xHigh, yLow, yHigh):
  864. """
  865. Generate a random (x,y) position tuple
  866. """
  867. return (
  868. self.np_random.integers(xLow, xHigh),
  869. self.np_random.integers(yLow, yHigh),
  870. )
  871. def place_obj(self, obj, top=None, size=None, reject_fn=None, max_tries=math.inf):
  872. """
  873. Place an object at an empty position in the grid
  874. :param top: top-left position of the rectangle where to place
  875. :param size: size of the rectangle where to place
  876. :param reject_fn: function to filter out potential positions
  877. """
  878. if top is None:
  879. top = (0, 0)
  880. else:
  881. top = (max(top[0], 0), max(top[1], 0))
  882. if size is None:
  883. size = (self.grid.width, self.grid.height)
  884. num_tries = 0
  885. while True:
  886. # This is to handle with rare cases where rejection sampling
  887. # gets stuck in an infinite loop
  888. if num_tries > max_tries:
  889. raise RecursionError("rejection sampling failed in place_obj")
  890. num_tries += 1
  891. pos = np.array(
  892. (
  893. self._rand_int(top[0], min(top[0] + size[0], self.grid.width)),
  894. self._rand_int(top[1], min(top[1] + size[1], self.grid.height)),
  895. )
  896. )
  897. pos = tuple(pos)
  898. # Don't place the object on top of another object
  899. if self.grid.get(*pos) is not None:
  900. continue
  901. # Don't place the object where the agent is
  902. if np.array_equal(pos, self.agent_pos):
  903. continue
  904. # Check if there is a filtering criterion
  905. if reject_fn and reject_fn(self, pos):
  906. continue
  907. break
  908. self.grid.set(pos[0], pos[1], obj)
  909. if obj is not None:
  910. obj.init_pos = pos
  911. obj.cur_pos = pos
  912. return pos
  913. def put_obj(self, obj, i, j):
  914. """
  915. Put an object at a specific position in the grid
  916. """
  917. self.grid.set(i, j, obj)
  918. obj.init_pos = (i, j)
  919. obj.cur_pos = (i, j)
  920. def place_agent(self, top=None, size=None, rand_dir=True, max_tries=math.inf):
  921. """
  922. Set the agent's starting point at an empty position in the grid
  923. """
  924. self.agent_pos = (-1, -1)
  925. pos = self.place_obj(None, top, size, max_tries=max_tries)
  926. self.agent_pos = pos
  927. if rand_dir:
  928. self.agent_dir = self._rand_int(0, 4)
  929. return pos
  930. @property
  931. def dir_vec(self):
  932. """
  933. Get the direction vector for the agent, pointing in the direction
  934. of forward movement.
  935. """
  936. assert self.agent_dir >= 0 and self.agent_dir < 4
  937. return DIR_TO_VEC[self.agent_dir]
  938. @property
  939. def right_vec(self):
  940. """
  941. Get the vector pointing to the right of the agent.
  942. """
  943. dx, dy = self.dir_vec
  944. return np.array((-dy, dx))
  945. @property
  946. def front_pos(self):
  947. """
  948. Get the position of the cell that is right in front of the agent
  949. """
  950. return self.agent_pos + self.dir_vec
  951. def get_view_coords(self, i, j):
  952. """
  953. Translate and rotate absolute grid coordinates (i, j) into the
  954. agent's partially observable view (sub-grid). Note that the resulting
  955. coordinates may be negative or outside of the agent's view size.
  956. """
  957. ax, ay = self.agent_pos
  958. dx, dy = self.dir_vec
  959. rx, ry = self.right_vec
  960. # Compute the absolute coordinates of the top-left view corner
  961. sz = self.agent_view_size
  962. hs = self.agent_view_size // 2
  963. tx = ax + (dx * (sz - 1)) - (rx * hs)
  964. ty = ay + (dy * (sz - 1)) - (ry * hs)
  965. lx = i - tx
  966. ly = j - ty
  967. # Project the coordinates of the object relative to the top-left
  968. # corner onto the agent's own coordinate system
  969. vx = rx * lx + ry * ly
  970. vy = -(dx * lx + dy * ly)
  971. return vx, vy
  972. def get_view_exts(self, agent_view_size=None):
  973. """
  974. Get the extents of the square set of tiles visible to the agent
  975. Note: the bottom extent indices are not included in the set
  976. if agent_view_size is None, use self.agent_view_size
  977. """
  978. agent_view_size = agent_view_size or self.agent_view_size
  979. # Facing right
  980. if self.agent_dir == 0:
  981. topX = self.agent_pos[0]
  982. topY = self.agent_pos[1] - agent_view_size // 2
  983. # Facing down
  984. elif self.agent_dir == 1:
  985. topX = self.agent_pos[0] - agent_view_size // 2
  986. topY = self.agent_pos[1]
  987. # Facing left
  988. elif self.agent_dir == 2:
  989. topX = self.agent_pos[0] - agent_view_size + 1
  990. topY = self.agent_pos[1] - agent_view_size // 2
  991. # Facing up
  992. elif self.agent_dir == 3:
  993. topX = self.agent_pos[0] - agent_view_size // 2
  994. topY = self.agent_pos[1] - agent_view_size + 1
  995. else:
  996. assert False, "invalid agent direction"
  997. botX = topX + agent_view_size
  998. botY = topY + agent_view_size
  999. return (topX, topY, botX, botY)
  1000. def relative_coords(self, x, y):
  1001. """
  1002. Check if a grid position belongs to the agent's field of view, and returns the corresponding coordinates
  1003. """
  1004. vx, vy = self.get_view_coords(x, y)
  1005. if vx < 0 or vy < 0 or vx >= self.agent_view_size or vy >= self.agent_view_size:
  1006. return None
  1007. return vx, vy
  1008. def in_view(self, x, y):
  1009. """
  1010. check if a grid position is visible to the agent
  1011. """
  1012. return self.relative_coords(x, y) is not None
  1013. def agent_sees(self, x, y):
  1014. """
  1015. Check if a non-empty grid position is visible to the agent
  1016. """
  1017. coordinates = self.relative_coords(x, y)
  1018. if coordinates is None:
  1019. return False
  1020. vx, vy = coordinates
  1021. obs = self.gen_obs()
  1022. obs_grid, _ = Grid.decode(obs["image"])
  1023. obs_cell = obs_grid.get(vx, vy)
  1024. world_cell = self.grid.get(x, y)
  1025. assert world_cell is not None
  1026. return obs_cell is not None and obs_cell.type == world_cell.type
  1027. def step(self, action):
  1028. self.step_count += 1
  1029. reward = 0
  1030. terminated = False
  1031. truncated = False
  1032. # Get the position in front of the agent
  1033. fwd_pos = self.front_pos
  1034. # Get the contents of the cell in front of the agent
  1035. fwd_cell = self.grid.get(*fwd_pos)
  1036. # Rotate left
  1037. if action == self.actions.left:
  1038. self.agent_dir -= 1
  1039. if self.agent_dir < 0:
  1040. self.agent_dir += 4
  1041. # Rotate right
  1042. elif action == self.actions.right:
  1043. self.agent_dir = (self.agent_dir + 1) % 4
  1044. # Move forward
  1045. elif action == self.actions.forward:
  1046. if fwd_cell is None or fwd_cell.can_overlap():
  1047. self.agent_pos = tuple(fwd_pos)
  1048. if fwd_cell is not None and fwd_cell.type == "goal":
  1049. terminated = True
  1050. reward = self._reward()
  1051. if fwd_cell is not None and fwd_cell.type == "lava":
  1052. terminated = True
  1053. # Pick up an object
  1054. elif action == self.actions.pickup:
  1055. if fwd_cell and fwd_cell.can_pickup():
  1056. if self.carrying is None:
  1057. self.carrying = fwd_cell
  1058. self.carrying.cur_pos = np.array([-1, -1])
  1059. self.grid.set(fwd_pos[0], fwd_pos[1], None)
  1060. # Drop an object
  1061. elif action == self.actions.drop:
  1062. if not fwd_cell and self.carrying:
  1063. self.grid.set(fwd_pos[0], fwd_pos[1], self.carrying)
  1064. self.carrying.cur_pos = fwd_pos
  1065. self.carrying = None
  1066. # Toggle/activate an object
  1067. elif action == self.actions.toggle:
  1068. if fwd_cell:
  1069. fwd_cell.toggle(self, fwd_pos)
  1070. # Done action (not used by default)
  1071. elif action == self.actions.done:
  1072. pass
  1073. else:
  1074. raise ValueError(f"Unknown action: {action}")
  1075. if self.step_count >= self.max_steps:
  1076. truncated = True
  1077. if self.render_mode == "human":
  1078. self.render()
  1079. obs = self.gen_obs()
  1080. return obs, reward, terminated, truncated, {}
  1081. def gen_obs_grid(self, agent_view_size=None):
  1082. """
  1083. Generate the sub-grid observed by the agent.
  1084. This method also outputs a visibility mask telling us which grid
  1085. cells the agent can actually see.
  1086. if agent_view_size is None, self.agent_view_size is used
  1087. """
  1088. topX, topY, botX, botY = self.get_view_exts(agent_view_size)
  1089. agent_view_size = agent_view_size or self.agent_view_size
  1090. grid = self.grid.slice(topX, topY, agent_view_size, agent_view_size)
  1091. for i in range(self.agent_dir + 1):
  1092. grid = grid.rotate_left()
  1093. # Process occluders and visibility
  1094. # Note that this incurs some performance cost
  1095. if not self.see_through_walls:
  1096. vis_mask = grid.process_vis(
  1097. agent_pos=(agent_view_size // 2, agent_view_size - 1)
  1098. )
  1099. else:
  1100. vis_mask = np.ones(shape=(grid.width, grid.height), dtype=bool)
  1101. # Make it so the agent sees what it's carrying
  1102. # We do this by placing the carried object at the agent's position
  1103. # in the agent's partially observable view
  1104. agent_pos = grid.width // 2, grid.height - 1
  1105. if self.carrying:
  1106. grid.set(*agent_pos, self.carrying)
  1107. else:
  1108. grid.set(*agent_pos, None)
  1109. return grid, vis_mask
  1110. def gen_obs(self):
  1111. """
  1112. Generate the agent's view (partially observable, low-resolution encoding)
  1113. """
  1114. grid, vis_mask = self.gen_obs_grid()
  1115. # Encode the partially observable view into a numpy array
  1116. image = grid.encode(vis_mask)
  1117. # Observations are dictionaries containing:
  1118. # - an image (partially observable view of the environment)
  1119. # - the agent's direction/orientation (acting as a compass)
  1120. # - a textual mission string (instructions for the agent)
  1121. obs = {"image": image, "direction": self.agent_dir, "mission": self.mission}
  1122. return obs
  1123. def get_pov_render(self, tile_size):
  1124. """
  1125. Render an agent's POV observation for visualization
  1126. """
  1127. grid, vis_mask = self.gen_obs_grid()
  1128. # Render the whole grid
  1129. img = grid.render(
  1130. tile_size,
  1131. agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1),
  1132. agent_dir=3,
  1133. highlight_mask=vis_mask,
  1134. )
  1135. return img
  1136. def get_full_render(self, highlight, tile_size):
  1137. """
  1138. Render a non-paratial observation for visualization
  1139. """
  1140. # Compute which cells are visible to the agent
  1141. _, vis_mask = self.gen_obs_grid()
  1142. # Compute the world coordinates of the bottom-left corner
  1143. # of the agent's view area
  1144. f_vec = self.dir_vec
  1145. r_vec = self.right_vec
  1146. top_left = (
  1147. self.agent_pos
  1148. + f_vec * (self.agent_view_size - 1)
  1149. - r_vec * (self.agent_view_size // 2)
  1150. )
  1151. # Mask of which cells to highlight
  1152. highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  1153. # For each cell in the visibility mask
  1154. for vis_j in range(0, self.agent_view_size):
  1155. for vis_i in range(0, self.agent_view_size):
  1156. # If this cell is not visible, don't highlight it
  1157. if not vis_mask[vis_i, vis_j]:
  1158. continue
  1159. # Compute the world coordinates of this cell
  1160. abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)
  1161. if abs_i < 0 or abs_i >= self.width:
  1162. continue
  1163. if abs_j < 0 or abs_j >= self.height:
  1164. continue
  1165. # Mark this cell to be highlighted
  1166. highlight_mask[abs_i, abs_j] = True
  1167. # Render the whole grid
  1168. img = self.grid.render(
  1169. tile_size,
  1170. self.agent_pos,
  1171. self.agent_dir,
  1172. highlight_mask=highlight_mask if highlight else None,
  1173. )
  1174. return img
  1175. def get_frame(
  1176. self,
  1177. highlight: bool = True,
  1178. tile_size: int = TILE_PIXELS,
  1179. agent_pov: bool = False,
  1180. ):
  1181. """Returns an RGB image corresponding to the whole environment or the agent's point of view.
  1182. Args:
  1183. highlight (bool): If true, the agent's field of view or point of view is highlighted with a lighter gray color.
  1184. tile_size (int): How many pixels will form a tile from the NxM grid.
  1185. agent_pov (bool): If true, the rendered frame will only contain the point of view of the agent.
  1186. Returns:
  1187. frame (np.ndarray): A frame of type numpy.ndarray with shape (x, y, 3) representing RGB values for the x-by-y pixel image.
  1188. """
  1189. if agent_pov:
  1190. return self.get_pov_render(tile_size)
  1191. else:
  1192. return self.get_full_render(highlight, tile_size)
  1193. def render(self):
  1194. img = self.get_frame(self.highlight, self.tile_size, self.agent_pov)
  1195. if self.render_mode == "human":
  1196. if self.window is None:
  1197. self.window = Window("minigrid")
  1198. self.window.show(block=False)
  1199. self.window.set_caption(self.mission)
  1200. self.window.show_img(img)
  1201. elif self.render_mode == "rgb_array":
  1202. return img
  1203. def close(self):
  1204. if self.window:
  1205. self.window.close()