# wrappers.py

import math
import operator
from functools import reduce

import numpy as np
import gym
from gym import spaces

from .minigrid import OBJECT_TO_IDX, COLOR_TO_IDX

class ReseedWrapper(gym.core.Wrapper):
    """
    Wrapper to always regenerate an environment with the same set of seeds.
    This can be used to force an environment to always keep the same
    configuration when reset.
    """

    def __init__(self, env, seeds=[0], seed_idx=0):
        self.seeds = list(seeds)
        self.seed_idx = seed_idx
        super().__init__(env)

    def reset(self, **kwargs):
        # Seed with the next entry in the list, cycling so that repeated
        # resets reproduce the same sequence of layouts
        seed = self.seeds[self.seed_idx]
        self.seed_idx = (self.seed_idx + 1) % len(self.seeds)
        self.env.seed(seed)
        return self.env.reset(**kwargs)

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        return obs, reward, done, info
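
# A minimal usage sketch (the environment id below is illustrative; any
# registered MiniGrid id works the same way):
#
#   env = ReseedWrapper(gym.make('MiniGrid-Empty-8x8-v0'), seeds=[1, 2, 3])
#   env.reset()  # seeded with 1
#   env.reset()  # seeded with 2
#   env.reset()  # seeded with 3, then the cycle restarts at 1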
class ActionBonus(gym.core.Wrapper):
    """
    Wrapper which adds an exploration bonus.
    This is a reward to encourage exploration of less
    visited (state, action) pairs.
    """

    def __init__(self, env):
        self.__dict__.update(vars(env))  # Pass values to super wrapper
        super().__init__(env)
        self.counts = {}

    def step(self, action):
        obs, reward, done, info = self.env.step(action)

        env = self.unwrapped
        tup = (tuple(env.agent_pos), env.agent_dir, action)

        # Get the count for this (s,a) pair
        pre_count = 0
        if tup in self.counts:
            pre_count = self.counts[tup]

        # Update the count for this (s,a) pair
        new_count = pre_count + 1
        self.counts[tup] = new_count

        # The bonus decays with the square root of the visit count
        bonus = 1 / math.sqrt(new_count)
        reward += bonus

        return obs, reward, done, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)
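
# Usage sketch: the bonus decays as 1/sqrt(n) with the visit count n of each
# (position, direction, action) triple, so the first visit adds 1.0 to the
# reward, the fourth adds 0.5, and so on. StateBonus below works the same
# way, but keyed on position alone. The environment id is illustrative:
#
#   env = ActionBonus(gym.make('MiniGrid-Empty-8x8-v0'))
#   env.reset()
#   obs, reward, done, info = env.step(env.action_space.sample())
#   # reward now includes the exploration bonus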
class StateBonus(gym.core.Wrapper):
    """
    Adds an exploration bonus based on which positions
    are visited on the grid.
    """

    def __init__(self, env):
        self.__dict__.update(vars(env))  # Pass values to super wrapper
        super().__init__(env)
        self.counts = {}

    def step(self, action):
        obs, reward, done, info = self.env.step(action)

        # Tuple based on which we index the counts
        # We use the position after the update
        env = self.unwrapped
        tup = tuple(env.agent_pos)

        # Get the count for this key
        pre_count = 0
        if tup in self.counts:
            pre_count = self.counts[tup]

        # Update the count for this key
        new_count = pre_count + 1
        self.counts[tup] = new_count

        # The bonus decays with the square root of the visit count
        bonus = 1 / math.sqrt(new_count)
        reward += bonus

        return obs, reward, done, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)

class ImgObsWrapper(gym.core.ObservationWrapper):
    """
    Use the image as the only observation output, no language/mission.
    """

    def __init__(self, env):
        self.__dict__.update(vars(env))  # Pass values to super wrapper
        super().__init__(env)
        self.observation_space = env.observation_space.spaces['image']

    def observation(self, obs):
        return obs['image']
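
# Usage sketch (illustrative id): handy for image-only policies, since the
# observation dict is reduced to its image array:
#
#   env = ImgObsWrapper(gym.make('MiniGrid-Empty-8x8-v0'))
#   obs = env.reset()  # a (view_size, view_size, 3) uint8 array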
class FullyObsWrapper(gym.core.ObservationWrapper):
    """
    Fully observable gridworld using a compact grid encoding
    """

    def __init__(self, env):
        self.__dict__.update(vars(env))  # Pass values to super wrapper
        super().__init__(env)

        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.env.width, self.env.height, 3),  # number of cells
            dtype='uint8'
        )

    def observation(self, obs):
        env = self.unwrapped
        full_grid = env.grid.encode()
        # Overwrite the agent's cell with an agent marker so its position
        # and direction remain observable in the full grid
        full_grid[env.agent_pos[0]][env.agent_pos[1]] = np.array([
            OBJECT_TO_IDX['agent'],
            COLOR_TO_IDX['red'],
            env.agent_dir
        ])
        return full_grid
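
# Usage sketch (illustrative id): each observation becomes a
# (width, height, 3) array holding an (object index, color index, state)
# triple per cell, with the agent's cell marked as above:
#
#   env = FullyObsWrapper(gym.make('MiniGrid-Empty-8x8-v0'))
#   obs = env.reset()  # shape (8, 8, 3) for an 8x8 grid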
class FlatObsWrapper(gym.core.ObservationWrapper):
    """
    Encode mission strings using a one-hot scheme,
    and combine these with observed images into one flat array
    """

    def __init__(self, env, maxStrLen=96):
        self.__dict__.update(vars(env))  # Pass values to super wrapper
        super().__init__(env)

        self.maxStrLen = maxStrLen
        self.numCharCodes = 27  # 'a'-'z' plus the space character

        imgSpace = env.observation_space.spaces['image']
        imgSize = reduce(operator.mul, imgSpace.shape, 1)

        # The flat observation is 1-D: the flattened image followed by the
        # one-hot mission encoding
        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(imgSize + self.numCharCodes * self.maxStrLen,),
            dtype='uint8'
        )

        self.cachedStr = None
        self.cachedArray = None

    def observation(self, obs):
        image = obs['image']
        mission = obs['mission']

        # Cache the last-encoded mission string
        if mission != self.cachedStr:
            assert len(mission) <= self.maxStrLen, 'mission string too long ({} chars)'.format(len(mission))
            mission = mission.lower()

            strArray = np.zeros(shape=(self.maxStrLen, self.numCharCodes), dtype='float32')

            for idx, ch in enumerate(mission):
                if ch >= 'a' and ch <= 'z':
                    chNo = ord(ch) - ord('a')
                elif ch == ' ':
                    chNo = ord('z') - ord('a') + 1
                else:
                    raise ValueError('character %s is not supported in mission strings' % ch)
                assert chNo < self.numCharCodes, '%s : %d' % (ch, chNo)
                strArray[idx, chNo] = 1

            self.cachedStr = mission
            self.cachedArray = strArray

        obs = np.concatenate((image.flatten(), self.cachedArray.flatten()))

        return obs
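
# Usage sketch (illustrative id): for the default 7x7x3 agent view the flat
# observation has 7*7*3 + 27*96 = 147 + 2592 = 2739 entries:
#
#   env = FlatObsWrapper(gym.make('MiniGrid-Empty-8x8-v0'))
#   obs = env.reset()  # shape (2739,)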
class AgentViewWrapper(gym.core.Wrapper):
    """
    Wrapper to customize the agent's field of view.
    """

    def __init__(self, env, agent_view_size=7):
        self.__dict__.update(vars(env))  # Pass values to super wrapper
        super().__init__(env)

        # Override the default view size
        env.unwrapped.agent_view_size = agent_view_size

        # Compute the observation space with the specified view size
        observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(agent_view_size, agent_view_size, 3),
            dtype='uint8'
        )

        # Override the environment's observation space
        self.observation_space = spaces.Dict({
            'image': observation_space
        })

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)

    def step(self, action):
        return self.env.step(action)
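
# Usage sketch (illustrative id): shrink or widen the agent's field of view;
# the wrapped env must support the chosen agent_view_size:
#
#   env = AgentViewWrapper(gym.make('MiniGrid-Empty-8x8-v0'), agent_view_size=5)


if __name__ == '__main__':
    # Minimal smoke test; a sketch rather than a test suite. Run with
    # `python -m gym_minigrid.wrappers`, assuming this module lives in the
    # gym_minigrid package and that importing the package registers the
    # MiniGrid environment ids.
    env = gym.make('MiniGrid-Empty-8x8-v0')
    env = ReseedWrapper(env, seeds=[0])
    env = FullyObsWrapper(env)

    obs = env.reset()
    print('observation shape:', obs.shape)

    obs, reward, done, info = env.step(env.action_space.sample())
    print('reward after one random step:', reward)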