# wrappers.py

from __future__ import annotations

import math
import operator
from functools import reduce

import gymnasium as gym
import numpy as np
from gymnasium import spaces
from gymnasium.core import ObservationWrapper, Wrapper

from minigrid.core.constants import COLOR_TO_IDX, OBJECT_TO_IDX, STATE_TO_IDX
from minigrid.core.world_object import Goal


class ReseedWrapper(Wrapper):
    """
    Wrapper to always regenerate an environment with the same set of seeds.
    This can be used to force an environment to always keep the same
    configuration when reset.

    Example:
        >>> import minigrid
        >>> import gymnasium as gym
        >>> from minigrid.wrappers import ReseedWrapper
        >>> env = gym.make("MiniGrid-Empty-5x5-v0")
        >>> [env.np_random.integers(10) for i in range(10)]
        [1, 9, 5, 8, 4, 3, 8, 8, 3, 1]
        >>> env = ReseedWrapper(env, seeds=[0, 1], seed_idx=0)
        >>> _, _ = env.reset()
        >>> [env.np_random.integers(10) for i in range(10)]
        [8, 6, 5, 2, 3, 0, 0, 0, 1, 8]
        >>> _, _ = env.reset()
        >>> [env.np_random.integers(10) for i in range(10)]
        [4, 5, 7, 9, 0, 1, 8, 9, 2, 3]
        >>> _, _ = env.reset()
        >>> [env.np_random.integers(10) for i in range(10)]
        [8, 6, 5, 2, 3, 0, 0, 0, 1, 8]
        >>> _, _ = env.reset()
        >>> [env.np_random.integers(10) for i in range(10)]
        [4, 5, 7, 9, 0, 1, 8, 9, 2, 3]
    """

    def __init__(self, env, seeds=[0], seed_idx=0):
        """A wrapper that always regenerates an environment with the same set of seeds.

        Args:
            env: The environment to apply the wrapper
            seeds: A list of seeds to be applied to the env
            seed_idx: Index of the initial seed in seeds
        """
        self.seeds = list(seeds)
        self.seed_idx = seed_idx
        super().__init__(env)

    def reset(self, **kwargs):
        """Resets the environment with `kwargs`."""
        seed = self.seeds[self.seed_idx]
        self.seed_idx = (self.seed_idx + 1) % len(self.seeds)
        return self.env.reset(seed=seed, **kwargs)

    def step(self, action):
        """Steps through the environment with `action`."""
        return self.env.step(action)


class ActionBonus(gym.Wrapper):
    """
    Wrapper which adds an exploration bonus.
    This is a reward to encourage exploration of less
    visited (state, action) pairs.

    Example:
        >>> import minigrid
        >>> import gymnasium as gym
        >>> from minigrid.wrappers import ActionBonus
        >>> env = gym.make("MiniGrid-Empty-5x5-v0")
        >>> _, _ = env.reset(seed=0)
        >>> _, reward, _, _, _ = env.step(1)
        >>> print(reward)
        0
        >>> _, reward, _, _, _ = env.step(1)
        >>> print(reward)
        0
        >>> env_bonus = ActionBonus(env)
        >>> _, _ = env_bonus.reset(seed=0)
        >>> _, reward, _, _, _ = env_bonus.step(1)
        >>> print(reward)
        1.0
        >>> _, reward, _, _, _ = env_bonus.step(1)
        >>> print(reward)
        1.0
    """

    def __init__(self, env):
        """A wrapper that adds an exploration bonus to less visited (state, action) pairs.

        Args:
            env: The environment to apply the wrapper
        """
        super().__init__(env)
        self.counts = {}

    def step(self, action):
        """Steps through the environment with `action`."""
        obs, reward, terminated, truncated, info = self.env.step(action)

        env = self.unwrapped
        tup = (tuple(env.agent_pos), env.agent_dir, action)

        # Get the count for this (s, a) pair
        pre_count = 0
        if tup in self.counts:
            pre_count = self.counts[tup]

        # Update the count for this (s, a) pair
        new_count = pre_count + 1
        self.counts[tup] = new_count

        # The bonus decays with the square root of the visit count
        bonus = 1 / math.sqrt(new_count)
        reward += bonus

        return obs, reward, terminated, truncated, info

    def reset(self, **kwargs):
        """Resets the environment with `kwargs`."""
        return self.env.reset(**kwargs)


# Should be named PositionBonus
class StateBonus(Wrapper):
    """
    Adds an exploration bonus based on which positions
    are visited on the grid.

    Example:
        >>> import minigrid
        >>> import gymnasium as gym
        >>> from minigrid.wrappers import StateBonus
        >>> env = gym.make("MiniGrid-Empty-5x5-v0")
        >>> _, _ = env.reset(seed=0)
        >>> _, reward, _, _, _ = env.step(1)
        >>> print(reward)
        0
        >>> _, reward, _, _, _ = env.step(1)
        >>> print(reward)
        0
        >>> env_bonus = StateBonus(env)
        >>> obs, _ = env_bonus.reset(seed=0)
        >>> obs, reward, terminated, truncated, info = env_bonus.step(1)
        >>> print(reward)
        1.0
        >>> obs, reward, terminated, truncated, info = env_bonus.step(1)
        >>> print(reward)
        0.7071067811865475
    """

    def __init__(self, env):
        """A wrapper that adds an exploration bonus to less visited positions.

        Args:
            env: The environment to apply the wrapper
        """
        super().__init__(env)
        self.counts = {}

    def step(self, action):
        """Steps through the environment with `action`."""
        obs, reward, terminated, truncated, info = self.env.step(action)

        # Tuple based on which we index the counts
        # We use the position after the update
        env = self.unwrapped
        tup = tuple(env.agent_pos)

        # Get the count for this key
        pre_count = 0
        if tup in self.counts:
            pre_count = self.counts[tup]

        # Update the count for this key
        new_count = pre_count + 1
        self.counts[tup] = new_count

        # The bonus decays with the square root of the visit count
        bonus = 1 / math.sqrt(new_count)
        reward += bonus

        return obs, reward, terminated, truncated, info

    def reset(self, **kwargs):
        """Resets the environment with `kwargs`."""
        return self.env.reset(**kwargs)


class ImgObsWrapper(ObservationWrapper):
    """
    Use the image as the only observation output, with no language/mission.

    Example:
        >>> import minigrid
        >>> import gymnasium as gym
        >>> from minigrid.wrappers import ImgObsWrapper
        >>> env = gym.make("MiniGrid-Empty-5x5-v0")
        >>> obs, _ = env.reset()
        >>> obs.keys()
        dict_keys(['image', 'direction', 'mission'])
        >>> env = ImgObsWrapper(env)
        >>> obs, _ = env.reset()
        >>> obs.shape
        (7, 7, 3)
    """

    def __init__(self, env):
        """A wrapper that makes the image the only observation.

        Args:
            env: The environment to apply the wrapper
        """
        super().__init__(env)
        self.observation_space = env.observation_space.spaces["image"]

    def observation(self, obs):
        return obs["image"]


class OneHotPartialObsWrapper(ObservationWrapper):
    """
    Wrapper to get a one-hot encoding of a partially observable
    agent view as observation.

    Example:
        >>> import minigrid
        >>> import gymnasium as gym
        >>> from minigrid.wrappers import OneHotPartialObsWrapper
        >>> env = gym.make("MiniGrid-Empty-5x5-v0")
        >>> obs, _ = env.reset()
        >>> obs["image"][0, :, :]
        array([[2, 5, 0],
               [2, 5, 0],
               [2, 5, 0],
               [2, 5, 0],
               [2, 5, 0],
               [2, 5, 0],
               [2, 5, 0]], dtype=uint8)
        >>> env = OneHotPartialObsWrapper(env)
        >>> obs, _ = env.reset()
        >>> obs["image"][0, :, :]
        array([[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
               [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
               [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
               [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
               [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
               [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
               [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0]],
              dtype=uint8)
    """

    def __init__(self, env, tile_size=8):
        """A wrapper that makes the image observation a one-hot encoding of a partially observable agent view.

        Args:
            env: The environment to apply the wrapper
            tile_size: The size of each grid tile in pixels; stored on the wrapper but not used by the one-hot encoding
        """
        super().__init__(env)

        self.tile_size = tile_size

        obs_shape = env.observation_space["image"].shape

        # Number of bits per cell: one-hot over object type, color and state
        num_bits = len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + len(STATE_TO_IDX)

        new_image_space = spaces.Box(
            low=0, high=255, shape=(obs_shape[0], obs_shape[1], num_bits), dtype="uint8"
        )
        self.observation_space = spaces.Dict(
            {**self.observation_space.spaces, "image": new_image_space}
        )

    def observation(self, obs):
        img = obs["image"]
        out = np.zeros(self.observation_space.spaces["image"].shape, dtype="uint8")

        for i in range(img.shape[0]):
            for j in range(img.shape[1]):
                type = img[i, j, 0]
                color = img[i, j, 1]
                state = img[i, j, 2]

                out[i, j, type] = 1
                out[i, j, len(OBJECT_TO_IDX) + color] = 1
                out[i, j, len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + state] = 1

        return {**obs, "image": out}


class RGBImgObsWrapper(ObservationWrapper):
    """
    Wrapper to use a fully observable RGB image as observation.
    This can be used to have the agent solve the gridworld in pixel space.

    Example:
        >>> import minigrid
        >>> import gymnasium as gym
        >>> import matplotlib.pyplot as plt
        >>> from minigrid.wrappers import RGBImgObsWrapper
        >>> env = gym.make("MiniGrid-Empty-5x5-v0")
        >>> obs, _ = env.reset()
        >>> plt.imshow(obs['image'])
        ![NoWrapper](../figures/lavacrossing_NoWrapper.png)
        >>> env = RGBImgObsWrapper(env)
        >>> obs, _ = env.reset()
        >>> plt.imshow(obs['image'])
        ![RGBImgObsWrapper](../figures/lavacrossing_RGBImgObsWrapper.png)
    """

    def __init__(self, env, tile_size=8):
        super().__init__(env)

        self.tile_size = tile_size

        new_image_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.env.width * tile_size, self.env.height * tile_size, 3),
            dtype="uint8",
        )

        self.observation_space = spaces.Dict(
            {**self.observation_space.spaces, "image": new_image_space}
        )

    def observation(self, obs):
        rgb_img = self.get_frame(highlight=True, tile_size=self.tile_size)

        return {**obs, "image": rgb_img}


class RGBImgPartialObsWrapper(ObservationWrapper):
    """
    Wrapper to use a partially observable RGB image as observation.
    This can be used to have the agent solve the gridworld in pixel space.

    Example:
        >>> import minigrid
        >>> import gymnasium as gym
        >>> import matplotlib.pyplot as plt
        >>> from minigrid.wrappers import RGBImgObsWrapper, RGBImgPartialObsWrapper
        >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
        >>> obs, _ = env.reset()
        >>> plt.imshow(obs["image"])
        ![NoWrapper](../figures/lavacrossing_NoWrapper.png)
        >>> env_obs = RGBImgObsWrapper(env)
        >>> obs, _ = env_obs.reset()
        >>> plt.imshow(obs["image"])
        ![RGBImgObsWrapper](../figures/lavacrossing_RGBImgObsWrapper.png)
        >>> env_obs = RGBImgPartialObsWrapper(env)
        >>> obs, _ = env_obs.reset()
        >>> plt.imshow(obs["image"])
        ![RGBImgPartialObsWrapper](../figures/lavacrossing_RGBImgPartialObsWrapper.png)
    """

    def __init__(self, env, tile_size=8):
        super().__init__(env)

        # Rendering attributes for observations
        self.tile_size = tile_size

        obs_shape = env.observation_space.spaces["image"].shape
        new_image_space = spaces.Box(
            low=0,
            high=255,
            shape=(obs_shape[0] * tile_size, obs_shape[1] * tile_size, 3),
            dtype="uint8",
        )

        self.observation_space = spaces.Dict(
            {**self.observation_space.spaces, "image": new_image_space}
        )

    def observation(self, obs):
        rgb_img_partial = self.get_frame(tile_size=self.tile_size, agent_pov=True)

        return {**obs, "image": rgb_img_partial}
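

# Illustrative sketch, not part of the original module: observation wrappers
# compose, so a pixel-only partial view can be obtained by stacking
# RGBImgPartialObsWrapper (which replaces the "image" entry with RGB pixels)
# with ImgObsWrapper (which keeps only "image"). The environment id and tile
# size below are assumptions chosen for illustration.
def _rgb_partial_pixels_sketch():
    env = gym.make("MiniGrid-Empty-5x5-v0")
    env = RGBImgPartialObsWrapper(env, tile_size=8)
    env = ImgObsWrapper(env)
    obs, _ = env.reset()
    # With the default 7x7 agent view and tile_size=8, obs is a
    # (56, 56, 3) uint8 RGB array.
    return obs.shape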


class FullyObsWrapper(ObservationWrapper):
    """
    Fully observable gridworld using a compact grid encoding instead of the agent view.

    Example:
        >>> import minigrid
        >>> import gymnasium as gym
        >>> import matplotlib.pyplot as plt
        >>> from minigrid.wrappers import FullyObsWrapper
        >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
        >>> obs, _ = env.reset()
        >>> obs['image'].shape
        (7, 7, 3)
        >>> env_obs = FullyObsWrapper(env)
        >>> obs, _ = env_obs.reset()
        >>> obs['image'].shape
        (11, 11, 3)
    """

    def __init__(self, env):
        super().__init__(env)

        new_image_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.env.width, self.env.height, 3),  # number of cells
            dtype="uint8",
        )

        self.observation_space = spaces.Dict(
            {**self.observation_space.spaces, "image": new_image_space}
        )

    def observation(self, obs):
        env = self.unwrapped
        full_grid = env.grid.encode()
        # Overwrite the agent's cell with its own encoding (object id, color id, direction)
        full_grid[env.agent_pos[0]][env.agent_pos[1]] = np.array(
            [OBJECT_TO_IDX["agent"], COLOR_TO_IDX["red"], env.agent_dir]
        )

        return {**obs, "image": full_grid}


class DictObservationSpaceWrapper(ObservationWrapper):
    """
    Transforms the observation space (that has a textual component) to a fully numerical observation space,
    where the textual instructions are replaced by arrays representing the indices of each word in a fixed vocabulary.

    This wrapper is not applicable to BabyAI environments, given that these have their own language component.

    Example:
        >>> import minigrid
        >>> import gymnasium as gym
        >>> import matplotlib.pyplot as plt
        >>> from minigrid.wrappers import DictObservationSpaceWrapper
        >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
        >>> obs, _ = env.reset()
        >>> obs['mission']
        'avoid the lava and get to the green goal square'
        >>> env_obs = DictObservationSpaceWrapper(env)
        >>> obs, _ = env_obs.reset()
        >>> obs['mission'][:10]
        [19, 31, 17, 36, 20, 38, 31, 2, 15, 35]
    """

    def __init__(self, env, max_words_in_mission=50, word_dict=None):
        """
        max_words_in_mission is the length of the array to represent a mission, value 0 for missing words
        word_dict is a dictionary of words to use (keys=words, values=indices from 1 to < max_words_in_mission);
                  if None, use the Minigrid language
        """
        super().__init__(env)

        if word_dict is None:
            word_dict = self.get_minigrid_words()

        self.max_words_in_mission = max_words_in_mission
        self.word_dict = word_dict

        image_observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.agent_view_size, self.agent_view_size, 3),
            dtype="uint8",
        )
        self.observation_space = spaces.Dict(
            {
                "image": image_observation_space,
                "direction": spaces.Discrete(4),
                "mission": spaces.MultiDiscrete(
                    [len(self.word_dict.keys())] * max_words_in_mission
                ),
            }
        )

    @staticmethod
    def get_minigrid_words():
        colors = ["red", "green", "blue", "yellow", "purple", "grey"]
        objects = [
            "unseen",
            "empty",
            "wall",
            "floor",
            "box",
            "key",
            "ball",
            "door",
            "goal",
            "agent",
            "lava",
        ]

        verbs = [
            "pick",
            "avoid",
            "get",
            "find",
            "put",
            "use",
            "open",
            "go",
            "fetch",
            "reach",
            "unlock",
            "traverse",
        ]

        extra_words = [
            "up",
            "the",
            "a",
            "at",
            ",",
            "square",
            "and",
            "then",
            "to",
            "of",
            "rooms",
            "near",
            "opening",
            "must",
            "you",
            "matching",
            "end",
            "hallway",
            "object",
            "from",
            "room",
        ]

        all_words = colors + objects + verbs + extra_words
        assert len(all_words) == len(set(all_words))
        return {word: i for i, word in enumerate(all_words)}

    def string_to_indices(self, string, offset=1):
        """
        Convert a string to a list of indices.
        """
        indices = []
        # adding space before and after commas
        string = string.replace(",", " , ")
        for word in string.split():
            if word in self.word_dict.keys():
                indices.append(self.word_dict[word] + offset)
            else:
                raise ValueError(f"Unknown word: {word}")
        return indices

    def observation(self, obs):
        obs["mission"] = self.string_to_indices(obs["mission"])
        assert len(obs["mission"]) < self.max_words_in_mission
        obs["mission"] += [0] * (self.max_words_in_mission - len(obs["mission"]))

        return obs


class FlatObsWrapper(ObservationWrapper):
    """
    Encode mission strings using a one-hot scheme,
    and combine these with observed images into one flat array.

    This wrapper is not applicable to BabyAI environments, given that these have their own language component.

    Example:
        >>> import minigrid
        >>> import gymnasium as gym
        >>> import matplotlib.pyplot as plt
        >>> from minigrid.wrappers import FlatObsWrapper
        >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
        >>> env_obs = FlatObsWrapper(env)
        >>> obs, _ = env_obs.reset()
        >>> obs.shape
        (2835,)
    """

    def __init__(self, env, maxStrLen=96):
        super().__init__(env)

        self.maxStrLen = maxStrLen
        self.numCharCodes = 28

        imgSpace = env.observation_space.spaces["image"]
        imgSize = reduce(operator.mul, imgSpace.shape, 1)

        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(imgSize + self.numCharCodes * self.maxStrLen,),
            dtype="uint8",
        )

        self.cachedStr: str = None

    def observation(self, obs):
        image = obs["image"]
        mission = obs["mission"]

        # Cache the last-encoded mission string
        if mission != self.cachedStr:
            assert (
                len(mission) <= self.maxStrLen
            ), f"mission string too long ({len(mission)} chars)"
            mission = mission.lower()

            strArray = np.zeros(
                shape=(self.maxStrLen, self.numCharCodes), dtype="float32"
            )

            for idx, ch in enumerate(mission):
                if ch >= "a" and ch <= "z":
                    chNo = ord(ch) - ord("a")
                elif ch == " ":
                    chNo = ord("z") - ord("a") + 1
                elif ch == ",":
                    chNo = ord("z") - ord("a") + 2
                else:
                    raise ValueError(
                        f"Character {ch} is not available in mission string."
                    )
                assert chNo < self.numCharCodes, "%s : %d" % (ch, chNo)
                strArray[idx, chNo] = 1

            self.cachedStr = mission
            self.cachedArray = strArray

        obs = np.concatenate((image.flatten(), self.cachedArray.flatten()))

        return obs


class ViewSizeWrapper(ObservationWrapper):
    """
    Wrapper to customize the agent field of view size.
    This cannot be used with fully observable wrappers.

    Example:
        >>> import minigrid
        >>> import gymnasium as gym
        >>> import matplotlib.pyplot as plt
        >>> from minigrid.wrappers import ViewSizeWrapper
        >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
        >>> obs, _ = env.reset()
        >>> obs['image'].shape
        (7, 7, 3)
        >>> env_obs = ViewSizeWrapper(env, agent_view_size=5)
        >>> obs, _ = env_obs.reset()
        >>> obs['image'].shape
        (5, 5, 3)
    """

    def __init__(self, env, agent_view_size=7):
        super().__init__(env)

        # The view size must be an odd number of at least 3 cells
        assert agent_view_size % 2 == 1
        assert agent_view_size >= 3

        self.agent_view_size = agent_view_size

        # Compute observation space with specified view size
        new_image_space = gym.spaces.Box(
            low=0, high=255, shape=(agent_view_size, agent_view_size, 3), dtype="uint8"
        )

        # Override the environment's observation space
        self.observation_space = spaces.Dict(
            {**self.observation_space.spaces, "image": new_image_space}
        )

    def observation(self, obs):
        env = self.unwrapped

        grid, vis_mask = env.gen_obs_grid(self.agent_view_size)

        # Encode the partially observable view into a numpy array
        image = grid.encode(vis_mask)

        return {**obs, "image": image}


class DirectionObsWrapper(ObservationWrapper):
    """
    Provides the slope/angular direction to the goal with the observations,
    modeled as (y2 - y1) / (x2 - x1).
    type = {slope, angle}
    """

    def __init__(self, env, type="slope"):
        super().__init__(env)
        self.goal_position: tuple = None
        self.type = type

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        if not self.goal_position:
            self.goal_position = [
                x for x, y in enumerate(self.grid.grid) if isinstance(y, Goal)
            ]
            # in case there are multiple goals, needs to be handled for other env types
            if len(self.goal_position) >= 1:
                self.goal_position = (
                    int(self.goal_position[0] / self.height),
                    self.goal_position[0] % self.width,
                )
        return self.observation(obs), info

    def observation(self, obs):
        slope = np.divide(
            self.goal_position[1] - self.agent_pos[1],
            self.goal_position[0] - self.agent_pos[0],
        )
        obs["goal_direction"] = np.arctan(slope) if self.type == "angle" else slope
        return obs
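

# Illustrative sketch, not part of the original module: after reset, the
# observation dict carries a "goal_direction" entry holding the slope (or the
# angle, if type="angle") from the agent position to the goal cell. The
# environment id below is an assumption chosen for illustration.
def _direction_obs_sketch():
    env = DirectionObsWrapper(gym.make("MiniGrid-Empty-5x5-v0"), type="angle")
    obs, _ = env.reset()
    return obs["goal_direction"]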


class SymbolicObsWrapper(ObservationWrapper):
    """
    Fully observable grid with a symbolic state representation.
    The symbol is a triple of (X, Y, IDX), where X and Y are
    the coordinates on the grid, and IDX is the id of the object.
    """

    def __init__(self, env):
        super().__init__(env)

        new_image_space = spaces.Box(
            low=0,
            high=max(OBJECT_TO_IDX.values()),
            shape=(self.env.width, self.env.height, 3),  # number of cells
            dtype="uint8",
        )
        self.observation_space = spaces.Dict(
            {**self.observation_space.spaces, "image": new_image_space}
        )

    def observation(self, obs):
        objects = np.array(
            [OBJECT_TO_IDX[o.type] if o is not None else -1 for o in self.grid.grid]
        )
        agent_pos = self.env.agent_pos
        w, h = self.width, self.height
        grid = np.mgrid[:w, :h]
        grid = np.concatenate([grid, objects.reshape(1, w, h)])
        grid = np.transpose(grid, (1, 2, 0))
        grid[agent_pos[0], agent_pos[1], 2] = OBJECT_TO_IDX["agent"]
        obs["image"] = grid

        return obs
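

# Illustrative sketch, not part of the original module: each cell of the
# symbolic observation is an (X, Y, IDX) triple, where IDX is the
# OBJECT_TO_IDX code of the object in that cell (-1 for an empty cell), and
# the agent's cell is overwritten with the "agent" code. The environment id
# below is an assumption chosen for illustration.
def _symbolic_obs_sketch():
    env = SymbolicObsWrapper(gym.make("MiniGrid-Empty-5x5-v0"))
    obs, _ = env.reset()
    x, y, idx = obs["image"][2, 2]  # symbolic triple for grid cell (2, 2)
    return x, y, idx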