# wrappers.py

from __future__ import annotations

import math
import operator
from functools import reduce

import gymnasium as gym
import numpy as np
from gymnasium import spaces
from gymnasium.core import ObservationWrapper, Wrapper

from minigrid.core.constants import COLOR_TO_IDX, OBJECT_TO_IDX, STATE_TO_IDX
from minigrid.core.world_object import Goal


class ReseedWrapper(Wrapper):
    """
    Wrapper to always regenerate an environment with the same set of seeds.
    This can be used to force an environment to always keep the same
    configuration when reset.
    """

    def __init__(self, env, seeds=[0], seed_idx=0):
        self.seeds = list(seeds)
        self.seed_idx = seed_idx
        super().__init__(env)

    def reset(self, **kwargs):
        seed = self.seeds[self.seed_idx]
        self.seed_idx = (self.seed_idx + 1) % len(self.seeds)
        return self.env.reset(seed=seed, **kwargs)

    def step(self, action):
        return self.env.step(action)
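

# Usage sketch for ReseedWrapper (hedged): the environment id below is an
# assumption; any registered MiniGrid environment should work the same way.
# Each reset draws the next seed from the list, cycling back to the start.
#
#   import gymnasium as gym
#   import minigrid  # noqa: F401  (importing minigrid registers the MiniGrid-* ids)
#   env = gym.make("MiniGrid-Empty-8x8-v0")
#   env = ReseedWrapper(env, seeds=[0, 1, 2])
#   obs, info = env.reset()  # seeded with 0; later resets use 1, 2, 0, ...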


class ActionBonus(gym.Wrapper):
    """
    Wrapper which adds an exploration bonus.
    This is a reward to encourage exploration of less
    visited (state, action) pairs.
    """

    def __init__(self, env):
        super().__init__(env)
        self.counts = {}

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)

        env = self.unwrapped
        tup = (tuple(env.agent_pos), env.agent_dir, action)

        # Get the count for this (s,a) pair
        pre_count = 0
        if tup in self.counts:
            pre_count = self.counts[tup]

        # Update the count for this (s,a) pair
        new_count = pre_count + 1
        self.counts[tup] = new_count

        bonus = 1 / math.sqrt(new_count)
        reward += bonus

        return obs, reward, terminated, truncated, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)


class StateBonus(Wrapper):
    """
    Adds an exploration bonus based on which positions
    are visited on the grid.
    """

    def __init__(self, env):
        super().__init__(env)
        self.counts = {}

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)

        # Tuple based on which we index the counts
        # We use the position after an update
        env = self.unwrapped
        tup = tuple(env.agent_pos)

        # Get the count for this key
        pre_count = 0
        if tup in self.counts:
            pre_count = self.counts[tup]

        # Update the count for this key
        new_count = pre_count + 1
        self.counts[tup] = new_count

        bonus = 1 / math.sqrt(new_count)
        reward += bonus

        return obs, reward, terminated, truncated, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)
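

# Usage sketch for the exploration bonuses (hedged): both wrappers add
# 1 / sqrt(count) to the reward, keyed on (position, direction, action) for
# ActionBonus and on position alone for StateBonus. The env id is an assumption.
#
#   env = StateBonus(ActionBonus(gym.make("MiniGrid-Empty-8x8-v0")))
#   obs, info = env.reset()
#   obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
#   # reward now includes both bonus terms on top of the environment reward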


class ImgObsWrapper(ObservationWrapper):
    """
    Use the image as the only observation output, no language/mission.
    """

    def __init__(self, env):
        super().__init__(env)
        self.observation_space = env.observation_space.spaces["image"]

    def observation(self, obs):
        return obs["image"]
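

# Usage sketch for ImgObsWrapper (hedged): the observation becomes the raw
# image array only, so downstream code indexes it directly instead of obs["image"].
#
#   env = ImgObsWrapper(gym.make("MiniGrid-Empty-8x8-v0"))
#   obs, info = env.reset()
#   print(obs.shape)  # e.g. (7, 7, 3) with the default partial view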


class OneHotPartialObsWrapper(ObservationWrapper):
    """
    Wrapper to get a one-hot encoding of a partially observable
    agent view as observation.
    """

    def __init__(self, env, tile_size=8):
        super().__init__(env)

        self.tile_size = tile_size

        obs_shape = env.observation_space["image"].shape

        # Number of bits per cell
        num_bits = len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + len(STATE_TO_IDX)

        new_image_space = spaces.Box(
            low=0, high=255, shape=(obs_shape[0], obs_shape[1], num_bits), dtype="uint8"
        )
        self.observation_space = spaces.Dict(
            {**self.observation_space.spaces, "image": new_image_space}
        )

    def observation(self, obs):
        img = obs["image"]
        out = np.zeros(self.observation_space.spaces["image"].shape, dtype="uint8")

        for i in range(img.shape[0]):
            for j in range(img.shape[1]):
                type = img[i, j, 0]
                color = img[i, j, 1]
                state = img[i, j, 2]

                out[i, j, type] = 1
                out[i, j, len(OBJECT_TO_IDX) + color] = 1
                out[i, j, len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + state] = 1

        return {**obs, "image": out}
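

# Usage sketch for OneHotPartialObsWrapper (hedged): each cell's
# (object, color, state) triple becomes a concatenated one-hot vector, so the
# channel dimension grows from 3 to
# len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + len(STATE_TO_IDX).
#
#   env = OneHotPartialObsWrapper(gym.make("MiniGrid-Empty-8x8-v0"))
#   obs, info = env.reset()
#   print(obs["image"].shape)  # (view, view, num_bits)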


class RGBImgObsWrapper(ObservationWrapper):
    """
    Wrapper to use a fully observable RGB image as observation.
    This can be used to have the agent solve the gridworld in pixel space.
    """

    def __init__(self, env, tile_size=8):
        super().__init__(env)

        self.tile_size = tile_size

        new_image_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.env.width * tile_size, self.env.height * tile_size, 3),
            dtype="uint8",
        )

        self.observation_space = spaces.Dict(
            {**self.observation_space.spaces, "image": new_image_space}
        )

    def observation(self, obs):
        rgb_img = self.get_frame(highlight=True, tile_size=self.tile_size)

        return {**obs, "image": rgb_img}
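

# Usage sketch for RGBImgObsWrapper (hedged): obs["image"] becomes a rendered
# RGB frame of the whole grid, tile_size pixels per cell, suitable for
# pixel-based agents.
#
#   env = RGBImgObsWrapper(gym.make("MiniGrid-Empty-8x8-v0"), tile_size=8)
#   obs, info = env.reset()
#   print(obs["image"].shape)  # full-grid RGB array, 8 pixels per cell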


class RGBImgPartialObsWrapper(ObservationWrapper):
    """
    Wrapper to use a partially observable RGB image as observation.
    This can be used to have the agent solve the gridworld in pixel space.
    """

    def __init__(self, env, tile_size=8):
        super().__init__(env)

        # Rendering attributes for observations
        self.tile_size = tile_size

        obs_shape = env.observation_space.spaces["image"].shape
        new_image_space = spaces.Box(
            low=0,
            high=255,
            shape=(obs_shape[0] * tile_size, obs_shape[1] * tile_size, 3),
            dtype="uint8",
        )

        self.observation_space = spaces.Dict(
            {**self.observation_space.spaces, "image": new_image_space}
        )

    def observation(self, obs):
        rgb_img_partial = self.get_frame(tile_size=self.tile_size, agent_pov=True)

        return {**obs, "image": rgb_img_partial}
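

# Usage sketch for RGBImgPartialObsWrapper (hedged): same idea as above, but
# only the agent's egocentric view is rendered (agent_pov=True), so the image
# size follows the partial view size rather than the full grid.
#
#   env = RGBImgPartialObsWrapper(gym.make("MiniGrid-Empty-8x8-v0"))
#   obs, info = env.reset()
#   print(obs["image"].shape)  # (view * tile_size, view * tile_size, 3)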


class FullyObsWrapper(ObservationWrapper):
    """
    Fully observable gridworld using a compact grid encoding.
    """

    def __init__(self, env):
        super().__init__(env)

        new_image_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.env.width, self.env.height, 3),  # number of cells
            dtype="uint8",
        )

        self.observation_space = spaces.Dict(
            {**self.observation_space.spaces, "image": new_image_space}
        )

    def observation(self, obs):
        env = self.unwrapped
        full_grid = env.grid.encode()
        full_grid[env.agent_pos[0]][env.agent_pos[1]] = np.array(
            [OBJECT_TO_IDX["agent"], COLOR_TO_IDX["red"], env.agent_dir]
        )

        return {**obs, "image": full_grid}
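

# Usage sketch for FullyObsWrapper (hedged): obs["image"] becomes the compact
# (width, height, 3) encoding of the whole grid, with the agent's cell marked
# by the "agent" object index, the red color index, and the agent direction.
#
#   env = FullyObsWrapper(gym.make("MiniGrid-Empty-8x8-v0"))
#   obs, info = env.reset()
#   print(obs["image"].shape)  # (width, height, 3)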


class DictObservationSpaceWrapper(ObservationWrapper):
    """
    Transforms the observation space (that has a textual component) into a fully
    numerical observation space, where the textual instructions are replaced by
    arrays representing the indices of each word in a fixed vocabulary.
    This wrapper is not applicable to BabyAI environments, given that these have
    their own language component.
    """

    def __init__(self, env, max_words_in_mission=50, word_dict=None):
        """
        max_words_in_mission is the length of the array used to represent a mission;
        the value 0 stands for missing words.
        word_dict is a dictionary of words to use (keys=words, values=indices from
        1 to < max_words_in_mission); if None, the Minigrid language is used.
        """
        super().__init__(env)

        if word_dict is None:
            word_dict = self.get_minigrid_words()

        self.max_words_in_mission = max_words_in_mission
        self.word_dict = word_dict

        image_observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.agent_view_size, self.agent_view_size, 3),
            dtype="uint8",
        )
        self.observation_space = spaces.Dict(
            {
                "image": image_observation_space,
                "direction": spaces.Discrete(4),
                "mission": spaces.MultiDiscrete(
                    [len(self.word_dict.keys())] * max_words_in_mission
                ),
            }
        )

    @staticmethod
    def get_minigrid_words():
        colors = ["red", "green", "blue", "yellow", "purple", "grey"]
        objects = [
            "unseen",
            "empty",
            "wall",
            "floor",
            "box",
            "key",
            "ball",
            "door",
            "goal",
            "agent",
            "lava",
        ]

        verbs = [
            "pick",
            "avoid",
            "get",
            "find",
            "put",
            "use",
            "open",
            "go",
            "fetch",
            "reach",
            "unlock",
            "traverse",
        ]

        extra_words = [
            "up",
            "the",
            "a",
            "at",
            ",",
            "square",
            "and",
            "then",
            "to",
            "of",
            "rooms",
            "near",
            "opening",
            "must",
            "you",
            "matching",
            "end",
            "hallway",
            "object",
            "from",
            "room",
        ]

        all_words = colors + objects + verbs + extra_words
        assert len(all_words) == len(set(all_words))
        return {word: i for i, word in enumerate(all_words)}

    def string_to_indices(self, string, offset=1):
        """
        Convert a string to a list of indices.
        """
        indices = []
        # adding space before and after commas
        string = string.replace(",", " , ")
        for word in string.split():
            if word in self.word_dict.keys():
                indices.append(self.word_dict[word] + offset)
            else:
                raise ValueError(f"Unknown word: {word}")
        return indices

    def observation(self, obs):
        obs["mission"] = self.string_to_indices(obs["mission"])
        assert len(obs["mission"]) < self.max_words_in_mission
        obs["mission"] += [0] * (self.max_words_in_mission - len(obs["mission"]))

        return obs
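

# Usage sketch for DictObservationSpaceWrapper (hedged): the mission string is
# tokenized by string_to_indices (indices offset by 1, with 0 used as padding)
# and padded to max_words_in_mission, so the whole observation is numerical.
#
#   env = DictObservationSpaceWrapper(gym.make("MiniGrid-Empty-8x8-v0"))
#   obs, info = env.reset()
#   print(obs["mission"][:10])  # word indices followed by zero padding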


class FlatObsWrapper(ObservationWrapper):
    """
    Encode mission strings using a one-hot scheme,
    and combine these with observed images into one flat array.
    This wrapper is not applicable to BabyAI environments, given that these have
    their own language component.
    """

    def __init__(self, env, maxStrLen=96):
        super().__init__(env)

        self.maxStrLen = maxStrLen
        self.numCharCodes = 28

        imgSpace = env.observation_space.spaces["image"]
        imgSize = reduce(operator.mul, imgSpace.shape, 1)

        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(imgSize + self.numCharCodes * self.maxStrLen,),
            dtype="uint8",
        )

        self.cachedStr: str = None

    def observation(self, obs):
        image = obs["image"]
        mission = obs["mission"]

        # Cache the last-encoded mission string
        if mission != self.cachedStr:
            assert (
                len(mission) <= self.maxStrLen
            ), f"mission string too long ({len(mission)} chars)"
            mission = mission.lower()

            strArray = np.zeros(
                shape=(self.maxStrLen, self.numCharCodes), dtype="float32"
            )

            for idx, ch in enumerate(mission):
                if ch >= "a" and ch <= "z":
                    chNo = ord(ch) - ord("a")
                elif ch == " ":
                    chNo = ord("z") - ord("a") + 1
                elif ch == ",":
                    chNo = ord("z") - ord("a") + 2
                else:
                    raise ValueError(
                        f"Character {ch} is not available in mission string."
                    )
                assert chNo < self.numCharCodes, "%s : %d" % (ch, chNo)
                strArray[idx, chNo] = 1

            self.cachedStr = mission
            self.cachedArray = strArray

        obs = np.concatenate((image.flatten(), self.cachedArray.flatten()))

        return obs
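

# Usage sketch for FlatObsWrapper (hedged): the observation is a single 1-D
# array of length imgSize + numCharCodes * maxStrLen, i.e. the flattened image
# followed by a one-hot character encoding of the (lower-cased) mission string.
#
#   env = FlatObsWrapper(gym.make("MiniGrid-Empty-8x8-v0"))
#   obs, info = env.reset()
#   print(obs.shape)  # (imgSize + 28 * 96,) with the defaults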


class ViewSizeWrapper(ObservationWrapper):
    """
    Wrapper to customize the agent field of view size.
    This cannot be used with fully observable wrappers.
    """

    def __init__(self, env, agent_view_size=7):
        super().__init__(env)

        assert agent_view_size % 2 == 1
        assert agent_view_size >= 3

        self.agent_view_size = agent_view_size

        # Compute observation space with specified view size
        new_image_space = gym.spaces.Box(
            low=0, high=255, shape=(agent_view_size, agent_view_size, 3), dtype="uint8"
        )

        # Override the environment's observation space
        self.observation_space = spaces.Dict(
            {**self.observation_space.spaces, "image": new_image_space}
        )

    def observation(self, obs):
        env = self.unwrapped

        grid, vis_mask = env.gen_obs_grid(self.agent_view_size)

        # Encode the partially observable view into a numpy array
        image = grid.encode(vis_mask)

        return {**obs, "image": image}
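

# Usage sketch for ViewSizeWrapper (hedged): agent_view_size must be an odd
# number >= 3; the wrapper re-encodes the partial view at that size.
#
#   env = ViewSizeWrapper(gym.make("MiniGrid-Empty-8x8-v0"), agent_view_size=5)
#   obs, info = env.reset()
#   print(obs["image"].shape)  # (5, 5, 3)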


class DirectionObsWrapper(ObservationWrapper):
    """
    Provides the slope/angular direction to the goal with the observations,
    modeled as (y2 - y1) / (x2 - x1).
    type = {slope, angle}
    """

    def __init__(self, env, type="slope"):
        super().__init__(env)
        self.goal_position: tuple = None
        self.type = type

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        if not self.goal_position:
            self.goal_position = [
                x for x, y in enumerate(self.grid.grid) if isinstance(y, Goal)
            ]
            # in case there are multiple goals, this needs to be handled for other env types
            if len(self.goal_position) >= 1:
                self.goal_position = (
                    int(self.goal_position[0] / self.height),
                    self.goal_position[0] % self.width,
                )
        return self.observation(obs), info

    def observation(self, obs):
        slope = np.divide(
            self.goal_position[1] - self.agent_pos[1],
            self.goal_position[0] - self.agent_pos[0],
        )
        obs["goal_direction"] = np.arctan(slope) if self.type == "angle" else slope
        return obs
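

# Usage sketch for DirectionObsWrapper (hedged): after reset locates the first
# Goal cell, each observation gains a "goal_direction" entry holding either the
# raw slope or its arctangent, depending on the `type` argument.
#
#   env = DirectionObsWrapper(gym.make("MiniGrid-Empty-8x8-v0"), type="angle")
#   obs, info = env.reset()
#   print(obs["goal_direction"])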


class SymbolicObsWrapper(ObservationWrapper):
    """
    Fully observable grid with a symbolic state representation.
    The symbol is a triple of (X, Y, IDX), where X and Y are
    the coordinates on the grid, and IDX is the id of the object.
    """

    def __init__(self, env):
        super().__init__(env)

        new_image_space = spaces.Box(
            low=0,
            high=max(OBJECT_TO_IDX.values()),
            shape=(self.env.width, self.env.height, 3),  # number of cells
            dtype="uint8",
        )
        self.observation_space = spaces.Dict(
            {**self.observation_space.spaces, "image": new_image_space}
        )

    def observation(self, obs):
        objects = np.array(
            [OBJECT_TO_IDX[o.type] if o is not None else -1 for o in self.grid.grid]
        )
        agent_pos = self.env.agent_pos
        w, h = self.width, self.height
        grid = np.mgrid[:w, :h]
        # grid.grid is stored row-major (flat index = y * width + x), so reshape
        # to (height, width) and transpose to align with the (x, y) coordinate planes
        _objects = np.transpose(objects.reshape(1, h, w), (0, 2, 1))
        grid = np.concatenate([grid, _objects])
        grid = np.transpose(grid, (1, 2, 0))
        grid[agent_pos[0], agent_pos[1], 2] = OBJECT_TO_IDX["agent"]
        obs["image"] = grid
        return obs
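

# Usage sketch for SymbolicObsWrapper (hedged): obs["image"][x, y] holds the
# triple (x, y, object_idx), with -1 where the grid holds no object and the
# agent's cell overwritten with OBJECT_TO_IDX["agent"].
#
#   env = SymbolicObsWrapper(gym.make("MiniGrid-Empty-8x8-v0"))
#   obs, info = env.reset()
#   print(obs["image"][:, :, 2])  # object ids laid out over the grid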