wrappers.py

import math
import operator
from functools import reduce

import numpy as np

import gym
from gym import error, spaces, utils

from .minigrid import OBJECT_TO_IDX, COLOR_TO_IDX
from .minigrid import CELL_PIXELS

class ReseedWrapper(gym.core.Wrapper):
    """
    Wrapper to always regenerate an environment with the same set of seeds.
    This can be used to force an environment to always keep the same
    configuration when reset.
    """

    def __init__(self, env, seeds=[0], seed_idx=0):
        self.seeds = list(seeds)
        self.seed_idx = seed_idx
        super().__init__(env)

    def reset(self, **kwargs):
        seed = self.seeds[self.seed_idx]
        self.seed_idx = (self.seed_idx + 1) % len(self.seeds)
        self.env.seed(seed)
        return self.env.reset(**kwargs)

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        return obs, reward, done, info
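
# A usage sketch (assumes the 'MiniGrid-Empty-8x8-v0' environment id is
# registered, e.g. via `import gym_minigrid`): the wrapper cycles through the
# seed list, so resets reproduce the same sequence of layouts.
#
#   env = ReseedWrapper(gym.make('MiniGrid-Empty-8x8-v0'), seeds=[0, 1])
#   env.reset()  # seeded with 0
#   env.reset()  # seeded with 1
#   env.reset()  # seeded with 0 again
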
class ActionBonus(gym.core.Wrapper):
    """
    Wrapper which adds an exploration bonus.
    This is a reward to encourage exploration of less
    visited (state, action) pairs.
    """

    def __init__(self, env):
        self.__dict__.update(vars(env))  # Pass values to super wrapper
        super().__init__(env)
        self.counts = {}

    def step(self, action):
        obs, reward, done, info = self.env.step(action)

        env = self.unwrapped
        tup = (tuple(env.agent_pos), env.agent_dir, action)

        # Get the count for this (s,a) pair
        pre_count = 0
        if tup in self.counts:
            pre_count = self.counts[tup]

        # Update the count for this (s,a) pair
        new_count = pre_count + 1
        self.counts[tup] = new_count

        bonus = 1 / math.sqrt(new_count)
        reward += bonus

        return obs, reward, done, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)
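
# The bonus decays as 1/sqrt(n) with the visit count n of a
# (agent position, direction, action) triple: the first visit adds 1.0 to
# the reward, the second about 0.71, the hundredth 0.1. Note that counts
# accumulate across episodes, since reset() does not clear self.counts.
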
class StateBonus(gym.core.Wrapper):
    """
    Adds an exploration bonus based on which positions
    are visited on the grid.
    """

    def __init__(self, env):
        self.__dict__.update(vars(env))  # Pass values to super wrapper
        super().__init__(env)
        self.counts = {}

    def step(self, action):
        obs, reward, done, info = self.env.step(action)

        # Tuple based on which we index the counts
        # We use the position after the update
        env = self.unwrapped
        tup = tuple(env.agent_pos)

        # Get the count for this key
        pre_count = 0
        if tup in self.counts:
            pre_count = self.counts[tup]

        # Update the count for this key
        new_count = pre_count + 1
        self.counts[tup] = new_count

        bonus = 1 / math.sqrt(new_count)
        reward += bonus

        return obs, reward, done, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)
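
# StateBonus differs from ActionBonus only in the count key: visits are
# counted per grid position rather than per (position, direction, action)
# triple, so the bonus rewards reaching new cells regardless of how the
# agent got there.
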
class ImgObsWrapper(gym.core.ObservationWrapper):
    """
    Use the image as the only observation output, no language/mission.
    """

    def __init__(self, env):
        self.__dict__.update(vars(env))  # Pass values to super wrapper
        super().__init__(env)
        self.observation_space = env.observation_space.spaces['image']

    def observation(self, obs):
        return obs['image']
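
# A usage sketch (environment id assumed as above): the wrapped observation
# is the raw partial-view image array instead of the observation dict, which
# is convenient for feeding a convolutional policy directly.
#
#   env = ImgObsWrapper(gym.make('MiniGrid-Empty-8x8-v0'))
#   obs = env.reset()  # a (7, 7, 3) uint8 array with the default view size
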
class RGBImgObsWrapper(gym.core.ObservationWrapper):
    """
    Wrapper to use a fully observable RGB image as the only observation
    output, no language/mission. This can be used to have the agent solve
    the gridworld in pixel space.
    """

    def __init__(self, env):
        self.__dict__.update(vars(env))  # Pass values to super wrapper
        super().__init__(env)

        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.env.width * CELL_PIXELS, self.env.height * CELL_PIXELS, 3),
            dtype='uint8'
        )

    def observation(self, obs):
        env = self.unwrapped
        return env.render(mode='rgb_array', highlight=False)
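
# A sizing sketch: the rendered frame is width * CELL_PIXELS by
# height * CELL_PIXELS pixels, so assuming CELL_PIXELS is 32 (its usual
# value in this codebase), an 8x8 grid yields a (256, 256, 3) uint8 image.
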
class FullyObsWrapper(gym.core.ObservationWrapper):
    """
    Fully observable gridworld using a compact grid encoding
    """

    def __init__(self, env):
        self.__dict__.update(vars(env))  # Pass values to super wrapper
        super().__init__(env)

        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.env.width, self.env.height, 3),  # number of cells
            dtype='uint8'
        )

    def observation(self, obs):
        env = self.unwrapped
        full_grid = env.grid.encode()

        # Overwrite the agent's cell with the agent's own encoding
        full_grid[env.agent_pos[0]][env.agent_pos[1]] = np.array([
            OBJECT_TO_IDX['agent'],
            COLOR_TO_IDX['red'],
            env.agent_dir
        ])

        return full_grid
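
# Encoding sketch: each cell of the returned (width, height, 3) array holds
# an (object index, color index, state) triple as produced by grid.encode();
# the agent's cell is overwritten above with the agent object index, its
# color, and its facing direction.
#
#   env = FullyObsWrapper(gym.make('MiniGrid-Empty-8x8-v0'))
#   grid = env.reset()  # shape (8, 8, 3)
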
class FlatObsWrapper(gym.core.ObservationWrapper):
    """
    Encode mission strings using a one-hot scheme,
    and combine these with observed images into one flat array
    """

    def __init__(self, env, maxStrLen=96):
        self.__dict__.update(vars(env))  # Pass values to super wrapper
        super().__init__(env)

        self.maxStrLen = maxStrLen
        self.numCharCodes = 27  # 'a'-'z' plus the space character

        imgSpace = env.observation_space.spaces['image']
        imgSize = reduce(operator.mul, imgSpace.shape, 1)

        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(1, imgSize + self.numCharCodes * self.maxStrLen),
            dtype='uint8'
        )

        self.cachedStr = None
        self.cachedArray = None

    def observation(self, obs):
        image = obs['image']
        mission = obs['mission']

        # Cache the last-encoded mission string
        if mission != self.cachedStr:
            assert len(mission) <= self.maxStrLen, 'mission string too long ({} chars)'.format(len(mission))
            mission = mission.lower()

            strArray = np.zeros(shape=(self.maxStrLen, self.numCharCodes), dtype='float32')

            for idx, ch in enumerate(mission):
                if ch >= 'a' and ch <= 'z':
                    chNo = ord(ch) - ord('a')
                elif ch == ' ':
                    chNo = ord('z') - ord('a') + 1
                else:
                    # Originally chNo was left unbound here, so an unsupported
                    # character raised a confusing NameError; fail explicitly
                    raise ValueError('unsupported character in mission string: {!r}'.format(ch))
                assert chNo < self.numCharCodes, '%s : %d' % (ch, chNo)
                strArray[idx, chNo] = 1

            self.cachedStr = mission
            self.cachedArray = strArray

        obs = np.concatenate((image.flatten(), self.cachedArray.flatten()))

        return obs
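
# A size sketch: with the default 7x7x3 partial view and maxStrLen=96, the
# flat observation has 7*7*3 + 27*96 = 147 + 2592 = 2739 entries. Note the
# concatenated vector is float32 (the one-hot block promotes the uint8
# image), despite the uint8 dtype declared on the observation space.
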
class AgentViewWrapper(gym.core.Wrapper):
    """
    Wrapper to customize the agent's field of view.
    """

    def __init__(self, env, agent_view_size=7):
        self.__dict__.update(vars(env))  # Pass values to super wrapper
        super().__init__(env)

        # Override default view size
        env.unwrapped.agent_view_size = agent_view_size

        # Compute observation space with specified view size
        observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(agent_view_size, agent_view_size, 3),
            dtype='uint8'
        )

        # Override the environment's observation space
        self.observation_space = spaces.Dict({
            'image': observation_space
        })

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)

    def step(self, action):
        return self.env.step(action)
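
if __name__ == '__main__':
    # Minimal smoke test, a sketch rather than part of the original module.
    # Run with `python -m gym_minigrid.wrappers` (the relative imports above
    # prevent running this file directly); assumes the 'MiniGrid-Empty-8x8-v0'
    # environment id is registered.
    import gym_minigrid  # noqa: F401  (registers the MiniGrid environments)

    env = AgentViewWrapper(gym.make('MiniGrid-Empty-8x8-v0'), agent_view_size=5)
    obs = env.reset()
    print(obs['image'].shape)  # expected: (5, 5, 3)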