minigrid.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717
  1. import hashlib
  2. import math
  3. from abc import abstractmethod
  4. from typing import Optional
  5. import gymnasium as gym
  6. import numpy as np
  7. from gymnasium import spaces
  8. from minigrid.core.actions import Actions
  9. from minigrid.core.constants import COLOR_NAMES, DIR_TO_VEC, TILE_PIXELS
  10. from minigrid.core.grid import Grid
  11. from minigrid.core.mission import MissionSpace
  12. from minigrid.utils.window import Window
  13. class MiniGridEnv(gym.Env):
  14. """
  15. 2D grid world game environment
  16. """
  17. metadata = {
  18. "render_modes": ["human", "rgb_array"],
  19. "render_fps": 10,
  20. }
  21. def __init__(
  22. self,
  23. mission_space: MissionSpace,
  24. grid_size: int = None,
  25. width: int = None,
  26. height: int = None,
  27. max_steps: int = 100,
  28. see_through_walls: bool = False,
  29. agent_view_size: int = 7,
  30. render_mode: Optional[str] = None,
  31. highlight: bool = True,
  32. tile_size: int = TILE_PIXELS,
  33. agent_pov: bool = False,
  34. ):
  35. # Initialize mission
  36. self.mission = mission_space.sample()
  37. # Can't set both grid_size and width/height
  38. if grid_size:
  39. assert width is None and height is None
  40. width = grid_size
  41. height = grid_size
  42. # Action enumeration for this environment
  43. self.actions = Actions
  44. # Actions are discrete integer values
  45. self.action_space = spaces.Discrete(len(self.actions))
  46. # Number of cells (width and height) in the agent view
  47. assert agent_view_size % 2 == 1
  48. assert agent_view_size >= 3
  49. self.agent_view_size = agent_view_size
  50. # Observations are dictionaries containing an
  51. # encoding of the grid and a textual 'mission' string
  52. image_observation_space = spaces.Box(
  53. low=0,
  54. high=255,
  55. shape=(self.agent_view_size, self.agent_view_size, 3),
  56. dtype="uint8",
  57. )
  58. self.observation_space = spaces.Dict(
  59. {
  60. "image": image_observation_space,
  61. "direction": spaces.Discrete(4),
  62. "mission": mission_space,
  63. }
  64. )
  65. # Range of possible rewards
  66. self.reward_range = (0, 1)
  67. self.window: Window = None
  68. # Environment configuration
  69. self.width = width
  70. self.height = height
  71. self.max_steps = max_steps
  72. self.see_through_walls = see_through_walls
  73. # Current position and direction of the agent
  74. self.agent_pos: np.ndarray = None
  75. self.agent_dir: int = None
  76. # Current grid and mission and carryinh
  77. self.grid = Grid(width, height)
  78. self.carrying = None
  79. # Rendering attributes
  80. self.render_mode = render_mode
  81. self.highlight = highlight
  82. self.tile_size = tile_size
  83. self.agent_pov = agent_pov
  84. def reset(self, *, seed=None, options=None):
  85. super().reset(seed=seed)
  86. # Reinitialize episode-specific variables
  87. self.agent_pos = (-1, -1)
  88. self.agent_dir = -1
  89. # Generate a new random grid at the start of each episode
  90. self._gen_grid(self.width, self.height)
  91. # These fields should be defined by _gen_grid
  92. assert (
  93. self.agent_pos >= (0, 0)
  94. if isinstance(self.agent_pos, tuple)
  95. else all(self.agent_pos >= 0) and self.agent_dir >= 0
  96. )
  97. # Check that the agent doesn't overlap with an object
  98. start_cell = self.grid.get(*self.agent_pos)
  99. assert start_cell is None or start_cell.can_overlap()
  100. # Item picked up, being carried, initially nothing
  101. self.carrying = None
  102. # Step count since episode start
  103. self.step_count = 0
  104. if self.render_mode == "human":
  105. self.render()
  106. # Return first observation
  107. obs = self.gen_obs()
  108. return obs, {}
  109. def hash(self, size=16):
  110. """Compute a hash that uniquely identifies the current state of the environment.
  111. :param size: Size of the hashing
  112. """
  113. sample_hash = hashlib.sha256()
  114. to_encode = [self.grid.encode().tolist(), self.agent_pos, self.agent_dir]
  115. for item in to_encode:
  116. sample_hash.update(str(item).encode("utf8"))
  117. return sample_hash.hexdigest()[:size]
  118. @property
  119. def steps_remaining(self):
  120. return self.max_steps - self.step_count
  121. def __str__(self):
  122. """
  123. Produce a pretty string of the environment's grid along with the agent.
  124. A grid cell is represented by 2-character string, the first one for
  125. the object and the second one for the color.
  126. """
  127. # Map of object types to short string
  128. OBJECT_TO_STR = {
  129. "wall": "W",
  130. "floor": "F",
  131. "door": "D",
  132. "key": "K",
  133. "ball": "A",
  134. "box": "B",
  135. "goal": "G",
  136. "lava": "V",
  137. }
  138. # Map agent's direction to short string
  139. AGENT_DIR_TO_STR = {0: ">", 1: "V", 2: "<", 3: "^"}
  140. str = ""
  141. for j in range(self.grid.height):
  142. for i in range(self.grid.width):
  143. if i == self.agent_pos[0] and j == self.agent_pos[1]:
  144. str += 2 * AGENT_DIR_TO_STR[self.agent_dir]
  145. continue
  146. c = self.grid.get(i, j)
  147. if c is None:
  148. str += " "
  149. continue
  150. if c.type == "door":
  151. if c.is_open:
  152. str += "__"
  153. elif c.is_locked:
  154. str += "L" + c.color[0].upper()
  155. else:
  156. str += "D" + c.color[0].upper()
  157. continue
  158. str += OBJECT_TO_STR[c.type] + c.color[0].upper()
  159. if j < self.grid.height - 1:
  160. str += "\n"
  161. return str
  162. @abstractmethod
  163. def _gen_grid(self, width, height):
  164. pass
  165. def _reward(self):
  166. """
  167. Compute the reward to be given upon success
  168. """
  169. return 1 - 0.9 * (self.step_count / self.max_steps)
  170. def _rand_int(self, low, high):
  171. """
  172. Generate random integer in [low,high[
  173. """
  174. return self.np_random.integers(low, high)
  175. def _rand_float(self, low, high):
  176. """
  177. Generate random float in [low,high[
  178. """
  179. return self.np_random.uniform(low, high)
  180. def _rand_bool(self):
  181. """
  182. Generate random boolean value
  183. """
  184. return self.np_random.integers(0, 2) == 0
  185. def _rand_elem(self, iterable):
  186. """
  187. Pick a random element in a list
  188. """
  189. lst = list(iterable)
  190. idx = self._rand_int(0, len(lst))
  191. return lst[idx]
  192. def _rand_subset(self, iterable, num_elems):
  193. """
  194. Sample a random subset of distinct elements of a list
  195. """
  196. lst = list(iterable)
  197. assert num_elems <= len(lst)
  198. out = []
  199. while len(out) < num_elems:
  200. elem = self._rand_elem(lst)
  201. lst.remove(elem)
  202. out.append(elem)
  203. return out
  204. def _rand_color(self):
  205. """
  206. Generate a random color name (string)
  207. """
  208. return self._rand_elem(COLOR_NAMES)
  209. def _rand_pos(self, xLow, xHigh, yLow, yHigh):
  210. """
  211. Generate a random (x,y) position tuple
  212. """
  213. return (
  214. self.np_random.integers(xLow, xHigh),
  215. self.np_random.integers(yLow, yHigh),
  216. )
  217. def place_obj(self, obj, top=None, size=None, reject_fn=None, max_tries=math.inf):
  218. """
  219. Place an object at an empty position in the grid
  220. :param top: top-left position of the rectangle where to place
  221. :param size: size of the rectangle where to place
  222. :param reject_fn: function to filter out potential positions
  223. """
  224. if top is None:
  225. top = (0, 0)
  226. else:
  227. top = (max(top[0], 0), max(top[1], 0))
  228. if size is None:
  229. size = (self.grid.width, self.grid.height)
  230. num_tries = 0
  231. while True:
  232. # This is to handle with rare cases where rejection sampling
  233. # gets stuck in an infinite loop
  234. if num_tries > max_tries:
  235. raise RecursionError("rejection sampling failed in place_obj")
  236. num_tries += 1
  237. pos = np.array(
  238. (
  239. self._rand_int(top[0], min(top[0] + size[0], self.grid.width)),
  240. self._rand_int(top[1], min(top[1] + size[1], self.grid.height)),
  241. )
  242. )
  243. pos = tuple(pos)
  244. # Don't place the object on top of another object
  245. if self.grid.get(*pos) is not None:
  246. continue
  247. # Don't place the object where the agent is
  248. if np.array_equal(pos, self.agent_pos):
  249. continue
  250. # Check if there is a filtering criterion
  251. if reject_fn and reject_fn(self, pos):
  252. continue
  253. break
  254. self.grid.set(pos[0], pos[1], obj)
  255. if obj is not None:
  256. obj.init_pos = pos
  257. obj.cur_pos = pos
  258. return pos
  259. def put_obj(self, obj, i, j):
  260. """
  261. Put an object at a specific position in the grid
  262. """
  263. self.grid.set(i, j, obj)
  264. obj.init_pos = (i, j)
  265. obj.cur_pos = (i, j)
  266. def place_agent(self, top=None, size=None, rand_dir=True, max_tries=math.inf):
  267. """
  268. Set the agent's starting point at an empty position in the grid
  269. """
  270. self.agent_pos = (-1, -1)
  271. pos = self.place_obj(None, top, size, max_tries=max_tries)
  272. self.agent_pos = pos
  273. if rand_dir:
  274. self.agent_dir = self._rand_int(0, 4)
  275. return pos
  276. @property
  277. def dir_vec(self):
  278. """
  279. Get the direction vector for the agent, pointing in the direction
  280. of forward movement.
  281. """
  282. assert self.agent_dir >= 0 and self.agent_dir < 4
  283. return DIR_TO_VEC[self.agent_dir]
  284. @property
  285. def right_vec(self):
  286. """
  287. Get the vector pointing to the right of the agent.
  288. """
  289. dx, dy = self.dir_vec
  290. return np.array((-dy, dx))
  291. @property
  292. def front_pos(self):
  293. """
  294. Get the position of the cell that is right in front of the agent
  295. """
  296. return self.agent_pos + self.dir_vec
  297. def get_view_coords(self, i, j):
  298. """
  299. Translate and rotate absolute grid coordinates (i, j) into the
  300. agent's partially observable view (sub-grid). Note that the resulting
  301. coordinates may be negative or outside of the agent's view size.
  302. """
  303. ax, ay = self.agent_pos
  304. dx, dy = self.dir_vec
  305. rx, ry = self.right_vec
  306. # Compute the absolute coordinates of the top-left view corner
  307. sz = self.agent_view_size
  308. hs = self.agent_view_size // 2
  309. tx = ax + (dx * (sz - 1)) - (rx * hs)
  310. ty = ay + (dy * (sz - 1)) - (ry * hs)
  311. lx = i - tx
  312. ly = j - ty
  313. # Project the coordinates of the object relative to the top-left
  314. # corner onto the agent's own coordinate system
  315. vx = rx * lx + ry * ly
  316. vy = -(dx * lx + dy * ly)
  317. return vx, vy
  318. def get_view_exts(self, agent_view_size=None):
  319. """
  320. Get the extents of the square set of tiles visible to the agent
  321. Note: the bottom extent indices are not included in the set
  322. if agent_view_size is None, use self.agent_view_size
  323. """
  324. agent_view_size = agent_view_size or self.agent_view_size
  325. # Facing right
  326. if self.agent_dir == 0:
  327. topX = self.agent_pos[0]
  328. topY = self.agent_pos[1] - agent_view_size // 2
  329. # Facing down
  330. elif self.agent_dir == 1:
  331. topX = self.agent_pos[0] - agent_view_size // 2
  332. topY = self.agent_pos[1]
  333. # Facing left
  334. elif self.agent_dir == 2:
  335. topX = self.agent_pos[0] - agent_view_size + 1
  336. topY = self.agent_pos[1] - agent_view_size // 2
  337. # Facing up
  338. elif self.agent_dir == 3:
  339. topX = self.agent_pos[0] - agent_view_size // 2
  340. topY = self.agent_pos[1] - agent_view_size + 1
  341. else:
  342. assert False, "invalid agent direction"
  343. botX = topX + agent_view_size
  344. botY = topY + agent_view_size
  345. return (topX, topY, botX, botY)
  346. def relative_coords(self, x, y):
  347. """
  348. Check if a grid position belongs to the agent's field of view, and returns the corresponding coordinates
  349. """
  350. vx, vy = self.get_view_coords(x, y)
  351. if vx < 0 or vy < 0 or vx >= self.agent_view_size or vy >= self.agent_view_size:
  352. return None
  353. return vx, vy
  354. def in_view(self, x, y):
  355. """
  356. check if a grid position is visible to the agent
  357. """
  358. return self.relative_coords(x, y) is not None
  359. def agent_sees(self, x, y):
  360. """
  361. Check if a non-empty grid position is visible to the agent
  362. """
  363. coordinates = self.relative_coords(x, y)
  364. if coordinates is None:
  365. return False
  366. vx, vy = coordinates
  367. obs = self.gen_obs()
  368. obs_grid, _ = Grid.decode(obs["image"])
  369. obs_cell = obs_grid.get(vx, vy)
  370. world_cell = self.grid.get(x, y)
  371. assert world_cell is not None
  372. return obs_cell is not None and obs_cell.type == world_cell.type
  373. def step(self, action):
  374. self.step_count += 1
  375. reward = 0
  376. terminated = False
  377. truncated = False
  378. # Get the position in front of the agent
  379. fwd_pos = self.front_pos
  380. # Get the contents of the cell in front of the agent
  381. fwd_cell = self.grid.get(*fwd_pos)
  382. # Rotate left
  383. if action == self.actions.left:
  384. self.agent_dir -= 1
  385. if self.agent_dir < 0:
  386. self.agent_dir += 4
  387. # Rotate right
  388. elif action == self.actions.right:
  389. self.agent_dir = (self.agent_dir + 1) % 4
  390. # Move forward
  391. elif action == self.actions.forward:
  392. if fwd_cell is None or fwd_cell.can_overlap():
  393. self.agent_pos = tuple(fwd_pos)
  394. if fwd_cell is not None and fwd_cell.type == "goal":
  395. terminated = True
  396. reward = self._reward()
  397. if fwd_cell is not None and fwd_cell.type == "lava":
  398. terminated = True
  399. # Pick up an object
  400. elif action == self.actions.pickup:
  401. if fwd_cell and fwd_cell.can_pickup():
  402. if self.carrying is None:
  403. self.carrying = fwd_cell
  404. self.carrying.cur_pos = np.array([-1, -1])
  405. self.grid.set(fwd_pos[0], fwd_pos[1], None)
  406. # Drop an object
  407. elif action == self.actions.drop:
  408. if not fwd_cell and self.carrying:
  409. self.grid.set(fwd_pos[0], fwd_pos[1], self.carrying)
  410. self.carrying.cur_pos = fwd_pos
  411. self.carrying = None
  412. # Toggle/activate an object
  413. elif action == self.actions.toggle:
  414. if fwd_cell:
  415. fwd_cell.toggle(self, fwd_pos)
  416. # Done action (not used by default)
  417. elif action == self.actions.done:
  418. pass
  419. else:
  420. raise ValueError(f"Unknown action: {action}")
  421. if self.step_count >= self.max_steps:
  422. truncated = True
  423. if self.render_mode == "human":
  424. self.render()
  425. obs = self.gen_obs()
  426. return obs, reward, terminated, truncated, {}
  427. def gen_obs_grid(self, agent_view_size=None):
  428. """
  429. Generate the sub-grid observed by the agent.
  430. This method also outputs a visibility mask telling us which grid
  431. cells the agent can actually see.
  432. if agent_view_size is None, self.agent_view_size is used
  433. """
  434. topX, topY, botX, botY = self.get_view_exts(agent_view_size)
  435. agent_view_size = agent_view_size or self.agent_view_size
  436. grid = self.grid.slice(topX, topY, agent_view_size, agent_view_size)
  437. for i in range(self.agent_dir + 1):
  438. grid = grid.rotate_left()
  439. # Process occluders and visibility
  440. # Note that this incurs some performance cost
  441. if not self.see_through_walls:
  442. vis_mask = grid.process_vis(
  443. agent_pos=(agent_view_size // 2, agent_view_size - 1)
  444. )
  445. else:
  446. vis_mask = np.ones(shape=(grid.width, grid.height), dtype=bool)
  447. # Make it so the agent sees what it's carrying
  448. # We do this by placing the carried object at the agent's position
  449. # in the agent's partially observable view
  450. agent_pos = grid.width // 2, grid.height - 1
  451. if self.carrying:
  452. grid.set(*agent_pos, self.carrying)
  453. else:
  454. grid.set(*agent_pos, None)
  455. return grid, vis_mask
  456. def gen_obs(self):
  457. """
  458. Generate the agent's view (partially observable, low-resolution encoding)
  459. """
  460. grid, vis_mask = self.gen_obs_grid()
  461. # Encode the partially observable view into a numpy array
  462. image = grid.encode(vis_mask)
  463. # Observations are dictionaries containing:
  464. # - an image (partially observable view of the environment)
  465. # - the agent's direction/orientation (acting as a compass)
  466. # - a textual mission string (instructions for the agent)
  467. obs = {"image": image, "direction": self.agent_dir, "mission": self.mission}
  468. return obs
  469. def get_pov_render(self, tile_size):
  470. """
  471. Render an agent's POV observation for visualization
  472. """
  473. grid, vis_mask = self.gen_obs_grid()
  474. # Render the whole grid
  475. img = grid.render(
  476. tile_size,
  477. agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1),
  478. agent_dir=3,
  479. highlight_mask=vis_mask,
  480. )
  481. return img
  482. def get_full_render(self, highlight, tile_size):
  483. """
  484. Render a non-paratial observation for visualization
  485. """
  486. # Compute which cells are visible to the agent
  487. _, vis_mask = self.gen_obs_grid()
  488. # Compute the world coordinates of the bottom-left corner
  489. # of the agent's view area
  490. f_vec = self.dir_vec
  491. r_vec = self.right_vec
  492. top_left = (
  493. self.agent_pos
  494. + f_vec * (self.agent_view_size - 1)
  495. - r_vec * (self.agent_view_size // 2)
  496. )
  497. # Mask of which cells to highlight
  498. highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  499. # For each cell in the visibility mask
  500. for vis_j in range(0, self.agent_view_size):
  501. for vis_i in range(0, self.agent_view_size):
  502. # If this cell is not visible, don't highlight it
  503. if not vis_mask[vis_i, vis_j]:
  504. continue
  505. # Compute the world coordinates of this cell
  506. abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)
  507. if abs_i < 0 or abs_i >= self.width:
  508. continue
  509. if abs_j < 0 or abs_j >= self.height:
  510. continue
  511. # Mark this cell to be highlighted
  512. highlight_mask[abs_i, abs_j] = True
  513. # Render the whole grid
  514. img = self.grid.render(
  515. tile_size,
  516. self.agent_pos,
  517. self.agent_dir,
  518. highlight_mask=highlight_mask if highlight else None,
  519. )
  520. return img
  521. def get_frame(
  522. self,
  523. highlight: bool = True,
  524. tile_size: int = TILE_PIXELS,
  525. agent_pov: bool = False,
  526. ):
  527. """Returns an RGB image corresponding to the whole environment or the agent's point of view.
  528. Args:
  529. highlight (bool): If true, the agent's field of view or point of view is highlighted with a lighter gray color.
  530. tile_size (int): How many pixels will form a tile from the NxM grid.
  531. agent_pov (bool): If true, the rendered frame will only contain the point of view of the agent.
  532. Returns:
  533. frame (np.ndarray): A frame of type numpy.ndarray with shape (x, y, 3) representing RGB values for the x-by-y pixel image.
  534. """
  535. if agent_pov:
  536. return self.get_pov_render(tile_size)
  537. else:
  538. return self.get_full_render(highlight, tile_size)
  539. def render(self):
  540. img = self.get_frame(self.highlight, self.tile_size, self.agent_pov)
  541. if self.render_mode == "human":
  542. if self.window is None:
  543. self.window = Window("minigrid")
  544. self.window.show(block=False)
  545. self.window.set_caption(self.mission)
  546. self.window.show_img(img)
  547. elif self.render_mode == "rgb_array":
  548. return img
  549. def close(self):
  550. if self.window:
  551. self.window.close()