wrappers.py

from __future__ import annotations

import math
from typing import Any

import gymnasium as gym
import numpy as np
from gymnasium import logger, spaces
from gymnasium.core import ActionWrapper, ObservationWrapper, ObsType, Wrapper

from minigrid.core.constants import COLOR_TO_IDX, OBJECT_TO_IDX, STATE_TO_IDX
from minigrid.core.world_object import Goal


class ReseedWrapper(Wrapper):
    """
    Wrapper to always regenerate an environment with the same set of seeds.
    This can be used to force an environment to always keep the same
    configuration when reset.

    Example:
        >>> import minigrid
        >>> import gymnasium as gym
        >>> from minigrid.wrappers import ReseedWrapper
        >>> env = gym.make("MiniGrid-Empty-5x5-v0")
        >>> _ = env.reset(seed=123)
        >>> [env.np_random.integers(10).item() for i in range(10)]
        [0, 6, 5, 0, 9, 2, 2, 1, 3, 1]
        >>> env = ReseedWrapper(env, seeds=[0, 1], seed_idx=0)
        >>> _, _ = env.reset()
        >>> [env.np_random.integers(10).item() for i in range(10)]
        [8, 6, 5, 2, 3, 0, 0, 0, 1, 8]
        >>> _, _ = env.reset()
        >>> [env.np_random.integers(10).item() for i in range(10)]
        [4, 5, 7, 9, 0, 1, 8, 9, 2, 3]
        >>> _, _ = env.reset()
        >>> [env.np_random.integers(10).item() for i in range(10)]
        [8, 6, 5, 2, 3, 0, 0, 0, 1, 8]
        >>> _, _ = env.reset()
        >>> [env.np_random.integers(10).item() for i in range(10)]
        [4, 5, 7, 9, 0, 1, 8, 9, 2, 3]
    """

    def __init__(self, env, seeds=(0,), seed_idx=0):
        """A wrapper that always regenerates an environment with the same set of seeds.

        Args:
            env: The environment to apply the wrapper
            seeds: A list of seeds to be applied to the env
            seed_idx: Index of the initial seed in seeds
        """
        self.seeds = list(seeds)
        self.seed_idx = seed_idx
        super().__init__(env)

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[ObsType, dict[str, Any]]:
        if seed is not None:
            logger.warn(
                "A seed has been passed to `ReseedWrapper.reset` which is ignored."
            )
        # Cycle through the configured seeds, wrapping around at the end
        seed = self.seeds[self.seed_idx]
        self.seed_idx = (self.seed_idx + 1) % len(self.seeds)
        return self.env.reset(seed=seed, options=options)


class ActionBonus(gym.Wrapper):
    """
    Wrapper which adds an exploration bonus.
    This is a reward to encourage exploration of less
    visited (state, action) pairs.

    Example:
        >>> import gymnasium as gym
        >>> from minigrid.wrappers import ActionBonus
        >>> env = gym.make("MiniGrid-Empty-5x5-v0")
        >>> _, _ = env.reset(seed=0)
        >>> _, reward, _, _, _ = env.step(1)
        >>> print(reward)
        0
        >>> _, reward, _, _, _ = env.step(1)
        >>> print(reward)
        0
        >>> env_bonus = ActionBonus(env)
        >>> _, _ = env_bonus.reset(seed=0)
        >>> _, reward, _, _, _ = env_bonus.step(1)
        >>> print(reward)
        1.0
        >>> _, reward, _, _, _ = env_bonus.step(1)
        >>> print(reward)
        1.0
    """

    def __init__(self, env):
        """A wrapper that adds an exploration bonus to less visited (state, action) pairs.

        Args:
            env: The environment to apply the wrapper
        """
        super().__init__(env)
        self.counts = {}

    def step(self, action):
        """Steps through the environment with `action`."""
        obs, reward, terminated, truncated, info = self.env.step(action)

        env = self.unwrapped
        tup = (tuple(env.agent_pos), env.agent_dir, action)

        # Get the count for this (s,a) pair
        pre_count = 0
        if tup in self.counts:
            pre_count = self.counts[tup]

        # Update the count for this (s,a) pair
        new_count = pre_count + 1
        self.counts[tup] = new_count

        # Bonus decays with the square root of the visit count
        bonus = 1 / math.sqrt(new_count)
        reward += bonus

        return obs, reward, terminated, truncated, info
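
# Note: the exploration bonus decays as 1 / sqrt(N(s, a)), so the first three
# visits to the same (position, direction, action) triple add bonuses of
# 1.0, ~0.707 and ~0.577, shrinking smoothly as a pair becomes well explored.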


class PositionBonus(Wrapper):
    """
    Adds a scaled exploration bonus based on which positions
    are visited on the grid.

    Note:
        This wrapper was previously called ``StateBonus``.

    Example:
        >>> import gymnasium as gym
        >>> from minigrid.wrappers import PositionBonus
        >>> env = gym.make("MiniGrid-Empty-5x5-v0")
        >>> _, _ = env.reset(seed=0)
        >>> _, reward, _, _, _ = env.step(1)
        >>> print(reward)
        0
        >>> _, reward, _, _, _ = env.step(1)
        >>> print(reward)
        0
        >>> env_bonus = PositionBonus(env, scale=1)
        >>> obs, _ = env_bonus.reset(seed=0)
        >>> obs, reward, terminated, truncated, info = env_bonus.step(1)
        >>> print(reward)
        1.0
        >>> obs, reward, terminated, truncated, info = env_bonus.step(1)
        >>> print(reward)
        0.7071067811865475
    """

    def __init__(self, env, scale=1):
        """A wrapper that adds an exploration bonus to less visited positions.

        Args:
            env: The environment to apply the wrapper
            scale: Factor by which the exploration bonus is multiplied
        """
        super().__init__(env)
        self.counts = {}
        self.scale = scale

    def step(self, action):
        """Steps through the environment with `action`."""
        obs, reward, terminated, truncated, info = self.env.step(action)

        # Tuple based on which we index the counts
        # We use the position after the update
        env = self.unwrapped
        tup = tuple(env.agent_pos)

        # Get the count for this key
        pre_count = self.counts.get(tup, 0)

        # Update the count for this key
        new_count = pre_count + 1
        self.counts[tup] = new_count

        bonus = 1 / math.sqrt(new_count)
        reward += bonus * self.scale

        return obs, reward, terminated, truncated, info
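
# Usage sketch (illustrative, not part of the module): `scale` rescales the
# count-based bonus, e.g. scale=0.1 keeps the exploration term small relative
# to the task reward.
#
#     env = PositionBonus(gym.make("MiniGrid-Empty-5x5-v0"), scale=0.1)
#     obs, _ = env.reset(seed=0)
#     _, reward, *_ = env.step(1)  # reward includes 0.1 * 1 / sqrt(visit count)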


class ImgObsWrapper(ObservationWrapper):
    """
    Use the image as the only observation output, no language/mission.

    Example:
        >>> import gymnasium as gym
        >>> from minigrid.wrappers import ImgObsWrapper
        >>> env = gym.make("MiniGrid-Empty-5x5-v0")
        >>> obs, _ = env.reset()
        >>> obs.keys()
        dict_keys(['image', 'direction', 'mission'])
        >>> env = ImgObsWrapper(env)
        >>> obs, _ = env.reset()
        >>> obs.shape
        (7, 7, 3)
    """

    def __init__(self, env):
        """A wrapper that makes the image the only observation.

        Args:
            env: The environment to apply the wrapper
        """
        super().__init__(env)
        self.observation_space = env.observation_space.spaces["image"]

    def observation(self, obs):
        return obs["image"]


class OneHotPartialObsWrapper(ObservationWrapper):
    """
    Wrapper to get a one-hot encoding of a partially observable
    agent view as observation.

    Example:
        >>> import gymnasium as gym
        >>> from minigrid.wrappers import OneHotPartialObsWrapper
        >>> env = gym.make("MiniGrid-Empty-5x5-v0")
        >>> obs, _ = env.reset()
        >>> obs["image"][0, :, :]
        array([[2, 5, 0],
               [2, 5, 0],
               [2, 5, 0],
               [2, 5, 0],
               [2, 5, 0],
               [2, 5, 0],
               [2, 5, 0]], dtype=uint8)
        >>> env = OneHotPartialObsWrapper(env)
        >>> obs, _ = env.reset()
        >>> obs["image"][0, :, :]
        array([[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
               [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
               [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
               [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
               [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
               [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
               [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0]],
              dtype=uint8)
    """

    def __init__(self, env, tile_size=8):
        """A wrapper that makes the image observation a one-hot encoding of a partially observable agent view.

        Args:
            env: The environment to apply the wrapper
            tile_size: The size of a single tile (stored, but unused in the encoding itself)
        """
        super().__init__(env)

        self.tile_size = tile_size

        obs_shape = env.observation_space["image"].shape

        # Number of bits per cell: one slot per object type, color and state
        num_bits = len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + len(STATE_TO_IDX)

        new_image_space = spaces.Box(
            low=0, high=255, shape=(obs_shape[0], obs_shape[1], num_bits), dtype="uint8"
        )
        self.observation_space = spaces.Dict(
            {**self.observation_space.spaces, "image": new_image_space}
        )

    def observation(self, obs):
        img = obs["image"]
        out = np.zeros(self.observation_space.spaces["image"].shape, dtype="uint8")

        for i in range(img.shape[0]):
            for j in range(img.shape[1]):
                type = img[i, j, 0]
                color = img[i, j, 1]
                state = img[i, j, 2]

                out[i, j, type] = 1
                out[i, j, len(OBJECT_TO_IDX) + color] = 1
                out[i, j, len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + state] = 1

        return {**obs, "image": out}
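
# The per-cell loop above can also be expressed with vectorized NumPy fancy
# indexing; a minimal sketch, assuming the same OBJECT/COLOR/STATE slot layout:
#
#     rows, cols = np.indices(img.shape[:2])
#     out[rows, cols, img[..., 0]] = 1
#     out[rows, cols, len(OBJECT_TO_IDX) + img[..., 1]] = 1
#     out[rows, cols, len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + img[..., 2]] = 1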


class RGBImgObsWrapper(ObservationWrapper):
    """
    Wrapper to use a fully observable RGB image as observation.
    This can be used to have the agent solve the gridworld in pixel space.

    Example:
        >>> import gymnasium as gym
        >>> import matplotlib.pyplot as plt
        >>> from minigrid.wrappers import RGBImgObsWrapper
        >>> env = gym.make("MiniGrid-Empty-5x5-v0")
        >>> obs, _ = env.reset()
        >>> plt.imshow(obs['image'])  # doctest: +SKIP
        ![NoWrapper](../figures/lavacrossing_NoWrapper.png)
        >>> env = RGBImgObsWrapper(env)
        >>> obs, _ = env.reset()
        >>> plt.imshow(obs['image'])  # doctest: +SKIP
        ![RGBImgObsWrapper](../figures/lavacrossing_RGBImgObsWrapper.png)
    """

    def __init__(self, env, tile_size=8):
        super().__init__(env)

        self.tile_size = tile_size

        new_image_space = spaces.Box(
            low=0,
            high=255,
            shape=(
                self.unwrapped.height * tile_size,
                self.unwrapped.width * tile_size,
                3,
            ),
            dtype="uint8",
        )

        self.observation_space = spaces.Dict(
            {**self.observation_space.spaces, "image": new_image_space}
        )

    def observation(self, obs):
        rgb_img = self.unwrapped.get_frame(
            highlight=self.unwrapped.highlight, tile_size=self.tile_size
        )

        return {**obs, "image": rgb_img}


class RGBImgPartialObsWrapper(ObservationWrapper):
    """
    Wrapper to use a partially observable RGB image as observation.
    This can be used to have the agent solve the gridworld in pixel space.

    Example:
        >>> import gymnasium as gym
        >>> import matplotlib.pyplot as plt
        >>> from minigrid.wrappers import RGBImgObsWrapper, RGBImgPartialObsWrapper
        >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
        >>> obs, _ = env.reset()
        >>> plt.imshow(obs["image"])  # doctest: +SKIP
        ![NoWrapper](../figures/lavacrossing_NoWrapper.png)
        >>> env_obs = RGBImgObsWrapper(env)
        >>> obs, _ = env_obs.reset()
        >>> plt.imshow(obs["image"])  # doctest: +SKIP
        ![RGBImgObsWrapper](../figures/lavacrossing_RGBImgObsWrapper.png)
        >>> env_obs = RGBImgPartialObsWrapper(env)
        >>> obs, _ = env_obs.reset()
        >>> plt.imshow(obs["image"])  # doctest: +SKIP
        ![RGBImgPartialObsWrapper](../figures/lavacrossing_RGBImgPartialObsWrapper.png)
    """

    def __init__(self, env, tile_size=8):
        super().__init__(env)

        # Rendering attributes for observations
        self.tile_size = tile_size

        obs_shape = env.observation_space.spaces["image"].shape
        new_image_space = spaces.Box(
            low=0,
            high=255,
            shape=(obs_shape[0] * tile_size, obs_shape[1] * tile_size, 3),
            dtype="uint8",
        )

        self.observation_space = spaces.Dict(
            {**self.observation_space.spaces, "image": new_image_space}
        )

    def observation(self, obs):
        rgb_img_partial = self.unwrapped.get_frame(
            tile_size=self.tile_size, agent_pov=True
        )

        return {**obs, "image": rgb_img_partial}


class FullyObsWrapper(ObservationWrapper):
    """
    Fully observable gridworld using a compact grid encoding instead of the agent view.

    Example:
        >>> import gymnasium as gym
        >>> import matplotlib.pyplot as plt
        >>> from minigrid.wrappers import FullyObsWrapper
        >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
        >>> obs, _ = env.reset()
        >>> obs['image'].shape
        (7, 7, 3)
        >>> env_obs = FullyObsWrapper(env)
        >>> obs, _ = env_obs.reset()
        >>> obs['image'].shape
        (11, 11, 3)
    """

    def __init__(self, env):
        super().__init__(env)

        new_image_space = spaces.Box(
            low=0,
            high=255,
            shape=(
                self.env.unwrapped.width,
                self.env.unwrapped.height,
                3,
            ),  # number of cells
            dtype="uint8",
        )

        self.observation_space = spaces.Dict(
            {**self.observation_space.spaces, "image": new_image_space}
        )

    def observation(self, obs):
        env = self.unwrapped
        full_grid = env.grid.encode()
        # Overwrite the agent's cell with the agent encoding
        full_grid[env.agent_pos[0]][env.agent_pos[1]] = np.array(
            [OBJECT_TO_IDX["agent"], COLOR_TO_IDX["red"], env.agent_dir]
        )

        return {**obs, "image": full_grid}
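
# Usage sketch (illustrative): FullyObsWrapper composes with ImgObsWrapper to
# yield a plain (width, height, 3) array, with the agent encoded in-place:
#
#     env = ImgObsWrapper(FullyObsWrapper(gym.make("MiniGrid-Empty-5x5-v0")))
#     obs, _ = env.reset()
#     # at the agent's cell:
#     # obs[x, y] == [OBJECT_TO_IDX["agent"], COLOR_TO_IDX["red"], agent_dir]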


class DictObservationSpaceWrapper(ObservationWrapper):
    """
    Transforms the observation space (that has a textual component) to a fully numerical observation space,
    where the textual instructions are replaced by arrays representing the indices of each word in a fixed vocabulary.

    This wrapper is not applicable to BabyAI environments, given that these have their own language component.

    Example:
        >>> import gymnasium as gym
        >>> import matplotlib.pyplot as plt
        >>> from minigrid.wrappers import DictObservationSpaceWrapper
        >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
        >>> obs, _ = env.reset()
        >>> obs['mission']
        'avoid the lava and get to the green goal square'
        >>> env_obs = DictObservationSpaceWrapper(env)
        >>> obs, _ = env_obs.reset()
        >>> obs['mission'][:10]
        [19, 31, 17, 36, 20, 38, 31, 2, 15, 35]
    """

    def __init__(self, env, max_words_in_mission=50, word_dict=None):
        """
        max_words_in_mission is the length of the array used to represent a mission,
        with the value 0 standing in for missing words.
        word_dict is a dictionary of words to use (keys=words, values=indices from 1 to < max_words_in_mission);
        if None, the Minigrid language is used.
        """
        super().__init__(env)

        if word_dict is None:
            word_dict = self.get_minigrid_words()

        self.max_words_in_mission = max_words_in_mission
        self.word_dict = word_dict

        self.observation_space = spaces.Dict(
            {
                "image": env.observation_space["image"],
                "direction": spaces.Discrete(4),
                "mission": spaces.MultiDiscrete(
                    [len(self.word_dict.keys())] * max_words_in_mission
                ),
            }
        )

    @staticmethod
    def get_minigrid_words():
        colors = ["red", "green", "blue", "yellow", "purple", "grey"]
        objects = [
            "unseen",
            "empty",
            "wall",
            "floor",
            "box",
            "key",
            "ball",
            "door",
            "goal",
            "agent",
            "lava",
        ]

        verbs = [
            "pick",
            "avoid",
            "get",
            "find",
            "put",
            "use",
            "open",
            "go",
            "fetch",
            "reach",
            "unlock",
            "traverse",
        ]

        extra_words = [
            "up",
            "the",
            "a",
            "at",
            ",",
            "square",
            "and",
            "then",
            "to",
            "of",
            "rooms",
            "near",
            "opening",
            "must",
            "you",
            "matching",
            "end",
            "hallway",
            "object",
            "from",
            "room",
            "maze",
        ]

        all_words = colors + objects + verbs + extra_words
        assert len(all_words) == len(set(all_words))
        return {word: i for i, word in enumerate(all_words)}

    def string_to_indices(self, string, offset=1):
        """
        Convert a string to a list of indices.
        """
        indices = []
        # Add a space before and after commas so they tokenize as words
        string = string.replace(",", " , ")
        for word in string.split():
            if word in self.word_dict.keys():
                indices.append(self.word_dict[word] + offset)
            else:
                raise ValueError(f"Unknown word: {word}")
        return indices

    def observation(self, obs):
        obs["mission"] = self.string_to_indices(obs["mission"])
        assert len(obs["mission"]) < self.max_words_in_mission
        # Pad the mission with zeros up to the fixed length
        obs["mission"] += [0] * (self.max_words_in_mission - len(obs["mission"]))

        return obs
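
# Decoding sketch (hypothetical helper, not part of the module): mission
# indices are offset by 1 relative to `word_dict`, with 0 used as padding,
# so an encoded mission can be mapped back to words as:
#
#     idx_to_word = {i + 1: w for w, i in wrapper.word_dict.items()}
#     words = [idx_to_word[i] for i in obs["mission"] if i != 0]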


class FlatObsWrapper(ObservationWrapper):
    """
    Encode mission strings using a one-hot scheme,
    and combine these with observed images into one flat array.

    This wrapper is not applicable to BabyAI environments, given that these have their own language component.

    Example:
        >>> import gymnasium as gym
        >>> import matplotlib.pyplot as plt
        >>> from minigrid.wrappers import FlatObsWrapper
        >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
        >>> env_obs = FlatObsWrapper(env)
        >>> obs, _ = env_obs.reset()
        >>> obs.shape
        (2835,)
    """

    def __init__(self, env, maxStrLen: int = 96):
        super().__init__(env)

        self.maxStrLen = maxStrLen
        self.numCharCodes = 28

        img_size = np.prod(env.observation_space["image"].shape)

        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(img_size + self.numCharCodes * self.maxStrLen,),
            dtype="uint8",
        )

        self.cachedStr: str | None = None

    def observation(self, obs):
        image = obs["image"]
        mission = obs["mission"]

        # Cache the last-encoded mission string
        if mission != self.cachedStr:
            assert (
                len(mission) <= self.maxStrLen
            ), f"mission string too long ({len(mission)} chars)"
            mission = mission.lower()

            # Since `numCharCodes` < 255, `uint8` is sufficient
            str_array = np.zeros(
                shape=(self.maxStrLen, self.numCharCodes), dtype="uint8"
            )

            for idx, ch in enumerate(mission):
                if "a" <= ch <= "z":
                    chNo = ord(ch) - ord("a")
                elif ch == " ":
                    chNo = ord("z") - ord("a") + 1
                elif ch == ",":
                    chNo = ord("z") - ord("a") + 2
                else:
                    raise ValueError(
                        f"Character {ch} is not available in mission string."
                    )
                assert chNo < self.numCharCodes, f"{ch} : {chNo:d}"
                str_array[idx, chNo] = 1

            self.cachedStr = mission
            self.cachedArray = str_array

        obs = np.concatenate((image.flatten(), self.cachedArray.flatten()))

        return obs
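
# Decoding sketch (illustrative): the flat observation is the flattened image
# followed by the one-hot mission encoding, so it can be split back apart:
#
#     env = FlatObsWrapper(gym.make("MiniGrid-LavaCrossingS11N5-v0"))
#     obs, _ = env.reset()
#     img_len = int(np.prod((7, 7, 3)))  # the partial-view image shape
#     image = obs[:img_len].reshape(7, 7, 3)
#     mission_onehot = obs[img_len:].reshape(env.maxStrLen, env.numCharCodes)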


class ViewSizeWrapper(ObservationWrapper):
    """
    Wrapper to customize the agent's field of view size.
    This cannot be used with fully observable wrappers.

    Example:
        >>> import gymnasium as gym
        >>> from minigrid.wrappers import ViewSizeWrapper
        >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
        >>> obs, _ = env.reset()
        >>> obs['image'].shape
        (7, 7, 3)
        >>> env_obs = ViewSizeWrapper(env, agent_view_size=5)
        >>> obs, _ = env_obs.reset()
        >>> obs['image'].shape
        (5, 5, 3)
    """

    def __init__(self, env, agent_view_size=7):
        super().__init__(env)

        # The view size must be an odd number of at least 3 cells, so the
        # agent sits at the center of its view
        assert agent_view_size % 2 == 1
        assert agent_view_size >= 3

        self.agent_view_size = agent_view_size

        # Compute observation space with specified view size
        new_image_space = gym.spaces.Box(
            low=0, high=255, shape=(agent_view_size, agent_view_size, 3), dtype="uint8"
        )

        # Override the environment's observation space
        self.observation_space = spaces.Dict(
            {**self.observation_space.spaces, "image": new_image_space}
        )

    def observation(self, obs):
        env = self.unwrapped

        grid, vis_mask = env.gen_obs_grid(self.agent_view_size)

        # Encode the partially observable view into a numpy array
        image = grid.encode(vis_mask)

        return {**obs, "image": image}
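
# Usage sketch (illustrative): combined with ImgObsWrapper, the resized view
# becomes the whole observation:
#
#     env = ImgObsWrapper(
#         ViewSizeWrapper(gym.make("MiniGrid-Empty-5x5-v0"), agent_view_size=3)
#     )
#     obs, _ = env.reset()
#     assert obs.shape == (3, 3, 3)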


class DirectionObsWrapper(ObservationWrapper):
    """
    Provides the slope/angular direction to the goal with the observations,
    as modeled by (y2 - y1) / (x2 - x1).
    type = {slope, angle}

    Example:
        >>> import gymnasium as gym
        >>> import matplotlib.pyplot as plt
        >>> from minigrid.wrappers import DirectionObsWrapper
        >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
        >>> env_obs = DirectionObsWrapper(env, type="slope")
        >>> obs, _ = env_obs.reset()
        >>> obs['goal_direction'].item()
        1.0
    """

    def __init__(self, env, type="slope"):
        super().__init__(env)
        self.goal_position: tuple = None
        self.type = type

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[ObsType, dict[str, Any]]:
        obs, info = self.env.reset(seed=seed, options=options)

        if not self.goal_position:
            self.goal_position = [
                x for x, y in enumerate(self.unwrapped.grid.grid) if isinstance(y, Goal)
            ]
            # In case there are multiple goals, only the first is used;
            # other env types would need this to be handled explicitly
            if len(self.goal_position) >= 1:
                # `grid.grid` is a row-major flat list, so recover (x, y)
                # from the flat index using the grid width
                self.goal_position = (
                    self.goal_position[0] % self.unwrapped.width,
                    self.goal_position[0] // self.unwrapped.width,
                )

        return self.observation(obs), info

    def observation(self, obs):
        slope = np.divide(
            self.goal_position[1] - self.unwrapped.agent_pos[1],
            self.goal_position[0] - self.unwrapped.agent_pos[0],
        )

        if self.type == "angle":
            obs["goal_direction"] = np.arctan(slope)
        else:
            obs["goal_direction"] = slope

        return obs
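
# Usage sketch (illustrative): with type="angle" the observation carries
# arctan(slope) in radians rather than the raw slope:
#
#     env = DirectionObsWrapper(gym.make("MiniGrid-Empty-5x5-v0"), type="angle")
#     obs, _ = env.reset()
#     # obs["goal_direction"] == np.arctan(slope to the goal), in (-pi/2, pi/2)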


class SymbolicObsWrapper(ObservationWrapper):
    """
    Fully observable grid with a symbolic state representation.
    The symbol is a triple of (X, Y, IDX), where X and Y are
    the coordinates on the grid, and IDX is the id of the object.

    Example:
        >>> import gymnasium as gym
        >>> from minigrid.wrappers import SymbolicObsWrapper
        >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
        >>> obs, _ = env.reset()
        >>> obs['image'].shape
        (7, 7, 3)
        >>> env_obs = SymbolicObsWrapper(env)
        >>> obs, _ = env_obs.reset()
        >>> obs['image'].shape
        (11, 11, 3)
    """

    def __init__(self, env):
        super().__init__(env)

        new_image_space = spaces.Box(
            low=0,
            high=max(OBJECT_TO_IDX.values()),
            shape=(
                self.env.unwrapped.width,
                self.env.unwrapped.height,
                3,
            ),  # number of cells
            dtype="uint8",
        )
        self.observation_space = spaces.Dict(
            {**self.observation_space.spaces, "image": new_image_space}
        )

    def observation(self, obs):
        objects = np.array(
            [
                OBJECT_TO_IDX[o.type] if o is not None else -1
                for o in self.unwrapped.grid.grid
            ]
        )
        agent_pos = self.env.unwrapped.agent_pos
        ncol, nrow = self.unwrapped.width, self.unwrapped.height
        # Stack the (X, Y) coordinate planes with the object-id plane
        grid = np.mgrid[:ncol, :nrow]
        _objects = np.transpose(objects.reshape(1, nrow, ncol), (0, 2, 1))

        grid = np.concatenate([grid, _objects])
        grid = np.transpose(grid, (1, 2, 0))
        grid[agent_pos[0], agent_pos[1], 2] = OBJECT_TO_IDX["agent"]
        obs["image"] = grid

        return obs


class StochasticActionWrapper(ActionWrapper):
    """
    Add stochasticity to the actions.

    The intended action is executed with probability `prob`; otherwise a
    random action is taken instead: `random_action` if one was provided,
    else an action sampled from the action space.
    """

    def __init__(self, env=None, prob=0.9, random_action=None):
        super().__init__(env)
        self.prob = prob
        self.random_action = random_action

    def action(self, action):
        """Returns the intended action with probability `prob`, else a random one."""
        # Use the environment's seeded RNG so results stay reproducible
        if self.np_random.uniform() < self.prob:
            return action
        else:
            if self.random_action is None:
                # `high` is exclusive, so this samples actions 0-5
                return self.np_random.integers(0, high=6)
            else:
                return self.random_action
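
# Usage sketch (illustrative): with prob=0.9, roughly one action in ten is
# replaced, so trajectories are stochastic even under a fixed policy:
#
#     env = StochasticActionWrapper(gym.make("MiniGrid-Empty-5x5-v0"), prob=0.9)
#     obs, _ = env.reset(seed=0)
#     obs, reward, terminated, truncated, info = env.step(2)  # forward, usually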


class NoDeath(Wrapper):
    """
    Wrapper to prevent death in specific cells (e.g., lava cells).
    Instead of dying, the agent will receive a negative reward.

    Example:
        >>> import gymnasium as gym
        >>> from minigrid.wrappers import NoDeath
        >>>
        >>> env = gym.make("MiniGrid-LavaCrossingS9N1-v0")
        >>> _, _ = env.reset(seed=2)
        >>> _, _, _, _, _ = env.step(1)
        >>> _, reward, term, *_ = env.step(2)
        >>> reward, term
        (0, True)
        >>>
        >>> env = NoDeath(env, no_death_types=("lava",), death_cost=-1.0)
        >>> _, _ = env.reset(seed=2)
        >>> _, _, _, _, _ = env.step(1)
        >>> _, reward, term, *_ = env.step(2)
        >>> reward, term
        (-1.0, False)
        >>>
        >>>
        >>> env = gym.make("MiniGrid-Dynamic-Obstacles-5x5-v0")
        >>> _, _ = env.reset(seed=2)
        >>> _, reward, term, *_ = env.step(2)
        >>> reward, term
        (-1, True)
        >>>
        >>> env = NoDeath(env, no_death_types=("ball",), death_cost=-1.0)
        >>> _, _ = env.reset(seed=2)
        >>> _, reward, term, *_ = env.step(2)
        >>> reward, term
        (-2.0, False)
    """

    def __init__(self, env, no_death_types: tuple[str, ...], death_cost: float = -1.0):
        """A wrapper to prevent death in specific cells.

        Args:
            env: The environment to apply the wrapper
            no_death_types: List of strings to identify death cells
            death_cost: The negative reward received in death cells
        """
        assert "goal" not in no_death_types, "goal cannot be a death cell"

        super().__init__(env)
        self.death_cost = death_cost
        self.no_death_types = no_death_types

    def step(self, action):
        # In Dynamic-Obstacles, obstacles move after the agent moves,
        # so we need to check for collision before self.env.step()
        front_cell = self.unwrapped.grid.get(*self.unwrapped.front_pos)
        going_to_death = (
            action == self.unwrapped.actions.forward
            and front_cell is not None
            and front_cell.type in self.no_death_types
        )

        obs, reward, terminated, truncated, info = self.env.step(action)

        # We also check if the agent stays in death cells (e.g., lava)
        # without moving
        current_cell = self.unwrapped.grid.get(*self.unwrapped.agent_pos)
        in_death = current_cell is not None and current_cell.type in self.no_death_types

        if terminated and (going_to_death or in_death):
            terminated = False
            reward += self.death_cost

        return obs, reward, terminated, truncated, info