minigrid.py 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561
  1. import hashlib
  2. import math
  3. from abc import abstractmethod
  4. from enum import IntEnum
  5. from functools import partial
  6. from typing import Any, Callable, Optional, Union
  7. import gym
  8. import numpy as np
  9. from gym import spaces
  10. from gym.utils import seeding
  11. from gym.utils.renderer import Renderer
  12. # Size in pixels of a tile in the full-scale human view
  13. from gym_minigrid.rendering import (
  14. downsample,
  15. fill_coords,
  16. highlight_img,
  17. point_in_circle,
  18. point_in_line,
  19. point_in_rect,
  20. point_in_triangle,
  21. rotate_fn,
  22. )
  23. from gym_minigrid.window import Window
  24. TILE_PIXELS = 32
  25. # Map of color names to RGB values
  26. COLORS = {
  27. "red": np.array([255, 0, 0]),
  28. "green": np.array([0, 255, 0]),
  29. "blue": np.array([0, 0, 255]),
  30. "purple": np.array([112, 39, 195]),
  31. "yellow": np.array([255, 255, 0]),
  32. "grey": np.array([100, 100, 100]),
  33. }
  34. COLOR_NAMES = sorted(list(COLORS.keys()))
  35. # Used to map colors to integers
  36. COLOR_TO_IDX = {"red": 0, "green": 1, "blue": 2, "purple": 3, "yellow": 4, "grey": 5}
  37. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  38. # Map of object type to integers
  39. OBJECT_TO_IDX = {
  40. "unseen": 0,
  41. "empty": 1,
  42. "wall": 2,
  43. "floor": 3,
  44. "door": 4,
  45. "key": 5,
  46. "ball": 6,
  47. "box": 7,
  48. "goal": 8,
  49. "lava": 9,
  50. "agent": 10,
  51. }
  52. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  53. # Map of state names to integers
  54. STATE_TO_IDX = {
  55. "open": 0,
  56. "closed": 1,
  57. "locked": 2,
  58. }
  59. # Map of agent direction indices to vectors
  60. DIR_TO_VEC = [
  61. # Pointing right (positive X)
  62. np.array((1, 0)),
  63. # Down (positive Y)
  64. np.array((0, 1)),
  65. # Pointing left (negative X)
  66. np.array((-1, 0)),
  67. # Up (negative Y)
  68. np.array((0, -1)),
  69. ]
  70. def check_if_no_duplicate(duplicate_list: list) -> bool:
  71. """Check if given list contains any duplicates"""
  72. return len(set(duplicate_list)) == len(duplicate_list)
  73. class MissionSpace(spaces.Space[str]):
  74. r"""A space representing a mission for the Gym-Minigrid environments.
  75. The space allows generating random mission strings constructed with an input placeholder list.
  76. Example Usage::
  77. >>> observation_space = MissionSpace(mission_func=lambda color: f"Get the {color} ball.",
  78. ordered_placeholders=[["green", "blue"]])
  79. >>> observation_space.sample()
  80. "Get the green ball."
  81. >>> observation_space = MissionSpace(mission_func=lambda : "Get the ball.".,
  82. ordered_placeholders=None)
  83. >>> observation_space.sample()
  84. "Get the ball."
  85. """
  86. def __init__(
  87. self,
  88. mission_func: Callable[..., str],
  89. ordered_placeholders: Optional["list[list[str]]"] = None,
  90. seed: Optional[Union[int, seeding.RandomNumberGenerator]] = None,
  91. ):
  92. r"""Constructor of :class:`MissionSpace` space.
  93. Args:
  94. mission_func (lambda _placeholders(str): _mission(str)): Function that generates a mission string from random placeholders.
  95. ordered_placeholders (Optional["list[list[str]]"]): List of lists of placeholders ordered in placing order in the mission function mission_func.
  96. seed: seed: The seed for sampling from the space.
  97. """
  98. # Check that the ordered placeholders and mission function are well defined.
  99. if ordered_placeholders is not None:
  100. assert (
  101. len(ordered_placeholders) == mission_func.__code__.co_argcount
  102. ), f"The number of placeholders {len(ordered_placeholders)} is different from the number of parameters in the mission function {mission_func.__code__.co_argcount}."
  103. for placeholder_list in ordered_placeholders:
  104. assert check_if_no_duplicate(
  105. placeholder_list
  106. ), "Make sure that the placeholders don't have any duplicate values."
  107. else:
  108. assert (
  109. mission_func.__code__.co_argcount == 0
  110. ), f"If the ordered placeholders are {ordered_placeholders}, the mission function shouldn't have any parameters."
  111. self.ordered_placeholders = ordered_placeholders
  112. self.mission_func = mission_func
  113. super().__init__(dtype=str, seed=seed)
  114. # Check that mission_func returns a string
  115. sampled_mission = self.sample()
  116. assert isinstance(
  117. sampled_mission, str
  118. ), f"mission_func must return type str not {type(sampled_mission)}"
  119. def sample(self) -> str:
  120. """Sample a random mission string."""
  121. if self.ordered_placeholders is not None:
  122. placeholders = []
  123. for rand_var_list in self.ordered_placeholders:
  124. idx = self.np_random.integers(0, len(rand_var_list))
  125. placeholders.append(rand_var_list[idx])
  126. return self.mission_func(*placeholders)
  127. else:
  128. return self.mission_func()
  129. def contains(self, x: Any) -> bool:
  130. """Return boolean specifying if x is a valid member of this space."""
  131. # Store a list of all the placeholders from self.ordered_placeholders that appear in x
  132. if self.ordered_placeholders is not None:
  133. check_placeholder_list = []
  134. for placeholder_list in self.ordered_placeholders:
  135. for placeholder in placeholder_list:
  136. if placeholder in x:
  137. check_placeholder_list.append(placeholder)
  138. # Remove duplicates from the list
  139. check_placeholder_list = list(set(check_placeholder_list))
  140. start_id_placeholder = []
  141. end_id_placeholder = []
  142. # Get the starting and ending id of the identified placeholders with possible duplicates
  143. new_check_placeholder_list = []
  144. for placeholder in check_placeholder_list:
  145. new_start_id_placeholder = [
  146. i for i in range(len(x)) if x.startswith(placeholder, i)
  147. ]
  148. new_check_placeholder_list += [placeholder] * len(
  149. new_start_id_placeholder
  150. )
  151. end_id_placeholder += [
  152. start_id + len(placeholder) - 1
  153. for start_id in new_start_id_placeholder
  154. ]
  155. start_id_placeholder += new_start_id_placeholder
  156. # Order by starting id the placeholders
  157. ordered_placeholder_list = sorted(
  158. zip(
  159. start_id_placeholder, end_id_placeholder, new_check_placeholder_list
  160. )
  161. )
  162. # Check for repeated placeholders contained in each other
  163. remove_placeholder_id = []
  164. for i, placeholder_1 in enumerate(ordered_placeholder_list):
  165. starting_id = i + 1
  166. for j, placeholder_2 in enumerate(
  167. ordered_placeholder_list[starting_id:]
  168. ):
  169. # Check if place holder ids overlap and keep the longest
  170. if max(placeholder_1[0], placeholder_2[0]) < min(
  171. placeholder_1[1], placeholder_2[1]
  172. ):
  173. remove_placeholder = min(
  174. placeholder_1[2], placeholder_2[2], key=len
  175. )
  176. if remove_placeholder == placeholder_1[2]:
  177. remove_placeholder_id.append(i)
  178. else:
  179. remove_placeholder_id.append(i + j + 1)
  180. for id in remove_placeholder_id:
  181. del ordered_placeholder_list[id]
  182. final_placeholders = [
  183. placeholder[2] for placeholder in ordered_placeholder_list
  184. ]
  185. # Check that the identified final placeholders are in the same order as the original placeholders.
  186. for orered_placeholder, final_placeholder in zip(
  187. self.ordered_placeholders, final_placeholders
  188. ):
  189. if final_placeholder in orered_placeholder:
  190. continue
  191. else:
  192. return False
  193. try:
  194. mission_string_with_placeholders = self.mission_func(
  195. *final_placeholders
  196. )
  197. except Exception as e:
  198. print(
  199. f"{x} is not contained in MissionSpace due to the following exception: {e}"
  200. )
  201. return False
  202. return bool(mission_string_with_placeholders == x)
  203. else:
  204. return bool(self.mission_func() == x)
  205. def __repr__(self) -> str:
  206. """Gives a string representation of this space."""
  207. return f"MissionSpace({self.mission_func}, {self.ordered_placeholders})"
  208. def __eq__(self, other) -> bool:
  209. """Check whether ``other`` is equivalent to this instance."""
  210. if isinstance(other, MissionSpace):
  211. # Check that place holder lists are the same
  212. if self.ordered_placeholders is not None:
  213. # Check length
  214. if (len(self.order_placeholder) == len(other.order_placeholder)) and (
  215. all(
  216. set(i) == set(j)
  217. for i, j in zip(self.order_placeholder, other.order_placeholder)
  218. )
  219. ):
  220. # Check mission string is the same with dummy space placeholders
  221. test_placeholders = [""] * len(self.order_placeholder)
  222. mission = self.mission_func(*test_placeholders)
  223. other_mission = other.mission_func(*test_placeholders)
  224. return mission == other_mission
  225. else:
  226. # Check that other is also None
  227. if other.ordered_placeholders is None:
  228. # Check mission string is the same
  229. mission = self.mission_func()
  230. other_mission = other.mission_func()
  231. return mission == other_mission
  232. # If none of the statements above return then False
  233. return False
  234. class WorldObj:
  235. """
  236. Base class for grid world objects
  237. """
  238. def __init__(self, type, color):
  239. assert type in OBJECT_TO_IDX, type
  240. assert color in COLOR_TO_IDX, color
  241. self.type = type
  242. self.color = color
  243. self.contains = None
  244. # Initial position of the object
  245. self.init_pos = None
  246. # Current position of the object
  247. self.cur_pos = None
  248. def can_overlap(self):
  249. """Can the agent overlap with this?"""
  250. return False
  251. def can_pickup(self):
  252. """Can the agent pick this up?"""
  253. return False
  254. def can_contain(self):
  255. """Can this contain another object?"""
  256. return False
  257. def see_behind(self):
  258. """Can the agent see behind this object?"""
  259. return True
  260. def toggle(self, env, pos):
  261. """Method to trigger/toggle an action this object performs"""
  262. return False
  263. def encode(self):
  264. """Encode the a description of this object as a 3-tuple of integers"""
  265. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0)
  266. @staticmethod
  267. def decode(type_idx, color_idx, state):
  268. """Create an object from a 3-tuple state description"""
  269. obj_type = IDX_TO_OBJECT[type_idx]
  270. color = IDX_TO_COLOR[color_idx]
  271. if obj_type == "empty" or obj_type == "unseen":
  272. return None
  273. # State, 0: open, 1: closed, 2: locked
  274. is_open = state == 0
  275. is_locked = state == 2
  276. if obj_type == "wall":
  277. v = Wall(color)
  278. elif obj_type == "floor":
  279. v = Floor(color)
  280. elif obj_type == "ball":
  281. v = Ball(color)
  282. elif obj_type == "key":
  283. v = Key(color)
  284. elif obj_type == "box":
  285. v = Box(color)
  286. elif obj_type == "door":
  287. v = Door(color, is_open, is_locked)
  288. elif obj_type == "goal":
  289. v = Goal()
  290. elif obj_type == "lava":
  291. v = Lava()
  292. else:
  293. assert False, "unknown object type in decode '%s'" % obj_type
  294. return v
  295. def render(self, r):
  296. """Draw this object with the given renderer"""
  297. raise NotImplementedError
  298. class Goal(WorldObj):
  299. def __init__(self):
  300. super().__init__("goal", "green")
  301. def can_overlap(self):
  302. return True
  303. def render(self, img):
  304. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  305. class Floor(WorldObj):
  306. """
  307. Colored floor tile the agent can walk over
  308. """
  309. def __init__(self, color="blue"):
  310. super().__init__("floor", color)
  311. def can_overlap(self):
  312. return True
  313. def render(self, img):
  314. # Give the floor a pale color
  315. color = COLORS[self.color] / 2
  316. fill_coords(img, point_in_rect(0.031, 1, 0.031, 1), color)
  317. class Lava(WorldObj):
  318. def __init__(self):
  319. super().__init__("lava", "red")
  320. def can_overlap(self):
  321. return True
  322. def render(self, img):
  323. c = (255, 128, 0)
  324. # Background color
  325. fill_coords(img, point_in_rect(0, 1, 0, 1), c)
  326. # Little waves
  327. for i in range(3):
  328. ylo = 0.3 + 0.2 * i
  329. yhi = 0.4 + 0.2 * i
  330. fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0, 0, 0))
  331. fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0, 0, 0))
  332. fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0, 0, 0))
  333. fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0, 0, 0))
  334. class Wall(WorldObj):
  335. def __init__(self, color="grey"):
  336. super().__init__("wall", color)
  337. def see_behind(self):
  338. return False
  339. def render(self, img):
  340. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  341. class Door(WorldObj):
  342. def __init__(self, color, is_open=False, is_locked=False):
  343. super().__init__("door", color)
  344. self.is_open = is_open
  345. self.is_locked = is_locked
  346. def can_overlap(self):
  347. """The agent can only walk over this cell when the door is open"""
  348. return self.is_open
  349. def see_behind(self):
  350. return self.is_open
  351. def toggle(self, env, pos):
  352. # If the player has the right key to open the door
  353. if self.is_locked:
  354. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  355. self.is_locked = False
  356. self.is_open = True
  357. return True
  358. return False
  359. self.is_open = not self.is_open
  360. return True
  361. def encode(self):
  362. """Encode the a description of this object as a 3-tuple of integers"""
  363. # State, 0: open, 1: closed, 2: locked
  364. if self.is_open:
  365. state = 0
  366. elif self.is_locked:
  367. state = 2
  368. # if door is closed and unlocked
  369. elif not self.is_open:
  370. state = 1
  371. else:
  372. raise ValueError(
  373. f"There is no possible state encoding for the state:\n -Door Open: {self.is_open}\n -Door Closed: {not self.is_open}\n -Door Locked: {self.is_locked}"
  374. )
  375. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state)
  376. def render(self, img):
  377. c = COLORS[self.color]
  378. if self.is_open:
  379. fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c)
  380. fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0, 0, 0))
  381. return
  382. # Door frame and door
  383. if self.is_locked:
  384. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  385. fill_coords(img, point_in_rect(0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))
  386. # Draw key slot
  387. fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)
  388. else:
  389. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  390. fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0, 0, 0))
  391. fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c)
  392. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0, 0, 0))
  393. # Draw door handle
  394. fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c)
  395. class Key(WorldObj):
  396. def __init__(self, color="blue"):
  397. super().__init__("key", color)
  398. def can_pickup(self):
  399. return True
  400. def render(self, img):
  401. c = COLORS[self.color]
  402. # Vertical quad
  403. fill_coords(img, point_in_rect(0.50, 0.63, 0.31, 0.88), c)
  404. # Teeth
  405. fill_coords(img, point_in_rect(0.38, 0.50, 0.59, 0.66), c)
  406. fill_coords(img, point_in_rect(0.38, 0.50, 0.81, 0.88), c)
  407. # Ring
  408. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.190), c)
  409. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0, 0, 0))
  410. class Ball(WorldObj):
  411. def __init__(self, color="blue"):
  412. super().__init__("ball", color)
  413. def can_pickup(self):
  414. return True
  415. def render(self, img):
  416. fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color])
  417. class Box(WorldObj):
  418. def __init__(self, color, contains=None):
  419. super().__init__("box", color)
  420. self.contains = contains
  421. def can_pickup(self):
  422. return True
  423. def render(self, img):
  424. c = COLORS[self.color]
  425. # Outline
  426. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), c)
  427. fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0, 0, 0))
  428. # Horizontal slit
  429. fill_coords(img, point_in_rect(0.16, 0.84, 0.47, 0.53), c)
  430. def toggle(self, env, pos):
  431. # Replace the box by its contents
  432. env.grid.set(pos[0], pos[1], self.contains)
  433. return True
  434. class Grid:
  435. """
  436. Represent a grid and operations on it
  437. """
  438. # Static cache of pre-renderer tiles
  439. tile_cache = {}
  440. def __init__(self, width, height):
  441. assert width >= 3
  442. assert height >= 3
  443. self.width = width
  444. self.height = height
  445. self.grid = [None] * width * height
  446. def __contains__(self, key):
  447. if isinstance(key, WorldObj):
  448. for e in self.grid:
  449. if e is key:
  450. return True
  451. elif isinstance(key, tuple):
  452. for e in self.grid:
  453. if e is None:
  454. continue
  455. if (e.color, e.type) == key:
  456. return True
  457. if key[0] is None and key[1] == e.type:
  458. return True
  459. return False
  460. def __eq__(self, other):
  461. grid1 = self.encode()
  462. grid2 = other.encode()
  463. return np.array_equal(grid2, grid1)
  464. def __ne__(self, other):
  465. return not self == other
  466. def copy(self):
  467. from copy import deepcopy
  468. return deepcopy(self)
  469. def set(self, i, j, v):
  470. assert i >= 0 and i < self.width
  471. assert j >= 0 and j < self.height
  472. self.grid[j * self.width + i] = v
  473. def get(self, i, j):
  474. assert i >= 0 and i < self.width
  475. assert j >= 0 and j < self.height
  476. return self.grid[j * self.width + i]
  477. def horz_wall(self, x, y, length=None, obj_type=Wall):
  478. if length is None:
  479. length = self.width - x
  480. for i in range(0, length):
  481. self.set(x + i, y, obj_type())
  482. def vert_wall(self, x, y, length=None, obj_type=Wall):
  483. if length is None:
  484. length = self.height - y
  485. for j in range(0, length):
  486. self.set(x, y + j, obj_type())
  487. def wall_rect(self, x, y, w, h):
  488. self.horz_wall(x, y, w)
  489. self.horz_wall(x, y + h - 1, w)
  490. self.vert_wall(x, y, h)
  491. self.vert_wall(x + w - 1, y, h)
  492. def rotate_left(self):
  493. """
  494. Rotate the grid to the left (counter-clockwise)
  495. """
  496. grid = Grid(self.height, self.width)
  497. for i in range(self.width):
  498. for j in range(self.height):
  499. v = self.get(i, j)
  500. grid.set(j, grid.height - 1 - i, v)
  501. return grid
  502. def slice(self, topX, topY, width, height):
  503. """
  504. Get a subset of the grid
  505. """
  506. grid = Grid(width, height)
  507. for j in range(0, height):
  508. for i in range(0, width):
  509. x = topX + i
  510. y = topY + j
  511. if x >= 0 and x < self.width and y >= 0 and y < self.height:
  512. v = self.get(x, y)
  513. else:
  514. v = Wall()
  515. grid.set(i, j, v)
  516. return grid
  517. @classmethod
  518. def render_tile(
  519. cls, obj, agent_dir=None, highlight=False, tile_size=TILE_PIXELS, subdivs=3
  520. ):
  521. """
  522. Render a tile and cache the result
  523. """
  524. # Hash map lookup key for the cache
  525. key = (agent_dir, highlight, tile_size)
  526. key = obj.encode() + key if obj else key
  527. if key in cls.tile_cache:
  528. return cls.tile_cache[key]
  529. img = np.zeros(
  530. shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8
  531. )
  532. # Draw the grid lines (top and left edges)
  533. fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
  534. fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100))
  535. if obj is not None:
  536. obj.render(img)
  537. # Overlay the agent on top
  538. if agent_dir is not None:
  539. tri_fn = point_in_triangle(
  540. (0.12, 0.19),
  541. (0.87, 0.50),
  542. (0.12, 0.81),
  543. )
  544. # Rotate the agent based on its direction
  545. tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5 * math.pi * agent_dir)
  546. fill_coords(img, tri_fn, (255, 0, 0))
  547. # Highlight the cell if needed
  548. if highlight:
  549. highlight_img(img)
  550. # Downsample the image to perform supersampling/anti-aliasing
  551. img = downsample(img, subdivs)
  552. # Cache the rendered tile
  553. cls.tile_cache[key] = img
  554. return img
  555. def render(self, tile_size, agent_pos, agent_dir=None, highlight_mask=None):
  556. """
  557. Render this grid at a given scale
  558. :param r: target renderer object
  559. :param tile_size: tile size in pixels
  560. """
  561. if highlight_mask is None:
  562. highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  563. # Compute the total grid size
  564. width_px = self.width * tile_size
  565. height_px = self.height * tile_size
  566. img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)
  567. # Render the grid
  568. for j in range(0, self.height):
  569. for i in range(0, self.width):
  570. cell = self.get(i, j)
  571. agent_here = np.array_equal(agent_pos, (i, j))
  572. tile_img = Grid.render_tile(
  573. cell,
  574. agent_dir=agent_dir if agent_here else None,
  575. highlight=highlight_mask[i, j],
  576. tile_size=tile_size,
  577. )
  578. ymin = j * tile_size
  579. ymax = (j + 1) * tile_size
  580. xmin = i * tile_size
  581. xmax = (i + 1) * tile_size
  582. img[ymin:ymax, xmin:xmax, :] = tile_img
  583. return img
  584. def encode(self, vis_mask=None):
  585. """
  586. Produce a compact numpy encoding of the grid
  587. """
  588. if vis_mask is None:
  589. vis_mask = np.ones((self.width, self.height), dtype=bool)
  590. array = np.zeros((self.width, self.height, 3), dtype="uint8")
  591. for i in range(self.width):
  592. for j in range(self.height):
  593. if vis_mask[i, j]:
  594. v = self.get(i, j)
  595. if v is None:
  596. array[i, j, 0] = OBJECT_TO_IDX["empty"]
  597. array[i, j, 1] = 0
  598. array[i, j, 2] = 0
  599. else:
  600. array[i, j, :] = v.encode()
  601. return array
  602. @staticmethod
  603. def decode(array):
  604. """
  605. Decode an array grid encoding back into a grid
  606. """
  607. width, height, channels = array.shape
  608. assert channels == 3
  609. vis_mask = np.ones(shape=(width, height), dtype=bool)
  610. grid = Grid(width, height)
  611. for i in range(width):
  612. for j in range(height):
  613. type_idx, color_idx, state = array[i, j]
  614. v = WorldObj.decode(type_idx, color_idx, state)
  615. grid.set(i, j, v)
  616. vis_mask[i, j] = type_idx != OBJECT_TO_IDX["unseen"]
  617. return grid, vis_mask
  618. def process_vis(self, agent_pos):
  619. mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  620. mask[agent_pos[0], agent_pos[1]] = True
  621. for j in reversed(range(0, self.height)):
  622. for i in range(0, self.width - 1):
  623. if not mask[i, j]:
  624. continue
  625. cell = self.get(i, j)
  626. if cell and not cell.see_behind():
  627. continue
  628. mask[i + 1, j] = True
  629. if j > 0:
  630. mask[i + 1, j - 1] = True
  631. mask[i, j - 1] = True
  632. for i in reversed(range(1, self.width)):
  633. if not mask[i, j]:
  634. continue
  635. cell = self.get(i, j)
  636. if cell and not cell.see_behind():
  637. continue
  638. mask[i - 1, j] = True
  639. if j > 0:
  640. mask[i - 1, j - 1] = True
  641. mask[i, j - 1] = True
  642. for j in range(0, self.height):
  643. for i in range(0, self.width):
  644. if not mask[i, j]:
  645. self.set(i, j, None)
  646. return mask
  647. class MiniGridEnv(gym.Env):
  648. """
  649. 2D grid world game environment
  650. """
  651. metadata = {
  652. "render_modes": ["human", "rgb_array", "single_rgb_array"],
  653. "render_fps": 10,
  654. }
  655. # Enumeration of possible actions
  656. class Actions(IntEnum):
  657. # Turn left, turn right, move forward
  658. left = 0
  659. right = 1
  660. forward = 2
  661. # Pick up an object
  662. pickup = 3
  663. # Drop an object
  664. drop = 4
  665. # Toggle/activate an object
  666. toggle = 5
  667. # Done completing task
  668. done = 6
  669. def __init__(
  670. self,
  671. mission_space: MissionSpace,
  672. grid_size: int = None,
  673. width: int = None,
  674. height: int = None,
  675. max_steps: int = 100,
  676. see_through_walls: bool = False,
  677. agent_view_size: int = 7,
  678. render_mode: Optional[str] = None,
  679. highlight: bool = True,
  680. tile_size: int = TILE_PIXELS,
  681. agent_pov: bool = False,
  682. ):
  683. # Rendering attributes
  684. self.render_mode = render_mode
  685. # Initialize mission
  686. self.mission = mission_space.sample()
  687. # Can't set both grid_size and width/height
  688. if grid_size:
  689. assert width is None and height is None
  690. width = grid_size
  691. height = grid_size
  692. # Action enumeration for this environment
  693. self.actions = MiniGridEnv.Actions
  694. # Actions are discrete integer values
  695. self.action_space = spaces.Discrete(len(self.actions))
  696. # Number of cells (width and height) in the agent view
  697. assert agent_view_size % 2 == 1
  698. assert agent_view_size >= 3
  699. self.agent_view_size = agent_view_size
  700. # Observations are dictionaries containing an
  701. # encoding of the grid and a textual 'mission' string
  702. image_observation_space = spaces.Box(
  703. low=0,
  704. high=255,
  705. shape=(self.agent_view_size, self.agent_view_size, 3),
  706. dtype="uint8",
  707. )
  708. self.observation_space = spaces.Dict(
  709. {
  710. "image": image_observation_space,
  711. "direction": spaces.Discrete(4),
  712. "mission": mission_space,
  713. }
  714. )
  715. # Range of possible rewards
  716. self.reward_range = (0, 1)
  717. self.window: Window = None
  718. # Environment configuration
  719. self.width = width
  720. self.height = height
  721. self.max_steps = max_steps
  722. self.see_through_walls = see_through_walls
  723. # Current position and direction of the agent
  724. self.agent_pos: np.ndarray = None
  725. self.agent_dir: int = None
  726. # Current grid and mission and carryinh
  727. self.grid = Grid(width, height)
  728. self.carrying = None
  729. frame_rendering = partial(
  730. self._render, highlight=highlight, tile_size=tile_size, agent_pov=agent_pov
  731. )
  732. self.renderer = Renderer(self.render_mode, frame_rendering)
  733. # Initialize the state
  734. self.reset()
  def reset(self, *, seed=None, return_info=False, options=None):
      """
      Start a new episode and return its first observation.

      :param seed: seed forwarded to the parent class's reset
      :param return_info: if true, return (obs, info) instead of just obs
      :param options: accepted but not used by this method
      """
      super().reset(seed=seed)

      # Reinitialize episode-specific variables
      self.agent_pos = (-1, -1)
      self.agent_dir = -1

      # Generate a new random grid at the start of each episode
      self._gen_grid(self.width, self.height)

      # These fields should be defined by _gen_grid.
      # Each coordinate is checked on its own: comparing tuples with >= is
      # lexicographic, so (1, -5) >= (0, 0) would wrongly pass. The
      # direction is checked whatever the type of the position.
      assert all(
          coord >= 0 for coord in self.agent_pos
      ), "_gen_grid must set a valid agent_pos"
      assert self.agent_dir >= 0, "_gen_grid must set a valid agent_dir"

      # Check that the agent doesn't overlap with an object
      start_cell = self.grid.get(*self.agent_pos)
      assert start_cell is None or start_cell.can_overlap()

      # Item picked up, being carried, initially nothing
      self.carrying = None

      # Step count since episode start
      self.step_count = 0

      # Return first observation
      obs = self.gen_obs()

      # Reset the renderer and record the first frame of the episode
      self.renderer.reset()
      self.renderer.render_step()

      if not return_info:
          return obs
      else:
          return obs, {}
  def hash(self, size=16):
      """Compute a hash of the current state of the environment.

      The hash covers the encoded grid, the agent position and the agent
      direction.

      :param size: number of leading hex characters of the digest to return
      """
      state_parts = (self.grid.encode().tolist(), self.agent_pos, self.agent_dir)
      # Feeding the concatenated text at once yields the same digest as
      # updating the hasher piece by piece
      payload = "".join(str(part) for part in state_parts)
      return hashlib.sha256(payload.encode("utf8")).hexdigest()[:size]
  @property
  def steps_remaining(self):
      """
      Number of steps left before step_count reaches max_steps
      """
      budget, used = self.max_steps, self.step_count
      return budget - used
  776. def __str__(self):
  777. """
  778. Produce a pretty string of the environment's grid along with the agent.
  779. A grid cell is represented by 2-character string, the first one for
  780. the object and the second one for the color.
  781. """
  782. # Map of object types to short string
  783. OBJECT_TO_STR = {
  784. "wall": "W",
  785. "floor": "F",
  786. "door": "D",
  787. "key": "K",
  788. "ball": "A",
  789. "box": "B",
  790. "goal": "G",
  791. "lava": "V",
  792. }
  793. # Map agent's direction to short string
  794. AGENT_DIR_TO_STR = {0: ">", 1: "V", 2: "<", 3: "^"}
  795. str = ""
  796. for j in range(self.grid.height):
  797. for i in range(self.grid.width):
  798. if i == self.agent_pos[0] and j == self.agent_pos[1]:
  799. str += 2 * AGENT_DIR_TO_STR[self.agent_dir]
  800. continue
  801. c = self.grid.get(i, j)
  802. if c is None:
  803. str += " "
  804. continue
  805. if c.type == "door":
  806. if c.is_open:
  807. str += "__"
  808. elif c.is_locked:
  809. str += "L" + c.color[0].upper()
  810. else:
  811. str += "D" + c.color[0].upper()
  812. continue
  813. str += OBJECT_TO_STR[c.type] + c.color[0].upper()
  814. if j < self.grid.height - 1:
  815. str += "\n"
  816. return str
  817. @abstractmethod
  818. def _gen_grid(self, width, height):
  819. pass
  820. def _reward(self):
  821. """
  822. Compute the reward to be given upon success
  823. """
  824. return 1 - 0.9 * (self.step_count / self.max_steps)
  825. def _rand_int(self, low, high):
  826. """
  827. Generate random integer in [low,high[
  828. """
  829. return self.np_random.integers(low, high)
  830. def _rand_float(self, low, high):
  831. """
  832. Generate random float in [low,high[
  833. """
  834. return self.np_random.uniform(low, high)
  835. def _rand_bool(self):
  836. """
  837. Generate random boolean value
  838. """
  839. return self.np_random.integers(0, 2) == 0
  840. def _rand_elem(self, iterable):
  841. """
  842. Pick a random element in a list
  843. """
  844. lst = list(iterable)
  845. idx = self._rand_int(0, len(lst))
  846. return lst[idx]
  847. def _rand_subset(self, iterable, num_elems):
  848. """
  849. Sample a random subset of distinct elements of a list
  850. """
  851. lst = list(iterable)
  852. assert num_elems <= len(lst)
  853. out = []
  854. while len(out) < num_elems:
  855. elem = self._rand_elem(lst)
  856. lst.remove(elem)
  857. out.append(elem)
  858. return out
  def _rand_color(self):
      """
      Pick a random entry of COLOR_NAMES
      """
      palette = COLOR_NAMES
      return self._rand_elem(palette)
  864. def _rand_pos(self, xLow, xHigh, yLow, yHigh):
  865. """
  866. Generate a random (x,y) position tuple
  867. """
  868. return (
  869. self.np_random.integers(xLow, xHigh),
  870. self.np_random.integers(yLow, yHigh),
  871. )
  def place_obj(self, obj, top=None, size=None, reject_fn=None, max_tries=math.inf):
      """
      Place an object at an empty position in the grid

      :param obj: object to place; None only samples a free position
      :param top: top-left position of the rectangle where to place
      :param size: size of the rectangle where to place
      :param reject_fn: function to filter out potential positions
      :param max_tries: maximum number of positions to sample
      :return: the chosen (x, y) position
      :raises RecursionError: if max_tries positions were all rejected
      """
      if top is None:
          top = (0, 0)
      else:
          top = (max(top[0], 0), max(top[1], 0))

      if size is None:
          size = (self.grid.width, self.grid.height)

      # Exclusive upper bounds of the sampling rectangle, clipped to the grid
      x_high = min(top[0] + size[0], self.grid.width)
      y_high = min(top[1] + size[1], self.grid.height)

      num_tries = 0

      while True:
          # This is to handle with rare cases where rejection sampling
          # gets stuck in an infinite loop. Checking with >= caps the number
          # of sampled positions at exactly max_tries.
          if num_tries >= max_tries:
              raise RecursionError("rejection sampling failed in place_obj")

          num_tries += 1

          pos = tuple(
              np.array(
                  (
                      self._rand_int(top[0], x_high),
                      self._rand_int(top[1], y_high),
                  )
              )
          )

          # Don't place the object on top of another object
          if self.grid.get(*pos) is not None:
              continue

          # Don't place the object where the agent is
          if np.array_equal(pos, self.agent_pos):
              continue

          # Check if there is a filtering criterion
          if reject_fn and reject_fn(self, pos):
              continue

          break

      self.grid.set(pos[0], pos[1], obj)

      if obj is not None:
          obj.init_pos = pos
          obj.cur_pos = pos

      return pos
  def put_obj(self, obj, i, j):
      """
      Store an object in grid cell (i, j) and record that cell as both its
      initial and its current position
      """
      self.grid.set(i, j, obj)
      for attr in ("init_pos", "cur_pos"):
          setattr(obj, attr, (i, j))
  def place_agent(self, top=None, size=None, rand_dir=True, max_tries=math.inf):
      """
      Set the agent's starting point at an empty position in the grid

      :param top: top-left position of the rectangle where to place
      :param size: size of the rectangle where to place
      :param rand_dir: if true, also draw a random agent direction
      :param max_tries: forwarded to place_obj
      :return: the chosen position
      """
      # Clear the previous position so that it does not block the sampling
      self.agent_pos = (-1, -1)

      spawn = self.place_obj(None, top, size, max_tries=max_tries)
      self.agent_pos = spawn

      if rand_dir:
          self.agent_dir = self._rand_int(0, 4)

      return spawn
  @property
  def dir_vec(self):
      """
      Direction vector of the agent, i.e. the DIR_TO_VEC entry for its
      current direction (the direction of forward movement).
      """
      heading = self.agent_dir
      assert 0 <= heading < 4
      return DIR_TO_VEC[heading]
  @property
  def right_vec(self):
      """
      Vector pointing to the right of the agent: the forward vector
      (dx, dy) turned into (-dy, dx).
      """
      forward_x, forward_y = self.dir_vec
      return np.array((-forward_y, forward_x))
  @property
  def front_pos(self):
      """
      Position of the cell right in front of the agent (agent position
      plus its direction vector)
      """
      here = self.agent_pos
      one_step = self.dir_vec
      return here + one_step
  def get_view_coords(self, i, j):
      """
      Convert absolute grid coordinates (i, j) into coordinates of the
      agent's partially observable view (sub-grid), by translating and
      rotating them. The result may be negative or outside of the agent's
      view size.
      """
      agent_x, agent_y = self.agent_pos
      fwd_x, fwd_y = self.dir_vec
      right_x, right_y = self.right_vec

      view = self.agent_view_size
      half = self.agent_view_size // 2

      # Absolute coordinates of the top-left view corner
      corner_x = agent_x + (fwd_x * (view - 1)) - (right_x * half)
      corner_y = agent_y + (fwd_y * (view - 1)) - (right_y * half)

      # Offset of the queried cell from that corner
      off_x = i - corner_x
      off_y = j - corner_y

      # Project the offset onto the agent's own axes
      vx = right_x * off_x + right_y * off_y
      vy = -(fwd_x * off_x + fwd_y * off_y)

      return vx, vy
  def get_view_exts(self, agent_view_size=None):
      """
      Get the extents (topX, topY, botX, botY) of the square set of tiles
      visible to the agent
      Note: the bottom extent indices are not included in the set
      if agent_view_size is None, use self.agent_view_size
      """
      view = agent_view_size or self.agent_view_size
      half = view // 2
      heading = self.agent_dir

      # Offset of the top-left corner of the view relative to the agent
      if heading == 0:  # Facing right
          off_x, off_y = 0, -half
      elif heading == 1:  # Facing down
          off_x, off_y = -half, 0
      elif heading == 2:  # Facing left
          off_x, off_y = 1 - view, -half
      elif heading == 3:  # Facing up
          off_x, off_y = -half, 1 - view
      else:
          assert False, "invalid agent direction"

      topX = self.agent_pos[0] + off_x
      topY = self.agent_pos[1] + off_y

      return (topX, topY, topX + view, topY + view)
  def relative_coords(self, x, y):
      """
      Return the view coordinates of a grid position if it lies inside the
      agent's field of view, None otherwise
      """
      vx, vy = self.get_view_coords(x, y)
      limit = self.agent_view_size

      if 0 <= vx < limit and 0 <= vy < limit:
          return vx, vy

      return None
  def in_view(self, x, y):
      """
      Tell whether a grid position lies inside the agent's field of view
      (relative_coords finds view coordinates for it)
      """
      view_pos = self.relative_coords(x, y)
      return view_pos is not None
  def agent_sees(self, x, y):
      """
      Check if a non-empty grid position is visible to the agent: it must
      be inside the view, and the decoded observation must hold an object
      of the same type as the world cell
      """
      view_pos = self.relative_coords(x, y)
      if view_pos is None:
          return False
      vx, vy = view_pos

      # Decode the agent's current observation back into a grid
      decoded_view, _ = Grid.decode(self.gen_obs()["image"])
      seen = decoded_view.get(vx, vy)

      actual = self.grid.get(x, y)
      assert actual is not None

      if seen is None:
          return False
      return seen.type == actual.type
  def step(self, action):
      """
      Apply one action and advance the episode by one step.

      :param action: one of the values of self.actions
      :return: (obs, reward, terminated, truncated, info); info is always {}
      :raises ValueError: if the action is not one of self.actions
      """
      self.step_count += 1

      reward = 0
      terminated = False
      truncated = False

      # Position and contents of the cell in front of the agent
      ahead_pos = self.front_pos
      ahead_cell = self.grid.get(*ahead_pos)

      actions = self.actions

      if action == actions.left:
          # Rotate left
          self.agent_dir -= 1
          if self.agent_dir < 0:
              self.agent_dir += 4

      elif action == actions.right:
          # Rotate right
          self.agent_dir = (self.agent_dir + 1) % 4

      elif action == actions.forward:
          # Move forward
          if ahead_cell is None:
              self.agent_pos = tuple(ahead_pos)
          else:
              if ahead_cell.can_overlap():
                  self.agent_pos = tuple(ahead_pos)
              # Reaching the goal ends the episode with a reward
              if ahead_cell.type == "goal":
                  terminated = True
                  reward = self._reward()
              # Lava ends the episode without one
              if ahead_cell.type == "lava":
                  terminated = True

      elif action == actions.pickup:
          # Pick up an object, unless the agent already carries one
          if ahead_cell and ahead_cell.can_pickup() and self.carrying is None:
              self.carrying = ahead_cell
              self.carrying.cur_pos = np.array([-1, -1])
              self.grid.set(ahead_pos[0], ahead_pos[1], None)

      elif action == actions.drop:
          # Drop the carried object onto the free cell ahead
          if not ahead_cell and self.carrying:
              self.grid.set(ahead_pos[0], ahead_pos[1], self.carrying)
              self.carrying.cur_pos = ahead_pos
              self.carrying = None

      elif action == actions.toggle:
          # Toggle/activate an object
          if ahead_cell:
              ahead_cell.toggle(self, ahead_pos)

      elif action == actions.done:
          # Done action (not used by default)
          pass

      else:
          raise ValueError(f"Unknown action: {action}")

      # Running out of steps truncates the episode
      if self.step_count >= self.max_steps:
          truncated = True

      obs = self.gen_obs()

      # Let the renderer record the frame for this step
      self.renderer.render_step()

      return obs, reward, terminated, truncated, {}
  def gen_obs_grid(self, agent_view_size=None):
      """
      Generate the sub-grid observed by the agent.
      This method also outputs a visibility mask telling us which grid
      cells the agent can actually see.
      if agent_view_size is None, self.agent_view_size is used

      :return: (grid, vis_mask)
      """
      topX, topY, botX, botY = self.get_view_exts(agent_view_size)
      view = agent_view_size or self.agent_view_size

      # Cut the view out of the world grid and rotate it once more than the
      # agent's direction index
      grid = self.grid.slice(topX, topY, view, view)
      for _ in range(self.agent_dir + 1):
          grid = grid.rotate_left()

      # Process occluders and visibility
      # Note that this incurs some performance cost
      if self.see_through_walls:
          vis_mask = np.ones(shape=(grid.width, grid.height), dtype=bool)
      else:
          vis_mask = grid.process_vis(agent_pos=(view // 2, view - 1))

      # Make it so the agent sees what it's carrying
      # We do this by placing the carried object at the agent's position
      # in the agent's partially observable view
      agent_pos = grid.width // 2, grid.height - 1
      grid.set(*agent_pos, self.carrying if self.carrying else None)

      return grid, vis_mask
  def gen_obs(self):
      """
      Generate the agent's view (partially observable, low-resolution encoding)

      The observation is a dictionary containing:
      - "image": the encoded partially observable view of the environment
      - "direction": the agent's direction/orientation (acting as a compass)
      - "mission": the textual mission string (instructions for the agent)
      """
      view_grid, visible = self.gen_obs_grid()

      return {
          # Encode the partially observable view into a numpy array
          "image": view_grid.encode(visible),
          "direction": self.agent_dir,
          "mission": self.mission,
      }
  def get_pov_render(self, tile_size):
      """
      Render an agent's POV observation for visualization

      :param tile_size: tile size passed on to the grid renderer
      """
      view_grid, visible = self.gen_obs_grid()
      view = self.agent_view_size

      # Render the observed sub-grid with the agent drawn at the
      # bottom-centre cell, using direction 3
      return view_grid.render(
          tile_size,
          agent_pos=(view // 2, view - 1),
          agent_dir=3,
          highlight_mask=visible,
      )
  def get_full_render(self, highlight, tile_size):
      """
      Render a non-partial observation for visualization

      :param highlight: if true, highlight the cells visible to the agent
      :param tile_size: tile size passed on to the grid renderer
      """
      # Compute which cells are visible to the agent
      _, vis_mask = self.gen_obs_grid()

      forward = self.dir_vec
      rightward = self.right_vec
      view = self.agent_view_size

      # World coordinates of the view cell with view index (0, 0)
      origin = self.agent_pos + forward * (view - 1) - rightward * (view // 2)

      # Mask of which cells to highlight
      highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)

      # For each visible cell of the view, mark its world position
      for vis_j in range(view):
          for vis_i in range(view):
              if vis_mask[vis_i, vis_j]:
                  # Compute the world coordinates of this cell
                  abs_i, abs_j = origin - (forward * vis_j) + (rightward * vis_i)

                  # Cells falling outside of the world grid are skipped
                  if 0 <= abs_i < self.width and 0 <= abs_j < self.height:
                      highlight_mask[abs_i, abs_j] = True

      # Render the whole grid
      return self.grid.render(
          tile_size,
          self.agent_pos,
          self.agent_dir,
          highlight_mask=highlight_mask if highlight else None,
      )
  def get_frame(
      self,
      highlight: bool = True,
      tile_size: int = TILE_PIXELS,
      agent_pov: bool = False,
  ):
      """
      Returns an image corresponding to the whole environment or the agent's pov

      :param highlight: passed to get_full_render; unused for the agent's pov
      :param tile_size: tile size passed on to the selected renderer
      :param agent_pov: if true, render the agent's pov instead of the whole grid
      """
      if not agent_pov:
          return self.get_full_render(highlight, tile_size)
      return self.get_pov_render(tile_size)
  1187. def _render(self, mode: str = "human", **kwargs):
  1188. """
  1189. Render the whole-grid human view
  1190. """
  1191. img = self.get_frame(**kwargs)
  1192. mode = self.render_mode
  1193. if mode == "human":
  1194. if self.window is None:
  1195. self.window = Window("gym_minigrid")
  1196. self.window.show(block=False)
  1197. self.window.set_caption(self.mission)
  1198. self.window.show_img(img)
  1199. else:
  1200. return img
  def render(
      self,
      mode: str = "human",
      highlight: Optional[bool] = None,
      tile_size: Optional[int] = None,
      agent_pov: Optional[bool] = None,
  ):
      """
      Render the environment.

      When the environment has a render_mode, no rendering argument may be
      given and the renders collected by the renderer are returned.
      Otherwise a frame is rendered directly, with missing arguments
      defaulting to highlight=True, tile_size=TILE_PIXELS, agent_pov=False.
      """
      if self.render_mode is None:
          # No render mode configured: fill in the defaults and render now
          if highlight is None:
              highlight = True
          if tile_size is None:
              tile_size = TILE_PIXELS
          if agent_pov is None:
              agent_pov = False
          return self._render(
              mode=mode, highlight=highlight, tile_size=tile_size, agent_pov=agent_pov
          )

      assert (
          highlight is None and tile_size is None and agent_pov is None
      ), "Unexpected argument for render. Specify render arguments at environment initialization."
      return self.renderer.get_renders()
  def close(self):
      """
      Close the rendering window, if there is one
      """
      window = self.window
      if window:
          window.close()