wrappers.py

import math
import operator
from functools import reduce

import numpy as np

import gym
from gym import error, spaces, utils

from .minigrid import OBJECT_TO_IDX, COLOR_TO_IDX, STATE_TO_IDX, Goal

class ReseedWrapper(gym.Wrapper):
    """
    Wrapper to always regenerate an environment with the same set of seeds.
    This can be used to force an environment to always keep the same
    configuration when reset.
    """

    def __init__(self, env, seeds=[0], seed_idx=0):
        self.seeds = list(seeds)
        self.seed_idx = seed_idx
        super().__init__(env)

    def reset(self, **kwargs):
        seed = self.seeds[self.seed_idx]
        self.seed_idx = (self.seed_idx + 1) % len(self.seeds)
        return self.env.reset(seed=seed, **kwargs)

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        return obs, reward, done, info

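# Example usage (a sketch; 'MiniGrid-Empty-8x8-v0' is an assumed registered env id):
#
#     env = ReseedWrapper(gym.make('MiniGrid-Empty-8x8-v0'), seeds=[0, 1, 2])
#     obs = env.reset()  # layout generated with seed 0
#     obs = env.reset()  # seed 1, then 2, then back to 0, ...
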
class ActionBonus(gym.Wrapper):
    """
    Wrapper which adds an exploration bonus.
    This is a reward to encourage exploration of less
    visited (state, action) pairs.
    """

    def __init__(self, env):
        super().__init__(env)
        self.counts = {}

    def step(self, action):
        obs, reward, done, info = self.env.step(action)

        env = self.unwrapped
        tup = (tuple(env.agent_pos), env.agent_dir, action)

        # Get the count for this (s,a) pair
        pre_count = 0
        if tup in self.counts:
            pre_count = self.counts[tup]

        # Update the count for this (s,a) pair
        new_count = pre_count + 1
        self.counts[tup] = new_count

        bonus = 1 / math.sqrt(new_count)
        reward += bonus

        return obs, reward, done, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)

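# Example usage (a sketch; assumed env id). The first visit to a
# (position, direction, action) triple adds 1 / sqrt(1) = 1.0 to the reward,
# the second visit adds 1 / sqrt(2), and so on.
#
#     env = ActionBonus(gym.make('MiniGrid-Empty-8x8-v0'))
#     obs = env.reset()
#     obs, reward, done, info = env.step(env.action_space.sample())
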
class StateBonus(gym.Wrapper):
    """
    Adds an exploration bonus based on which positions
    are visited on the grid.
    """

    def __init__(self, env):
        super().__init__(env)
        self.counts = {}

    def step(self, action):
        obs, reward, done, info = self.env.step(action)

        # Tuple based on which we index the counts
        # We use the position after an update
        env = self.unwrapped
        tup = tuple(env.agent_pos)

        # Get the count for this key
        pre_count = 0
        if tup in self.counts:
            pre_count = self.counts[tup]

        # Update the count for this key
        new_count = pre_count + 1
        self.counts[tup] = new_count

        bonus = 1 / math.sqrt(new_count)
        reward += bonus

        return obs, reward, done, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)

class ImgObsWrapper(gym.ObservationWrapper):
    """
    Use the image as the only observation output, no language/mission.
    """

    def __init__(self, env):
        super().__init__(env)
        self.observation_space = env.observation_space.spaces['image']

    def observation(self, obs):
        return obs['image']

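# Example usage (a sketch; assumed env id). Strips the observation dict down to
# the image array, which is what an image-based policy typically expects.
#
#     env = ImgObsWrapper(gym.make('MiniGrid-Empty-8x8-v0'))
#     obs = env.reset()  # a (view_size, view_size, 3) uint8 array, not a dict
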
class OneHotPartialObsWrapper(gym.ObservationWrapper):
    """
    Wrapper to get a one-hot encoding of a partially observable
    agent view as observation.
    """

    def __init__(self, env, tile_size=8):
        super().__init__(env)

        self.tile_size = tile_size

        obs_shape = env.observation_space['image'].shape

        # Number of bits per cell
        num_bits = len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + len(STATE_TO_IDX)

        new_image_space = spaces.Box(
            low=0,
            high=255,
            shape=(obs_shape[0], obs_shape[1], num_bits),
            dtype='uint8'
        )
        self.observation_space = spaces.Dict(
            {**self.observation_space, 'image': new_image_space})

    def observation(self, obs):
        img = obs['image']
        out = np.zeros(
            self.observation_space.spaces['image'].shape, dtype='uint8')

        for i in range(img.shape[0]):
            for j in range(img.shape[1]):
                type = img[i, j, 0]
                color = img[i, j, 1]
                state = img[i, j, 2]

                out[i, j, type] = 1
                out[i, j, len(OBJECT_TO_IDX) + color] = 1
                out[i, j, len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + state] = 1

        return {
            **obs,
            'image': out
        }

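# Example usage (a sketch; assumed env id). Each cell's (type, color, state)
# triple becomes a binary vector with
# len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + len(STATE_TO_IDX) channels,
# exactly three of which are set to 1.
#
#     env = OneHotPartialObsWrapper(gym.make('MiniGrid-Empty-8x8-v0'))
#     obs = env.reset()  # obs['image'].shape[-1] equals the channel count above
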
class RGBImgObsWrapper(gym.ObservationWrapper):
    """
    Wrapper to use a fully observable RGB image as observation.
    This can be used to have the agent solve the gridworld in pixel space.
    To use it, make the unwrapped environment with render_mode='rgb_array'.
    """

    def __init__(self, env, tile_size=8):
        super().__init__(env)

        self.tile_size = tile_size

        new_image_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.env.width * tile_size, self.env.height * tile_size, 3),
            dtype='uint8'
        )
        self.observation_space = spaces.Dict(
            {**self.observation_space, 'image': new_image_space})

    def observation(self, obs):
        env = self.unwrapped
        assert env.render_mode == 'rgb_array', env.render_mode
        rgb_img = env.render(
            highlight=False,
            tile_size=self.tile_size
        )

        return {
            **obs,
            'image': rgb_img
        }

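# Example usage (a sketch; assumed env id). The unwrapped environment must be
# created with render_mode='rgb_array', as the docstring above requires.
#
#     env = gym.make('MiniGrid-Empty-8x8-v0', render_mode='rgb_array')
#     env = RGBImgObsWrapper(env, tile_size=8)
#     obs = env.reset()  # obs['image'] is a (width * 8, height * 8, 3) RGB frame
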
class RGBImgPartialObsWrapper(gym.ObservationWrapper):
    """
    Wrapper to use a partially observable RGB image as observation.
    This can be used to have the agent solve the gridworld in pixel space.
    """

    def __init__(self, env, tile_size=8):
        super().__init__(env)

        self.tile_size = tile_size

        obs_shape = env.observation_space.spaces['image'].shape
        new_image_space = spaces.Box(
            low=0,
            high=255,
            shape=(obs_shape[0] * tile_size, obs_shape[1] * tile_size, 3),
            dtype='uint8'
        )
        self.observation_space = spaces.Dict(
            {**self.observation_space, 'image': new_image_space})

    def observation(self, obs):
        env = self.unwrapped
        rgb_img_partial = env.get_obs_render(
            obs['image'],
            tile_size=self.tile_size
        )

        return {
            **obs,
            'image': rgb_img_partial
        }

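# Example usage (a sketch; assumed env id). A common combination for pixel-based
# agents is to chain this with ImgObsWrapper so only the egocentric RGB view remains.
#
#     env = RGBImgPartialObsWrapper(gym.make('MiniGrid-Empty-8x8-v0'), tile_size=8)
#     env = ImgObsWrapper(env)
#     obs = env.reset()  # a (7 * 8, 7 * 8, 3) array for the default 7x7 agent view
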
class FullyObsWrapper(gym.ObservationWrapper):
    """
    Fully observable gridworld using a compact grid encoding.
    """

    def __init__(self, env):
        super().__init__(env)

        new_image_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.env.width, self.env.height, 3),  # number of cells
            dtype='uint8'
        )
        self.observation_space = spaces.Dict(
            {**self.observation_space, 'image': new_image_space})

    def observation(self, obs):
        env = self.unwrapped
        full_grid = env.grid.encode()
        full_grid[env.agent_pos[0]][env.agent_pos[1]] = np.array([
            OBJECT_TO_IDX['agent'],
            COLOR_TO_IDX['red'],
            env.agent_dir
        ])

        return {
            **obs,
            'image': full_grid
        }

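# Example usage (a sketch; assumed env id). The image becomes a
# (width, height, 3) encoding of the whole grid, with the agent written into
# the cell at its current position.
#
#     env = FullyObsWrapper(gym.make('MiniGrid-Empty-8x8-v0'))
#     obs = env.reset()  # obs['image'].shape == (env.width, env.height, 3)
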
class DictObservationSpaceWrapper(gym.ObservationWrapper):
    """
    Use a Dict observation space encoding images, missions, and directions.
    """

    def __init__(self, env, max_words_in_mission=50, word_dict=None):
        """
        max_words_in_mission is the length of the array used to represent a mission;
        the value 0 marks missing words.
        word_dict is a dictionary of words to use (keys=words, values=indices from
        1 to < max_words_in_mission); if None, use the MiniGrid language.
        """
        super().__init__(env)

        if word_dict is None:
            word_dict = DictObservationSpaceWrapper.get_minigrid_words()

        self.max_words_in_mission = max_words_in_mission
        self.word_dict = word_dict

        image_observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.agent_view_size, self.agent_view_size, 3),
            dtype='uint8'
        )
        self.observation_space = spaces.Dict({
            'image': image_observation_space,
            'direction': spaces.Discrete(4),
            'mission': spaces.MultiDiscrete(
                [len(self.word_dict.keys())] * max_words_in_mission)
        })

    @staticmethod
    def get_minigrid_words():
        colors = ['red', 'green', 'blue', 'yellow', 'purple', 'grey']
        objects = ['unseen', 'empty', 'wall', 'floor', 'box', 'key', 'ball',
                   'door', 'goal', 'agent', 'lava']
        verbs = ['pick', 'avoid', 'get', 'find', 'put',
                 'use', 'open', 'go', 'fetch',
                 'reach', 'unlock', 'traverse']
        extra_words = ['up', 'the', 'a', 'at', ',', 'square',
                       'and', 'then', 'to', 'of', 'rooms', 'near',
                       'opening', 'must', 'you', 'matching', 'end',
                       'hallway', 'object', 'from', 'room']

        all_words = colors + objects + verbs + extra_words
        assert len(all_words) == len(set(all_words))
        return {word: i for i, word in enumerate(all_words)}

    def string_to_indices(self, string, offset=1):
        """
        Convert a string to a list of indices.
        """
        indices = []
        # Add spaces around commas so they are split into separate tokens
        string = string.replace(',', ' , ')
        for word in string.split():
            if word in self.word_dict.keys():
                indices.append(self.word_dict[word] + offset)
            else:
                raise ValueError('Unknown word: {}'.format(word))
        return indices

    def observation(self, obs):
        obs['mission'] = self.string_to_indices(obs['mission'])
        assert len(obs['mission']) < self.max_words_in_mission
        obs['mission'] += [0] * (self.max_words_in_mission - len(obs['mission']))

        return obs

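# Example usage (a sketch; assumed env id, and the mission must only use words
# known to word_dict). The mission string is tokenized into word indices
# (offset by 1) and zero-padded to max_words_in_mission entries.
#
#     env = DictObservationSpaceWrapper(gym.make('MiniGrid-Empty-8x8-v0'))
#     obs = env.reset()  # len(obs['mission']) == env.max_words_in_mission
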
class FlatObsWrapper(gym.ObservationWrapper):
    """
    Encode mission strings using a one-hot scheme,
    and combine these with observed images into one flat array.
    """

    def __init__(self, env, maxStrLen=96):
        super().__init__(env)

        self.maxStrLen = maxStrLen
        self.numCharCodes = 27

        imgSpace = env.observation_space.spaces['image']
        imgSize = reduce(operator.mul, imgSpace.shape, 1)

        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(imgSize + self.numCharCodes * self.maxStrLen,),
            dtype='uint8'
        )

        self.cachedStr = None
        self.cachedArray = None

    def observation(self, obs):
        image = obs['image']
        mission = obs['mission']

        # Cache the last-encoded mission string
        if mission != self.cachedStr:
            assert len(mission) <= self.maxStrLen, \
                'mission string too long ({} chars)'.format(len(mission))
            mission = mission.lower()

            strArray = np.zeros(
                shape=(self.maxStrLen, self.numCharCodes), dtype='float32')

            for idx, ch in enumerate(mission):
                # Only lowercase letters and spaces are supported
                if ch >= 'a' and ch <= 'z':
                    chNo = ord(ch) - ord('a')
                elif ch == ' ':
                    chNo = ord('z') - ord('a') + 1
                else:
                    raise ValueError(
                        'unsupported character in mission string: {!r}'.format(ch))
                assert chNo < self.numCharCodes, '%s : %d' % (ch, chNo)
                strArray[idx, chNo] = 1

            self.cachedStr = mission
            self.cachedArray = strArray

        obs = np.concatenate((image.flatten(), self.cachedArray.flatten()))

        return obs

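# Example usage (a sketch; assumed env id). The image pixels and the one-hot
# mission encoding are concatenated into a single flat vector whose length
# matches the Box observation space declared above.
#
#     env = FlatObsWrapper(gym.make('MiniGrid-Empty-8x8-v0'))
#     obs = env.reset()  # obs.shape == env.observation_space.shape
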
class ViewSizeWrapper(gym.ObservationWrapper):
    """
    Wrapper to customize the agent's field of view size.
    This cannot be used with fully observable wrappers.
    """

    def __init__(self, env, agent_view_size=7):
        super().__init__(env)

        assert agent_view_size % 2 == 1
        assert agent_view_size >= 3

        self.agent_view_size = agent_view_size

        # Compute observation space with specified view size
        new_image_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(agent_view_size, agent_view_size, 3),
            dtype='uint8'
        )

        # Override the environment's observation space
        self.observation_space = spaces.Dict(
            {**self.observation_space, 'image': new_image_space})

    def observation(self, obs):
        env = self.unwrapped

        grid, vis_mask = env.gen_obs_grid(self.agent_view_size)

        # Encode the partially observable view into a numpy array
        image = grid.encode(vis_mask)

        return {
            **obs,
            'image': image
        }

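# Example usage (a sketch; assumed env id). Widens the egocentric view from the
# default 7x7 to 9x9 cells; the view size must be odd and at least 3.
#
#     env = ViewSizeWrapper(gym.make('MiniGrid-Empty-8x8-v0'), agent_view_size=9)
#     obs = env.reset()  # obs['image'].shape == (9, 9, 3)
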
class DirectionObsWrapper(gym.ObservationWrapper):
    """
    Provides the slope/angular direction to the goal with the observations,
    modeled as (y2 - y1) / (x2 - x1).
    type = {slope, angle}
    """

    def __init__(self, env, type='slope'):
        super().__init__(env)
        self.goal_position = None
        self.type = type

    def reset(self):
        obs = self.env.reset()
        if not self.goal_position:
            self.goal_position = [x for x, y in enumerate(self.grid.grid)
                                  if isinstance(y, Goal)]
            # In case there are multiple goals; needs to be handled for other env types
            if len(self.goal_position) >= 1:
                # Convert the flat grid index into (row, column) coordinates
                self.goal_position = (
                    int(self.goal_position[0] / self.height),
                    self.goal_position[0] % self.width)
        return obs

    def observation(self, obs):
        slope = np.divide(
            self.goal_position[1] - self.agent_pos[1],
            self.goal_position[0] - self.agent_pos[0])
        obs['goal_direction'] = np.arctan(slope) if self.type == 'angle' else slope
        return obs

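# Example usage (a sketch; assumed env id). Adds a 'goal_direction' entry holding
# the slope toward the goal, or its arctangent when type='angle'.
#
#     env = DirectionObsWrapper(gym.make('MiniGrid-Empty-8x8-v0'), type='angle')
#     obs = env.reset()
#     obs, reward, done, info = env.step(env.action_space.sample())
#     # obs['goal_direction'] is the angle toward the goal
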
class SymbolicObsWrapper(gym.ObservationWrapper):
    """
    Fully observable grid with a symbolic state representation.
    The symbol is a triple of (X, Y, IDX), where X and Y are
    the coordinates on the grid, and IDX is the id of the object.
    """

    def __init__(self, env):
        super().__init__(env)

        new_image_space = spaces.Box(
            low=0,
            high=max(OBJECT_TO_IDX.values()),
            shape=(self.env.width, self.env.height, 3),  # number of cells
            dtype='uint8',
        )
        self.observation_space = spaces.Dict(
            {**self.observation_space, 'image': new_image_space})

    def observation(self, obs):
        objects = np.array(
            [OBJECT_TO_IDX[o.type] if o is not None else -1 for o in self.grid.grid]
        )
        w, h = self.width, self.height
        grid = np.mgrid[:w, :h]
        grid = np.concatenate([grid, objects.reshape(1, w, h)])
        grid = np.transpose(grid, (1, 2, 0))
        obs['image'] = grid
        return obs

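# Example usage (a sketch; assumed env id). Every cell is described by its
# (X, Y) coordinates plus the object index, with -1 for empty cells.
#
#     env = SymbolicObsWrapper(gym.make('MiniGrid-Empty-8x8-v0'))
#     obs = env.reset()
#     x, y, obj_idx = obs['image'][3, 3]  # symbolic triple for the cell at (3, 3)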