wrappers.py

import math
import operator
from functools import reduce

import numpy as np
import gym
from gym import error, spaces, utils

from .minigrid import OBJECT_TO_IDX, COLOR_TO_IDX, STATE_TO_IDX, Goal

class ReseedWrapper(gym.core.Wrapper):
    """
    Wrapper to always regenerate an environment with the same set of seeds.
    This can be used to force an environment to always keep the same
    configuration when reset.
    """

    def __init__(self, env, seeds=[0], seed_idx=0):
        self.seeds = list(seeds)
        self.seed_idx = seed_idx
        super().__init__(env)

    def reset(self, **kwargs):
        seed = self.seeds[self.seed_idx]
        self.seed_idx = (self.seed_idx + 1) % len(self.seeds)
        self.env.seed(seed)
        return self.env.reset(**kwargs)

    def step(self, action):
        return self.env.step(action)
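
# Usage sketch, assuming a registered MiniGrid env id such as
# 'MiniGrid-Empty-8x8-v0'; the seed cycle repeats deterministically:
#
#   env = ReseedWrapper(gym.make('MiniGrid-Empty-8x8-v0'), seeds=[0, 1, 2])
#   env.reset()  # generated from seed 0
#   env.reset()  # seed 1, then 2, then the cycle wraps back to 0
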
class ActionBonus(gym.core.Wrapper):
    """
    Wrapper which adds an exploration bonus.
    This is a reward to encourage exploration of less
    visited (state, action) pairs.
    """

    def __init__(self, env):
        super().__init__(env)
        self.counts = {}

    def step(self, action):
        obs, reward, done, info = self.env.step(action)

        env = self.unwrapped
        tup = (tuple(env.agent_pos), env.agent_dir, action)

        # Get the count for this (s,a) pair
        pre_count = self.counts.get(tup, 0)

        # Update the count for this (s,a) pair
        new_count = pre_count + 1
        self.counts[tup] = new_count

        bonus = 1 / math.sqrt(new_count)
        reward += bonus

        return obs, reward, done, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)
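
# Usage sketch, assuming a registered MiniGrid env id such as
# 'MiniGrid-Empty-8x8-v0':
#
#   env = ActionBonus(gym.make('MiniGrid-Empty-8x8-v0'))
#   env.reset()
#   obs, reward, done, info = env.step(env.action_space.sample())
#   # reward includes a 1/sqrt(count) bonus for this (pos, dir, action) triple
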
class StateBonus(gym.core.Wrapper):
    """
    Adds an exploration bonus based on which positions
    are visited on the grid.
    """

    def __init__(self, env):
        super().__init__(env)
        self.counts = {}

    def step(self, action):
        obs, reward, done, info = self.env.step(action)

        # Tuple based on which we index the counts
        # We use the position after an update
        env = self.unwrapped
        tup = tuple(env.agent_pos)

        # Get the count for this key
        pre_count = self.counts.get(tup, 0)

        # Update the count for this key
        new_count = pre_count + 1
        self.counts[tup] = new_count

        bonus = 1 / math.sqrt(new_count)
        reward += bonus

        return obs, reward, done, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)
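
# Usage sketch, same assumptions as for ActionBonus above; here the bonus
# decays with visits to a grid position, regardless of direction or action:
#
#   env = StateBonus(gym.make('MiniGrid-Empty-8x8-v0'))
#   env.reset()
#   obs, reward, done, info = env.step(env.action_space.sample())
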
class ImgObsWrapper(gym.core.ObservationWrapper):
    """
    Use the image as the only observation output, no language/mission.
    """

    def __init__(self, env):
        super().__init__(env)
        self.observation_space = env.observation_space.spaces['image']

    def observation(self, obs):
        return obs['image']
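
# Usage sketch: the observation is the bare image array (e.g. (7, 7, 3) for
# the default partial view), convenient for image-only policies:
#
#   env = ImgObsWrapper(gym.make('MiniGrid-Empty-8x8-v0'))
#   obs = env.reset()
#   assert obs.shape == env.observation_space.shape
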
class OneHotPartialObsWrapper(gym.core.ObservationWrapper):
    """
    Wrapper to get a one-hot encoding of a partially observable
    agent view as observation.
    """

    def __init__(self, env, tile_size=8):
        super().__init__(env)

        self.tile_size = tile_size
        obs_shape = env.observation_space['image'].shape

        # Number of bits per cell
        num_bits = len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + len(STATE_TO_IDX)

        self.observation_space.spaces["image"] = spaces.Box(
            low=0,
            high=255,
            shape=(obs_shape[0], obs_shape[1], num_bits),
            dtype='uint8'
        )

    def observation(self, obs):
        img = obs['image']
        out = np.zeros(self.observation_space.spaces['image'].shape, dtype='uint8')

        for i in range(img.shape[0]):
            for j in range(img.shape[1]):
                obj_type = img[i, j, 0]
                color = img[i, j, 1]
                state = img[i, j, 2]

                out[i, j, obj_type] = 1
                out[i, j, len(OBJECT_TO_IDX) + color] = 1
                out[i, j, len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + state] = 1

        return {
            'mission': obs['mission'],
            'image': out
        }
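
# Usage sketch: each cell's (type, color, state) triple becomes three one-hot
# segments, so the channel count equals
# len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + len(STATE_TO_IDX):
#
#   env = OneHotPartialObsWrapper(gym.make('MiniGrid-Empty-8x8-v0'))
#   obs = env.reset()
#   # obs['image'] contains only 0s and 1s along its last axis
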
class RGBImgObsWrapper(gym.core.ObservationWrapper):
    """
    Wrapper to use the fully observable RGB image as the only observation
    output, no language/mission. This can be used to have the agent solve
    the gridworld in pixel space.
    """

    def __init__(self, env, tile_size=8):
        super().__init__(env)

        self.tile_size = tile_size

        self.observation_space.spaces['image'] = spaces.Box(
            low=0,
            high=255,
            shape=(self.env.width * tile_size, self.env.height * tile_size, 3),
            dtype='uint8'
        )

    def observation(self, obs):
        env = self.unwrapped

        rgb_img = env.render(
            mode='rgb_array',
            highlight=False,
            tile_size=self.tile_size
        )

        return {
            'mission': obs['mission'],
            'image': rgb_img
        }
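
# Usage sketch: observations become full-grid RGB frames of shape
# (width * tile_size, height * tile_size, 3), rendered without the
# partial-view highlight:
#
#   env = RGBImgObsWrapper(gym.make('MiniGrid-Empty-8x8-v0'), tile_size=8)
#   obs = env.reset()
#   # obs['image'].dtype == 'uint8'
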
class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
    """
    Wrapper to use a partially observable RGB image as the only observation
    output. This can be used to have the agent solve the gridworld in pixel
    space.
    """

    def __init__(self, env, tile_size=8):
        super().__init__(env)

        self.tile_size = tile_size

        obs_shape = env.observation_space.spaces['image'].shape
        self.observation_space.spaces['image'] = spaces.Box(
            low=0,
            high=255,
            shape=(obs_shape[0] * tile_size, obs_shape[1] * tile_size, 3),
            dtype='uint8'
        )

    def observation(self, obs):
        env = self.unwrapped

        rgb_img_partial = env.get_obs_render(
            obs['image'],
            tile_size=self.tile_size
        )

        return {
            'mission': obs['mission'],
            'image': rgb_img_partial
        }
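
# Usage sketch: like RGBImgObsWrapper, but renders only the agent's partial
# view, giving shape (view_w * tile_size, view_h * tile_size, 3):
#
#   env = RGBImgPartialObsWrapper(gym.make('MiniGrid-Empty-8x8-v0'))
#   obs = env.reset()
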
class FullyObsWrapper(gym.core.ObservationWrapper):
    """
    Fully observable gridworld using a compact grid encoding.
    """

    def __init__(self, env):
        super().__init__(env)

        self.observation_space.spaces["image"] = spaces.Box(
            low=0,
            high=255,
            shape=(self.env.width, self.env.height, 3),  # number of cells
            dtype='uint8'
        )

    def observation(self, obs):
        env = self.unwrapped
        full_grid = env.grid.encode()

        # Overwrite the agent's cell with the agent's own encoding
        full_grid[env.agent_pos[0]][env.agent_pos[1]] = np.array([
            OBJECT_TO_IDX['agent'],
            COLOR_TO_IDX['red'],
            env.agent_dir
        ])

        return {
            'mission': obs['mission'],
            'image': full_grid
        }
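
# Usage sketch: the image becomes the (width, height, 3) encoding of the whole
# grid, with the agent written into its own cell:
#
#   env = FullyObsWrapper(gym.make('MiniGrid-Empty-8x8-v0'))
#   obs = env.reset()
#   # obs['image'][x, y] holds (object_idx, color_idx, state) for cell (x, y)
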
class FlatObsWrapper(gym.core.ObservationWrapper):
    """
    Encode mission strings using a one-hot scheme,
    and combine these with observed images into one flat array.
    """

    def __init__(self, env, maxStrLen=96):
        super().__init__(env)

        self.maxStrLen = maxStrLen
        self.numCharCodes = 27  # 26 lowercase letters plus the space character

        imgSpace = env.observation_space.spaces['image']
        imgSize = reduce(operator.mul, imgSpace.shape, 1)

        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(imgSize + self.numCharCodes * self.maxStrLen,),
            dtype='uint8'
        )

        self.cachedStr = None
        self.cachedArray = None

    def observation(self, obs):
        image = obs['image']
        mission = obs['mission']

        # Cache the last-encoded mission string
        if mission != self.cachedStr:
            assert len(mission) <= self.maxStrLen, 'mission string too long ({} chars)'.format(len(mission))
            mission = mission.lower()

            strArray = np.zeros(shape=(self.maxStrLen, self.numCharCodes), dtype='float32')

            for idx, ch in enumerate(mission):
                if ch >= 'a' and ch <= 'z':
                    chNo = ord(ch) - ord('a')
                elif ch == ' ':
                    chNo = ord('z') - ord('a') + 1
                else:
                    # Fail fast on characters the encoding cannot represent,
                    # rather than leaving chNo unset or stale
                    raise ValueError('mission contains unsupported character: %r' % ch)
                assert chNo < self.numCharCodes, '%s : %d' % (ch, chNo)
                strArray[idx, chNo] = 1

            self.cachedStr = mission
            self.cachedArray = strArray

        obs = np.concatenate((image.flatten(), self.cachedArray.flatten()))

        return obs
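
# Usage sketch: the observation is one flat vector of
# image_size + numCharCodes * maxStrLen entries:
#
#   env = FlatObsWrapper(gym.make('MiniGrid-Empty-8x8-v0'))
#   obs = env.reset()
#   assert obs.shape == env.observation_space.shape
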
class ViewSizeWrapper(gym.core.Wrapper):
    """
    Wrapper to customize the agent's field-of-view size.
    This cannot be used with fully observable wrappers.
    """

    def __init__(self, env, agent_view_size=7):
        super().__init__(env)

        # The view size must be odd and at least 3
        assert agent_view_size % 2 == 1
        assert agent_view_size >= 3

        # Override the default view size
        env.unwrapped.agent_view_size = agent_view_size

        # Compute the observation space with the specified view size
        observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(agent_view_size, agent_view_size, 3),
            dtype='uint8'
        )

        # Override the environment's observation space
        self.observation_space = spaces.Dict({
            'image': observation_space
        })

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)

    def step(self, action):
        return self.env.step(action)
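
# Usage sketch: shrink or grow the agent's egocentric view; the size must be
# odd and at least 3:
#
#   env = ViewSizeWrapper(gym.make('MiniGrid-Empty-8x8-v0'), agent_view_size=5)
#   obs = env.reset()
#   # obs['image'].shape == (5, 5, 3)
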
class DirectionObsWrapper(gym.core.ObservationWrapper):
    """
    Provides the slope/angular direction to the goal along with the
    observations, as modeled by (y2 - y1) / (x2 - x1).
    type = {slope, angle}
    """

    def __init__(self, env, type='slope'):
        super().__init__(env)
        self.goal_position = None
        self.type = type

    def reset(self):
        obs = self.env.reset()

        if not self.goal_position:
            self.goal_position = [
                x for x, y in enumerate(self.grid.grid) if isinstance(y, Goal)
            ]
            # In case there are multiple goals, only the first is used;
            # this needs to be handled for other env types
            if len(self.goal_position) >= 1:
                # Convert the flat grid index to (x, y), matching agent_pos
                idx = self.goal_position[0]
                self.goal_position = (idx % self.width, idx // self.width)

        return obs

    def observation(self, obs):
        # np.divide yields inf when the goal shares the agent's x-coordinate
        slope = np.divide(
            self.goal_position[1] - self.agent_pos[1],
            self.goal_position[0] - self.agent_pos[0]
        )
        obs['goal_direction'] = np.arctan(slope) if self.type == 'angle' else slope
        return obs
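
# Usage sketch: adds obs['goal_direction'], either the raw slope or its
# arctangent, measured from the agent toward the first Goal cell found:
#
#   env = DirectionObsWrapper(gym.make('MiniGrid-Empty-8x8-v0'), type='angle')
#   env.reset()  # locates the goal; 'goal_direction' appears on step()
#   obs, reward, done, info = env.step(env.action_space.sample())
#   # obs['goal_direction'] == arctan((y2 - y1) / (x2 - x1)), in radians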