# minigrid_env.py (22 KB)
  from __future__ import annotations

  import hashlib
  import math
  from abc import abstractmethod
  from typing import Any, Callable, Iterable, SupportsFloat, TypeVar

  import gymnasium as gym
  import numpy as np
  from gymnasium import spaces
  from gymnasium.core import ActType, ObsType

  from minigrid.core.actions import Actions
  from minigrid.core.constants import COLOR_NAMES, DIR_TO_VEC, TILE_PIXELS
  from minigrid.core.grid import Grid
  from minigrid.core.mission import MissionSpace
  from minigrid.core.world_object import Point, WorldObj
  from minigrid.utils.window import Window
  # Generic element type used by the random-sampling helpers (_rand_elem, _rand_subset).
  T = TypeVar("T")
  class MiniGridEnv(gym.Env):
      """
      2D grid world game environment.

      The agent occupies one cell of a ``width`` x ``height`` grid, faces one of
      four directions (0: right, 1: down, 2: left, 3: up) and observes a square,
      egocentric window of ``agent_view_size`` x ``agent_view_size`` cells.
      Subclasses implement ``_gen_grid`` to populate the grid and to place the
      agent at the start of every episode.
      """

      metadata = {
          "render_modes": ["human", "rgb_array"],
          "render_fps": 10,
      }

      def __init__(
          self,
          mission_space: MissionSpace,
          grid_size: int | None = None,
          width: int | None = None,
          height: int | None = None,
          max_steps: int = 100,
          see_through_walls: bool = False,
          agent_view_size: int = 7,
          render_mode: str | None = None,
          highlight: bool = True,
          tile_size: int = TILE_PIXELS,
          agent_pov: bool = False,
      ):
          """
          Configure the environment; the grid itself is generated in ``reset``.

          :param mission_space: space the mission string is sampled from; also
              used as the "mission" entry of the observation space
          :param grid_size: side length of a square grid (exclusive with
              ``width``/``height``)
          :param width: grid width, required when ``grid_size`` is not given
          :param height: grid height, required when ``grid_size`` is not given
          :param max_steps: number of steps after which an episode is truncated
          :param see_through_walls: if true, every cell of the view is marked
              visible instead of running the occlusion pass
          :param agent_view_size: side length of the agent's view; must be odd
              and at least 3
          :param render_mode: "human", "rgb_array" or None
          :param highlight: whether full renders highlight the visible cells
          :param tile_size: size in pixels of one rendered tile
          :param agent_pov: if true, render the agent's view instead of the grid
          :raises AssertionError: on inconsistent size arguments, an invalid
              ``agent_view_size`` or a non-integer ``max_steps``
          """
          # Initialize mission
          self.mission = mission_space.sample()

          # Can't set both grid_size and width/height
          if grid_size:
              assert width is None and height is None
              width = grid_size
              height = grid_size
          assert width is not None and height is not None

          # Action enumeration for this environment
          self.actions = Actions

          # Actions are discrete integer values
          self.action_space = spaces.Discrete(len(self.actions))

          # Number of cells (width and height) in the agent view
          assert agent_view_size % 2 == 1
          assert agent_view_size >= 3
          self.agent_view_size = agent_view_size

          # Observations are dictionaries containing an
          # encoding of the grid and a textual 'mission' string
          image_observation_space = spaces.Box(
              low=0,
              high=255,
              shape=(self.agent_view_size, self.agent_view_size, 3),
              dtype="uint8",
          )
          self.observation_space = spaces.Dict(
              {
                  "image": image_observation_space,
                  "direction": spaces.Discrete(4),
                  "mission": mission_space,
              }
          )

          # Range of possible rewards
          self.reward_range = (0, 1)

          # Created lazily by render() in "human" mode
          self.window: Window | None = None

          # Environment configuration
          self.width = width
          self.height = height

          assert isinstance(
              max_steps, int
          ), f"The argument max_steps must be an integer, got: {type(max_steps)}"
          self.max_steps = max_steps

          self.see_through_walls = see_through_walls

          # Current position and direction of the agent (set during reset)
          self.agent_pos: np.ndarray | tuple[int, int] | None = None
          self.agent_dir: int | None = None

          # Current grid and mission and carrying
          self.grid = Grid(width, height)
          self.carrying = None

          # Step count since episode start; defined here so that
          # steps_remaining is usable before the first reset()
          self.step_count = 0

          # Rendering attributes
          self.render_mode = render_mode
          self.highlight = highlight
          self.tile_size = tile_size
          self.agent_pov = agent_pov

      def reset(
          self,
          *,
          seed: int | None = None,
          options: dict[str, Any] | None = None,
      ) -> tuple[ObsType, dict[str, Any]]:
          """
          Start a new episode: regenerate the grid and return the first observation.

          :param seed: forwarded to ``gym.Env.reset``
          :param options: accepted for API compatibility; not used here
          :return: tuple of the first observation and an empty info dict
          :raises AssertionError: if ``_gen_grid`` left the agent without a valid
              position or direction, or on a cell that cannot be overlapped
          """
          super().reset(seed=seed)

          # Reinitialize episode-specific variables
          self.agent_pos = (-1, -1)
          self.agent_dir = -1

          # Generate a new random grid at the start of each episode
          self._gen_grid(self.width, self.height)

          # These fields should be defined by _gen_grid. Every coordinate is
          # checked individually (tuple comparison would be lexicographic) and
          # the direction is checked for tuple and array positions alike.
          assert (
              all(coord >= 0 for coord in self.agent_pos) and self.agent_dir >= 0
          ), (
              "_gen_grid must set a valid agent position and direction, got "
              f"agent_pos={self.agent_pos}, agent_dir={self.agent_dir}"
          )

          # Check that the agent doesn't overlap with an object
          start_cell = self.grid.get(*self.agent_pos)
          assert start_cell is None or start_cell.can_overlap()

          # Item picked up, being carried, initially nothing
          self.carrying = None

          # Step count since episode start
          self.step_count = 0

          if self.render_mode == "human":
              self.render()

          # Return first observation
          obs = self.gen_obs()

          return obs, {}

      def hash(self, size=16):
          """Compute a hash that identifies the current state of the environment.

          The hash covers the encoded grid, the agent position and the agent
          direction. Position and direction are converted to built-in ints first,
          so that the same state hashes identically whether it is stored as
          Python ints, NumPy scalars or an array (their ``str()`` forms differ).

          :param size: number of leading hexadecimal characters of the SHA-256
              digest to return
          :return: the truncated hexadecimal digest
          """
          agent_pos = self.agent_pos
          if agent_pos is not None:
              agent_pos = tuple(int(coord) for coord in agent_pos)

          agent_dir = self.agent_dir
          if agent_dir is not None:
              agent_dir = int(agent_dir)

          sample_hash = hashlib.sha256()
          for item in (self.grid.encode().tolist(), agent_pos, agent_dir):
              sample_hash.update(str(item).encode("utf8"))

          return sample_hash.hexdigest()[:size]

      @property
      def steps_remaining(self):
          """Number of steps left before the episode is truncated."""
          return self.max_steps - self.step_count

      def __str__(self):
          """
          Produce a pretty string of the environment's grid along with the agent.
          A grid cell is represented by 2-character string, the first one for
          the object and the second one for the color. Empty cells are two
          spaces, the agent is its direction symbol repeated twice, and doors
          are "__" when open, "L"/"D" plus the color initial when locked/closed.
          """
          # Map of object types to short string
          OBJECT_TO_STR = {
              "wall": "W",
              "floor": "F",
              "door": "D",
              "key": "K",
              "ball": "A",
              "box": "B",
              "goal": "G",
              "lava": "V",
          }

          # Map agent's direction to short string
          AGENT_DIR_TO_STR = {0: ">", 1: "V", 2: "<", 3: "^"}

          rows = []
          for j in range(self.grid.height):
              cells = []
              for i in range(self.grid.width):
                  if i == self.agent_pos[0] and j == self.agent_pos[1]:
                      cells.append(2 * AGENT_DIR_TO_STR[self.agent_dir])
                      continue

                  tile = self.grid.get(i, j)

                  if tile is None:
                      # Two characters, like every other cell, to keep columns aligned
                      cells.append("  ")
                      continue

                  if tile.type == "door":
                      if tile.is_open:
                          cells.append("__")
                      elif tile.is_locked:
                          cells.append("L" + tile.color[0].upper())
                      else:
                          cells.append("D" + tile.color[0].upper())
                      continue

                  cells.append(OBJECT_TO_STR[tile.type] + tile.color[0].upper())

              rows.append("".join(cells))

          # One line per grid row, no trailing newline
          return "\n".join(rows)

      @abstractmethod
      def _gen_grid(self, width, height):
          """
          Populate ``self.grid`` and place the agent; implemented by subclasses.

          ``reset`` calls this with the configured width and height and then
          asserts that ``agent_pos`` and ``agent_dir`` have been set.
          """
          pass

      def _reward(self) -> float:
          """
          Compute the reward to be given upon success: 1 minus 0.9 times the
          fraction of ``max_steps`` used so far.
          """
          return 1 - 0.9 * (self.step_count / self.max_steps)

      def _rand_int(self, low: int, high: int) -> int:
          """
          Generate random integer in [low,high[
          """
          return self.np_random.integers(low, high)

      def _rand_float(self, low: float, high: float) -> float:
          """
          Generate random float in [low,high[
          """
          return self.np_random.uniform(low, high)

      def _rand_bool(self) -> bool:
          """
          Generate random boolean value
          """
          return self.np_random.integers(0, 2) == 0

      def _rand_elem(self, iterable: Iterable[T]) -> T:
          """
          Pick a random element in a list
          """
          lst = list(iterable)
          idx = self._rand_int(0, len(lst))
          return lst[idx]

      def _rand_subset(self, iterable: Iterable[T], num_elems: int) -> list[T]:
          """
          Sample a random subset of distinct elements of a list

          :param iterable: elements to sample from (copied, not modified)
          :param num_elems: number of elements to draw; must not exceed the
              number of available elements
          :return: the drawn elements, in the order they were drawn
          """
          lst = list(iterable)
          assert num_elems <= len(lst)

          out: list[T] = []
          while len(out) < num_elems:
              elem = self._rand_elem(lst)
              # Remove the drawn element so it cannot be drawn again
              lst.remove(elem)
              out.append(elem)

          return out

      def _rand_color(self) -> str:
          """
          Generate a random color name (string)
          """
          return self._rand_elem(COLOR_NAMES)

      def _rand_pos(
          self, x_low: int, x_high: int, y_low: int, y_high: int
      ) -> tuple[int, int]:
          """
          Generate a random (x,y) position tuple with x in [x_low,x_high[
          and y in [y_low,y_high[
          """
          return (
              self.np_random.integers(x_low, x_high),
              self.np_random.integers(y_low, y_high),
          )

      def place_obj(
          self,
          obj: WorldObj | None,
          top: Point | None = None,
          size: tuple[int, int] | None = None,
          reject_fn: Callable[[MiniGridEnv, tuple[int, int]], bool] | None = None,
          max_tries: float = math.inf,
      ):
          """
          Place an object at an empty position in the grid

          Positions are drawn by rejection sampling inside the rectangle, clipped
          to the grid, until one is found that is empty, is not the agent's
          position and is not rejected by ``reject_fn``.

          :param obj: object to place; ``None`` only selects a free position
          :param top: top-left position of the rectangle where to place
              (defaults to (0, 0); negative coordinates are clamped to 0)
          :param size: size of the rectangle where to place (defaults to the
              whole grid)
          :param reject_fn: function to filter out potential positions; called
              as ``reject_fn(env, pos)`` and returns true to reject ``pos``
          :param max_tries: sampling budget; the error is raised once more than
              ``max_tries`` attempts have already been made, i.e. up to
              ``max_tries + 1`` positions are tried
          :return: the chosen (x, y) position
          :raises RecursionError: if no valid position is found within the budget
          """
          if top is None:
              top = (0, 0)
          else:
              top = (max(top[0], 0), max(top[1], 0))

          if size is None:
              size = (self.grid.width, self.grid.height)

          num_tries = 0

          while True:
              # This is to handle with rare cases where rejection sampling
              # gets stuck in an infinite loop
              if num_tries > max_tries:
                  raise RecursionError("rejection sampling failed in place_obj")

              num_tries += 1

              pos = (
                  self._rand_int(top[0], min(top[0] + size[0], self.grid.width)),
                  self._rand_int(top[1], min(top[1] + size[1], self.grid.height)),
              )

              # Don't place the object on top of another object
              if self.grid.get(*pos) is not None:
                  continue

              # Don't place the object where the agent is
              if np.array_equal(pos, self.agent_pos):
                  continue

              # Check if there is a filtering criterion
              if reject_fn and reject_fn(self, pos):
                  continue

              break

          self.grid.set(pos[0], pos[1], obj)

          if obj is not None:
              obj.init_pos = pos
              obj.cur_pos = pos

          return pos

      def put_obj(self, obj: WorldObj, i: int, j: int):
          """
          Put an object at a specific position in the grid and record that
          position as both its initial and its current position
          """
          self.grid.set(i, j, obj)
          obj.init_pos = (i, j)
          obj.cur_pos = (i, j)

      def place_agent(self, top=None, size=None, rand_dir=True, max_tries=math.inf):
          """
          Set the agent's starting point at an empty position in the grid

          :param top: top-left position of the rectangle to sample from
          :param size: size of the rectangle to sample from
          :param rand_dir: if true, also draw a random direction for the agent
          :param max_tries: sampling budget forwarded to ``place_obj``
          :return: the chosen (x, y) position
          """
          # Clear the previous position so it does not block the sampling
          self.agent_pos = (-1, -1)
          pos = self.place_obj(None, top, size, max_tries=max_tries)
          self.agent_pos = pos

          if rand_dir:
              self.agent_dir = self._rand_int(0, 4)

          return pos

      @property
      def dir_vec(self):
          """
          Get the direction vector for the agent, pointing in the direction
          of forward movement.
          """
          assert (
              self.agent_dir >= 0 and self.agent_dir < 4
          ), f"Invalid agent_dir: {self.agent_dir} is not within range(0, 4)"
          return DIR_TO_VEC[self.agent_dir]

      @property
      def right_vec(self):
          """
          Get the vector pointing to the right of the agent.
          """
          dx, dy = self.dir_vec
          return np.array((-dy, dx))

      @property
      def front_pos(self):
          """
          Get the position of the cell that is right in front of the agent
          """
          return self.agent_pos + self.dir_vec

      def get_view_coords(self, i, j):
          """
          Translate and rotate absolute grid coordinates (i, j) into the
          agent's partially observable view (sub-grid). Note that the resulting
          coordinates may be negative or outside of the agent's view size.
          """
          ax, ay = self.agent_pos
          dx, dy = self.dir_vec
          rx, ry = self.right_vec

          # Compute the absolute coordinates of the top-left view corner
          sz = self.agent_view_size
          hs = self.agent_view_size // 2
          tx = ax + (dx * (sz - 1)) - (rx * hs)
          ty = ay + (dy * (sz - 1)) - (ry * hs)

          lx = i - tx
          ly = j - ty

          # Project the coordinates of the object relative to the top-left
          # corner onto the agent's own coordinate system
          vx = rx * lx + ry * ly
          vy = -(dx * lx + dy * ly)

          return vx, vy

      def get_view_exts(self, agent_view_size=None):
          """
          Get the extents of the square set of tiles visible to the agent
          Note: the bottom extent indices are not included in the set
          if agent_view_size is None, use self.agent_view_size

          :return: tuple (topX, topY, botX, botY)
          :raises AssertionError: if the agent direction is not 0, 1, 2 or 3
          """
          agent_view_size = agent_view_size or self.agent_view_size

          # Facing right
          if self.agent_dir == 0:
              topX = self.agent_pos[0]
              topY = self.agent_pos[1] - agent_view_size // 2
          # Facing down
          elif self.agent_dir == 1:
              topX = self.agent_pos[0] - agent_view_size // 2
              topY = self.agent_pos[1]
          # Facing left
          elif self.agent_dir == 2:
              topX = self.agent_pos[0] - agent_view_size + 1
              topY = self.agent_pos[1] - agent_view_size // 2
          # Facing up
          elif self.agent_dir == 3:
              topX = self.agent_pos[0] - agent_view_size // 2
              topY = self.agent_pos[1] - agent_view_size + 1
          else:
              # Raised explicitly (rather than ``assert False``) so the check
              # survives ``python -O`` and topX/topY are never left unbound
              raise AssertionError("invalid agent direction")

          botX = topX + agent_view_size
          botY = topY + agent_view_size

          return topX, topY, botX, botY

      def relative_coords(self, x, y):
          """
          Check if a grid position belongs to the agent's field of view, and
          returns the corresponding view coordinates, or None if it does not
          """
          vx, vy = self.get_view_coords(x, y)

          if vx < 0 or vy < 0 or vx >= self.agent_view_size or vy >= self.agent_view_size:
              return None

          return vx, vy

      def in_view(self, x, y):
          """
          check if a grid position is visible to the agent
          """
          return self.relative_coords(x, y) is not None

      def agent_sees(self, x, y):
          """
          Check if a non-empty grid position is visible to the agent

          :raises AssertionError: if the position is inside the view but the
              grid cell at (x, y) is empty
          """
          coordinates = self.relative_coords(x, y)
          if coordinates is None:
              return False
          vx, vy = coordinates

          # Compare what the observation shows at that view cell with the world
          obs = self.gen_obs()
          obs_grid, _ = Grid.decode(obs["image"])
          obs_cell = obs_grid.get(vx, vy)
          world_cell = self.grid.get(x, y)

          assert world_cell is not None

          return obs_cell is not None and obs_cell.type == world_cell.type

      def step(
          self, action: ActType
      ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
          """
          Apply one action and advance the episode by one step.

          Moving onto a "goal" cell terminates the episode with the reward from
          ``_reward``; moving onto a "lava" cell terminates it with reward 0.
          The episode is truncated once ``step_count`` reaches ``max_steps``.

          :param action: one of ``self.actions`` (left, right, forward, pickup,
              drop, toggle, done)
          :return: tuple (observation, reward, terminated, truncated, info)
          :raises ValueError: if ``action`` is not a known action
          """
          self.step_count += 1

          reward = 0
          terminated = False
          truncated = False

          # Get the position in front of the agent
          fwd_pos = self.front_pos

          # Get the contents of the cell in front of the agent
          fwd_cell = self.grid.get(*fwd_pos)

          # Rotate left
          if action == self.actions.left:
              self.agent_dir = (self.agent_dir - 1) % 4

          # Rotate right
          elif action == self.actions.right:
              self.agent_dir = (self.agent_dir + 1) % 4

          # Move forward
          elif action == self.actions.forward:
              if fwd_cell is None or fwd_cell.can_overlap():
                  self.agent_pos = tuple(fwd_pos)
              if fwd_cell is not None and fwd_cell.type == "goal":
                  terminated = True
                  reward = self._reward()
              if fwd_cell is not None and fwd_cell.type == "lava":
                  terminated = True

          # Pick up an object
          elif action == self.actions.pickup:
              if fwd_cell and fwd_cell.can_pickup():
                  if self.carrying is None:
                      self.carrying = fwd_cell
                      # A carried object has no position on the grid
                      self.carrying.cur_pos = np.array([-1, -1])
                      self.grid.set(fwd_pos[0], fwd_pos[1], None)

          # Drop an object
          elif action == self.actions.drop:
              if not fwd_cell and self.carrying:
                  self.grid.set(fwd_pos[0], fwd_pos[1], self.carrying)
                  self.carrying.cur_pos = fwd_pos
                  self.carrying = None

          # Toggle/activate an object
          elif action == self.actions.toggle:
              if fwd_cell:
                  fwd_cell.toggle(self, fwd_pos)

          # Done action (not used by default)
          elif action == self.actions.done:
              pass

          else:
              raise ValueError(f"Unknown action: {action}")

          if self.step_count >= self.max_steps:
              truncated = True

          if self.render_mode == "human":
              self.render()

          obs = self.gen_obs()

          return obs, reward, terminated, truncated, {}

      def gen_obs_grid(self, agent_view_size=None):
          """
          Generate the sub-grid observed by the agent.
          This method also outputs a visibility mask telling us which grid
          cells the agent can actually see.
          if agent_view_size is None, self.agent_view_size is used

          :return: tuple (grid, vis_mask)
          """
          topX, topY, botX, botY = self.get_view_exts(agent_view_size)

          agent_view_size = agent_view_size or self.agent_view_size

          grid = self.grid.slice(topX, topY, agent_view_size, agent_view_size)

          # Rotate the slice into the agent's frame of reference
          for _ in range(self.agent_dir + 1):
              grid = grid.rotate_left()

          # Process occluders and visibility
          # Note that this incurs some performance cost
          if not self.see_through_walls:
              vis_mask = grid.process_vis(
                  agent_pos=(agent_view_size // 2, agent_view_size - 1)
              )
          else:
              vis_mask = np.ones(shape=(grid.width, grid.height), dtype=bool)

          # Make it so the agent sees what it's carrying
          # We do this by placing the carried object at the agent's position
          # in the agent's partially observable view
          agent_pos = grid.width // 2, grid.height - 1
          if self.carrying:
              grid.set(*agent_pos, self.carrying)
          else:
              grid.set(*agent_pos, None)

          return grid, vis_mask

      def gen_obs(self):
          """
          Generate the agent's view (partially observable, low-resolution encoding)

          :return: dict with the keys "image", "direction" and "mission"
          """
          grid, vis_mask = self.gen_obs_grid()

          # Encode the partially observable view into a numpy array
          image = grid.encode(vis_mask)

          # Observations are dictionaries containing:
          # - an image (partially observable view of the environment)
          # - the agent's direction/orientation (acting as a compass)
          # - a textual mission string (instructions for the agent)
          obs = {"image": image, "direction": self.agent_dir, "mission": self.mission}

          return obs

      def get_pov_render(self, tile_size):
          """
          Render an agent's POV observation for visualization

          :param tile_size: size in pixels of one rendered tile
          """
          grid, vis_mask = self.gen_obs_grid()

          # Render the whole view; the agent is drawn at the bottom-centre
          # cell with direction 3 (up)
          img = grid.render(
              tile_size,
              agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1),
              agent_dir=3,
              highlight_mask=vis_mask,
          )

          return img

      def get_full_render(self, highlight, tile_size):
          """
          Render a non-partial observation (the whole grid) for visualization

          :param highlight: if true, highlight the cells visible to the agent
          :param tile_size: size in pixels of one rendered tile
          """
          # Compute which cells are visible to the agent
          _, vis_mask = self.gen_obs_grid()

          # Compute the world coordinates of the top-left corner
          # of the agent's view area
          f_vec = self.dir_vec
          r_vec = self.right_vec
          top_left = (
              self.agent_pos
              + f_vec * (self.agent_view_size - 1)
              - r_vec * (self.agent_view_size // 2)
          )

          # Mask of which cells to highlight
          highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)

          # For each cell in the visibility mask
          for vis_j in range(0, self.agent_view_size):
              for vis_i in range(0, self.agent_view_size):
                  # If this cell is not visible, don't highlight it
                  if not vis_mask[vis_i, vis_j]:
                      continue

                  # Compute the world coordinates of this cell
                  abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)

                  # Skip view cells that fall outside the grid
                  if abs_i < 0 or abs_i >= self.width:
                      continue
                  if abs_j < 0 or abs_j >= self.height:
                      continue

                  # Mark this cell to be highlighted
                  highlight_mask[abs_i, abs_j] = True

          # Render the whole grid
          img = self.grid.render(
              tile_size,
              self.agent_pos,
              self.agent_dir,
              highlight_mask=highlight_mask if highlight else None,
          )

          return img

      def get_frame(
          self,
          highlight: bool = True,
          tile_size: int = TILE_PIXELS,
          agent_pov: bool = False,
      ):
          """Returns an RGB image corresponding to the whole environment or the agent's point of view.

          Args:
              highlight (bool): If true, the agent's field of view or point of view is highlighted with a lighter gray color.
              tile_size (int): How many pixels will form a tile from the NxM grid.
              agent_pov (bool): If true, the rendered frame will only contain the point of view of the agent.

          Returns:
              frame (np.ndarray): A frame of type numpy.ndarray with shape (x, y, 3) representing RGB values for the x-by-y pixel image.
          """
          if agent_pov:
              return self.get_pov_render(tile_size)
          else:
              return self.get_full_render(highlight, tile_size)

      def render(self):
          """
          Render the current frame according to ``render_mode``.

          In "human" mode the frame is shown in a window that is created on first
          use and nothing is returned; in "rgb_array" mode the frame is returned.
          """
          img = self.get_frame(self.highlight, self.tile_size, self.agent_pov)

          if self.render_mode == "human":
              if self.window is None:
                  self.window = Window("minigrid")
                  self.window.show(block=False)
              self.window.set_caption(self.mission)
              self.window.show_img(img)
          elif self.render_mode == "rgb_array":
              return img

      def close(self):
          """Close the rendering window, if one was opened."""
          if self.window:
              self.window.close()