# wrappers.py

import math
import operator
from functools import reduce

import numpy as np
import gym
from gym import spaces

from .minigrid import OBJECT_TO_IDX, COLOR_TO_IDX, STATE_TO_IDX


class ReseedWrapper(gym.core.Wrapper):
    """
    Wrapper to always regenerate an environment with the same set of seeds.
    This can be used to force an environment to always keep the same
    configuration when reset.
    """

    def __init__(self, env, seeds=[0], seed_idx=0):
        self.seeds = list(seeds)
        self.seed_idx = seed_idx
        super().__init__(env)

    def reset(self, **kwargs):
        seed = self.seeds[self.seed_idx]
        self.seed_idx = (self.seed_idx + 1) % len(self.seeds)
        self.env.seed(seed)
        return self.env.reset(**kwargs)

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        return obs, reward, done, info
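
# A usage sketch (illustrative, not part of the original file; the env id
# 'MiniGrid-Empty-8x8-v0' is assumed to be registered by gym_minigrid):
#
#   env = ReseedWrapper(gym.make('MiniGrid-Empty-8x8-v0'), seeds=[0, 7])
#   env.reset()   # layout generated from seed 0
#   env.reset()   # layout generated from seed 7
#   env.reset()   # cycles back to seed 0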


class ActionBonus(gym.core.Wrapper):
    """
    Wrapper which adds an exploration bonus.
    This is a reward to encourage exploration of less
    visited (state, action) pairs.
    """

    def __init__(self, env):
        super().__init__(env)
        self.counts = {}

    def step(self, action):
        obs, reward, done, info = self.env.step(action)

        env = self.unwrapped
        tup = (tuple(env.agent_pos), env.agent_dir, action)

        # Get the count for this (s,a) pair
        pre_count = 0
        if tup in self.counts:
            pre_count = self.counts[tup]

        # Update the count for this (s,a) pair
        new_count = pre_count + 1
        self.counts[tup] = new_count

        bonus = 1 / math.sqrt(new_count)
        reward += bonus

        return obs, reward, done, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)


class StateBonus(gym.core.Wrapper):
    """
    Adds an exploration bonus based on which positions
    are visited on the grid.
    """

    def __init__(self, env):
        super().__init__(env)
        self.counts = {}

    def step(self, action):
        obs, reward, done, info = self.env.step(action)

        # Tuple based on which we index the counts
        # We use the position after the update
        env = self.unwrapped
        tup = tuple(env.agent_pos)

        # Get the count for this key
        pre_count = 0
        if tup in self.counts:
            pre_count = self.counts[tup]

        # Update the count for this key
        new_count = pre_count + 1
        self.counts[tup] = new_count

        bonus = 1 / math.sqrt(new_count)
        reward += bonus

        return obs, reward, done, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)
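
# Both bonus wrappers use the same count-based scheme: the n-th visit to a
# key adds 1 / sqrt(n) to the reward, so a first visit adds 1.0, a second
# about 0.707, decaying toward zero. ActionBonus keys its counts on
# (agent_pos, agent_dir, action); StateBonus on agent_pos alone. Sketch
# (illustrative; assumes a gym_minigrid env id with a fixed start position
# and the MiniGrid `actions` attribute forwarded through the wrapper):
#
#   env = ActionBonus(gym.make('MiniGrid-Empty-8x8-v0'))
#   env.reset()
#   _, reward, _, _ = env.step(env.actions.forward)   # includes a +1.0 bonus
#   env.reset()
#   _, reward, _, _ = env.step(env.actions.forward)   # same (s, a): +1/sqrt(2)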


class ImgObsWrapper(gym.core.ObservationWrapper):
    """
    Use the image as the only observation output, no language/mission.
    """

    def __init__(self, env):
        super().__init__(env)
        self.observation_space = env.observation_space.spaces['image']

    def observation(self, obs):
        return obs['image']
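
# ImgObsWrapper is typically the outermost wrapper when feeding a CNN, since
# it reduces the Dict observation to the raw image array. Sketch (env id
# assumed to be registered by gym_minigrid):
#
#   env = ImgObsWrapper(gym.make('MiniGrid-Empty-8x8-v0'))
#   obs = env.reset()   # a (7, 7, 3) uint8 array by default; no 'mission' key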


class OneHotPartialObsWrapper(gym.core.ObservationWrapper):
    """
    Wrapper to get a one-hot encoding of a partially observable
    agent view as observation.
    """

    def __init__(self, env, tile_size=8):
        super().__init__(env)

        self.tile_size = tile_size
        obs_shape = env.observation_space['image'].shape

        # Number of bits per cell
        num_bits = len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + len(STATE_TO_IDX)

        self.observation_space.spaces["image"] = spaces.Box(
            low=0,
            high=255,
            shape=(obs_shape[0], obs_shape[1], num_bits),
            dtype='uint8'
        )

    def observation(self, obs):
        img = obs['image']
        out = np.zeros(self.observation_space.spaces['image'].shape, dtype='uint8')

        for i in range(img.shape[0]):
            for j in range(img.shape[1]):
                obj_type = img[i, j, 0]
                color = img[i, j, 1]
                state = img[i, j, 2]

                out[i, j, obj_type] = 1
                out[i, j, len(OBJECT_TO_IDX) + color] = 1
                out[i, j, len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + state] = 1

        return {
            'mission': obs['mission'],
            'image': out
        }
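
# Sketch: the compact 3-channel view becomes a binary tensor with one channel
# per object type, color, and state; the exact channel count depends on the
# OBJECT_TO_IDX / COLOR_TO_IDX / STATE_TO_IDX tables (env id assumed):
#
#   env = OneHotPartialObsWrapper(gym.make('MiniGrid-Empty-8x8-v0'))
#   img = env.reset()['image']
#   img.shape   # (7, 7, len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + len(STATE_TO_IDX))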


class RGBImgObsWrapper(gym.core.ObservationWrapper):
    """
    Wrapper to use a fully observable RGB image as the only observation
    output, no language/mission. This can be used to have the agent solve
    the gridworld in pixel space.
    """

    def __init__(self, env, tile_size=8):
        super().__init__(env)

        self.tile_size = tile_size

        self.observation_space.spaces['image'] = spaces.Box(
            low=0,
            high=255,
            shape=(self.env.width * tile_size, self.env.height * tile_size, 3),
            dtype='uint8'
        )

    def observation(self, obs):
        env = self.unwrapped

        rgb_img = env.render(
            mode='rgb_array',
            highlight=False,
            tile_size=self.tile_size
        )

        return {
            'mission': obs['mission'],
            'image': rgb_img
        }


class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
    """
    Wrapper to use a partially observable RGB image as the only observation
    output. This can be used to have the agent solve the gridworld in
    pixel space.
    """

    def __init__(self, env, tile_size=8):
        super().__init__(env)

        self.tile_size = tile_size
        obs_shape = env.observation_space['image'].shape

        self.observation_space.spaces['image'] = spaces.Box(
            low=0,
            high=255,
            shape=(obs_shape[0] * tile_size, obs_shape[1] * tile_size, 3),
            dtype='uint8'
        )

    def observation(self, obs):
        env = self.unwrapped

        rgb_img_partial = env.get_obs_render(
            obs['image'],
            tile_size=self.tile_size,
            mode='rgb_array'
        )

        return {
            'mission': obs['mission'],
            'image': rgb_img_partial
        }
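
# Both RGB wrappers swap the compact encoding for rendered pixels:
# RGBImgObsWrapper renders the whole grid, RGBImgPartialObsWrapper only the
# agent's egocentric view. Sketch with tile_size=8 on an 8x8 grid and the
# default 7x7 view (env id assumed):
#
#   env = RGBImgObsWrapper(gym.make('MiniGrid-Empty-8x8-v0'))
#   env.reset()['image'].shape    # (64, 64, 3)
#
#   env = RGBImgPartialObsWrapper(gym.make('MiniGrid-Empty-8x8-v0'))
#   env.reset()['image'].shape    # (56, 56, 3)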


class FullyObsWrapper(gym.core.ObservationWrapper):
    """
    Fully observable gridworld using a compact grid encoding.
    """

    def __init__(self, env):
        super().__init__(env)

        self.observation_space.spaces["image"] = spaces.Box(
            low=0,
            high=255,
            shape=(self.env.width, self.env.height, 3),  # number of cells
            dtype='uint8'
        )

    def observation(self, obs):
        env = self.unwrapped
        full_grid = env.grid.encode()
        # Overwrite the agent's cell with the agent's own encoding
        full_grid[env.agent_pos[0]][env.agent_pos[1]] = np.array([
            OBJECT_TO_IDX['agent'],
            COLOR_TO_IDX['red'],
            env.agent_dir
        ])

        return {
            'mission': obs['mission'],
            'image': full_grid
        }
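
# Sketch: the full grid comes back in the same compact 3-channel encoding as
# the partial view, with the agent's cell overwritten in place (env id
# assumed):
#
#   env = FullyObsWrapper(gym.make('MiniGrid-Empty-8x8-v0'))
#   obs = env.reset()
#   obs['image'].shape                                              # (8, 8, 3)
#   obs['image'][tuple(env.agent_pos)][0] == OBJECT_TO_IDX['agent'] # True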


class FlatObsWrapper(gym.core.ObservationWrapper):
    """
    Encode mission strings using a one-hot scheme,
    and combine these with observed images into one flat array.
    """

    def __init__(self, env, maxStrLen=96):
        super().__init__(env)

        self.maxStrLen = maxStrLen
        self.numCharCodes = 27

        imgSpace = env.observation_space.spaces['image']
        imgSize = reduce(operator.mul, imgSpace.shape, 1)

        self.observation_space = spaces.Box(
            low=0,
            high=255,
            # 1-D shape, matching the flat array returned by observation()
            shape=(imgSize + self.numCharCodes * self.maxStrLen,),
            dtype='uint8'
        )

        self.cachedStr = None
        self.cachedArray = None

    def observation(self, obs):
        image = obs['image']
        mission = obs['mission']

        # Cache the last-encoded mission string
        if mission != self.cachedStr:
            assert len(mission) <= self.maxStrLen, 'mission string too long ({} chars)'.format(len(mission))
            mission = mission.lower()

            strArray = np.zeros(shape=(self.maxStrLen, self.numCharCodes), dtype='float32')

            for idx, ch in enumerate(mission):
                if ch >= 'a' and ch <= 'z':
                    chNo = ord(ch) - ord('a')
                elif ch == ' ':
                    chNo = ord('z') - ord('a') + 1
                else:
                    raise ValueError('char "{}" is not supported by the mission encoding'.format(ch))
                assert chNo < self.numCharCodes, '%s : %d' % (ch, chNo)
                strArray[idx, chNo] = 1

            self.cachedStr = mission
            self.cachedArray = strArray

        obs = np.concatenate((image.flatten(), self.cachedArray.flatten()))

        return obs
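
# Sketch: with the default 7x7x3 partial view and maxStrLen=96, the flat
# vector has 7*7*3 + 27*96 = 147 + 2592 = 2739 entries, image first, then
# the one-hot mission encoding (env id assumed):
#
#   env = FlatObsWrapper(gym.make('MiniGrid-Empty-8x8-v0'))
#   env.reset().shape   # (2739,)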


class ViewSizeWrapper(gym.core.Wrapper):
    """
    Wrapper to customize the agent field of view size.
    This cannot be used with fully observable wrappers.
    """

    def __init__(self, env, agent_view_size=7):
        super().__init__(env)

        # The view size must be odd (the agent sits on the center column)
        assert agent_view_size % 2 == 1
        assert agent_view_size >= 3

        # Override the default view size
        env.unwrapped.agent_view_size = agent_view_size

        # Compute observation space with the specified view size
        observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(agent_view_size, agent_view_size, 3),
            dtype='uint8'
        )

        # Override the environment's observation space
        self.observation_space = spaces.Dict({
            'image': observation_space
        })

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)

    def step(self, action):
        return self.env.step(action)
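
# Sketch: shrink or widen the egocentric view; the underlying env renders
# observations at the new size from the next reset onward (env id assumed):
#
#   env = ViewSizeWrapper(gym.make('MiniGrid-Empty-8x8-v0'), agent_view_size=5)
#   env.reset()['image'].shape   # (5, 5, 3)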