wrappers.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523
  1. import math
  2. import operator
  3. from functools import reduce
  4. import gym
  5. import numpy as np
  6. from gym import spaces
  7. from gym.core import ObservationWrapper, Wrapper
  8. from gym_minigrid.minigrid import COLOR_TO_IDX, OBJECT_TO_IDX, STATE_TO_IDX, Goal
  9. class ReseedWrapper(Wrapper):
  10. """
  11. Wrapper to always regenerate an environment with the same set of seeds.
  12. This can be used to force an environment to always keep the same
  13. configuration when reset.
  14. """
  15. def __init__(self, env, seeds=[0], seed_idx=0):
  16. self.seeds = list(seeds)
  17. self.seed_idx = seed_idx
  18. super().__init__(env)
  19. def reset(self, **kwargs):
  20. seed = self.seeds[self.seed_idx]
  21. self.seed_idx = (self.seed_idx + 1) % len(self.seeds)
  22. return self.env.reset(seed=seed, **kwargs)
  23. class ActionBonus(gym.Wrapper):
  24. """
  25. Wrapper which adds an exploration bonus.
  26. This is a reward to encourage exploration of less
  27. visited (state,action) pairs.
  28. """
  29. def __init__(self, env):
  30. super().__init__(env)
  31. self.counts = {}
  32. def step(self, action):
  33. obs, reward, terminated, truncated, info = self.env.step(action)
  34. env = self.unwrapped
  35. tup = (tuple(env.agent_pos), env.agent_dir, action)
  36. # Get the count for this (s,a) pair
  37. pre_count = 0
  38. if tup in self.counts:
  39. pre_count = self.counts[tup]
  40. # Update the count for this (s,a) pair
  41. new_count = pre_count + 1
  42. self.counts[tup] = new_count
  43. bonus = 1 / math.sqrt(new_count)
  44. reward += bonus
  45. return obs, reward, terminated, truncated, info
  46. def reset(self, **kwargs):
  47. return self.env.reset(**kwargs)
  48. class StateBonus(Wrapper):
  49. """
  50. Adds an exploration bonus based on which positions
  51. are visited on the grid.
  52. """
  53. def __init__(self, env):
  54. super().__init__(env)
  55. self.counts = {}
  56. def step(self, action):
  57. obs, reward, terminated, truncated, info = self.env.step(action)
  58. # Tuple based on which we index the counts
  59. # We use the position after an update
  60. env = self.unwrapped
  61. tup = tuple(env.agent_pos)
  62. # Get the count for this key
  63. pre_count = 0
  64. if tup in self.counts:
  65. pre_count = self.counts[tup]
  66. # Update the count for this key
  67. new_count = pre_count + 1
  68. self.counts[tup] = new_count
  69. bonus = 1 / math.sqrt(new_count)
  70. reward += bonus
  71. return obs, reward, terminated, truncated, info
  72. def reset(self, **kwargs):
  73. return self.env.reset(**kwargs)
  74. class ImgObsWrapper(ObservationWrapper):
  75. """
  76. Use the image as the only observation output, no language/mission.
  77. """
  78. def __init__(self, env):
  79. super().__init__(env)
  80. self.observation_space = env.observation_space.spaces["image"]
  81. def observation(self, obs):
  82. return obs["image"]
  83. class OneHotPartialObsWrapper(ObservationWrapper):
  84. """
  85. Wrapper to get a one-hot encoding of a partially observable
  86. agent view as observation.
  87. """
  88. def __init__(self, env, tile_size=8):
  89. super().__init__(env)
  90. self.tile_size = tile_size
  91. obs_shape = env.observation_space["image"].shape
  92. # Number of bits per cell
  93. num_bits = len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + len(STATE_TO_IDX)
  94. new_image_space = spaces.Box(
  95. low=0, high=255, shape=(obs_shape[0], obs_shape[1], num_bits), dtype="uint8"
  96. )
  97. self.observation_space = spaces.Dict(
  98. {**self.observation_space.spaces, "image": new_image_space}
  99. )
  100. def observation(self, obs):
  101. img = obs["image"]
  102. out = np.zeros(self.observation_space.spaces["image"].shape, dtype="uint8")
  103. for i in range(img.shape[0]):
  104. for j in range(img.shape[1]):
  105. type = img[i, j, 0]
  106. color = img[i, j, 1]
  107. state = img[i, j, 2]
  108. out[i, j, type] = 1
  109. out[i, j, len(OBJECT_TO_IDX) + color] = 1
  110. out[i, j, len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + state] = 1
  111. return {**obs, "image": out}
  112. class RGBImgObsWrapper(ObservationWrapper):
  113. """
  114. Wrapper to use fully observable RGB image as observation,
  115. This can be used to have the agent to solve the gridworld in pixel space.
  116. """
  117. def __init__(self, env, tile_size=8):
  118. super().__init__(env)
  119. self.tile_size = tile_size
  120. new_image_space = spaces.Box(
  121. low=0,
  122. high=255,
  123. shape=(self.env.width * tile_size, self.env.height * tile_size, 3),
  124. dtype="uint8",
  125. )
  126. self.observation_space = spaces.Dict(
  127. {**self.observation_space.spaces, "image": new_image_space}
  128. )
  129. def observation(self, obs):
  130. env = self.unwrapped
  131. rgb_img = env.render(mode="rgb_array", highlight=True, tile_size=self.tile_size)
  132. return {**obs, "image": rgb_img}
  133. class RGBImgPartialObsWrapper(ObservationWrapper):
  134. """
  135. Wrapper to use partially observable RGB image as observation.
  136. This can be used to have the agent to solve the gridworld in pixel space.
  137. """
  138. def __init__(self, env, tile_size=8):
  139. super().__init__(env)
  140. self.tile_size = tile_size
  141. obs_shape = env.observation_space.spaces["image"].shape
  142. new_image_space = spaces.Box(
  143. low=0,
  144. high=255,
  145. shape=(obs_shape[0] * tile_size, obs_shape[1] * tile_size, 3),
  146. dtype="uint8",
  147. )
  148. self.observation_space = spaces.Dict(
  149. {**self.observation_space.spaces, "image": new_image_space}
  150. )
  151. def observation(self, obs):
  152. env = self.unwrapped
  153. rgb_img_partial = env.get_obs_render(obs["image"], tile_size=self.tile_size)
  154. return {**obs, "image": rgb_img_partial}
  155. class FullyObsWrapper(ObservationWrapper):
  156. """
  157. Fully observable gridworld using a compact grid encoding
  158. """
  159. def __init__(self, env):
  160. super().__init__(env)
  161. new_image_space = spaces.Box(
  162. low=0,
  163. high=255,
  164. shape=(self.env.width, self.env.height, 3), # number of cells
  165. dtype="uint8",
  166. )
  167. self.observation_space = spaces.Dict(
  168. {**self.observation_space.spaces, "image": new_image_space}
  169. )
  170. def observation(self, obs):
  171. env = self.unwrapped
  172. full_grid = env.grid.encode()
  173. full_grid[env.agent_pos[0]][env.agent_pos[1]] = np.array(
  174. [OBJECT_TO_IDX["agent"], COLOR_TO_IDX["red"], env.agent_dir]
  175. )
  176. return {**obs, "image": full_grid}
  177. class DictObservationSpaceWrapper(ObservationWrapper):
  178. """
  179. Transforms the observation space (that has a textual component) to a fully numerical observation space,
  180. where the textual instructions are replaced by arrays representing the indices of each word in a fixed vocabulary.
  181. """
  182. def __init__(self, env, max_words_in_mission=50, word_dict=None):
  183. """
  184. max_words_in_mission is the length of the array to represent a mission, value 0 for missing words
  185. word_dict is a dictionary of words to use (keys=words, values=indices from 1 to < max_words_in_mission),
  186. if None, use the Minigrid language
  187. """
  188. super().__init__(env)
  189. if word_dict is None:
  190. word_dict = self.get_minigrid_words()
  191. self.max_words_in_mission = max_words_in_mission
  192. self.word_dict = word_dict
  193. image_observation_space = spaces.Box(
  194. low=0,
  195. high=255,
  196. shape=(self.agent_view_size, self.agent_view_size, 3),
  197. dtype="uint8",
  198. )
  199. self.observation_space = spaces.Dict(
  200. {
  201. "image": image_observation_space,
  202. "direction": spaces.Discrete(4),
  203. "mission": spaces.MultiDiscrete(
  204. [len(self.word_dict.keys())] * max_words_in_mission
  205. ),
  206. }
  207. )
  208. @staticmethod
  209. def get_minigrid_words():
  210. colors = ["red", "green", "blue", "yellow", "purple", "grey"]
  211. objects = [
  212. "unseen",
  213. "empty",
  214. "wall",
  215. "floor",
  216. "box",
  217. "key",
  218. "ball",
  219. "door",
  220. "goal",
  221. "agent",
  222. "lava",
  223. ]
  224. verbs = [
  225. "pick",
  226. "avoid",
  227. "get",
  228. "find",
  229. "put",
  230. "use",
  231. "open",
  232. "go",
  233. "fetch",
  234. "reach",
  235. "unlock",
  236. "traverse",
  237. ]
  238. extra_words = [
  239. "up",
  240. "the",
  241. "a",
  242. "at",
  243. ",",
  244. "square",
  245. "and",
  246. "then",
  247. "to",
  248. "of",
  249. "rooms",
  250. "near",
  251. "opening",
  252. "must",
  253. "you",
  254. "matching",
  255. "end",
  256. "hallway",
  257. "object",
  258. "from",
  259. "room",
  260. ]
  261. all_words = colors + objects + verbs + extra_words
  262. assert len(all_words) == len(set(all_words))
  263. return {word: i for i, word in enumerate(all_words)}
  264. def string_to_indices(self, string, offset=1):
  265. """
  266. Convert a string to a list of indices.
  267. """
  268. indices = []
  269. # adding space before and after commas
  270. string = string.replace(",", " , ")
  271. for word in string.split():
  272. if word in self.word_dict.keys():
  273. indices.append(self.word_dict[word] + offset)
  274. else:
  275. raise ValueError(f"Unknown word: {word}")
  276. return indices
  277. def observation(self, obs):
  278. obs["mission"] = self.string_to_indices(obs["mission"])
  279. assert len(obs["mission"]) < self.max_words_in_mission
  280. obs["mission"] += [0] * (self.max_words_in_mission - len(obs["mission"]))
  281. return obs
  282. class FlatObsWrapper(ObservationWrapper):
  283. """
  284. Encode mission strings using a one-hot scheme,
  285. and combine these with observed images into one flat array
  286. """
  287. def __init__(self, env, maxStrLen=96):
  288. super().__init__(env)
  289. self.maxStrLen = maxStrLen
  290. self.numCharCodes = 28
  291. imgSpace = env.observation_space.spaces["image"]
  292. imgSize = reduce(operator.mul, imgSpace.shape, 1)
  293. self.observation_space = spaces.Box(
  294. low=0,
  295. high=255,
  296. shape=(imgSize + self.numCharCodes * self.maxStrLen,),
  297. dtype="uint8",
  298. )
  299. self.cachedStr: str = None
  300. def observation(self, obs):
  301. image = obs["image"]
  302. mission = obs["mission"]
  303. # Cache the last-encoded mission string
  304. if mission != self.cachedStr:
  305. assert (
  306. len(mission) <= self.maxStrLen
  307. ), f"mission string too long ({len(mission)} chars)"
  308. mission = mission.lower()
  309. strArray = np.zeros(
  310. shape=(self.maxStrLen, self.numCharCodes), dtype="float32"
  311. )
  312. for idx, ch in enumerate(mission):
  313. if ch >= "a" and ch <= "z":
  314. chNo = ord(ch) - ord("a")
  315. elif ch == " ":
  316. chNo = ord("z") - ord("a") + 1
  317. elif ch == ",":
  318. chNo = ord("z") - ord("a") + 2
  319. else:
  320. raise ValueError(
  321. f"Character {ch} is not available in mission string."
  322. )
  323. assert chNo < self.numCharCodes, "%s : %d" % (ch, chNo)
  324. strArray[idx, chNo] = 1
  325. self.cachedStr = mission
  326. self.cachedArray = strArray
  327. obs = np.concatenate((image.flatten(), self.cachedArray.flatten()))
  328. return obs
  329. class ViewSizeWrapper(Wrapper):
  330. """
  331. Wrapper to customize the agent field of view size.
  332. This cannot be used with fully observable wrappers.
  333. """
  334. def __init__(self, env, agent_view_size=7):
  335. super().__init__(env)
  336. assert agent_view_size % 2 == 1
  337. assert agent_view_size >= 3
  338. self.agent_view_size = agent_view_size
  339. # Compute observation space with specified view size
  340. new_image_space = gym.spaces.Box(
  341. low=0, high=255, shape=(agent_view_size, agent_view_size, 3), dtype="uint8"
  342. )
  343. # Override the environment's observation spaceexit
  344. self.observation_space = spaces.Dict(
  345. {**self.observation_space.spaces, "image": new_image_space}
  346. )
  347. def observation(self, obs):
  348. env = self.unwrapped
  349. grid, vis_mask = env.gen_obs_grid(self.agent_view_size)
  350. # Encode the partially observable view into a numpy array
  351. image = grid.encode(vis_mask)
  352. return {**obs, "image": image}
  353. class DirectionObsWrapper(ObservationWrapper):
  354. """
  355. Provides the slope/angular direction to the goal with the observations as modeled by (y2 - y2 )/( x2 - x1)
  356. type = {slope , angle}
  357. """
  358. def __init__(self, env, type="slope"):
  359. super().__init__(env)
  360. self.goal_position: tuple = None
  361. self.type = type
  362. def reset(self):
  363. obs = self.env.reset()
  364. if not self.goal_position:
  365. self.goal_position = [
  366. x for x, y in enumerate(self.grid.grid) if isinstance(y, Goal)
  367. ]
  368. # in case there are multiple goals , needs to be handled for other env types
  369. if len(self.goal_position) >= 1:
  370. self.goal_position = (
  371. int(self.goal_position[0] / self.height),
  372. self.goal_position[0] % self.width,
  373. )
  374. return obs
  375. def observation(self, obs):
  376. slope = np.divide(
  377. self.goal_position[1] - self.agent_pos[1],
  378. self.goal_position[0] - self.agent_pos[0],
  379. )
  380. obs["goal_direction"] = np.arctan(slope) if self.type == "angle" else slope
  381. return obs
  382. class SymbolicObsWrapper(ObservationWrapper):
  383. """
  384. Fully observable grid with a symbolic state representation.
  385. The symbol is a triple of (X, Y, IDX), where X and Y are
  386. the coordinates on the grid, and IDX is the id of the object.
  387. """
  388. def __init__(self, env):
  389. super().__init__(env)
  390. new_image_space = spaces.Box(
  391. low=0,
  392. high=max(OBJECT_TO_IDX.values()),
  393. shape=(self.env.width, self.env.height, 3), # number of cells
  394. dtype="uint8",
  395. )
  396. self.observation_space = spaces.Dict(
  397. {**self.observation_space.spaces, "image": new_image_space}
  398. )
  399. def observation(self, obs):
  400. objects = np.array(
  401. [OBJECT_TO_IDX[o.type] if o is not None else -1 for o in self.grid.grid]
  402. )
  403. w, h = self.width, self.height
  404. grid = np.mgrid[:w, :h]
  405. grid = np.concatenate([grid, objects.reshape(1, w, h)])
  406. grid = np.transpose(grid, (1, 2, 0))
  407. obs["image"] = grid
  408. return obs