# wrappers.py

from __future__ import annotations

import math
import operator
from functools import reduce
from typing import Any

import gymnasium as gym
import numpy as np
from gymnasium import logger, spaces
from gymnasium.core import ActionWrapper, ObservationWrapper, ObsType, Wrapper

from minigrid.core.constants import COLOR_TO_IDX, OBJECT_TO_IDX, STATE_TO_IDX
from minigrid.core.world_object import Goal


class ReseedWrapper(Wrapper):
    """
    Wrapper to always regenerate an environment with the same set of seeds.
    This can be used to force an environment to always keep the same
    configuration when reset.

    Example:
        >>> import minigrid
        >>> import gymnasium as gym
        >>> from minigrid.wrappers import ReseedWrapper
        >>> env = gym.make("MiniGrid-Empty-5x5-v0")
        >>> _ = env.reset(seed=123)
        >>> [env.np_random.integers(10) for i in range(10)]
        [0, 6, 5, 0, 9, 2, 2, 1, 3, 1]
        >>> env = ReseedWrapper(env, seeds=[0, 1], seed_idx=0)
        >>> _, _ = env.reset()
        >>> [env.np_random.integers(10) for i in range(10)]
        [8, 6, 5, 2, 3, 0, 0, 0, 1, 8]
        >>> _, _ = env.reset()
        >>> [env.np_random.integers(10) for i in range(10)]
        [4, 5, 7, 9, 0, 1, 8, 9, 2, 3]
        >>> _, _ = env.reset()
        >>> [env.np_random.integers(10) for i in range(10)]
        [8, 6, 5, 2, 3, 0, 0, 0, 1, 8]
        >>> _, _ = env.reset()
        >>> [env.np_random.integers(10) for i in range(10)]
        [4, 5, 7, 9, 0, 1, 8, 9, 2, 3]
    """

    def __init__(self, env, seeds=(0,), seed_idx=0):
        """A wrapper that always regenerates an environment with the same set of seeds.

        Args:
            env: The environment to apply the wrapper
            seeds: A list of seeds to be applied to the env
            seed_idx: Index of the initial seed in seeds
        """
        self.seeds = list(seeds)
        self.seed_idx = seed_idx
        super().__init__(env)

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[ObsType, dict[str, Any]]:
        if seed is not None:
            logger.warn(
                "A seed has been passed to `ReseedWrapper.reset` which is ignored."
            )
        seed = self.seeds[self.seed_idx]
        self.seed_idx = (self.seed_idx + 1) % len(self.seeds)
        return self.env.reset(seed=seed, options=options)


class ActionBonus(gym.Wrapper):
    """
    Wrapper which adds an exploration bonus.
    This is a reward to encourage exploration of less
    visited (state, action) pairs.

    Example:
        >>> import gymnasium as gym
        >>> from minigrid.wrappers import ActionBonus
        >>> env = gym.make("MiniGrid-Empty-5x5-v0")
        >>> _, _ = env.reset(seed=0)
        >>> _, reward, _, _, _ = env.step(1)
        >>> print(reward)
        0
        >>> _, reward, _, _, _ = env.step(1)
        >>> print(reward)
        0
        >>> env_bonus = ActionBonus(env)
        >>> _, _ = env_bonus.reset(seed=0)
        >>> _, reward, _, _, _ = env_bonus.step(1)
        >>> print(reward)
        1.0
        >>> _, reward, _, _, _ = env_bonus.step(1)
        >>> print(reward)
        1.0

    Both bonuses above are ``1.0`` because turning changes ``agent_dir``, so
    each step visits a fresh (state, action) pair.
    """

    def __init__(self, env):
        """A wrapper that adds an exploration bonus to less visited (state, action) pairs.

        Args:
            env: The environment to apply the wrapper
        """
        super().__init__(env)
        self.counts = {}

    def step(self, action):
        """Steps through the environment with `action`."""
        obs, reward, terminated, truncated, info = self.env.step(action)

        env = self.unwrapped
        tup = (tuple(env.agent_pos), env.agent_dir, action)

        # Get the count for this (s,a) pair
        pre_count = 0
        if tup in self.counts:
            pre_count = self.counts[tup]

        # Update the count for this (s,a) pair
        new_count = pre_count + 1
        self.counts[tup] = new_count

        bonus = 1 / math.sqrt(new_count)
        reward += bonus

        return obs, reward, terminated, truncated, info


class PositionBonus(Wrapper):
    """
    Adds an exploration bonus based on which positions
    are visited on the grid.

    Note:
        This wrapper was previously called ``StateBonus``.

    Example:
        >>> import gymnasium as gym
        >>> from minigrid.wrappers import PositionBonus
        >>> env = gym.make("MiniGrid-Empty-5x5-v0")
        >>> _, _ = env.reset(seed=0)
        >>> _, reward, _, _, _ = env.step(1)
        >>> print(reward)
        0
        >>> _, reward, _, _, _ = env.step(1)
        >>> print(reward)
        0
        >>> env_bonus = PositionBonus(env)
        >>> obs, _ = env_bonus.reset(seed=0)
        >>> obs, reward, terminated, truncated, info = env_bonus.step(1)
        >>> print(reward)
        1.0
        >>> obs, reward, terminated, truncated, info = env_bonus.step(1)
        >>> print(reward)
        0.7071067811865475

    Turning in place keeps the same position, so the second bonus decays to
    ``1 / sqrt(2)``.
    """

    def __init__(self, env):
        """A wrapper that adds an exploration bonus to less visited positions.

        Args:
            env: The environment to apply the wrapper
        """
        super().__init__(env)
        self.counts = {}

    def step(self, action):
        """Steps through the environment with `action`."""
        obs, reward, terminated, truncated, info = self.env.step(action)

        # Tuple based on which we index the counts
        # We use the position after an update
        env = self.unwrapped
        tup = tuple(env.agent_pos)

        # Get the count for this key
        pre_count = 0
        if tup in self.counts:
            pre_count = self.counts[tup]

        # Update the count for this key
        new_count = pre_count + 1
        self.counts[tup] = new_count

        bonus = 1 / math.sqrt(new_count)
        reward += bonus

        return obs, reward, terminated, truncated, info


class ImgObsWrapper(ObservationWrapper):
    """
    Use the image as the only observation output, no language/mission.

    Example:
        >>> import gymnasium as gym
        >>> from minigrid.wrappers import ImgObsWrapper
        >>> env = gym.make("MiniGrid-Empty-5x5-v0")
        >>> obs, _ = env.reset()
        >>> obs.keys()
        dict_keys(['image', 'direction', 'mission'])
        >>> env = ImgObsWrapper(env)
        >>> obs, _ = env.reset()
        >>> obs.shape
        (7, 7, 3)
    """

    def __init__(self, env):
        """A wrapper that makes the image the only observation.

        Args:
            env: The environment to apply the wrapper
        """
        super().__init__(env)
        self.observation_space = env.observation_space.spaces["image"]

    def observation(self, obs):
        return obs["image"]


class OneHotPartialObsWrapper(ObservationWrapper):
    """
    Wrapper to get a one-hot encoding of a partially observable
    agent view as observation.

    Example:
        >>> import gymnasium as gym
        >>> from minigrid.wrappers import OneHotPartialObsWrapper
        >>> env = gym.make("MiniGrid-Empty-5x5-v0")
        >>> obs, _ = env.reset()
        >>> obs["image"][0, :, :]
        array([[2, 5, 0],
               [2, 5, 0],
               [2, 5, 0],
               [2, 5, 0],
               [2, 5, 0],
               [2, 5, 0],
               [2, 5, 0]], dtype=uint8)
        >>> env = OneHotPartialObsWrapper(env)
        >>> obs, _ = env.reset()
        >>> obs["image"][0, :, :]
        array([[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
               [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
               [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
               [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
               [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
               [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
               [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0]],
              dtype=uint8)
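
    Each cell is encoded with ``len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) +
    len(STATE_TO_IDX)`` bits (11 + 6 + 3 = 20). In the rows above, bit 2 marks
    the object type ("wall"), bit 16 the color ("grey", offset by 11), and bit
    17 the state (0, offset by 11 + 6).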
  208. """
  209. def __init__(self, env, tile_size=8):
  210. """A wrapper that makes the image observation a one-hot encoding of a partially observable agent view.
  211. Args:
  212. env: The environment to apply the wrapper
  213. """
  214. super().__init__(env)
  215. self.tile_size = tile_size
  216. obs_shape = env.observation_space["image"].shape
  217. # Number of bits per cell
  218. num_bits = len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + len(STATE_TO_IDX)
  219. new_image_space = spaces.Box(
  220. low=0, high=255, shape=(obs_shape[0], obs_shape[1], num_bits), dtype="uint8"
  221. )
  222. self.observation_space = spaces.Dict(
  223. {**self.observation_space.spaces, "image": new_image_space}
  224. )

    def observation(self, obs):
        img = obs["image"]
        out = np.zeros(self.observation_space.spaces["image"].shape, dtype="uint8")

        for i in range(img.shape[0]):
            for j in range(img.shape[1]):
                obj_type = img[i, j, 0]
                color = img[i, j, 1]
                state = img[i, j, 2]

                out[i, j, obj_type] = 1
                out[i, j, len(OBJECT_TO_IDX) + color] = 1
                out[i, j, len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + state] = 1

        return {**obs, "image": out}


class RGBImgObsWrapper(ObservationWrapper):
    """
    Wrapper to use a fully observable RGB image as observation.
    This can be used to have the agent solve the gridworld in pixel space.

    Example:
        >>> import gymnasium as gym
        >>> import matplotlib.pyplot as plt
        >>> from minigrid.wrappers import RGBImgObsWrapper
        >>> env = gym.make("MiniGrid-Empty-5x5-v0")
        >>> obs, _ = env.reset()
        >>> plt.imshow(obs['image'])  # doctest: +SKIP
        ![NoWrapper](../figures/lavacrossing_NoWrapper.png)
        >>> env = RGBImgObsWrapper(env)
        >>> obs, _ = env.reset()
        >>> plt.imshow(obs['image'])  # doctest: +SKIP
        ![RGBImgObsWrapper](../figures/lavacrossing_RGBImgObsWrapper.png)
    """

    def __init__(self, env, tile_size=8):
        super().__init__(env)

        self.tile_size = tile_size

        new_image_space = spaces.Box(
            low=0,
            high=255,
            shape=(
                self.unwrapped.width * tile_size,
                self.unwrapped.height * tile_size,
                3,
            ),
            dtype="uint8",
        )

        self.observation_space = spaces.Dict(
            {**self.observation_space.spaces, "image": new_image_space}
        )

    def observation(self, obs):
        rgb_img = self.get_frame(
            highlight=self.unwrapped.highlight, tile_size=self.tile_size
        )

        return {**obs, "image": rgb_img}


class RGBImgPartialObsWrapper(ObservationWrapper):
    """
    Wrapper to use a partially observable RGB image as observation.
    This can be used to have the agent solve the gridworld in pixel space.

    Example:
        >>> import gymnasium as gym
        >>> import matplotlib.pyplot as plt
        >>> from minigrid.wrappers import RGBImgObsWrapper, RGBImgPartialObsWrapper
        >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
        >>> obs, _ = env.reset()
        >>> plt.imshow(obs["image"])  # doctest: +SKIP
        ![NoWrapper](../figures/lavacrossing_NoWrapper.png)
        >>> env_obs = RGBImgObsWrapper(env)
        >>> obs, _ = env_obs.reset()
        >>> plt.imshow(obs["image"])  # doctest: +SKIP
        ![RGBImgObsWrapper](../figures/lavacrossing_RGBImgObsWrapper.png)
        >>> env_obs = RGBImgPartialObsWrapper(env)
        >>> obs, _ = env_obs.reset()
        >>> plt.imshow(obs["image"])  # doctest: +SKIP
        ![RGBImgPartialObsWrapper](../figures/lavacrossing_RGBImgPartialObsWrapper.png)
    """

    def __init__(self, env, tile_size=8):
        super().__init__(env)

        # Rendering attributes for observations
        self.tile_size = tile_size

        obs_shape = env.observation_space.spaces["image"].shape
        new_image_space = spaces.Box(
            low=0,
            high=255,
            shape=(obs_shape[0] * tile_size, obs_shape[1] * tile_size, 3),
            dtype="uint8",
        )

        self.observation_space = spaces.Dict(
            {**self.observation_space.spaces, "image": new_image_space}
        )

    def observation(self, obs):
        rgb_img_partial = self.get_frame(tile_size=self.tile_size, agent_pov=True)

        return {**obs, "image": rgb_img_partial}


class FullyObsWrapper(ObservationWrapper):
    """
    Fully observable gridworld using a compact grid encoding instead of the agent view.

    Example:
        >>> import gymnasium as gym
        >>> import matplotlib.pyplot as plt
        >>> from minigrid.wrappers import FullyObsWrapper
        >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
        >>> obs, _ = env.reset()
        >>> obs['image'].shape
        (7, 7, 3)
        >>> env_obs = FullyObsWrapper(env)
        >>> obs, _ = env_obs.reset()
        >>> obs['image'].shape
        (11, 11, 3)
    """

    def __init__(self, env):
        super().__init__(env)

        new_image_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.env.width, self.env.height, 3),  # number of cells
            dtype="uint8",
        )

        self.observation_space = spaces.Dict(
            {**self.observation_space.spaces, "image": new_image_space}
        )

    def observation(self, obs):
        env = self.unwrapped
        full_grid = env.grid.encode()
        # Overwrite the agent's cell with the agent marker and its direction
        full_grid[env.agent_pos[0]][env.agent_pos[1]] = np.array(
            [OBJECT_TO_IDX["agent"], COLOR_TO_IDX["red"], env.agent_dir]
        )

        return {**obs, "image": full_grid}


class DictObservationSpaceWrapper(ObservationWrapper):
    """
    Transforms the observation space (that has a textual component) to a fully numerical observation space,
    where the textual instructions are replaced by arrays representing the indices of each word in a fixed vocabulary.

    This wrapper is not applicable to BabyAI environments, given that these have their own language component.

    Example:
        >>> import gymnasium as gym
        >>> import matplotlib.pyplot as plt
        >>> from minigrid.wrappers import DictObservationSpaceWrapper
        >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
        >>> obs, _ = env.reset()
        >>> obs['mission']
        'avoid the lava and get to the green goal square'
        >>> env_obs = DictObservationSpaceWrapper(env)
        >>> obs, _ = env_obs.reset()
        >>> obs['mission'][:10]
        [19, 31, 17, 36, 20, 38, 31, 2, 15, 35]
    """

    def __init__(self, env, max_words_in_mission=50, word_dict=None):
        """
        Args:
            env: The environment to apply the wrapper
            max_words_in_mission: The length of the array used to represent a
                mission; unused positions are filled with 0
            word_dict: A dictionary of words to use (keys=words, values=indices
                from 1 to < max_words_in_mission); if None, use the MiniGrid
                language
        """
        super().__init__(env)

        if word_dict is None:
            word_dict = self.get_minigrid_words()

        self.max_words_in_mission = max_words_in_mission
        self.word_dict = word_dict

        self.observation_space = spaces.Dict(
            {
                "image": env.observation_space["image"],
                "direction": spaces.Discrete(4),
                "mission": spaces.MultiDiscrete(
                    [len(self.word_dict.keys())] * max_words_in_mission
                ),
            }
        )

    @staticmethod
    def get_minigrid_words():
        colors = ["red", "green", "blue", "yellow", "purple", "grey"]
        objects = [
            "unseen",
            "empty",
            "wall",
            "floor",
            "box",
            "key",
            "ball",
            "door",
            "goal",
            "agent",
            "lava",
        ]

        verbs = [
            "pick",
            "avoid",
            "get",
            "find",
            "put",
            "use",
            "open",
            "go",
            "fetch",
            "reach",
            "unlock",
            "traverse",
        ]

        extra_words = [
            "up",
            "the",
            "a",
            "at",
            ",",
            "square",
            "and",
            "then",
            "to",
            "of",
            "rooms",
            "near",
            "opening",
            "must",
            "you",
            "matching",
            "end",
            "hallway",
            "object",
            "from",
            "room",
            "maze",
        ]

        all_words = colors + objects + verbs + extra_words
        assert len(all_words) == len(set(all_words))
        return {word: i for i, word in enumerate(all_words)}

    def string_to_indices(self, string, offset=1):
        """
        Convert a string to a list of indices.
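
        Example (a sketch using the default MiniGrid vocabulary from
        ``get_minigrid_words``; the returned indices include the ``offset``
        of 1):
            >>> import gymnasium as gym
            >>> env = DictObservationSpaceWrapper(gym.make("MiniGrid-LavaCrossingS11N5-v0"))
            >>> env.string_to_indices("go to the green goal")
            [25, 38, 31, 2, 15]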
  445. """
  446. indices = []
  447. # adding space before and after commas
  448. string = string.replace(",", " , ")
  449. for word in string.split():
  450. if word in self.word_dict.keys():
  451. indices.append(self.word_dict[word] + offset)
  452. else:
  453. raise ValueError(f"Unknown word: {word}")
  454. return indices
  455. def observation(self, obs):
  456. obs["mission"] = self.string_to_indices(obs["mission"])
  457. assert len(obs["mission"]) < self.max_words_in_mission
  458. obs["mission"] += [0] * (self.max_words_in_mission - len(obs["mission"]))
  459. return obs


class FlatObsWrapper(ObservationWrapper):
    """
    Encode mission strings using a one-hot scheme,
    and combine these with observed images into one flat array.

    This wrapper is not applicable to BabyAI environments, given that these have their own language component.

    Example:
        >>> import gymnasium as gym
        >>> import matplotlib.pyplot as plt
        >>> from minigrid.wrappers import FlatObsWrapper
        >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
        >>> env_obs = FlatObsWrapper(env)
        >>> obs, _ = env_obs.reset()
        >>> obs.shape
        (2835,)
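
    The 2835 entries are the flattened 7x7x3 image (147 values) plus a 96x28
    one-hot character encoding of the mission string (2688 values):
    147 + 2688 = 2835.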
  474. """
  475. def __init__(self, env, maxStrLen=96):
  476. super().__init__(env)
  477. self.maxStrLen = maxStrLen
  478. self.numCharCodes = 28
  479. imgSpace = env.observation_space.spaces["image"]
  480. imgSize = reduce(operator.mul, imgSpace.shape, 1)
  481. self.observation_space = spaces.Box(
  482. low=0,
  483. high=255,
  484. shape=(imgSize + self.numCharCodes * self.maxStrLen,),
  485. dtype="uint8",
  486. )
  487. self.cachedStr: str = None

    def observation(self, obs):
        image = obs["image"]
        mission = obs["mission"]

        # Cache the last-encoded mission string
        if mission != self.cachedStr:
            assert (
                len(mission) <= self.maxStrLen
            ), f"mission string too long ({len(mission)} chars)"
            mission = mission.lower()

            strArray = np.zeros(
                shape=(self.maxStrLen, self.numCharCodes), dtype="float32"
            )

            for idx, ch in enumerate(mission):
                if ch >= "a" and ch <= "z":
                    chNo = ord(ch) - ord("a")
                elif ch == " ":
                    chNo = ord("z") - ord("a") + 1
                elif ch == ",":
                    chNo = ord("z") - ord("a") + 2
                else:
                    raise ValueError(
                        f"Character {ch} is not available in mission string."
                    )
                assert chNo < self.numCharCodes, "%s : %d" % (ch, chNo)
                strArray[idx, chNo] = 1

            self.cachedStr = mission
            self.cachedArray = strArray

        obs = np.concatenate((image.flatten(), self.cachedArray.flatten()))

        return obs


class ViewSizeWrapper(ObservationWrapper):
    """
    Wrapper to customize the agent field of view size.
    This cannot be used with fully observable wrappers.

    Example:
        >>> import gymnasium as gym
        >>> from minigrid.wrappers import ViewSizeWrapper
        >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
        >>> obs, _ = env.reset()
        >>> obs['image'].shape
        (7, 7, 3)
        >>> env_obs = ViewSizeWrapper(env, agent_view_size=5)
        >>> obs, _ = env_obs.reset()
        >>> obs['image'].shape
        (5, 5, 3)
    """

    def __init__(self, env, agent_view_size=7):
        super().__init__(env)

        assert agent_view_size % 2 == 1
        assert agent_view_size >= 3

        self.agent_view_size = agent_view_size

        # Compute observation space with specified view size
        new_image_space = gym.spaces.Box(
            low=0, high=255, shape=(agent_view_size, agent_view_size, 3), dtype="uint8"
        )

        # Override the environment's observation space
        self.observation_space = spaces.Dict(
            {**self.observation_space.spaces, "image": new_image_space}
        )

    def observation(self, obs):
        env = self.unwrapped

        grid, vis_mask = env.gen_obs_grid(self.agent_view_size)

        # Encode the partially observable view into a numpy array
        image = grid.encode(vis_mask)

        return {**obs, "image": image}


class DirectionObsWrapper(ObservationWrapper):
    """
    Provides the slope/angular direction to the goal with the observations,
    modeled as (y2 - y1) / (x2 - x1). ``type`` selects the representation:
    either "slope" or "angle".

    Example:
        >>> import gymnasium as gym
        >>> import matplotlib.pyplot as plt
        >>> from minigrid.wrappers import DirectionObsWrapper
        >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
        >>> env_obs = DirectionObsWrapper(env, type="slope")
        >>> obs, _ = env_obs.reset()
        >>> obs['goal_direction']
        1.0
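        >>> # With type="angle", the wrapper returns arctan(slope); for the
        >>> # slope of 1.0 above that is pi/4 (illustrative sketch, hence the skip):
        >>> env_obs = DirectionObsWrapper(env, type="angle")
        >>> obs, _ = env_obs.reset()
        >>> obs['goal_direction']  # doctest: +SKIP
        0.7853981633974483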
  565. """
  566. def __init__(self, env, type="slope"):
  567. super().__init__(env)
  568. self.goal_position: tuple = None
  569. self.type = type
  570. def reset(
  571. self, *, seed: int | None = None, options: dict[str, Any] | None = None
  572. ) -> tuple[ObsType, dict[str, Any]]:
  573. obs, info = self.env.reset()
  574. if not self.goal_position:
  575. self.goal_position = [
  576. x for x, y in enumerate(self.grid.grid) if isinstance(y, Goal)
  577. ]
  578. # in case there are multiple goals , needs to be handled for other env types
  579. if len(self.goal_position) >= 1:
  580. self.goal_position = (
  581. int(self.goal_position[0] / self.height),
  582. self.goal_position[0] % self.width,
  583. )
  584. return self.observation(obs), info
  585. def observation(self, obs):
  586. slope = np.divide(
  587. self.goal_position[1] - self.agent_pos[1],
  588. self.goal_position[0] - self.agent_pos[0],
  589. )
  590. if self.type == "angle":
  591. obs["goal_direction"] = np.arctan(slope)
  592. else:
  593. obs["goal_direction"] = slope
  594. return obs


class SymbolicObsWrapper(ObservationWrapper):
    """
    Fully observable grid with a symbolic state representation.
    The symbol is a triple of (X, Y, IDX), where X and Y are
    the coordinates on the grid, and IDX is the id of the object.

    Example:
        >>> import gymnasium as gym
        >>> from minigrid.wrappers import SymbolicObsWrapper
        >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
        >>> obs, _ = env.reset()
        >>> obs['image'].shape
        (7, 7, 3)
        >>> env_obs = SymbolicObsWrapper(env)
        >>> obs, _ = env_obs.reset()
        >>> obs['image'].shape
        (11, 11, 3)
    """

    def __init__(self, env):
        super().__init__(env)

        new_image_space = spaces.Box(
            low=0,
            high=max(OBJECT_TO_IDX.values()),
            shape=(self.env.width, self.env.height, 3),  # number of cells
            dtype="uint8",
        )
        self.observation_space = spaces.Dict(
            {**self.observation_space.spaces, "image": new_image_space}
        )

    def observation(self, obs):
        objects = np.array(
            [OBJECT_TO_IDX[o.type] if o is not None else -1 for o in self.grid.grid]
        )
        agent_pos = self.env.agent_pos
        ncol, nrow = self.width, self.height
        grid = np.mgrid[:ncol, :nrow]
        _objects = np.transpose(objects.reshape(1, nrow, ncol), (0, 2, 1))

        grid = np.concatenate([grid, _objects])
        grid = np.transpose(grid, (1, 2, 0))
        grid[agent_pos[0], agent_pos[1], 2] = OBJECT_TO_IDX["agent"]
        obs["image"] = grid

        return obs


class StochasticActionWrapper(ActionWrapper):
    """
    Add stochasticity to the actions.

    The intended action is executed with probability ``prob``. Otherwise,
    ``random_action`` is executed if provided, else a random action is
    sampled from the first six actions.
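
    Example (illustrative; roughly ``1 - prob`` of the steps execute a random
    action, so outcomes are not deterministic):
        >>> import gymnasium as gym
        >>> from minigrid.wrappers import StochasticActionWrapper
        >>> env = gym.make("MiniGrid-Empty-5x5-v0")
        >>> env = StochasticActionWrapper(env, prob=0.9)
        >>> _, _ = env.reset(seed=0)
        >>> obs, reward, terminated, truncated, info = env.step(1)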
  641. """
  642. def __init__(self, env=None, prob=0.9, random_action=None):
  643. super().__init__(env)
  644. self.prob = prob
  645. self.random_action = random_action
  646. def action(self, action):
  647. """ """
  648. if np.random.uniform() < self.prob:
  649. return action
  650. else:
  651. if self.random_action is None:
  652. return self.np_random.integers(0, high=6)
  653. else:
  654. return self.random_action


class NoDeath(Wrapper):
    """
    Wrapper to prevent death in specific cells (e.g., lava cells).
    Instead of dying, the agent will receive a negative reward.

    Example:
        >>> import gymnasium as gym
        >>> from minigrid.wrappers import NoDeath
        >>>
        >>> env = gym.make("MiniGrid-LavaCrossingS9N1-v0")
        >>> _, _ = env.reset(seed=2)
        >>> _, _, _, _, _ = env.step(1)
        >>> _, reward, term, *_ = env.step(2)
        >>> reward, term
        (0, True)
        >>>
        >>> env = NoDeath(env, no_death_types=("lava",), death_cost=-1.0)
        >>> _, _ = env.reset(seed=2)
        >>> _, _, _, _, _ = env.step(1)
        >>> _, reward, term, *_ = env.step(2)
        >>> reward, term
        (-1.0, False)
        >>>
        >>>
        >>> env = gym.make("MiniGrid-Dynamic-Obstacles-5x5-v0")
        >>> _, _ = env.reset(seed=2)
        >>> _, reward, term, *_ = env.step(2)
        >>> reward, term
        (-1, True)
        >>>
        >>> env = NoDeath(env, no_death_types=("ball",), death_cost=-1.0)
        >>> _, _ = env.reset(seed=2)
        >>> _, reward, term, *_ = env.step(2)
        >>> reward, term
        (-2.0, False)
    """

    def __init__(self, env, no_death_types: tuple[str, ...], death_cost: float = -1.0):
        """A wrapper to prevent death in specific cells.

        Args:
            env: The environment to apply the wrapper
            no_death_types: List of strings to identify death cells
            death_cost: The negative reward received in death cells
        """
        assert "goal" not in no_death_types, "goal cannot be a death cell"

        super().__init__(env)
        self.death_cost = death_cost
        self.no_death_types = no_death_types

    def step(self, action):
        # In Dynamic-Obstacles, obstacles move after the agent moves,
        # so we need to check for collision before self.env.step()
        front_cell = self.grid.get(*self.front_pos)
        going_to_death = (
            action == self.actions.forward
            and front_cell is not None
            and front_cell.type in self.no_death_types
        )

        obs, reward, terminated, truncated, info = self.env.step(action)

        # We also check if the agent stays in death cells (e.g., lava)
        # without moving
        current_cell = self.grid.get(*self.agent_pos)
        in_death = current_cell is not None and current_cell.type in self.no_death_types

        if terminated and (going_to_death or in_death):
            terminated = False
            reward += self.death_cost

        return obs, reward, terminated, truncated, info
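

# A minimal composition sketch (an illustration added here, not part of the
# upstream API): the wrappers above can be chained, e.g. full-grid pixel
# observations with the mission text stripped. Uses the same env id as the
# doctests above.
if __name__ == "__main__":
    demo_env = ImgObsWrapper(RGBImgObsWrapper(gym.make("MiniGrid-Empty-5x5-v0")))
    demo_obs, _ = demo_env.reset(seed=0)
    # Expected: (width * tile_size, height * tile_size, 3) == (40, 40, 3)
    print(demo_obs.shape)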