# minigrid_env.py
  1. import hashlib
  2. import math
  3. from abc import abstractmethod
  4. from enum import IntEnum
  5. from typing import Optional
  6. import gymnasium as gym
  7. import numpy as np
  8. from gymnasium import spaces
  9. from minigrid.core.constants import COLOR_NAMES, DIR_TO_VEC, TILE_PIXELS
  10. from minigrid.core.grid import Grid
  11. from minigrid.core.mission import MissionSpace
  12. from minigrid.utils.window import Window
  13. class MiniGridEnv(gym.Env):
  14. """
  15. 2D grid world game environment
  16. """
  17. metadata = {
  18. "render_modes": ["human", "rgb_array"],
  19. "render_fps": 10,
  20. }
  21. # Enumeration of possible actions
  22. class Actions(IntEnum):
  23. # Turn left, turn right, move forward
  24. left = 0
  25. right = 1
  26. forward = 2
  27. # Pick up an object
  28. pickup = 3
  29. # Drop an object
  30. drop = 4
  31. # Toggle/activate an object
  32. toggle = 5
  33. # Done completing task
  34. done = 6
  35. def __init__(
  36. self,
  37. mission_space: MissionSpace,
  38. grid_size: int = None,
  39. width: int = None,
  40. height: int = None,
  41. max_steps: int = 100,
  42. see_through_walls: bool = False,
  43. agent_view_size: int = 7,
  44. render_mode: Optional[str] = None,
  45. highlight: bool = True,
  46. tile_size: int = TILE_PIXELS,
  47. agent_pov: bool = False,
  48. ):
  49. # Initialize mission
  50. self.mission = mission_space.sample()
  51. # Can't set both grid_size and width/height
  52. if grid_size:
  53. assert width is None and height is None
  54. width = grid_size
  55. height = grid_size
  56. # Action enumeration for this environment
  57. self.actions = MiniGridEnv.Actions
  58. # Actions are discrete integer values
  59. self.action_space = spaces.Discrete(len(self.actions))
  60. # Number of cells (width and height) in the agent view
  61. assert agent_view_size % 2 == 1
  62. assert agent_view_size >= 3
  63. self.agent_view_size = agent_view_size
  64. # Observations are dictionaries containing an
  65. # encoding of the grid and a textual 'mission' string
  66. image_observation_space = spaces.Box(
  67. low=0,
  68. high=255,
  69. shape=(self.agent_view_size, self.agent_view_size, 3),
  70. dtype="uint8",
  71. )
  72. self.observation_space = spaces.Dict(
  73. {
  74. "image": image_observation_space,
  75. "direction": spaces.Discrete(4),
  76. "mission": mission_space,
  77. }
  78. )
  79. # Range of possible rewards
  80. self.reward_range = (0, 1)
  81. self.window: Window = None
  82. # Environment configuration
  83. self.width = width
  84. self.height = height
  85. assert isinstance(
  86. max_steps, int
  87. ), f"The argument max_steps must be an integer, got: {type(max_steps)}"
  88. self.max_steps = max_steps
  89. self.see_through_walls = see_through_walls
  90. # Current position and direction of the agent
  91. self.agent_pos: np.ndarray = None
  92. self.agent_dir: int = None
  93. # Current grid and mission and carrying
  94. self.grid = Grid(width, height)
  95. self.carrying = None
  96. # Rendering attributes
  97. self.render_mode = render_mode
  98. self.highlight = highlight
  99. self.tile_size = tile_size
  100. self.agent_pov = agent_pov
  101. def reset(self, *, seed=None, options=None):
  102. super().reset(seed=seed)
  103. # Reinitialize episode-specific variables
  104. self.agent_pos = (-1, -1)
  105. self.agent_dir = -1
  106. # Generate a new random grid at the start of each episode
  107. self._gen_grid(self.width, self.height)
  108. # These fields should be defined by _gen_grid
  109. assert (
  110. self.agent_pos >= (0, 0)
  111. if isinstance(self.agent_pos, tuple)
  112. else all(self.agent_pos >= 0) and self.agent_dir >= 0
  113. )
  114. # Check that the agent doesn't overlap with an object
  115. start_cell = self.grid.get(*self.agent_pos)
  116. assert start_cell is None or start_cell.can_overlap()
  117. # Item picked up, being carried, initially nothing
  118. self.carrying = None
  119. # Step count since episode start
  120. self.step_count = 0
  121. if self.render_mode == "human":
  122. self.render()
  123. # Return first observation
  124. obs = self.gen_obs()
  125. return obs, {}
  126. def hash(self, size=16):
  127. """Compute a hash that uniquely identifies the current state of the environment.
  128. :param size: Size of the hashing
  129. """
  130. sample_hash = hashlib.sha256()
  131. to_encode = [self.grid.encode().tolist(), self.agent_pos, self.agent_dir]
  132. for item in to_encode:
  133. sample_hash.update(str(item).encode("utf8"))
  134. return sample_hash.hexdigest()[:size]
  135. @property
  136. def steps_remaining(self):
  137. return self.max_steps - self.step_count
  138. def __str__(self):
  139. """
  140. Produce a pretty string of the environment's grid along with the agent.
  141. A grid cell is represented by 2-character string, the first one for
  142. the object and the second one for the color.
  143. """
  144. # Map of object types to short string
  145. OBJECT_TO_STR = {
  146. "wall": "W",
  147. "floor": "F",
  148. "door": "D",
  149. "key": "K",
  150. "ball": "A",
  151. "box": "B",
  152. "goal": "G",
  153. "lava": "V",
  154. }
  155. # Map agent's direction to short string
  156. AGENT_DIR_TO_STR = {0: ">", 1: "V", 2: "<", 3: "^"}
  157. str = ""
  158. for j in range(self.grid.height):
  159. for i in range(self.grid.width):
  160. if i == self.agent_pos[0] and j == self.agent_pos[1]:
  161. str += 2 * AGENT_DIR_TO_STR[self.agent_dir]
  162. continue
  163. c = self.grid.get(i, j)
  164. if c is None:
  165. str += " "
  166. continue
  167. if c.type == "door":
  168. if c.is_open:
  169. str += "__"
  170. elif c.is_locked:
  171. str += "L" + c.color[0].upper()
  172. else:
  173. str += "D" + c.color[0].upper()
  174. continue
  175. str += OBJECT_TO_STR[c.type] + c.color[0].upper()
  176. if j < self.grid.height - 1:
  177. str += "\n"
  178. return str
  179. @abstractmethod
  180. def _gen_grid(self, width, height):
  181. pass
  182. def _reward(self):
  183. """
  184. Compute the reward to be given upon success
  185. """
  186. return 1 - 0.9 * (self.step_count / self.max_steps)
  187. def _rand_int(self, low, high):
  188. """
  189. Generate random integer in [low,high[
  190. """
  191. return self.np_random.integers(low, high)
  192. def _rand_float(self, low, high):
  193. """
  194. Generate random float in [low,high[
  195. """
  196. return self.np_random.uniform(low, high)
  197. def _rand_bool(self):
  198. """
  199. Generate random boolean value
  200. """
  201. return self.np_random.integers(0, 2) == 0
  202. def _rand_elem(self, iterable):
  203. """
  204. Pick a random element in a list
  205. """
  206. lst = list(iterable)
  207. idx = self._rand_int(0, len(lst))
  208. return lst[idx]
  209. def _rand_subset(self, iterable, num_elems):
  210. """
  211. Sample a random subset of distinct elements of a list
  212. """
  213. lst = list(iterable)
  214. assert num_elems <= len(lst)
  215. out = []
  216. while len(out) < num_elems:
  217. elem = self._rand_elem(lst)
  218. lst.remove(elem)
  219. out.append(elem)
  220. return out
  221. def _rand_color(self):
  222. """
  223. Generate a random color name (string)
  224. """
  225. return self._rand_elem(COLOR_NAMES)
  226. def _rand_pos(self, xLow, xHigh, yLow, yHigh):
  227. """
  228. Generate a random (x,y) position tuple
  229. """
  230. return (
  231. self.np_random.integers(xLow, xHigh),
  232. self.np_random.integers(yLow, yHigh),
  233. )
  234. def place_obj(self, obj, top=None, size=None, reject_fn=None, max_tries=math.inf):
  235. """
  236. Place an object at an empty position in the grid
  237. :param top: top-left position of the rectangle where to place
  238. :param size: size of the rectangle where to place
  239. :param reject_fn: function to filter out potential positions
  240. """
  241. if top is None:
  242. top = (0, 0)
  243. else:
  244. top = (max(top[0], 0), max(top[1], 0))
  245. if size is None:
  246. size = (self.grid.width, self.grid.height)
  247. num_tries = 0
  248. while True:
  249. # This is to handle with rare cases where rejection sampling
  250. # gets stuck in an infinite loop
  251. if num_tries > max_tries:
  252. raise RecursionError("rejection sampling failed in place_obj")
  253. num_tries += 1
  254. pos = np.array(
  255. (
  256. self._rand_int(top[0], min(top[0] + size[0], self.grid.width)),
  257. self._rand_int(top[1], min(top[1] + size[1], self.grid.height)),
  258. )
  259. )
  260. pos = tuple(pos)
  261. # Don't place the object on top of another object
  262. if self.grid.get(*pos) is not None:
  263. continue
  264. # Don't place the object where the agent is
  265. if np.array_equal(pos, self.agent_pos):
  266. continue
  267. # Check if there is a filtering criterion
  268. if reject_fn and reject_fn(self, pos):
  269. continue
  270. break
  271. self.grid.set(pos[0], pos[1], obj)
  272. if obj is not None:
  273. obj.init_pos = pos
  274. obj.cur_pos = pos
  275. return pos
  276. def put_obj(self, obj, i, j):
  277. """
  278. Put an object at a specific position in the grid
  279. """
  280. self.grid.set(i, j, obj)
  281. obj.init_pos = (i, j)
  282. obj.cur_pos = (i, j)
  283. def place_agent(self, top=None, size=None, rand_dir=True, max_tries=math.inf):
  284. """
  285. Set the agent's starting point at an empty position in the grid
  286. """
  287. self.agent_pos = (-1, -1)
  288. pos = self.place_obj(None, top, size, max_tries=max_tries)
  289. self.agent_pos = pos
  290. if rand_dir:
  291. self.agent_dir = self._rand_int(0, 4)
  292. return pos
  293. @property
  294. def dir_vec(self):
  295. """
  296. Get the direction vector for the agent, pointing in the direction
  297. of forward movement.
  298. """
  299. assert self.agent_dir >= 0 and self.agent_dir < 4
  300. return DIR_TO_VEC[self.agent_dir]
  301. @property
  302. def right_vec(self):
  303. """
  304. Get the vector pointing to the right of the agent.
  305. """
  306. dx, dy = self.dir_vec
  307. return np.array((-dy, dx))
  308. @property
  309. def front_pos(self):
  310. """
  311. Get the position of the cell that is right in front of the agent
  312. """
  313. return self.agent_pos + self.dir_vec
  314. def get_view_coords(self, i, j):
  315. """
  316. Translate and rotate absolute grid coordinates (i, j) into the
  317. agent's partially observable view (sub-grid). Note that the resulting
  318. coordinates may be negative or outside of the agent's view size.
  319. """
  320. ax, ay = self.agent_pos
  321. dx, dy = self.dir_vec
  322. rx, ry = self.right_vec
  323. # Compute the absolute coordinates of the top-left view corner
  324. sz = self.agent_view_size
  325. hs = self.agent_view_size // 2
  326. tx = ax + (dx * (sz - 1)) - (rx * hs)
  327. ty = ay + (dy * (sz - 1)) - (ry * hs)
  328. lx = i - tx
  329. ly = j - ty
  330. # Project the coordinates of the object relative to the top-left
  331. # corner onto the agent's own coordinate system
  332. vx = rx * lx + ry * ly
  333. vy = -(dx * lx + dy * ly)
  334. return vx, vy
  335. def get_view_exts(self, agent_view_size=None):
  336. """
  337. Get the extents of the square set of tiles visible to the agent
  338. Note: the bottom extent indices are not included in the set
  339. if agent_view_size is None, use self.agent_view_size
  340. """
  341. agent_view_size = agent_view_size or self.agent_view_size
  342. # Facing right
  343. if self.agent_dir == 0:
  344. topX = self.agent_pos[0]
  345. topY = self.agent_pos[1] - agent_view_size // 2
  346. # Facing down
  347. elif self.agent_dir == 1:
  348. topX = self.agent_pos[0] - agent_view_size // 2
  349. topY = self.agent_pos[1]
  350. # Facing left
  351. elif self.agent_dir == 2:
  352. topX = self.agent_pos[0] - agent_view_size + 1
  353. topY = self.agent_pos[1] - agent_view_size // 2
  354. # Facing up
  355. elif self.agent_dir == 3:
  356. topX = self.agent_pos[0] - agent_view_size // 2
  357. topY = self.agent_pos[1] - agent_view_size + 1
  358. else:
  359. assert False, "invalid agent direction"
  360. botX = topX + agent_view_size
  361. botY = topY + agent_view_size
  362. return (topX, topY, botX, botY)
  363. def relative_coords(self, x, y):
  364. """
  365. Check if a grid position belongs to the agent's field of view, and returns the corresponding coordinates
  366. """
  367. vx, vy = self.get_view_coords(x, y)
  368. if vx < 0 or vy < 0 or vx >= self.agent_view_size or vy >= self.agent_view_size:
  369. return None
  370. return vx, vy
  371. def in_view(self, x, y):
  372. """
  373. check if a grid position is visible to the agent
  374. """
  375. return self.relative_coords(x, y) is not None
  376. def agent_sees(self, x, y):
  377. """
  378. Check if a non-empty grid position is visible to the agent
  379. """
  380. coordinates = self.relative_coords(x, y)
  381. if coordinates is None:
  382. return False
  383. vx, vy = coordinates
  384. obs = self.gen_obs()
  385. obs_grid, _ = Grid.decode(obs["image"])
  386. obs_cell = obs_grid.get(vx, vy)
  387. world_cell = self.grid.get(x, y)
  388. assert world_cell is not None
  389. return obs_cell is not None and obs_cell.type == world_cell.type
  390. def step(self, action):
  391. self.step_count += 1
  392. reward = 0
  393. terminated = False
  394. truncated = False
  395. # Get the position in front of the agent
  396. fwd_pos = self.front_pos
  397. # Get the contents of the cell in front of the agent
  398. fwd_cell = self.grid.get(*fwd_pos)
  399. # Rotate left
  400. if action == self.actions.left:
  401. self.agent_dir -= 1
  402. if self.agent_dir < 0:
  403. self.agent_dir += 4
  404. # Rotate right
  405. elif action == self.actions.right:
  406. self.agent_dir = (self.agent_dir + 1) % 4
  407. # Move forward
  408. elif action == self.actions.forward:
  409. if fwd_cell is None or fwd_cell.can_overlap():
  410. self.agent_pos = tuple(fwd_pos)
  411. if fwd_cell is not None and fwd_cell.type == "goal":
  412. terminated = True
  413. reward = self._reward()
  414. if fwd_cell is not None and fwd_cell.type == "lava":
  415. terminated = True
  416. # Pick up an object
  417. elif action == self.actions.pickup:
  418. if fwd_cell and fwd_cell.can_pickup():
  419. if self.carrying is None:
  420. self.carrying = fwd_cell
  421. self.carrying.cur_pos = np.array([-1, -1])
  422. self.grid.set(fwd_pos[0], fwd_pos[1], None)
  423. # Drop an object
  424. elif action == self.actions.drop:
  425. if not fwd_cell and self.carrying:
  426. self.grid.set(fwd_pos[0], fwd_pos[1], self.carrying)
  427. self.carrying.cur_pos = fwd_pos
  428. self.carrying = None
  429. # Toggle/activate an object
  430. elif action == self.actions.toggle:
  431. if fwd_cell:
  432. fwd_cell.toggle(self, fwd_pos)
  433. # Done action (not used by default)
  434. elif action == self.actions.done:
  435. pass
  436. else:
  437. raise ValueError(f"Unknown action: {action}")
  438. if self.step_count >= self.max_steps:
  439. truncated = True
  440. if self.render_mode == "human":
  441. self.render()
  442. obs = self.gen_obs()
  443. return obs, reward, terminated, truncated, {}
  444. def gen_obs_grid(self, agent_view_size=None):
  445. """
  446. Generate the sub-grid observed by the agent.
  447. This method also outputs a visibility mask telling us which grid
  448. cells the agent can actually see.
  449. if agent_view_size is None, self.agent_view_size is used
  450. """
  451. topX, topY, botX, botY = self.get_view_exts(agent_view_size)
  452. agent_view_size = agent_view_size or self.agent_view_size
  453. grid = self.grid.slice(topX, topY, agent_view_size, agent_view_size)
  454. for i in range(self.agent_dir + 1):
  455. grid = grid.rotate_left()
  456. # Process occluders and visibility
  457. # Note that this incurs some performance cost
  458. if not self.see_through_walls:
  459. vis_mask = grid.process_vis(
  460. agent_pos=(agent_view_size // 2, agent_view_size - 1)
  461. )
  462. else:
  463. vis_mask = np.ones(shape=(grid.width, grid.height), dtype=bool)
  464. # Make it so the agent sees what it's carrying
  465. # We do this by placing the carried object at the agent's position
  466. # in the agent's partially observable view
  467. agent_pos = grid.width // 2, grid.height - 1
  468. if self.carrying:
  469. grid.set(*agent_pos, self.carrying)
  470. else:
  471. grid.set(*agent_pos, None)
  472. return grid, vis_mask
  473. def gen_obs(self):
  474. """
  475. Generate the agent's view (partially observable, low-resolution encoding)
  476. """
  477. grid, vis_mask = self.gen_obs_grid()
  478. # Encode the partially observable view into a numpy array
  479. image = grid.encode(vis_mask)
  480. # Observations are dictionaries containing:
  481. # - an image (partially observable view of the environment)
  482. # - the agent's direction/orientation (acting as a compass)
  483. # - a textual mission string (instructions for the agent)
  484. obs = {"image": image, "direction": self.agent_dir, "mission": self.mission}
  485. return obs
  486. def get_pov_render(self, tile_size):
  487. """
  488. Render an agent's POV observation for visualization
  489. """
  490. grid, vis_mask = self.gen_obs_grid()
  491. # Render the whole grid
  492. img = grid.render(
  493. tile_size,
  494. agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1),
  495. agent_dir=3,
  496. highlight_mask=vis_mask,
  497. )
  498. return img
  499. def get_full_render(self, highlight, tile_size):
  500. """
  501. Render a non-paratial observation for visualization
  502. """
  503. # Compute which cells are visible to the agent
  504. _, vis_mask = self.gen_obs_grid()
  505. # Compute the world coordinates of the bottom-left corner
  506. # of the agent's view area
  507. f_vec = self.dir_vec
  508. r_vec = self.right_vec
  509. top_left = (
  510. self.agent_pos
  511. + f_vec * (self.agent_view_size - 1)
  512. - r_vec * (self.agent_view_size // 2)
  513. )
  514. # Mask of which cells to highlight
  515. highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  516. # For each cell in the visibility mask
  517. for vis_j in range(0, self.agent_view_size):
  518. for vis_i in range(0, self.agent_view_size):
  519. # If this cell is not visible, don't highlight it
  520. if not vis_mask[vis_i, vis_j]:
  521. continue
  522. # Compute the world coordinates of this cell
  523. abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)
  524. if abs_i < 0 or abs_i >= self.width:
  525. continue
  526. if abs_j < 0 or abs_j >= self.height:
  527. continue
  528. # Mark this cell to be highlighted
  529. highlight_mask[abs_i, abs_j] = True
  530. # Render the whole grid
  531. img = self.grid.render(
  532. tile_size,
  533. self.agent_pos,
  534. self.agent_dir,
  535. highlight_mask=highlight_mask if highlight else None,
  536. )
  537. return img
  538. def get_frame(
  539. self,
  540. highlight: bool = True,
  541. tile_size: int = TILE_PIXELS,
  542. agent_pov: bool = False,
  543. ):
  544. """Returns an RGB image corresponding to the whole environment or the agent's point of view.
  545. Args:
  546. highlight (bool): If true, the agent's field of view or point of view is highlighted with a lighter gray color.
  547. tile_size (int): How many pixels will form a tile from the NxM grid.
  548. agent_pov (bool): If true, the rendered frame will only contain the point of view of the agent.
  549. Returns:
  550. frame (np.ndarray): A frame of type numpy.ndarray with shape (x, y, 3) representing RGB values for the x-by-y pixel image.
  551. """
  552. if agent_pov:
  553. return self.get_pov_render(tile_size)
  554. else:
  555. return self.get_full_render(highlight, tile_size)
  556. def render(self):
  557. img = self.get_frame(self.highlight, self.tile_size, self.agent_pov)
  558. if self.render_mode == "human":
  559. if self.window is None:
  560. self.window = Window("minigrid")
  561. self.window.show(block=False)
  562. self.window.set_caption(self.mission)
  563. self.window.show_img(img)
  564. elif self.render_mode == "rgb_array":
  565. return img
  566. def close(self):
  567. if self.window:
  568. self.window.close()