# wrappers.py

import math
import operator
from functools import reduce

import numpy as np

import gym
from gym import error, spaces, utils


class ActionBonus(gym.core.Wrapper):
    """
    Wrapper which adds an exploration bonus.
    This is a reward to encourage exploration of less
    visited (state, action) pairs.
    """

    def __init__(self, env):
        self.__dict__.update(vars(env))  # Pass values to super wrapper
        super().__init__(env)
        self.counts = {}

    def step(self, action):
        obs, reward, done, info = self.env.step(action)

        env = self.unwrapped
        tup = (tuple(env.agent_pos), env.agent_dir, action)

        # Get the count for this (s,a) pair
        pre_count = 0
        if tup in self.counts:
            pre_count = self.counts[tup]

        # Update the count for this (s,a) pair
        new_count = pre_count + 1
        self.counts[tup] = new_count

        bonus = 1 / math.sqrt(new_count)
        reward += bonus

        return obs, reward, done, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)
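
# Usage sketch (illustrative; 'MiniGrid-Empty-8x8-v0' is an assumed MiniGrid
# environment id, not something this module provides):
#
#   env = ActionBonus(gym.make('MiniGrid-Empty-8x8-v0'))
#   obs = env.reset()
#   obs, reward, done, info = env.step(env.action_space.sample())
#   # reward now includes a 1/sqrt(count) bonus keyed on (pos, dir, action)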


class StateBonus(gym.core.Wrapper):
    """
    Adds an exploration bonus based on which positions
    are visited on the grid.
    """

    def __init__(self, env):
        self.__dict__.update(vars(env))  # Pass values to super wrapper
        super().__init__(env)
        self.counts = {}

    def step(self, action):
        obs, reward, done, info = self.env.step(action)

        # Tuple based on which we index the counts
        # We use the position after an update
        env = self.unwrapped
        tup = tuple(env.agent_pos)

        # Get the count for this key
        pre_count = 0
        if tup in self.counts:
            pre_count = self.counts[tup]

        # Update the count for this key
        new_count = pre_count + 1
        self.counts[tup] = new_count

        bonus = 1 / math.sqrt(new_count)
        reward += bonus

        return obs, reward, done, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)
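
# Usage sketch (illustrative, same assumed environment id as above). Unlike
# ActionBonus, the count key here is the agent position only, so the bonus
# decays once a cell has been visited, regardless of the action taken:
#
#   env = StateBonus(gym.make('MiniGrid-Empty-8x8-v0'))
#   obs = env.reset()
#   obs, reward, done, info = env.step(env.action_space.sample())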


class ImgObsWrapper(gym.core.ObservationWrapper):
    """
    Use the image as the only observation output, no language/mission.
    """

    def __init__(self, env):
        self.__dict__.update(vars(env))  # Pass values to super wrapper
        super().__init__(env)
        self.observation_space = env.observation_space.spaces['image']

    def observation(self, obs):
        return obs['image']
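
# Usage sketch (illustrative): the wrapped observation is the raw image array
# rather than the usual dict with 'image' and 'mission' entries:
#
#   env = ImgObsWrapper(gym.make('MiniGrid-Empty-8x8-v0'))
#   obs = env.reset()
#   # obs is a numpy array matching env.observation_space.shape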


class FullyObsWrapper(gym.core.ObservationWrapper):
    """
    Fully observable gridworld using a compact grid encoding.
    """

    def __init__(self, env):
        self.__dict__.update(vars(env))  # Pass values to super wrapper
        super().__init__(env)

        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.env.width, self.env.height, 3),  # number of cells
            dtype='uint8'
        )

    def observation(self, obs):
        env = self.unwrapped
        full_grid = env.grid.encode()
        # Mark the agent's cell: 255 in the first channel, direction in the second
        full_grid[env.agent_pos[0]][env.agent_pos[1]] = np.array([255, env.agent_dir, 0])
        return full_grid
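
# Usage sketch (illustrative): the observation becomes a (width, height, 3)
# encoding of the whole grid instead of the agent's partial view:
#
#   env = FullyObsWrapper(gym.make('MiniGrid-Empty-8x8-v0'))
#   obs = env.reset()
#   # obs.shape == (env.width, env.height, 3)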


class FlatObsWrapper(gym.core.ObservationWrapper):
    """
    Encode mission strings using a one-hot scheme,
    and combine these with observed images into one flat array.
    """

    def __init__(self, env, maxStrLen=96):
        self.__dict__.update(vars(env))  # Pass values to super wrapper
        super().__init__(env)

        self.maxStrLen = maxStrLen
        self.numCharCodes = 27  # 26 letters plus the space character

        imgSpace = env.observation_space.spaces['image']
        imgSize = reduce(operator.mul, imgSpace.shape, 1)

        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(1, imgSize + self.numCharCodes * self.maxStrLen),
            dtype='uint8'
        )

        self.cachedStr = None
        self.cachedArray = None

    def observation(self, obs):
        image = obs['image']
        mission = obs['mission']

        # Cache the last-encoded mission string
        if mission != self.cachedStr:
            assert len(mission) <= self.maxStrLen, 'mission string too long ({} chars)'.format(len(mission))
            mission = mission.lower()

            strArray = np.zeros(shape=(self.maxStrLen, self.numCharCodes), dtype='float32')

            for idx, ch in enumerate(mission):
                if ch >= 'a' and ch <= 'z':
                    chNo = ord(ch) - ord('a')
                elif ch == ' ':
                    chNo = ord('z') - ord('a') + 1
                else:
                    # Previously, any other character left chNo undefined (or
                    # stale from the prior iteration); fail loudly instead
                    raise ValueError('unsupported character in mission string: {!r}'.format(ch))
                assert chNo < self.numCharCodes, '%s : %d' % (ch, chNo)
                strArray[idx, chNo] = 1

            self.cachedStr = mission
            self.cachedArray = strArray

        obs = np.concatenate((image.flatten(), self.cachedArray.flatten()))

        return obs
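
# Usage sketch (illustrative): each observation is the flattened image
# concatenated with a one-hot (maxStrLen x numCharCodes) encoding of the
# lower-cased mission string:
#
#   env = FlatObsWrapper(gym.make('MiniGrid-Empty-8x8-v0'))
#   obs = env.reset()
#   # obs.size equals the flattened image size plus 27 * 96 (default maxStrLen)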


class AgentViewWrapper(gym.core.Wrapper):
    """
    Wrapper to customize the agent's field of view.
    """

    def __init__(self, env, agent_view_size=7):
        self.__dict__.update(vars(env))  # Pass values to super wrapper
        super().__init__(env)

        # Override the default view size
        env.unwrapped.agent_view_size = agent_view_size

        # Compute the observation space with the specified view size
        observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(agent_view_size, agent_view_size, 3),
            dtype='uint8'
        )

        # Override the environment's observation space
        self.observation_space = spaces.Dict({
            'image': observation_space
        })

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)

    def step(self, action):
        return self.env.step(action)
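
# Usage sketch (illustrative): shrink or enlarge the agent's egocentric view.
# Assuming the underlying environment regenerates observations from
# agent_view_size, the image shape should follow the override:
#
#   env = AgentViewWrapper(gym.make('MiniGrid-Empty-8x8-v0'), agent_view_size=5)
#   obs = env.reset()
#   # obs['image'].shape should be (5, 5, 3)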