minigrid.py 46 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538
  1. import hashlib
  2. import math
  3. from abc import abstractmethod
  4. from enum import IntEnum
  5. from typing import Any, Callable, Optional, Union
  6. import gymnasium as gym
  7. import numpy as np
  8. from gymnasium import spaces
  9. from gymnasium.utils import seeding
  10. # Size in pixels of a tile in the full-scale human view
  11. from minigrid.rendering import (
  12. downsample,
  13. fill_coords,
  14. highlight_img,
  15. point_in_circle,
  16. point_in_line,
  17. point_in_rect,
  18. point_in_triangle,
  19. rotate_fn,
  20. )
  21. from minigrid.window import Window
  22. TILE_PIXELS = 32
  23. # Map of color names to RGB values
  24. COLORS = {
  25. "red": np.array([255, 0, 0]),
  26. "green": np.array([0, 255, 0]),
  27. "blue": np.array([0, 0, 255]),
  28. "purple": np.array([112, 39, 195]),
  29. "yellow": np.array([255, 255, 0]),
  30. "grey": np.array([100, 100, 100]),
  31. }
  32. COLOR_NAMES = sorted(list(COLORS.keys()))
  33. # Used to map colors to integers
  34. COLOR_TO_IDX = {"red": 0, "green": 1, "blue": 2, "purple": 3, "yellow": 4, "grey": 5}
  35. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  36. # Map of object type to integers
  37. OBJECT_TO_IDX = {
  38. "unseen": 0,
  39. "empty": 1,
  40. "wall": 2,
  41. "floor": 3,
  42. "door": 4,
  43. "key": 5,
  44. "ball": 6,
  45. "box": 7,
  46. "goal": 8,
  47. "lava": 9,
  48. "agent": 10,
  49. }
  50. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  51. # Map of state names to integers
  52. STATE_TO_IDX = {
  53. "open": 0,
  54. "closed": 1,
  55. "locked": 2,
  56. }
  57. # Map of agent direction indices to vectors
  58. DIR_TO_VEC = [
  59. # Pointing right (positive X)
  60. np.array((1, 0)),
  61. # Down (positive Y)
  62. np.array((0, 1)),
  63. # Pointing left (negative X)
  64. np.array((-1, 0)),
  65. # Up (negative Y)
  66. np.array((0, -1)),
  67. ]
  68. def check_if_no_duplicate(duplicate_list: list) -> bool:
  69. """Check if given list contains any duplicates"""
  70. return len(set(duplicate_list)) == len(duplicate_list)
  71. class MissionSpace(spaces.Space[str]):
  72. r"""A space representing a mission for the Gym-Minigrid environments.
  73. The space allows generating random mission strings constructed with an input placeholder list.
  74. Example Usage::
  75. >>> observation_space = MissionSpace(mission_func=lambda color: f"Get the {color} ball.",
  76. ordered_placeholders=[["green", "blue"]])
  77. >>> observation_space.sample()
  78. "Get the green ball."
  79. >>> observation_space = MissionSpace(mission_func=lambda : "Get the ball.".,
  80. ordered_placeholders=None)
  81. >>> observation_space.sample()
  82. "Get the ball."
  83. """
  84. def __init__(
  85. self,
  86. mission_func: Callable[..., str],
  87. ordered_placeholders: Optional["list[list[str]]"] = None,
  88. seed: Optional[Union[int, seeding.RandomNumberGenerator]] = None,
  89. ):
  90. r"""Constructor of :class:`MissionSpace` space.
  91. Args:
  92. mission_func (lambda _placeholders(str): _mission(str)): Function that generates a mission string from random placeholders.
  93. ordered_placeholders (Optional["list[list[str]]"]): List of lists of placeholders ordered in placing order in the mission function mission_func.
  94. seed: seed: The seed for sampling from the space.
  95. """
  96. # Check that the ordered placeholders and mission function are well defined.
  97. if ordered_placeholders is not None:
  98. assert (
  99. len(ordered_placeholders) == mission_func.__code__.co_argcount
  100. ), f"The number of placeholders {len(ordered_placeholders)} is different from the number of parameters in the mission function {mission_func.__code__.co_argcount}."
  101. for placeholder_list in ordered_placeholders:
  102. assert check_if_no_duplicate(
  103. placeholder_list
  104. ), "Make sure that the placeholders don't have any duplicate values."
  105. else:
  106. assert (
  107. mission_func.__code__.co_argcount == 0
  108. ), f"If the ordered placeholders are {ordered_placeholders}, the mission function shouldn't have any parameters."
  109. self.ordered_placeholders = ordered_placeholders
  110. self.mission_func = mission_func
  111. super().__init__(dtype=str, seed=seed)
  112. # Check that mission_func returns a string
  113. sampled_mission = self.sample()
  114. assert isinstance(
  115. sampled_mission, str
  116. ), f"mission_func must return type str not {type(sampled_mission)}"
  117. def sample(self) -> str:
  118. """Sample a random mission string."""
  119. if self.ordered_placeholders is not None:
  120. placeholders = []
  121. for rand_var_list in self.ordered_placeholders:
  122. idx = self.np_random.integers(0, len(rand_var_list))
  123. placeholders.append(rand_var_list[idx])
  124. return self.mission_func(*placeholders)
  125. else:
  126. return self.mission_func()
  127. def contains(self, x: Any) -> bool:
  128. """Return boolean specifying if x is a valid member of this space."""
  129. # Store a list of all the placeholders from self.ordered_placeholders that appear in x
  130. if self.ordered_placeholders is not None:
  131. check_placeholder_list = []
  132. for placeholder_list in self.ordered_placeholders:
  133. for placeholder in placeholder_list:
  134. if placeholder in x:
  135. check_placeholder_list.append(placeholder)
  136. # Remove duplicates from the list
  137. check_placeholder_list = list(set(check_placeholder_list))
  138. start_id_placeholder = []
  139. end_id_placeholder = []
  140. # Get the starting and ending id of the identified placeholders with possible duplicates
  141. new_check_placeholder_list = []
  142. for placeholder in check_placeholder_list:
  143. new_start_id_placeholder = [
  144. i for i in range(len(x)) if x.startswith(placeholder, i)
  145. ]
  146. new_check_placeholder_list += [placeholder] * len(
  147. new_start_id_placeholder
  148. )
  149. end_id_placeholder += [
  150. start_id + len(placeholder) - 1
  151. for start_id in new_start_id_placeholder
  152. ]
  153. start_id_placeholder += new_start_id_placeholder
  154. # Order by starting id the placeholders
  155. ordered_placeholder_list = sorted(
  156. zip(
  157. start_id_placeholder, end_id_placeholder, new_check_placeholder_list
  158. )
  159. )
  160. # Check for repeated placeholders contained in each other
  161. remove_placeholder_id = []
  162. for i, placeholder_1 in enumerate(ordered_placeholder_list):
  163. starting_id = i + 1
  164. for j, placeholder_2 in enumerate(
  165. ordered_placeholder_list[starting_id:]
  166. ):
  167. # Check if place holder ids overlap and keep the longest
  168. if max(placeholder_1[0], placeholder_2[0]) < min(
  169. placeholder_1[1], placeholder_2[1]
  170. ):
  171. remove_placeholder = min(
  172. placeholder_1[2], placeholder_2[2], key=len
  173. )
  174. if remove_placeholder == placeholder_1[2]:
  175. remove_placeholder_id.append(i)
  176. else:
  177. remove_placeholder_id.append(i + j + 1)
  178. for id in remove_placeholder_id:
  179. del ordered_placeholder_list[id]
  180. final_placeholders = [
  181. placeholder[2] for placeholder in ordered_placeholder_list
  182. ]
  183. # Check that the identified final placeholders are in the same order as the original placeholders.
  184. for orered_placeholder, final_placeholder in zip(
  185. self.ordered_placeholders, final_placeholders
  186. ):
  187. if final_placeholder in orered_placeholder:
  188. continue
  189. else:
  190. return False
  191. try:
  192. mission_string_with_placeholders = self.mission_func(
  193. *final_placeholders
  194. )
  195. except Exception as e:
  196. print(
  197. f"{x} is not contained in MissionSpace due to the following exception: {e}"
  198. )
  199. return False
  200. return bool(mission_string_with_placeholders == x)
  201. else:
  202. return bool(self.mission_func() == x)
  203. def __repr__(self) -> str:
  204. """Gives a string representation of this space."""
  205. return f"MissionSpace({self.mission_func}, {self.ordered_placeholders})"
  206. def __eq__(self, other) -> bool:
  207. """Check whether ``other`` is equivalent to this instance."""
  208. if isinstance(other, MissionSpace):
  209. # Check that place holder lists are the same
  210. if self.ordered_placeholders is not None:
  211. # Check length
  212. if (len(self.order_placeholder) == len(other.order_placeholder)) and (
  213. all(
  214. set(i) == set(j)
  215. for i, j in zip(self.order_placeholder, other.order_placeholder)
  216. )
  217. ):
  218. # Check mission string is the same with dummy space placeholders
  219. test_placeholders = [""] * len(self.order_placeholder)
  220. mission = self.mission_func(*test_placeholders)
  221. other_mission = other.mission_func(*test_placeholders)
  222. return mission == other_mission
  223. else:
  224. # Check that other is also None
  225. if other.ordered_placeholders is None:
  226. # Check mission string is the same
  227. mission = self.mission_func()
  228. other_mission = other.mission_func()
  229. return mission == other_mission
  230. # If none of the statements above return then False
  231. return False
  232. class WorldObj:
  233. """
  234. Base class for grid world objects
  235. """
  236. def __init__(self, type, color):
  237. assert type in OBJECT_TO_IDX, type
  238. assert color in COLOR_TO_IDX, color
  239. self.type = type
  240. self.color = color
  241. self.contains = None
  242. # Initial position of the object
  243. self.init_pos = None
  244. # Current position of the object
  245. self.cur_pos = None
  246. def can_overlap(self):
  247. """Can the agent overlap with this?"""
  248. return False
  249. def can_pickup(self):
  250. """Can the agent pick this up?"""
  251. return False
  252. def can_contain(self):
  253. """Can this contain another object?"""
  254. return False
  255. def see_behind(self):
  256. """Can the agent see behind this object?"""
  257. return True
  258. def toggle(self, env, pos):
  259. """Method to trigger/toggle an action this object performs"""
  260. return False
  261. def encode(self):
  262. """Encode the a description of this object as a 3-tuple of integers"""
  263. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0)
  264. @staticmethod
  265. def decode(type_idx, color_idx, state):
  266. """Create an object from a 3-tuple state description"""
  267. obj_type = IDX_TO_OBJECT[type_idx]
  268. color = IDX_TO_COLOR[color_idx]
  269. if obj_type == "empty" or obj_type == "unseen":
  270. return None
  271. # State, 0: open, 1: closed, 2: locked
  272. is_open = state == 0
  273. is_locked = state == 2
  274. if obj_type == "wall":
  275. v = Wall(color)
  276. elif obj_type == "floor":
  277. v = Floor(color)
  278. elif obj_type == "ball":
  279. v = Ball(color)
  280. elif obj_type == "key":
  281. v = Key(color)
  282. elif obj_type == "box":
  283. v = Box(color)
  284. elif obj_type == "door":
  285. v = Door(color, is_open, is_locked)
  286. elif obj_type == "goal":
  287. v = Goal()
  288. elif obj_type == "lava":
  289. v = Lava()
  290. else:
  291. assert False, "unknown object type in decode '%s'" % obj_type
  292. return v
  293. def render(self, r):
  294. """Draw this object with the given renderer"""
  295. raise NotImplementedError
  296. class Goal(WorldObj):
  297. def __init__(self):
  298. super().__init__("goal", "green")
  299. def can_overlap(self):
  300. return True
  301. def render(self, img):
  302. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  303. class Floor(WorldObj):
  304. """
  305. Colored floor tile the agent can walk over
  306. """
  307. def __init__(self, color="blue"):
  308. super().__init__("floor", color)
  309. def can_overlap(self):
  310. return True
  311. def render(self, img):
  312. # Give the floor a pale color
  313. color = COLORS[self.color] / 2
  314. fill_coords(img, point_in_rect(0.031, 1, 0.031, 1), color)
  315. class Lava(WorldObj):
  316. def __init__(self):
  317. super().__init__("lava", "red")
  318. def can_overlap(self):
  319. return True
  320. def render(self, img):
  321. c = (255, 128, 0)
  322. # Background color
  323. fill_coords(img, point_in_rect(0, 1, 0, 1), c)
  324. # Little waves
  325. for i in range(3):
  326. ylo = 0.3 + 0.2 * i
  327. yhi = 0.4 + 0.2 * i
  328. fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0, 0, 0))
  329. fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0, 0, 0))
  330. fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0, 0, 0))
  331. fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0, 0, 0))
  332. class Wall(WorldObj):
  333. def __init__(self, color="grey"):
  334. super().__init__("wall", color)
  335. def see_behind(self):
  336. return False
  337. def render(self, img):
  338. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  339. class Door(WorldObj):
  340. def __init__(self, color, is_open=False, is_locked=False):
  341. super().__init__("door", color)
  342. self.is_open = is_open
  343. self.is_locked = is_locked
  344. def can_overlap(self):
  345. """The agent can only walk over this cell when the door is open"""
  346. return self.is_open
  347. def see_behind(self):
  348. return self.is_open
  349. def toggle(self, env, pos):
  350. # If the player has the right key to open the door
  351. if self.is_locked:
  352. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  353. self.is_locked = False
  354. self.is_open = True
  355. return True
  356. return False
  357. self.is_open = not self.is_open
  358. return True
  359. def encode(self):
  360. """Encode the a description of this object as a 3-tuple of integers"""
  361. # State, 0: open, 1: closed, 2: locked
  362. if self.is_open:
  363. state = 0
  364. elif self.is_locked:
  365. state = 2
  366. # if door is closed and unlocked
  367. elif not self.is_open:
  368. state = 1
  369. else:
  370. raise ValueError(
  371. f"There is no possible state encoding for the state:\n -Door Open: {self.is_open}\n -Door Closed: {not self.is_open}\n -Door Locked: {self.is_locked}"
  372. )
  373. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state)
  374. def render(self, img):
  375. c = COLORS[self.color]
  376. if self.is_open:
  377. fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c)
  378. fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0, 0, 0))
  379. return
  380. # Door frame and door
  381. if self.is_locked:
  382. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  383. fill_coords(img, point_in_rect(0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))
  384. # Draw key slot
  385. fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)
  386. else:
  387. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  388. fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0, 0, 0))
  389. fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c)
  390. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0, 0, 0))
  391. # Draw door handle
  392. fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c)
  393. class Key(WorldObj):
  394. def __init__(self, color="blue"):
  395. super().__init__("key", color)
  396. def can_pickup(self):
  397. return True
  398. def render(self, img):
  399. c = COLORS[self.color]
  400. # Vertical quad
  401. fill_coords(img, point_in_rect(0.50, 0.63, 0.31, 0.88), c)
  402. # Teeth
  403. fill_coords(img, point_in_rect(0.38, 0.50, 0.59, 0.66), c)
  404. fill_coords(img, point_in_rect(0.38, 0.50, 0.81, 0.88), c)
  405. # Ring
  406. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.190), c)
  407. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0, 0, 0))
  408. class Ball(WorldObj):
  409. def __init__(self, color="blue"):
  410. super().__init__("ball", color)
  411. def can_pickup(self):
  412. return True
  413. def render(self, img):
  414. fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color])
  415. class Box(WorldObj):
  416. def __init__(self, color, contains=None):
  417. super().__init__("box", color)
  418. self.contains = contains
  419. def can_pickup(self):
  420. return True
  421. def render(self, img):
  422. c = COLORS[self.color]
  423. # Outline
  424. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), c)
  425. fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0, 0, 0))
  426. # Horizontal slit
  427. fill_coords(img, point_in_rect(0.16, 0.84, 0.47, 0.53), c)
  428. def toggle(self, env, pos):
  429. # Replace the box by its contents
  430. env.grid.set(pos[0], pos[1], self.contains)
  431. return True
  432. class Grid:
  433. """
  434. Represent a grid and operations on it
  435. """
  436. # Static cache of pre-renderer tiles
  437. tile_cache = {}
  438. def __init__(self, width, height):
  439. assert width >= 3
  440. assert height >= 3
  441. self.width = width
  442. self.height = height
  443. self.grid = [None] * width * height
  444. def __contains__(self, key):
  445. if isinstance(key, WorldObj):
  446. for e in self.grid:
  447. if e is key:
  448. return True
  449. elif isinstance(key, tuple):
  450. for e in self.grid:
  451. if e is None:
  452. continue
  453. if (e.color, e.type) == key:
  454. return True
  455. if key[0] is None and key[1] == e.type:
  456. return True
  457. return False
  458. def __eq__(self, other):
  459. grid1 = self.encode()
  460. grid2 = other.encode()
  461. return np.array_equal(grid2, grid1)
  462. def __ne__(self, other):
  463. return not self == other
  464. def copy(self):
  465. from copy import deepcopy
  466. return deepcopy(self)
  467. def set(self, i, j, v):
  468. assert i >= 0 and i < self.width
  469. assert j >= 0 and j < self.height
  470. self.grid[j * self.width + i] = v
  471. def get(self, i, j):
  472. assert i >= 0 and i < self.width
  473. assert j >= 0 and j < self.height
  474. return self.grid[j * self.width + i]
  475. def horz_wall(self, x, y, length=None, obj_type=Wall):
  476. if length is None:
  477. length = self.width - x
  478. for i in range(0, length):
  479. self.set(x + i, y, obj_type())
  480. def vert_wall(self, x, y, length=None, obj_type=Wall):
  481. if length is None:
  482. length = self.height - y
  483. for j in range(0, length):
  484. self.set(x, y + j, obj_type())
  485. def wall_rect(self, x, y, w, h):
  486. self.horz_wall(x, y, w)
  487. self.horz_wall(x, y + h - 1, w)
  488. self.vert_wall(x, y, h)
  489. self.vert_wall(x + w - 1, y, h)
  490. def rotate_left(self):
  491. """
  492. Rotate the grid to the left (counter-clockwise)
  493. """
  494. grid = Grid(self.height, self.width)
  495. for i in range(self.width):
  496. for j in range(self.height):
  497. v = self.get(i, j)
  498. grid.set(j, grid.height - 1 - i, v)
  499. return grid
  500. def slice(self, topX, topY, width, height):
  501. """
  502. Get a subset of the grid
  503. """
  504. grid = Grid(width, height)
  505. for j in range(0, height):
  506. for i in range(0, width):
  507. x = topX + i
  508. y = topY + j
  509. if x >= 0 and x < self.width and y >= 0 and y < self.height:
  510. v = self.get(x, y)
  511. else:
  512. v = Wall()
  513. grid.set(i, j, v)
  514. return grid
  515. @classmethod
  516. def render_tile(
  517. cls, obj, agent_dir=None, highlight=False, tile_size=TILE_PIXELS, subdivs=3
  518. ):
  519. """
  520. Render a tile and cache the result
  521. """
  522. # Hash map lookup key for the cache
  523. key = (agent_dir, highlight, tile_size)
  524. key = obj.encode() + key if obj else key
  525. if key in cls.tile_cache:
  526. return cls.tile_cache[key]
  527. img = np.zeros(
  528. shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8
  529. )
  530. # Draw the grid lines (top and left edges)
  531. fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
  532. fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100))
  533. if obj is not None:
  534. obj.render(img)
  535. # Overlay the agent on top
  536. if agent_dir is not None:
  537. tri_fn = point_in_triangle(
  538. (0.12, 0.19),
  539. (0.87, 0.50),
  540. (0.12, 0.81),
  541. )
  542. # Rotate the agent based on its direction
  543. tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5 * math.pi * agent_dir)
  544. fill_coords(img, tri_fn, (255, 0, 0))
  545. # Highlight the cell if needed
  546. if highlight:
  547. highlight_img(img)
  548. # Downsample the image to perform supersampling/anti-aliasing
  549. img = downsample(img, subdivs)
  550. # Cache the rendered tile
  551. cls.tile_cache[key] = img
  552. return img
  553. def render(self, tile_size, agent_pos, agent_dir=None, highlight_mask=None):
  554. """
  555. Render this grid at a given scale
  556. :param r: target renderer object
  557. :param tile_size: tile size in pixels
  558. """
  559. if highlight_mask is None:
  560. highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  561. # Compute the total grid size
  562. width_px = self.width * tile_size
  563. height_px = self.height * tile_size
  564. img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)
  565. # Render the grid
  566. for j in range(0, self.height):
  567. for i in range(0, self.width):
  568. cell = self.get(i, j)
  569. agent_here = np.array_equal(agent_pos, (i, j))
  570. tile_img = Grid.render_tile(
  571. cell,
  572. agent_dir=agent_dir if agent_here else None,
  573. highlight=highlight_mask[i, j],
  574. tile_size=tile_size,
  575. )
  576. ymin = j * tile_size
  577. ymax = (j + 1) * tile_size
  578. xmin = i * tile_size
  579. xmax = (i + 1) * tile_size
  580. img[ymin:ymax, xmin:xmax, :] = tile_img
  581. return img
  582. def encode(self, vis_mask=None):
  583. """
  584. Produce a compact numpy encoding of the grid
  585. """
  586. if vis_mask is None:
  587. vis_mask = np.ones((self.width, self.height), dtype=bool)
  588. array = np.zeros((self.width, self.height, 3), dtype="uint8")
  589. for i in range(self.width):
  590. for j in range(self.height):
  591. if vis_mask[i, j]:
  592. v = self.get(i, j)
  593. if v is None:
  594. array[i, j, 0] = OBJECT_TO_IDX["empty"]
  595. array[i, j, 1] = 0
  596. array[i, j, 2] = 0
  597. else:
  598. array[i, j, :] = v.encode()
  599. return array
  600. @staticmethod
  601. def decode(array):
  602. """
  603. Decode an array grid encoding back into a grid
  604. """
  605. width, height, channels = array.shape
  606. assert channels == 3
  607. vis_mask = np.ones(shape=(width, height), dtype=bool)
  608. grid = Grid(width, height)
  609. for i in range(width):
  610. for j in range(height):
  611. type_idx, color_idx, state = array[i, j]
  612. v = WorldObj.decode(type_idx, color_idx, state)
  613. grid.set(i, j, v)
  614. vis_mask[i, j] = type_idx != OBJECT_TO_IDX["unseen"]
  615. return grid, vis_mask
  616. def process_vis(self, agent_pos):
  617. mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  618. mask[agent_pos[0], agent_pos[1]] = True
  619. for j in reversed(range(0, self.height)):
  620. for i in range(0, self.width - 1):
  621. if not mask[i, j]:
  622. continue
  623. cell = self.get(i, j)
  624. if cell and not cell.see_behind():
  625. continue
  626. mask[i + 1, j] = True
  627. if j > 0:
  628. mask[i + 1, j - 1] = True
  629. mask[i, j - 1] = True
  630. for i in reversed(range(1, self.width)):
  631. if not mask[i, j]:
  632. continue
  633. cell = self.get(i, j)
  634. if cell and not cell.see_behind():
  635. continue
  636. mask[i - 1, j] = True
  637. if j > 0:
  638. mask[i - 1, j - 1] = True
  639. mask[i, j - 1] = True
  640. for j in range(0, self.height):
  641. for i in range(0, self.width):
  642. if not mask[i, j]:
  643. self.set(i, j, None)
  644. return mask
  645. class MiniGridEnv(gym.Env):
  646. """
  647. 2D grid world game environment
  648. """
  649. metadata = {
  650. "render_modes": ["human", "rgb_array"],
  651. "render_fps": 10,
  652. }
  653. # Enumeration of possible actions
  654. class Actions(IntEnum):
  655. # Turn left, turn right, move forward
  656. left = 0
  657. right = 1
  658. forward = 2
  659. # Pick up an object
  660. pickup = 3
  661. # Drop an object
  662. drop = 4
  663. # Toggle/activate an object
  664. toggle = 5
  665. # Done completing task
  666. done = 6
  667. def __init__(
  668. self,
  669. mission_space: MissionSpace,
  670. grid_size: int = None,
  671. width: int = None,
  672. height: int = None,
  673. max_steps: int = 100,
  674. see_through_walls: bool = False,
  675. agent_view_size: int = 7,
  676. render_mode: Optional[str] = None,
  677. highlight: bool = True,
  678. tile_size: int = TILE_PIXELS,
  679. agent_pov: bool = False,
  680. ):
  681. # Initialize mission
  682. self.mission = mission_space.sample()
  683. # Can't set both grid_size and width/height
  684. if grid_size:
  685. assert width is None and height is None
  686. width = grid_size
  687. height = grid_size
  688. # Action enumeration for this environment
  689. self.actions = MiniGridEnv.Actions
  690. # Actions are discrete integer values
  691. self.action_space = spaces.Discrete(len(self.actions))
  692. # Number of cells (width and height) in the agent view
  693. assert agent_view_size % 2 == 1
  694. assert agent_view_size >= 3
  695. self.agent_view_size = agent_view_size
  696. # Observations are dictionaries containing an
  697. # encoding of the grid and a textual 'mission' string
  698. image_observation_space = spaces.Box(
  699. low=0,
  700. high=255,
  701. shape=(self.agent_view_size, self.agent_view_size, 3),
  702. dtype="uint8",
  703. )
  704. self.observation_space = spaces.Dict(
  705. {
  706. "image": image_observation_space,
  707. "direction": spaces.Discrete(4),
  708. "mission": mission_space,
  709. }
  710. )
  711. # Range of possible rewards
  712. self.reward_range = (0, 1)
  713. self.window: Window = None
  714. # Environment configuration
  715. self.width = width
  716. self.height = height
  717. self.max_steps = max_steps
  718. self.see_through_walls = see_through_walls
  719. # Current position and direction of the agent
  720. self.agent_pos: np.ndarray = None
  721. self.agent_dir: int = None
  722. # Current grid and mission and carryinh
  723. self.grid = Grid(width, height)
  724. self.carrying = None
  725. # Rendering attributes
  726. self.render_mode = render_mode
  727. self.highlight = highlight
  728. self.tile_size = tile_size
  729. self.agent_pov = agent_pov
  def reset(self, *, seed=None, options=None):
      """Start a new episode and return its first observation.

      Seeds the parent environment, regenerates the grid through
      ``_gen_grid``, validates the agent pose it produced and clears the
      per-episode state (carried object, step counter).

      :param seed: optional seed forwarded to the parent ``reset``
      :param options: accepted for API compatibility; not used here
      :return: ``(obs, info)`` where ``obs`` comes from ``gen_obs()`` and
          ``info`` is an empty dict
      """
      super().reset(seed=seed)

      # Sentinel pose; _gen_grid is expected to overwrite both values.
      self.agent_pos = (-1, -1)
      self.agent_dir = -1

      # Generate a new random grid at the start of each episode
      self._gen_grid(self.width, self.height)

      # _gen_grid must have set a non-negative position AND direction.
      # Each coordinate is checked on its own: a tuple comparison such as
      # ``pos >= (0, 0)`` is lexicographic and would accept e.g. (1, -1).
      # The direction is checked for tuple and ndarray positions alike.
      assert all(coord >= 0 for coord in self.agent_pos) and self.agent_dir >= 0

      # Check that the agent doesn't overlap with an object
      start_cell = self.grid.get(*self.agent_pos)
      assert start_cell is None or start_cell.can_overlap()

      # Item picked up, being carried, initially nothing
      self.carrying = None

      # Step count since episode start
      self.step_count = 0

      if self.render_mode == "human":
          self.render()

      # Return first observation
      obs = self.gen_obs()
      return obs, {}
  def hash(self, size=16):
      """Compute a hash identifying the current grid contents and agent pose.

      The SHA-256 digest covers the encoded grid, the agent position and the
      agent direction; ``self.carrying`` is not part of the hashed data.
      Position and direction are normalized to plain Python ints first, so
      the same state hashes identically whether ``agent_pos`` is a tuple of
      ints, a tuple of numpy scalars or an ndarray.

      :param size: number of leading hexadecimal characters of the digest
          to return (a SHA-256 hex digest has 64)
      :return: the truncated hexadecimal digest string
      """
      sample_hash = hashlib.sha256()
      agent_pos = (
          None
          if self.agent_pos is None
          else tuple(int(coord) for coord in self.agent_pos)
      )
      agent_dir = None if self.agent_dir is None else int(self.agent_dir)
      to_encode = [self.grid.encode().tolist(), agent_pos, agent_dir]
      for item in to_encode:
          sample_hash.update(str(item).encode("utf8"))
      return sample_hash.hexdigest()[:size]
  @property
  def steps_remaining(self):
      """Steps left before ``max_steps`` is reached (not clamped at zero)."""
      used = self.step_count
      return self.max_steps - used
  def __str__(self):
      """
      Produce a pretty string of the environment's grid along with the agent.

      Every grid cell is rendered as exactly 2 characters: the agent as its
      direction arrow twice, an empty cell as two spaces, an open door as
      "__", and any other object as its type letter followed by the first
      letter of its color (locked doors use "L"). Rows are separated by
      newlines.
      """
      # Map of object types to short string
      OBJECT_TO_STR = {
          "wall": "W",
          "floor": "F",
          "door": "D",
          "key": "K",
          "ball": "A",
          "box": "B",
          "goal": "G",
          "lava": "V",
      }

      # Map agent's direction to short string
      AGENT_DIR_TO_STR = {0: ">", 1: "V", 2: "<", 3: "^"}

      rows = []
      for j in range(self.grid.height):
          cells = []
          for i in range(self.grid.width):
              if i == self.agent_pos[0] and j == self.agent_pos[1]:
                  cells.append(2 * AGENT_DIR_TO_STR[self.agent_dir])
                  continue

              c = self.grid.get(i, j)

              if c is None:
                  # Two spaces keep every cell 2 characters wide.
                  cells.append("  ")
                  continue

              if c.type == "door":
                  if c.is_open:
                      cells.append("__")
                  elif c.is_locked:
                      cells.append("L" + c.color[0].upper())
                  else:
                      cells.append("D" + c.color[0].upper())
                  continue

              cells.append(OBJECT_TO_STR[c.type] + c.color[0].upper())
          rows.append("".join(cells))

      return "\n".join(rows)
  @abstractmethod
  def _gen_grid(self, width, height):
      """Build the episode's grid and place the agent; subclasses override this.

      ``reset`` calls it with the configured width and height and afterwards
      asserts that ``agent_pos`` and ``agent_dir`` hold valid, non-negative
      values and that the agent's start cell can be overlapped.
      """
  def _reward(self):
      """Success reward: 1 at step 0, decaying linearly to 0.1 at ``max_steps``."""
      elapsed_fraction = self.step_count / self.max_steps
      return 1 - 0.9 * elapsed_fraction
  def _rand_int(self, low, high):
      """Draw a random integer from the half-open interval [low, high)."""
      rng = self.np_random
      return rng.integers(low, high)
  def _rand_float(self, low, high):
      """Draw a uniformly distributed float from [low, high)."""
      rng = self.np_random
      return rng.uniform(low, high)
  def _rand_bool(self):
      """Flip a fair coin: True when the drawn bit in {0, 1} is 0."""
      draw = self.np_random.integers(0, 2)
      return draw == 0
  def _rand_elem(self, iterable):
      """Return one element of ``iterable``, chosen with ``_rand_int``."""
      candidates = list(iterable)
      return candidates[self._rand_int(0, len(candidates))]
  def _rand_subset(self, iterable, num_elems):
      """Return ``num_elems`` elements drawn without replacement via ``_rand_elem``.

      Each pick is removed from the remaining pool before the next draw, so
      the result holds distinct entries of the input (in draw order).
      """
      pool = list(iterable)
      assert num_elems <= len(pool)

      chosen = []
      while len(chosen) < num_elems:
          pick = self._rand_elem(pool)
          pool.remove(pick)
          chosen.append(pick)
      return chosen
  def _rand_color(self):
      """Pick a random color name from ``COLOR_NAMES``."""
      palette = COLOR_NAMES
      return self._rand_elem(palette)
  def _rand_pos(self, xLow, xHigh, yLow, yHigh):
      """Draw a random ``(x, y)`` with x in [xLow, xHigh) and y in [yLow, yHigh).

      The x coordinate is drawn before the y coordinate.
      """
      rng = self.np_random
      x = rng.integers(xLow, xHigh)
      y = rng.integers(yLow, yHigh)
      return (x, y)
  def place_obj(self, obj, top=None, size=None, reject_fn=None, max_tries=math.inf):
      """Put ``obj`` at a random free cell inside a rectangle and return that cell.

      Candidate cells are drawn with ``_rand_int`` (x first, then y) and are
      rejected while the cell already holds an object, equals ``agent_pos``,
      or ``reject_fn`` refuses it.

      :param obj: object to store in the grid; may be ``None``, in which case
          only the cell is chosen (and set to ``None``) and no position
          attributes are written
      :param top: top-left corner of the sampling rectangle; negative
          components are clamped to 0. Defaults to ``(0, 0)``.
      :param size: ``(width, height)`` of the rectangle; the far edge is
          clipped to the grid. Defaults to the full grid size.
      :param reject_fn: optional predicate ``reject_fn(env, pos)``; a truthy
          result rejects ``pos``
      :param max_tries: sampling budget; once more than ``max_tries``
          attempts have been made, ``RecursionError`` is raised
      :return: the chosen ``(x, y)`` position
      """
      if top is None:
          x0, y0 = 0, 0
      else:
          x0, y0 = max(top[0], 0), max(top[1], 0)

      if size is None:
          size = (self.grid.width, self.grid.height)

      attempts = 0
      while True:
          # Bail out instead of looping forever when no cell is acceptable.
          if attempts > max_tries:
              raise RecursionError("rejection sampling failed in place_obj")
          attempts += 1

          candidate = tuple(
              np.array(
                  (
                      self._rand_int(x0, min(x0 + size[0], self.grid.width)),
                      self._rand_int(y0, min(y0 + size[1], self.grid.height)),
                  )
              )
          )

          # Accept only a free cell, not under the agent, that reject_fn allows.
          if (
              self.grid.get(*candidate) is None
              and not np.array_equal(candidate, self.agent_pos)
              and not (reject_fn and reject_fn(self, candidate))
          ):
              break

      self.grid.set(candidate[0], candidate[1], obj)
      if obj is not None:
          obj.init_pos = candidate
          obj.cur_pos = candidate
      return candidate
  def put_obj(self, obj, i, j):
      """Store ``obj`` at cell ``(i, j)`` and record that cell as its initial and current position."""
      self.grid.set(i, j, obj)
      obj.init_pos = obj.cur_pos = (i, j)
  def place_agent(self, top=None, size=None, rand_dir=True, max_tries=math.inf):
      """Move the agent to a random free cell and return that cell.

      ``top``, ``size`` and ``max_tries`` are forwarded to ``place_obj``.
      When ``rand_dir`` is true the facing direction is also re-drawn from
      {0, 1, 2, 3}.
      """
      # Park the agent off-grid first so that place_obj, which rejects the
      # cell equal to agent_pos, does not exclude the agent's old cell.
      self.agent_pos = (-1, -1)
      self.agent_pos = self.place_obj(None, top, size, max_tries=max_tries)

      if rand_dir:
          self.agent_dir = self._rand_int(0, 4)

      return self.agent_pos
  @property
  def dir_vec(self):
      """Vector of the agent's forward movement, looked up in ``DIR_TO_VEC``.

      ``agent_dir`` must lie in the range 0..3.
      """
      assert 0 <= self.agent_dir < 4
      return DIR_TO_VEC[self.agent_dir]
  @property
  def right_vec(self):
      """Vector pointing to the agent's right: ``dir_vec`` (dx, dy) mapped to (-dy, dx)."""
      forward_x, forward_y = self.dir_vec
      return np.array((-forward_y, forward_x))
  @property
  def front_pos(self):
      """Cell directly ahead of the agent: ``agent_pos + dir_vec``."""
      position, step = self.agent_pos, self.dir_vec
      return position + step
  def get_view_coords(self, i, j):
      """Map world cell ``(i, j)`` into the agent's view frame.

      The view frame has its origin at the view's top-left corner, with x
      along ``right_vec`` and y pointing back towards the agent. The result
      is not range-checked: it may be negative or lie beyond
      ``agent_view_size``.
      """
      ax, ay = self.agent_pos
      fx, fy = self.dir_vec
      rx, ry = self.right_vec

      view = self.agent_view_size
      half = self.agent_view_size // 2

      # World position of the view's top-left corner: (view - 1) cells ahead
      # of the agent and `half` cells to its left.
      corner_x = ax + (fx * (view - 1)) - (rx * half)
      corner_y = ay + (fy * (view - 1)) - (ry * half)

      off_x = i - corner_x
      off_y = j - corner_y

      # Project the offset onto the agent's right / backward axes.
      return rx * off_x + ry * off_y, -(fx * off_x + fy * off_y)
  def get_view_exts(self, agent_view_size=None):
      """Return ``(topX, topY, botX, botY)`` bounding the agent's square view.

      The bottom extents are exclusive. A falsy ``agent_view_size`` falls
      back to ``self.agent_view_size``. The agent sits on the edge of the
      square opposite its facing direction, centred along that edge.
      """
      agent_view_size = agent_view_size or self.agent_view_size
      half = agent_view_size // 2

      # Offset of the top-left corner relative to the agent, per direction.
      if self.agent_dir == 0:  # facing right
          off_x, off_y = 0, -half
      elif self.agent_dir == 1:  # facing down
          off_x, off_y = -half, 0
      elif self.agent_dir == 2:  # facing left
          off_x, off_y = 1 - agent_view_size, -half
      elif self.agent_dir == 3:  # facing up
          off_x, off_y = -half, 1 - agent_view_size
      else:
          assert False, "invalid agent direction"

      top_x = self.agent_pos[0] + off_x
      top_y = self.agent_pos[1] + off_y
      return (top_x, top_y, top_x + agent_view_size, top_y + agent_view_size)
  def relative_coords(self, x, y):
      """Return the view-frame coordinates of world cell ``(x, y)``, or None.

      None is returned when the cell falls outside the
      ``agent_view_size`` x ``agent_view_size`` view square.
      """
      vx, vy = self.get_view_coords(x, y)
      inside = 0 <= vx < self.agent_view_size and 0 <= vy < self.agent_view_size
      return (vx, vy) if inside else None
  def in_view(self, x, y):
      """True when world cell ``(x, y)`` lies inside the agent's view square."""
      coords = self.relative_coords(x, y)
      return coords is not None
  def agent_sees(self, x, y):
      """Check whether the agent's observation shows the object at ``(x, y)``.

      The world cell must be non-empty (asserted). Returns False when the
      cell is outside the view square; otherwise decodes the current
      observation and compares the object type seen there with the world's.
      """
      coords = self.relative_coords(x, y)
      if coords is None:
          return False
      view_x, view_y = coords

      seen_grid, _ = Grid.decode(self.gen_obs()["image"])
      seen = seen_grid.get(view_x, view_y)
      actual = self.grid.get(x, y)

      assert actual is not None
      return seen is not None and seen.type == actual.type
  def step(self, action):
      """Apply one action and return ``(obs, reward, terminated, truncated, info)``.

      - ``left``/``right`` rotate the agent (direction kept in 0..3).
      - ``forward`` moves into the front cell when it is empty or can be
        overlapped. A goal in front ends the episode with ``_reward()``;
        lava in front ends it with reward 0.
      - ``pickup`` takes a pickable front object when nothing is carried.
      - ``drop`` places the carried object on an empty front cell.
      - ``toggle`` calls ``toggle(self, fwd_pos)`` on the front object.
      - ``done`` does nothing.

      ``truncated`` becomes True once ``step_count`` reaches ``max_steps``.

      :raises ValueError: if ``action`` matches none of ``self.actions``
      """
      self.step_count += 1

      reward = 0
      terminated = False
      truncated = False

      # Cell the agent is facing, and what occupies it.
      fwd_pos = self.front_pos
      fwd_cell = self.grid.get(*fwd_pos)

      acts = self.actions

      if action == acts.left:
          turned = self.agent_dir - 1
          self.agent_dir = turned + 4 if turned < 0 else turned

      elif action == acts.right:
          self.agent_dir = (self.agent_dir + 1) % 4

      elif action == acts.forward:
          if fwd_cell is None or fwd_cell.can_overlap():
              self.agent_pos = tuple(fwd_pos)
          if fwd_cell is not None:
              if fwd_cell.type == "goal":
                  terminated = True
                  reward = self._reward()
              if fwd_cell.type == "lava":
                  terminated = True

      elif action == acts.pickup:
          if fwd_cell and fwd_cell.can_pickup() and self.carrying is None:
              self.carrying = fwd_cell
              # Carried objects are marked as being off the grid.
              self.carrying.cur_pos = np.array([-1, -1])
              self.grid.set(fwd_pos[0], fwd_pos[1], None)

      elif action == acts.drop:
          if not fwd_cell and self.carrying:
              self.grid.set(fwd_pos[0], fwd_pos[1], self.carrying)
              self.carrying.cur_pos = fwd_pos
              self.carrying = None

      elif action == acts.toggle:
          if fwd_cell:
              fwd_cell.toggle(self, fwd_pos)

      elif action == acts.done:
          pass  # no effect by default

      else:
          raise ValueError(f"Unknown action: {action}")

      if self.step_count >= self.max_steps:
          truncated = True

      if self.render_mode == "human":
          self.render()

      obs = self.gen_obs()
      return obs, reward, terminated, truncated, {}
  def gen_obs_grid(self, agent_view_size=None):
      """Return ``(grid, vis_mask)`` for the agent's partial view.

      The sub-grid is sliced from the world at ``get_view_exts`` and rotated
      ``agent_dir + 1`` times to the left, so that the agent ends up at the
      bottom-centre cell. ``vis_mask`` marks the cells the agent can see:
      all of them when ``see_through_walls`` is set, otherwise the result of
      ``process_vis``. A falsy ``agent_view_size`` means
      ``self.agent_view_size``.
      """
      top_x, top_y, _, _ = self.get_view_exts(agent_view_size)
      view = agent_view_size or self.agent_view_size

      grid = self.grid.slice(top_x, top_y, view, view)
      for _ in range(self.agent_dir + 1):
          grid = grid.rotate_left()

      if self.see_through_walls:
          vis_mask = np.ones(shape=(grid.width, grid.height), dtype=bool)
      else:
          # Occlusion processing has a performance cost, so it only runs
          # when walls actually block sight.
          vis_mask = grid.process_vis(agent_pos=(view // 2, view - 1))

      # Show the carried object (or nothing) in the agent's own cell so the
      # agent can observe what it holds.
      agent_cell = (grid.width // 2, grid.height - 1)
      grid.set(*agent_cell, self.carrying if self.carrying else None)

      return grid, vis_mask
  def gen_obs(self):
      """Build the observation dict: encoded partial view, direction and mission.

      ``"image"`` is the partial-view grid encoded with its visibility mask,
      ``"direction"`` is ``agent_dir`` and ``"mission"`` is the mission
      string.
      """
      grid, vis_mask = self.gen_obs_grid()
      return {
          "image": grid.encode(vis_mask),
          "direction": self.agent_dir,
          "mission": self.mission,
      }
  def get_pov_render(self, tile_size):
      """Render the agent's partial view, drawing the agent at the bottom centre facing up."""
      grid, vis_mask = self.gen_obs_grid()
      centre_col = self.agent_view_size // 2
      bottom_row = self.agent_view_size - 1
      return grid.render(
          tile_size,
          agent_pos=(centre_col, bottom_row),
          agent_dir=3,
          highlight_mask=vis_mask,
      )
  def get_full_render(self, highlight, tile_size):
      """Render the full (non-partial) grid for visualization.

      :param highlight: when truthy, world cells currently visible to the
          agent are highlighted; when falsy the visibility computation is
          skipped entirely, since its result would be discarded
      :param tile_size: tile size in pixels, forwarded to ``self.grid.render``
      :return: the image produced by ``self.grid.render``
      """
      highlight_mask = None
      if highlight:
          # Which cells of the partial view the agent can see (view frame).
          _, vis_mask = self.gen_obs_grid()

          f_vec = self.dir_vec
          r_vec = self.right_vec

          # World coordinates of the view's top-left corner:
          # (agent_view_size - 1) cells ahead of the agent and
          # agent_view_size // 2 cells to its left.
          top_left = (
              self.agent_pos
              + f_vec * (self.agent_view_size - 1)
              - r_vec * (self.agent_view_size // 2)
          )

          highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)

          for vis_j in range(self.agent_view_size):
              for vis_i in range(self.agent_view_size):
                  if not vis_mask[vis_i, vis_j]:
                      continue

                  # Map the view cell back into world coordinates.
                  abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)

                  # View cells may fall outside the world grid.
                  if not (0 <= abs_i < self.width and 0 <= abs_j < self.height):
                      continue

                  highlight_mask[abs_i, abs_j] = True

      return self.grid.render(
          tile_size,
          self.agent_pos,
          self.agent_dir,
          highlight_mask=highlight_mask,
      )
  def get_frame(
      self,
      highlight: bool = True,
      tile_size: int = TILE_PIXELS,
      agent_pov: bool = False,
  ):
      """Return an RGB frame of either the whole environment or the agent's view.

      Args:
          highlight (bool): For the full-environment frame, highlight the
              cells the agent can currently see. Ignored when ``agent_pov``
              is true.
          tile_size (int): Number of pixels per grid tile.
          agent_pov (bool): If true, render only the agent's point of view.

      Returns:
          frame (np.ndarray): The rendered RGB frame.
      """
      if not agent_pov:
          return self.get_full_render(highlight, tile_size)
      return self.get_pov_render(tile_size)
  def render(self):
      """Render a frame according to ``self.render_mode``.

      ``"rgb_array"`` returns the frame. ``"human"`` shows it in a
      ``Window`` that is created and opened on first use and captioned with
      the mission. Any other mode renders the frame and returns None.
      """
      frame = self.get_frame(self.highlight, self.tile_size, self.agent_pov)

      if self.render_mode == "rgb_array":
          return frame

      if self.render_mode == "human":
          if self.window is None:
              self.window = Window("minigrid")
              self.window.show(block=False)
          self.window.set_caption(self.mission)
          self.window.show_img(frame)
  def close(self):
      """Close the human-render window if one has been opened."""
      window = self.window
      if window:
          window.close()