minigrid_env.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752
  1. from __future__ import annotations
  2. import hashlib
  3. import math
  4. from abc import abstractmethod
  5. from enum import IntEnum
  6. from typing import Iterable, TypeVar
  7. import gymnasium as gym
  8. import numpy as np
  9. from gymnasium import spaces
  10. from minigrid.core.constants import COLOR_NAMES, DIR_TO_VEC, TILE_PIXELS
  11. from minigrid.core.grid import Grid
  12. from minigrid.core.mission import MissionSpace
  13. from minigrid.core.world_object import Point, WorldObj
  14. from minigrid.utils.window import Window
  15. T = TypeVar("T")
  16. class MiniGridEnv(gym.Env):
  17. """
  18. 2D grid world game environment
  19. """
  20. metadata = {
  21. "render_modes": ["human", "rgb_array"],
  22. "render_fps": 10,
  23. }
  24. # Enumeration of possible actions
  25. class Actions(IntEnum):
  26. # Turn left, turn right, move forward
  27. left = 0
  28. right = 1
  29. forward = 2
  30. # Pick up an object
  31. pickup = 3
  32. # Drop an object
  33. drop = 4
  34. # Toggle/activate an object
  35. toggle = 5
  36. # Done completing task
  37. done = 6
  38. def __init__(
  39. self,
  40. mission_space: MissionSpace,
  41. grid_size: int | None = None,
  42. width: int | None = None,
  43. height: int | None = None,
  44. max_steps: int = 100,
  45. see_through_walls: bool = False,
  46. agent_view_size: int = 7,
  47. render_mode: str | None = None,
  48. highlight: bool = True,
  49. tile_size: int = TILE_PIXELS,
  50. agent_pov: bool = False,
  51. ):
  52. # Initialize mission
  53. self.mission = mission_space.sample()
  54. # Can't set both grid_size and width/height
  55. if grid_size:
  56. assert width is None and height is None
  57. width = grid_size
  58. height = grid_size
  59. assert width is not None and height is not None
  60. # Action enumeration for this environment
  61. self.actions = MiniGridEnv.Actions
  62. # Actions are discrete integer values
  63. self.action_space = spaces.Discrete(len(self.actions))
  64. # Number of cells (width and height) in the agent view
  65. assert agent_view_size % 2 == 1
  66. assert agent_view_size >= 3
  67. self.agent_view_size = agent_view_size
  68. # Observations are dictionaries containing an
  69. # encoding of the grid and a textual 'mission' string
  70. image_observation_space = spaces.Box(
  71. low=0,
  72. high=255,
  73. shape=(self.agent_view_size, self.agent_view_size, 3),
  74. dtype="uint8",
  75. )
  76. self.observation_space = spaces.Dict(
  77. {
  78. "image": image_observation_space,
  79. "direction": spaces.Discrete(4),
  80. "mission": mission_space,
  81. }
  82. )
  83. # Range of possible rewards
  84. self.reward_range = (0, 1)
  85. self.window: Window = None
  86. # Environment configuration
  87. self.width = width
  88. self.height = height
  89. assert isinstance(
  90. max_steps, int
  91. ), f"The argument max_steps must be an integer, got: {type(max_steps)}"
  92. self.max_steps = max_steps
  93. self.see_through_walls = see_through_walls
  94. # Current position and direction of the agent
  95. self.agent_pos: np.ndarray | tuple[int, int] = None
  96. self.agent_dir: int = None
  97. # Current grid and mission and carrying
  98. self.grid = Grid(width, height)
  99. self.carrying = None
  100. # Rendering attributes
  101. self.render_mode = render_mode
  102. self.highlight = highlight
  103. self.tile_size = tile_size
  104. self.agent_pov = agent_pov
  105. def reset(self, *, seed=None, options=None):
  106. super().reset(seed=seed)
  107. # Reinitialize episode-specific variables
  108. self.agent_pos = (-1, -1)
  109. self.agent_dir = -1
  110. # Generate a new random grid at the start of each episode
  111. self._gen_grid(self.width, self.height)
  112. # These fields should be defined by _gen_grid
  113. assert (
  114. self.agent_pos >= (0, 0)
  115. if isinstance(self.agent_pos, tuple)
  116. else all(self.agent_pos >= 0) and self.agent_dir >= 0
  117. )
  118. # Check that the agent doesn't overlap with an object
  119. start_cell = self.grid.get(*self.agent_pos)
  120. assert start_cell is None or start_cell.can_overlap()
  121. # Item picked up, being carried, initially nothing
  122. self.carrying = None
  123. # Step count since episode start
  124. self.step_count = 0
  125. if self.render_mode == "human":
  126. self.render()
  127. # Return first observation
  128. obs = self.gen_obs()
  129. return obs, {}
  130. def hash(self, size=16):
  131. """Compute a hash that uniquely identifies the current state of the environment.
  132. :param size: Size of the hashing
  133. """
  134. sample_hash = hashlib.sha256()
  135. to_encode = [self.grid.encode().tolist(), self.agent_pos, self.agent_dir]
  136. for item in to_encode:
  137. sample_hash.update(str(item).encode("utf8"))
  138. return sample_hash.hexdigest()[:size]
  139. @property
  140. def steps_remaining(self):
  141. return self.max_steps - self.step_count
  142. def __str__(self):
  143. """
  144. Produce a pretty string of the environment's grid along with the agent.
  145. A grid cell is represented by 2-character string, the first one for
  146. the object and the second one for the color.
  147. """
  148. # Map of object types to short string
  149. OBJECT_TO_STR = {
  150. "wall": "W",
  151. "floor": "F",
  152. "door": "D",
  153. "key": "K",
  154. "ball": "A",
  155. "box": "B",
  156. "goal": "G",
  157. "lava": "V",
  158. }
  159. # Map agent's direction to short string
  160. AGENT_DIR_TO_STR = {0: ">", 1: "V", 2: "<", 3: "^"}
  161. str = ""
  162. for j in range(self.grid.height):
  163. for i in range(self.grid.width):
  164. if i == self.agent_pos[0] and j == self.agent_pos[1]:
  165. str += 2 * AGENT_DIR_TO_STR[self.agent_dir]
  166. continue
  167. c = self.grid.get(i, j)
  168. if c is None:
  169. str += " "
  170. continue
  171. if c.type == "door":
  172. if c.is_open:
  173. str += "__"
  174. elif c.is_locked:
  175. str += "L" + c.color[0].upper()
  176. else:
  177. str += "D" + c.color[0].upper()
  178. continue
  179. str += OBJECT_TO_STR[c.type] + c.color[0].upper()
  180. if j < self.grid.height - 1:
  181. str += "\n"
  182. return str
  183. @abstractmethod
  184. def _gen_grid(self, width, height):
  185. pass
  186. def _reward(self) -> float:
  187. """
  188. Compute the reward to be given upon success
  189. """
  190. return 1 - 0.9 * (self.step_count / self.max_steps)
  191. def _rand_int(self, low: int, high: int) -> int:
  192. """
  193. Generate random integer in [low,high[
  194. """
  195. return self.np_random.integers(low, high)
  196. def _rand_float(self, low: float, high: float) -> float:
  197. """
  198. Generate random float in [low,high[
  199. """
  200. return self.np_random.uniform(low, high)
  201. def _rand_bool(self) -> bool:
  202. """
  203. Generate random boolean value
  204. """
  205. return self.np_random.integers(0, 2) == 0
  206. def _rand_elem(self, iterable: Iterable[T]) -> T:
  207. """
  208. Pick a random element in a list
  209. """
  210. lst = list(iterable)
  211. idx = self._rand_int(0, len(lst))
  212. return lst[idx]
  213. def _rand_subset(self, iterable: Iterable[T], num_elems: int) -> list[T]:
  214. """
  215. Sample a random subset of distinct elements of a list
  216. """
  217. lst = list(iterable)
  218. assert num_elems <= len(lst)
  219. out: list[T] = []
  220. while len(out) < num_elems:
  221. elem = self._rand_elem(lst)
  222. lst.remove(elem)
  223. out.append(elem)
  224. return out
  225. def _rand_color(self) -> str:
  226. """
  227. Generate a random color name (string)
  228. """
  229. return self._rand_elem(COLOR_NAMES)
  230. def _rand_pos(
  231. self, x_low: int, x_high: int, y_low: int, y_high: int
  232. ) -> tuple[int, int]:
  233. """
  234. Generate a random (x,y) position tuple
  235. """
  236. return (
  237. self.np_random.integers(x_low, x_high),
  238. self.np_random.integers(y_low, y_high),
  239. )
  240. def place_obj(
  241. self,
  242. obj: WorldObj | None,
  243. top: Point = None,
  244. size: tuple[int, int] = None,
  245. reject_fn=None,
  246. max_tries=math.inf,
  247. ):
  248. """
  249. Place an object at an empty position in the grid
  250. :param top: top-left position of the rectangle where to place
  251. :param size: size of the rectangle where to place
  252. :param reject_fn: function to filter out potential positions
  253. """
  254. if top is None:
  255. top = (0, 0)
  256. else:
  257. top = (max(top[0], 0), max(top[1], 0))
  258. if size is None:
  259. size = (self.grid.width, self.grid.height)
  260. num_tries = 0
  261. while True:
  262. # This is to handle with rare cases where rejection sampling
  263. # gets stuck in an infinite loop
  264. if num_tries > max_tries:
  265. raise RecursionError("rejection sampling failed in place_obj")
  266. num_tries += 1
  267. pos = (
  268. self._rand_int(top[0], min(top[0] + size[0], self.grid.width)),
  269. self._rand_int(top[1], min(top[1] + size[1], self.grid.height)),
  270. )
  271. # Don't place the object on top of another object
  272. if self.grid.get(*pos) is not None:
  273. continue
  274. # Don't place the object where the agent is
  275. if np.array_equal(pos, self.agent_pos):
  276. continue
  277. # Check if there is a filtering criterion
  278. if reject_fn and reject_fn(self, pos):
  279. continue
  280. break
  281. self.grid.set(pos[0], pos[1], obj)
  282. if obj is not None:
  283. obj.init_pos = pos
  284. obj.cur_pos = pos
  285. return pos
  286. def put_obj(self, obj: WorldObj, i: int, j: int):
  287. """
  288. Put an object at a specific position in the grid
  289. """
  290. self.grid.set(i, j, obj)
  291. obj.init_pos = (i, j)
  292. obj.cur_pos = (i, j)
  293. def place_agent(self, top=None, size=None, rand_dir=True, max_tries=math.inf):
  294. """
  295. Set the agent's starting point at an empty position in the grid
  296. """
  297. self.agent_pos = (-1, -1)
  298. pos = self.place_obj(None, top, size, max_tries=max_tries)
  299. self.agent_pos = pos
  300. if rand_dir:
  301. self.agent_dir = self._rand_int(0, 4)
  302. return pos
  303. @property
  304. def dir_vec(self):
  305. """
  306. Get the direction vector for the agent, pointing in the direction
  307. of forward movement.
  308. """
  309. assert (
  310. self.agent_dir >= 0 and self.agent_dir < 4
  311. ), f"Invalid agent_dir: {self.agent_dir} is not within range(0, 4)"
  312. return DIR_TO_VEC[self.agent_dir]
  313. @property
  314. def right_vec(self):
  315. """
  316. Get the vector pointing to the right of the agent.
  317. """
  318. dx, dy = self.dir_vec
  319. return np.array((-dy, dx))
  320. @property
  321. def front_pos(self):
  322. """
  323. Get the position of the cell that is right in front of the agent
  324. """
  325. return self.agent_pos + self.dir_vec
  326. def get_view_coords(self, i, j):
  327. """
  328. Translate and rotate absolute grid coordinates (i, j) into the
  329. agent's partially observable view (sub-grid). Note that the resulting
  330. coordinates may be negative or outside of the agent's view size.
  331. """
  332. ax, ay = self.agent_pos
  333. dx, dy = self.dir_vec
  334. rx, ry = self.right_vec
  335. # Compute the absolute coordinates of the top-left view corner
  336. sz = self.agent_view_size
  337. hs = self.agent_view_size // 2
  338. tx = ax + (dx * (sz - 1)) - (rx * hs)
  339. ty = ay + (dy * (sz - 1)) - (ry * hs)
  340. lx = i - tx
  341. ly = j - ty
  342. # Project the coordinates of the object relative to the top-left
  343. # corner onto the agent's own coordinate system
  344. vx = rx * lx + ry * ly
  345. vy = -(dx * lx + dy * ly)
  346. return vx, vy
  347. def get_view_exts(self, agent_view_size=None):
  348. """
  349. Get the extents of the square set of tiles visible to the agent
  350. Note: the bottom extent indices are not included in the set
  351. if agent_view_size is None, use self.agent_view_size
  352. """
  353. agent_view_size = agent_view_size or self.agent_view_size
  354. # Facing right
  355. if self.agent_dir == 0:
  356. topX = self.agent_pos[0]
  357. topY = self.agent_pos[1] - agent_view_size // 2
  358. # Facing down
  359. elif self.agent_dir == 1:
  360. topX = self.agent_pos[0] - agent_view_size // 2
  361. topY = self.agent_pos[1]
  362. # Facing left
  363. elif self.agent_dir == 2:
  364. topX = self.agent_pos[0] - agent_view_size + 1
  365. topY = self.agent_pos[1] - agent_view_size // 2
  366. # Facing up
  367. elif self.agent_dir == 3:
  368. topX = self.agent_pos[0] - agent_view_size // 2
  369. topY = self.agent_pos[1] - agent_view_size + 1
  370. else:
  371. assert False, "invalid agent direction"
  372. botX = topX + agent_view_size
  373. botY = topY + agent_view_size
  374. return (topX, topY, botX, botY)
  375. def relative_coords(self, x, y):
  376. """
  377. Check if a grid position belongs to the agent's field of view, and returns the corresponding coordinates
  378. """
  379. vx, vy = self.get_view_coords(x, y)
  380. if vx < 0 or vy < 0 or vx >= self.agent_view_size or vy >= self.agent_view_size:
  381. return None
  382. return vx, vy
  383. def in_view(self, x, y):
  384. """
  385. check if a grid position is visible to the agent
  386. """
  387. return self.relative_coords(x, y) is not None
  388. def agent_sees(self, x, y):
  389. """
  390. Check if a non-empty grid position is visible to the agent
  391. """
  392. coordinates = self.relative_coords(x, y)
  393. if coordinates is None:
  394. return False
  395. vx, vy = coordinates
  396. obs = self.gen_obs()
  397. obs_grid, _ = Grid.decode(obs["image"])
  398. obs_cell = obs_grid.get(vx, vy)
  399. world_cell = self.grid.get(x, y)
  400. assert world_cell is not None
  401. return obs_cell is not None and obs_cell.type == world_cell.type
  402. def step(self, action):
  403. self.step_count += 1
  404. reward = 0
  405. terminated = False
  406. truncated = False
  407. # Get the position in front of the agent
  408. fwd_pos = self.front_pos
  409. # Get the contents of the cell in front of the agent
  410. fwd_cell = self.grid.get(*fwd_pos)
  411. # Rotate left
  412. if action == self.actions.left:
  413. self.agent_dir -= 1
  414. if self.agent_dir < 0:
  415. self.agent_dir += 4
  416. # Rotate right
  417. elif action == self.actions.right:
  418. self.agent_dir = (self.agent_dir + 1) % 4
  419. # Move forward
  420. elif action == self.actions.forward:
  421. if fwd_cell is None or fwd_cell.can_overlap():
  422. self.agent_pos = tuple(fwd_pos)
  423. if fwd_cell is not None and fwd_cell.type == "goal":
  424. terminated = True
  425. reward = self._reward()
  426. if fwd_cell is not None and fwd_cell.type == "lava":
  427. terminated = True
  428. # Pick up an object
  429. elif action == self.actions.pickup:
  430. if fwd_cell and fwd_cell.can_pickup():
  431. if self.carrying is None:
  432. self.carrying = fwd_cell
  433. self.carrying.cur_pos = np.array([-1, -1])
  434. self.grid.set(fwd_pos[0], fwd_pos[1], None)
  435. # Drop an object
  436. elif action == self.actions.drop:
  437. if not fwd_cell and self.carrying:
  438. self.grid.set(fwd_pos[0], fwd_pos[1], self.carrying)
  439. self.carrying.cur_pos = fwd_pos
  440. self.carrying = None
  441. # Toggle/activate an object
  442. elif action == self.actions.toggle:
  443. if fwd_cell:
  444. fwd_cell.toggle(self, fwd_pos)
  445. # Done action (not used by default)
  446. elif action == self.actions.done:
  447. pass
  448. else:
  449. raise ValueError(f"Unknown action: {action}")
  450. if self.step_count >= self.max_steps:
  451. truncated = True
  452. if self.render_mode == "human":
  453. self.render()
  454. obs = self.gen_obs()
  455. return obs, reward, terminated, truncated, {}
  456. def gen_obs_grid(self, agent_view_size=None):
  457. """
  458. Generate the sub-grid observed by the agent.
  459. This method also outputs a visibility mask telling us which grid
  460. cells the agent can actually see.
  461. if agent_view_size is None, self.agent_view_size is used
  462. """
  463. topX, topY, botX, botY = self.get_view_exts(agent_view_size)
  464. agent_view_size = agent_view_size or self.agent_view_size
  465. grid = self.grid.slice(topX, topY, agent_view_size, agent_view_size)
  466. for i in range(self.agent_dir + 1):
  467. grid = grid.rotate_left()
  468. # Process occluders and visibility
  469. # Note that this incurs some performance cost
  470. if not self.see_through_walls:
  471. vis_mask = grid.process_vis(
  472. agent_pos=(agent_view_size // 2, agent_view_size - 1)
  473. )
  474. else:
  475. vis_mask = np.ones(shape=(grid.width, grid.height), dtype=bool)
  476. # Make it so the agent sees what it's carrying
  477. # We do this by placing the carried object at the agent's position
  478. # in the agent's partially observable view
  479. agent_pos = grid.width // 2, grid.height - 1
  480. if self.carrying:
  481. grid.set(*agent_pos, self.carrying)
  482. else:
  483. grid.set(*agent_pos, None)
  484. return grid, vis_mask
  485. def gen_obs(self):
  486. """
  487. Generate the agent's view (partially observable, low-resolution encoding)
  488. """
  489. grid, vis_mask = self.gen_obs_grid()
  490. # Encode the partially observable view into a numpy array
  491. image = grid.encode(vis_mask)
  492. # Observations are dictionaries containing:
  493. # - an image (partially observable view of the environment)
  494. # - the agent's direction/orientation (acting as a compass)
  495. # - a textual mission string (instructions for the agent)
  496. obs = {"image": image, "direction": self.agent_dir, "mission": self.mission}
  497. return obs
  498. def get_pov_render(self, tile_size):
  499. """
  500. Render an agent's POV observation for visualization
  501. """
  502. grid, vis_mask = self.gen_obs_grid()
  503. # Render the whole grid
  504. img = grid.render(
  505. tile_size,
  506. agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1),
  507. agent_dir=3,
  508. highlight_mask=vis_mask,
  509. )
  510. return img
  511. def get_full_render(self, highlight, tile_size):
  512. """
  513. Render a non-paratial observation for visualization
  514. """
  515. # Compute which cells are visible to the agent
  516. _, vis_mask = self.gen_obs_grid()
  517. # Compute the world coordinates of the bottom-left corner
  518. # of the agent's view area
  519. f_vec = self.dir_vec
  520. r_vec = self.right_vec
  521. top_left = (
  522. self.agent_pos
  523. + f_vec * (self.agent_view_size - 1)
  524. - r_vec * (self.agent_view_size // 2)
  525. )
  526. # Mask of which cells to highlight
  527. highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  528. # For each cell in the visibility mask
  529. for vis_j in range(0, self.agent_view_size):
  530. for vis_i in range(0, self.agent_view_size):
  531. # If this cell is not visible, don't highlight it
  532. if not vis_mask[vis_i, vis_j]:
  533. continue
  534. # Compute the world coordinates of this cell
  535. abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)
  536. if abs_i < 0 or abs_i >= self.width:
  537. continue
  538. if abs_j < 0 or abs_j >= self.height:
  539. continue
  540. # Mark this cell to be highlighted
  541. highlight_mask[abs_i, abs_j] = True
  542. # Render the whole grid
  543. img = self.grid.render(
  544. tile_size,
  545. self.agent_pos,
  546. self.agent_dir,
  547. highlight_mask=highlight_mask if highlight else None,
  548. )
  549. return img
  550. def get_frame(
  551. self,
  552. highlight: bool = True,
  553. tile_size: int = TILE_PIXELS,
  554. agent_pov: bool = False,
  555. ):
  556. """Returns an RGB image corresponding to the whole environment or the agent's point of view.
  557. Args:
  558. highlight (bool): If true, the agent's field of view or point of view is highlighted with a lighter gray color.
  559. tile_size (int): How many pixels will form a tile from the NxM grid.
  560. agent_pov (bool): If true, the rendered frame will only contain the point of view of the agent.
  561. Returns:
  562. frame (np.ndarray): A frame of type numpy.ndarray with shape (x, y, 3) representing RGB values for the x-by-y pixel image.
  563. """
  564. if agent_pov:
  565. return self.get_pov_render(tile_size)
  566. else:
  567. return self.get_full_render(highlight, tile_size)
  568. def render(self):
  569. img = self.get_frame(self.highlight, self.tile_size, self.agent_pov)
  570. if self.render_mode == "human":
  571. if self.window is None:
  572. self.window = Window("minigrid")
  573. self.window.show(block=False)
  574. self.window.set_caption(self.mission)
  575. self.window.show_img(img)
  576. elif self.render_mode == "rgb_array":
  577. return img
  578. def close(self):
  579. if self.window:
  580. self.window.close()