wrappers.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790
  1. from __future__ import annotations
  2. import math
  3. import operator
  4. from functools import reduce
  5. import gymnasium as gym
  6. import numpy as np
  7. from gymnasium import spaces
  8. from gymnasium.core import ObservationWrapper, Wrapper
  9. from minigrid.core.constants import COLOR_TO_IDX, OBJECT_TO_IDX, STATE_TO_IDX
  10. from minigrid.core.world_object import Goal
  11. class ReseedWrapper(Wrapper):
  12. """
  13. Wrapper to always regenerate an environment with the same set of seeds.
  14. This can be used to force an environment to always keep the same
  15. configuration when reset.
  16. Example:
  17. >>> import minigrid
  18. >>> import gymnasium as gym
  19. >>> from minigrid.wrappers import ReseedWrapper
  20. >>> env = gym.make("MiniGrid-Empty-5x5-v0")
  21. >>> [env.np_random.integers(10) for i in range(10)]
  22. [1, 9, 5, 8, 4, 3, 8, 8, 3, 1]
  23. >>> env = ReseedWrapper(env, seeds=[0, 1], seed_idx=0)
  24. >>> _, _ = env.reset()
  25. >>> [env.np_random.integers(10) for i in range(10)]
  26. [8, 6, 5, 2, 3, 0, 0, 0, 1, 8]
  27. >>> _, _ = env.reset()
  28. >>> [env.np_random.integers(10) for i in range(10)]
  29. [4, 5, 7, 9, 0, 1, 8, 9, 2, 3]
  30. >>> _, _ = env.reset()
  31. >>> [env.np_random.integers(10) for i in range(10)]
  32. [8, 6, 5, 2, 3, 0, 0, 0, 1, 8]
  33. >>> _, _ = env.reset()
  34. >>> [env.np_random.integers(10) for i in range(10)]
  35. [4, 5, 7, 9, 0, 1, 8, 9, 2, 3]
  36. """
  37. def __init__(self, env, seeds=[0], seed_idx=0):
  38. """A wrapper that always regenerate an environment with the same set of seeds.
  39. Args:
  40. env: The environment to apply the wrapper
  41. seeds: A list of seed to be applied to the env
  42. seed_idx: Index of the initial seed in seeds
  43. """
  44. self.seeds = list(seeds)
  45. self.seed_idx = seed_idx
  46. super().__init__(env)
  47. def reset(self, **kwargs):
  48. """Resets the environment with `kwargs`."""
  49. seed = self.seeds[self.seed_idx]
  50. self.seed_idx = (self.seed_idx + 1) % len(self.seeds)
  51. return self.env.reset(seed=seed, **kwargs)
  52. def step(self, action):
  53. """Steps through the environment with `action`."""
  54. return self.env.step(action)
  55. class ActionBonus(gym.Wrapper):
  56. """
  57. Wrapper which adds an exploration bonus.
  58. This is a reward to encourage exploration of less
  59. visited (state,action) pairs.
  60. Example:
  61. >>> import minigrid
  62. >>> import gymnasium as gym
  63. >>> from minigrid.wrappers import ActionBonus
  64. >>> env = gym.make("MiniGrid-Empty-5x5-v0")
  65. >>> _, _ = env.reset(seed=0)
  66. >>> _, reward, _, _, _ = env.step(1)
  67. >>> print(reward)
  68. 0
  69. >>> _, reward, _, _, _ = env.step(1)
  70. >>> print(reward)
  71. 0
  72. >>> env_bonus = ActionBonus(env)
  73. >>> _, _ = env_bonus.reset(seed=0)
  74. >>> _, reward, _, _, _ = env_bonus.step(1)
  75. >>> print(reward)
  76. 1.0
  77. >>> _, reward, _, _, _ = env_bonus.step(1)
  78. >>> print(reward)
  79. 1.0
  80. """
  81. def __init__(self, env):
  82. """A wrapper that adds an exploration bonus to less visited (state,action) pairs.
  83. Args:
  84. env: The environment to apply the wrapper
  85. """
  86. super().__init__(env)
  87. self.counts = {}
  88. def step(self, action):
  89. """Steps through the environment with `action`."""
  90. obs, reward, terminated, truncated, info = self.env.step(action)
  91. env = self.unwrapped
  92. tup = (tuple(env.agent_pos), env.agent_dir, action)
  93. # Get the count for this (s,a) pair
  94. pre_count = 0
  95. if tup in self.counts:
  96. pre_count = self.counts[tup]
  97. # Update the count for this (s,a) pair
  98. new_count = pre_count + 1
  99. self.counts[tup] = new_count
  100. bonus = 1 / math.sqrt(new_count)
  101. reward += bonus
  102. return obs, reward, terminated, truncated, info
  103. def reset(self, **kwargs):
  104. """Resets the environment with `kwargs`."""
  105. return self.env.reset(**kwargs)
  106. class PositionBonus(Wrapper):
  107. """
  108. Adds an exploration bonus based on which positions
  109. are visited on the grid.
  110. Note:
  111. This wrapper was previously called ``StateBonus``.
  112. Example:
  113. >>> import minigrid
  114. >>> import gymnasium as gym
  115. >>> from minigrid.wrappers import PositionBonus
  116. >>> env = gym.make("MiniGrid-Empty-5x5-v0")
  117. >>> _, _ = env.reset(seed=0)
  118. >>> _, reward, _, _, _ = env.step(1)
  119. >>> print(reward)
  120. 0
  121. >>> _, reward, _, _, _ = env.step(1)
  122. >>> print(reward)
  123. 0
  124. >>> env_bonus = PositionBonus(env)
  125. >>> obs, _ = env_bonus.reset(seed=0)
  126. >>> obs, reward, terminated, truncated, info = env_bonus.step(1)
  127. >>> print(reward)
  128. 1.0
  129. >>> obs, reward, terminated, truncated, info = env_bonus.step(1)
  130. >>> print(reward)
  131. 0.7071067811865475
  132. """
  133. def __init__(self, env):
  134. """A wrapper that adds an exploration bonus to less visited positions.
  135. Args:
  136. env: The environment to apply the wrapper
  137. """
  138. super().__init__(env)
  139. self.counts = {}
  140. def step(self, action):
  141. """Steps through the environment with `action`."""
  142. obs, reward, terminated, truncated, info = self.env.step(action)
  143. # Tuple based on which we index the counts
  144. # We use the position after an update
  145. env = self.unwrapped
  146. tup = tuple(env.agent_pos)
  147. # Get the count for this key
  148. pre_count = 0
  149. if tup in self.counts:
  150. pre_count = self.counts[tup]
  151. # Update the count for this key
  152. new_count = pre_count + 1
  153. self.counts[tup] = new_count
  154. bonus = 1 / math.sqrt(new_count)
  155. reward += bonus
  156. return obs, reward, terminated, truncated, info
  157. def reset(self, **kwargs):
  158. """Resets the environment with `kwargs`."""
  159. return self.env.reset(**kwargs)
  160. class ImgObsWrapper(ObservationWrapper):
  161. """
  162. Use the image as the only observation output, no language/mission.
  163. Example:
  164. >>> import minigrid
  165. >>> import gymnasium as gym
  166. >>> from minigrid.wrappers import ImgObsWrapper
  167. >>> env = gym.make("MiniGrid-Empty-5x5-v0")
  168. >>> obs, _ = env.reset()
  169. >>> obs.keys()
  170. dict_keys(['image', 'direction', 'mission'])
  171. >>> env = ImgObsWrapper(env)
  172. >>> obs, _ = env.reset()
  173. >>> obs.shape
  174. (7, 7, 3)
  175. """
  176. def __init__(self, env):
  177. """A wrapper that makes image the only observation.
  178. Args:
  179. env: The environment to apply the wrapper
  180. """
  181. super().__init__(env)
  182. self.observation_space = env.observation_space.spaces["image"]
  183. def observation(self, obs):
  184. return obs["image"]
  185. class OneHotPartialObsWrapper(ObservationWrapper):
  186. """
  187. Wrapper to get a one-hot encoding of a partially observable
  188. agent view as observation.
  189. Example:
  190. >>> import minigrid
  191. >>> import gymnasium as gym
  192. >>> from minigrid.wrappers import OneHotPartialObsWrapper
  193. >>> env = gym.make("MiniGrid-Empty-5x5-v0")
  194. >>> obs, _ = env.reset()
  195. >>> obs["image"][0, :, :]
  196. array([[2, 5, 0],
  197. [2, 5, 0],
  198. [2, 5, 0],
  199. [2, 5, 0],
  200. [2, 5, 0],
  201. [2, 5, 0],
  202. [2, 5, 0]], dtype=uint8)
  203. >>> env = OneHotPartialObsWrapper(env)
  204. >>> obs, _ = env.reset()
  205. >>> obs["image"][0, :, :]
  206. array([[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
  207. [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
  208. [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
  209. [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
  210. [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
  211. [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
  212. [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0]],
  213. dtype=uint8)
  214. """
  215. def __init__(self, env, tile_size=8):
  216. """A wrapper that makes the image observation a one-hot encoding of a partially observable agent view.
  217. Args:
  218. env: The environment to apply the wrapper
  219. """
  220. super().__init__(env)
  221. self.tile_size = tile_size
  222. obs_shape = env.observation_space["image"].shape
  223. # Number of bits per cell
  224. num_bits = len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + len(STATE_TO_IDX)
  225. new_image_space = spaces.Box(
  226. low=0, high=255, shape=(obs_shape[0], obs_shape[1], num_bits), dtype="uint8"
  227. )
  228. self.observation_space = spaces.Dict(
  229. {**self.observation_space.spaces, "image": new_image_space}
  230. )
  231. def observation(self, obs):
  232. img = obs["image"]
  233. out = np.zeros(self.observation_space.spaces["image"].shape, dtype="uint8")
  234. for i in range(img.shape[0]):
  235. for j in range(img.shape[1]):
  236. type = img[i, j, 0]
  237. color = img[i, j, 1]
  238. state = img[i, j, 2]
  239. out[i, j, type] = 1
  240. out[i, j, len(OBJECT_TO_IDX) + color] = 1
  241. out[i, j, len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + state] = 1
  242. return {**obs, "image": out}
  243. class RGBImgObsWrapper(ObservationWrapper):
  244. """
  245. Wrapper to use fully observable RGB image as observation,
  246. This can be used to have the agent to solve the gridworld in pixel space.
  247. Example:
  248. >>> import minigrid
  249. >>> import gymnasium as gym
  250. >>> import matplotlib.pyplot as plt
  251. >>> from minigrid.wrappers import RGBImgObsWrapper
  252. >>> env = gym.make("MiniGrid-Empty-5x5-v0")
  253. >>> obs, _ = env.reset()
  254. >>> plt.imshow(obs['image'])
  255. ![NoWrapper](../figures/lavacrossing_NoWrapper.png)
  256. >>> env = RGBImgObsWrapper(env)
  257. >>> obs, _ = env.reset()
  258. >>> plt.imshow(obs['image'])
  259. ![RGBImgObsWrapper](../figures/lavacrossing_RGBImgObsWrapper.png)
  260. """
  261. def __init__(self, env, tile_size=8):
  262. super().__init__(env)
  263. self.tile_size = tile_size
  264. new_image_space = spaces.Box(
  265. low=0,
  266. high=255,
  267. shape=(self.env.width * tile_size, self.env.height * tile_size, 3),
  268. dtype="uint8",
  269. )
  270. self.observation_space = spaces.Dict(
  271. {**self.observation_space.spaces, "image": new_image_space}
  272. )
  273. def observation(self, obs):
  274. rgb_img = self.get_frame(highlight=True, tile_size=self.tile_size)
  275. return {**obs, "image": rgb_img}
  276. class RGBImgPartialObsWrapper(ObservationWrapper):
  277. """
  278. Wrapper to use partially observable RGB image as observation.
  279. This can be used to have the agent to solve the gridworld in pixel space.
  280. Example:
  281. >>> import minigrid
  282. >>> import gymnasium as gym
  283. >>> import matplotlib.pyplot as plt
  284. >>> from minigrid.wrappers import RGBImgObsWrapper, RGBImgPartialObsWrapper
  285. >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
  286. >>> obs, _ = env.reset()
  287. >>> plt.imshow(obs["image"])
  288. ![NoWrapper](../figures/lavacrossing_NoWrapper.png)
  289. >>> env_obs = RGBImgObsWrapper(env)
  290. >>> obs, _ = env_obs.reset()
  291. >>> plt.imshow(obs["image"])
  292. ![RGBImgObsWrapper](../figures/lavacrossing_RGBImgObsWrapper.png)
  293. >>> env_obs = RGBImgPartialObsWrapper(env)
  294. >>> obs, _ = env_obs.reset()
  295. >>> plt.imshow(obs["image"])
  296. ![RGBImgPartialObsWrapper](../figures/lavacrossing_RGBImgPartialObsWrapper.png)
  297. """
  298. def __init__(self, env, tile_size=8):
  299. super().__init__(env)
  300. # Rendering attributes for observations
  301. self.tile_size = tile_size
  302. obs_shape = env.observation_space.spaces["image"].shape
  303. new_image_space = spaces.Box(
  304. low=0,
  305. high=255,
  306. shape=(obs_shape[0] * tile_size, obs_shape[1] * tile_size, 3),
  307. dtype="uint8",
  308. )
  309. self.observation_space = spaces.Dict(
  310. {**self.observation_space.spaces, "image": new_image_space}
  311. )
  312. def observation(self, obs):
  313. rgb_img_partial = self.get_frame(tile_size=self.tile_size, agent_pov=True)
  314. return {**obs, "image": rgb_img_partial}
  315. class FullyObsWrapper(ObservationWrapper):
  316. """
  317. Fully observable gridworld using a compact grid encoding instead of the agent view.
  318. Example:
  319. >>> import minigrid
  320. >>> import gymnasium as gym
  321. >>> import matplotlib.pyplot as plt
  322. >>> from minigrid.wrappers import FullyObsWrapper
  323. >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
  324. >>> obs, _ = env.reset()
  325. >>> obs['image'].shape
  326. (7, 7, 3)
  327. >>> env_obs = FullyObsWrapper(env)
  328. >>> obs, _ = env_obs.reset()
  329. >>> obs['image'].shape
  330. (11, 11, 3)
  331. """
  332. def __init__(self, env):
  333. super().__init__(env)
  334. new_image_space = spaces.Box(
  335. low=0,
  336. high=255,
  337. shape=(self.env.width, self.env.height, 3), # number of cells
  338. dtype="uint8",
  339. )
  340. self.observation_space = spaces.Dict(
  341. {**self.observation_space.spaces, "image": new_image_space}
  342. )
  343. def observation(self, obs):
  344. env = self.unwrapped
  345. full_grid = env.grid.encode()
  346. full_grid[env.agent_pos[0]][env.agent_pos[1]] = np.array(
  347. [OBJECT_TO_IDX["agent"], COLOR_TO_IDX["red"], env.agent_dir]
  348. )
  349. return {**obs, "image": full_grid}
  350. class DictObservationSpaceWrapper(ObservationWrapper):
  351. """
  352. Transforms the observation space (that has a textual component) to a fully numerical observation space,
  353. where the textual instructions are replaced by arrays representing the indices of each word in a fixed vocabulary.
  354. This wrapper is not applicable to BabyAI environments, given that these have their own language component.
  355. Example:
  356. >>> import minigrid
  357. >>> import gymnasium as gym
  358. >>> import matplotlib.pyplot as plt
  359. >>> from minigrid.wrappers import DictObservationSpaceWrapper
  360. >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
  361. >>> obs, _ = env.reset()
  362. >>> obs['mission']
  363. 'avoid the lava and get to the green goal square'
  364. >>> env_obs = DictObservationSpaceWrapper(env)
  365. >>> obs, _ = env_obs.reset()
  366. >>> obs['mission'][:10]
  367. [19, 31, 17, 36, 20, 38, 31, 2, 15, 35]
  368. """
  369. def __init__(self, env, max_words_in_mission=50, word_dict=None):
  370. """
  371. max_words_in_mission is the length of the array to represent a mission, value 0 for missing words
  372. word_dict is a dictionary of words to use (keys=words, values=indices from 1 to < max_words_in_mission),
  373. if None, use the Minigrid language
  374. """
  375. super().__init__(env)
  376. if word_dict is None:
  377. word_dict = self.get_minigrid_words()
  378. self.max_words_in_mission = max_words_in_mission
  379. self.word_dict = word_dict
  380. image_observation_space = spaces.Box(
  381. low=0,
  382. high=255,
  383. shape=(self.agent_view_size, self.agent_view_size, 3),
  384. dtype="uint8",
  385. )
  386. self.observation_space = spaces.Dict(
  387. {
  388. "image": image_observation_space,
  389. "direction": spaces.Discrete(4),
  390. "mission": spaces.MultiDiscrete(
  391. [len(self.word_dict.keys())] * max_words_in_mission
  392. ),
  393. }
  394. )
  395. @staticmethod
  396. def get_minigrid_words():
  397. colors = ["red", "green", "blue", "yellow", "purple", "grey"]
  398. objects = [
  399. "unseen",
  400. "empty",
  401. "wall",
  402. "floor",
  403. "box",
  404. "key",
  405. "ball",
  406. "door",
  407. "goal",
  408. "agent",
  409. "lava",
  410. ]
  411. verbs = [
  412. "pick",
  413. "avoid",
  414. "get",
  415. "find",
  416. "put",
  417. "use",
  418. "open",
  419. "go",
  420. "fetch",
  421. "reach",
  422. "unlock",
  423. "traverse",
  424. ]
  425. extra_words = [
  426. "up",
  427. "the",
  428. "a",
  429. "at",
  430. ",",
  431. "square",
  432. "and",
  433. "then",
  434. "to",
  435. "of",
  436. "rooms",
  437. "near",
  438. "opening",
  439. "must",
  440. "you",
  441. "matching",
  442. "end",
  443. "hallway",
  444. "object",
  445. "from",
  446. "room",
  447. ]
  448. all_words = colors + objects + verbs + extra_words
  449. assert len(all_words) == len(set(all_words))
  450. return {word: i for i, word in enumerate(all_words)}
  451. def string_to_indices(self, string, offset=1):
  452. """
  453. Convert a string to a list of indices.
  454. """
  455. indices = []
  456. # adding space before and after commas
  457. string = string.replace(",", " , ")
  458. for word in string.split():
  459. if word in self.word_dict.keys():
  460. indices.append(self.word_dict[word] + offset)
  461. else:
  462. raise ValueError(f"Unknown word: {word}")
  463. return indices
  464. def observation(self, obs):
  465. obs["mission"] = self.string_to_indices(obs["mission"])
  466. assert len(obs["mission"]) < self.max_words_in_mission
  467. obs["mission"] += [0] * (self.max_words_in_mission - len(obs["mission"]))
  468. return obs
  469. class FlatObsWrapper(ObservationWrapper):
  470. """
  471. Encode mission strings using a one-hot scheme,
  472. and combine these with observed images into one flat array.
  473. This wrapper is not applicable to BabyAI environments, given that these have their own language component.
  474. Example:
  475. >>> import minigrid
  476. >>> import gymnasium as gym
  477. >>> import matplotlib.pyplot as plt
  478. >>> from minigrid.wrappers import FlatObsWrapper
  479. >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
  480. >>> env_obs = FlatObsWrapper(env)
  481. >>> obs, _ = env_obs.reset()
  482. >>> obs.shape
  483. (2835,)
  484. """
  485. def __init__(self, env, maxStrLen=96):
  486. super().__init__(env)
  487. self.maxStrLen = maxStrLen
  488. self.numCharCodes = 28
  489. imgSpace = env.observation_space.spaces["image"]
  490. imgSize = reduce(operator.mul, imgSpace.shape, 1)
  491. self.observation_space = spaces.Box(
  492. low=0,
  493. high=255,
  494. shape=(imgSize + self.numCharCodes * self.maxStrLen,),
  495. dtype="uint8",
  496. )
  497. self.cachedStr: str = None
  498. def observation(self, obs):
  499. image = obs["image"]
  500. mission = obs["mission"]
  501. # Cache the last-encoded mission string
  502. if mission != self.cachedStr:
  503. assert (
  504. len(mission) <= self.maxStrLen
  505. ), f"mission string too long ({len(mission)} chars)"
  506. mission = mission.lower()
  507. strArray = np.zeros(
  508. shape=(self.maxStrLen, self.numCharCodes), dtype="float32"
  509. )
  510. for idx, ch in enumerate(mission):
  511. if ch >= "a" and ch <= "z":
  512. chNo = ord(ch) - ord("a")
  513. elif ch == " ":
  514. chNo = ord("z") - ord("a") + 1
  515. elif ch == ",":
  516. chNo = ord("z") - ord("a") + 2
  517. else:
  518. raise ValueError(
  519. f"Character {ch} is not available in mission string."
  520. )
  521. assert chNo < self.numCharCodes, "%s : %d" % (ch, chNo)
  522. strArray[idx, chNo] = 1
  523. self.cachedStr = mission
  524. self.cachedArray = strArray
  525. obs = np.concatenate((image.flatten(), self.cachedArray.flatten()))
  526. return obs
  527. class ViewSizeWrapper(ObservationWrapper):
  528. """
  529. Wrapper to customize the agent field of view size.
  530. This cannot be used with fully observable wrappers.
  531. Example:
  532. >>> import minigrid
  533. >>> import gymnasium as gym
  534. >>> import matplotlib.pyplot as plt
  535. >>> from minigrid.wrappers import ViewSizeWrapper
  536. >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
  537. >>> obs, _ = env.reset()
  538. >>> obs['image'].shape
  539. (7, 7, 3)
  540. >>> env_obs = ViewSizeWrapper(env, agent_view_size=5)
  541. >>> obs, _ = env_obs.reset()
  542. >>> obs['image'].shape
  543. (5, 5, 3)
  544. """
  545. def __init__(self, env, agent_view_size=7):
  546. super().__init__(env)
  547. assert agent_view_size % 2 == 1
  548. assert agent_view_size >= 3
  549. self.agent_view_size = agent_view_size
  550. # Compute observation space with specified view size
  551. new_image_space = gym.spaces.Box(
  552. low=0, high=255, shape=(agent_view_size, agent_view_size, 3), dtype="uint8"
  553. )
  554. # Override the environment's observation spaceexit
  555. self.observation_space = spaces.Dict(
  556. {**self.observation_space.spaces, "image": new_image_space}
  557. )
  558. def observation(self, obs):
  559. env = self.unwrapped
  560. grid, vis_mask = env.gen_obs_grid(self.agent_view_size)
  561. # Encode the partially observable view into a numpy array
  562. image = grid.encode(vis_mask)
  563. return {**obs, "image": image}
  564. class DirectionObsWrapper(ObservationWrapper):
  565. """
  566. Provides the slope/angular direction to the goal with the observations as modeled by (y2 - y2 )/( x2 - x1)
  567. type = {slope , angle}
  568. Example:
  569. >>> import minigrid
  570. >>> import gymnasium as gym
  571. >>> import matplotlib.pyplot as plt
  572. >>> from minigrid.wrappers import DirectionObsWrapper
  573. >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
  574. >>> env_obs = DirectionObsWrapper(env, type="slope")
  575. >>> obs, _ = env_obs.reset()
  576. >>> obs['goal_direction']
  577. 1.0
  578. """
  579. def __init__(self, env, type="slope"):
  580. super().__init__(env)
  581. self.goal_position: tuple = None
  582. self.type = type
  583. def reset(self):
  584. obs, _ = self.env.reset()
  585. if not self.goal_position:
  586. self.goal_position = [
  587. x for x, y in enumerate(self.grid.grid) if isinstance(y, Goal)
  588. ]
  589. # in case there are multiple goals , needs to be handled for other env types
  590. if len(self.goal_position) >= 1:
  591. self.goal_position = (
  592. int(self.goal_position[0] / self.height),
  593. self.goal_position[0] % self.width,
  594. )
  595. return self.observation(obs)
  596. def observation(self, obs):
  597. slope = np.divide(
  598. self.goal_position[1] - self.agent_pos[1],
  599. self.goal_position[0] - self.agent_pos[0],
  600. )
  601. if self.type == "angle":
  602. obs["goal_direction"] = np.arctan(slope)
  603. else:
  604. obs["goal_direction"] = slope
  605. return obs
  606. class SymbolicObsWrapper(ObservationWrapper):
  607. """
  608. Fully observable grid with a symbolic state representation.
  609. The symbol is a triple of (X, Y, IDX), where X and Y are
  610. the coordinates on the grid, and IDX is the id of the object.
  611. Example:
  612. >>> import minigrid
  613. >>> import gymnasium as gym
  614. >>> import matplotlib.pyplot as plt
  615. >>> from minigrid.wrappers import SymbolicObsWrapper
  616. >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
  617. >>> obs, _ = env.reset()
  618. >>> obs['image'].shape
  619. (7, 7, 3)
  620. >>> env_obs = SymbolicObsWrapper(env)
  621. >>> obs, _ = env_obs.reset()
  622. >>> obs['image'].shape
  623. (11, 11, 3)
  624. """
  625. def __init__(self, env):
  626. super().__init__(env)
  627. new_image_space = spaces.Box(
  628. low=0,
  629. high=max(OBJECT_TO_IDX.values()),
  630. shape=(self.env.width, self.env.height, 3), # number of cells
  631. dtype="uint8",
  632. )
  633. self.observation_space = spaces.Dict(
  634. {**self.observation_space.spaces, "image": new_image_space}
  635. )
  636. def observation(self, obs):
  637. objects = np.array(
  638. [OBJECT_TO_IDX[o.type] if o is not None else -1 for o in self.grid.grid]
  639. )
  640. agent_pos = self.env.agent_pos
  641. ncol, nrow = self.width, self.height
  642. grid = np.mgrid[:ncol, :nrow]
  643. _objects = np.transpose(objects.reshape(1, nrow, ncol), (0, 2, 1))
  644. grid = np.concatenate([grid, _objects])
  645. grid = np.transpose(grid, (1, 2, 0))
  646. grid[agent_pos[0], agent_pos[1], 2] = OBJECT_TO_IDX["agent"]
  647. obs["image"] = grid
  648. return obs