wrappers.py

import math
import operator
from functools import reduce

import numpy as np
import gym
from gym import spaces

from .minigrid import OBJECT_TO_IDX


class ActionBonus(gym.core.Wrapper):
    """
    Wrapper which adds an exploration bonus.
    This is a reward to encourage exploration of less
    visited (state, action) pairs.
    """

    def __init__(self, env):
        self.__dict__.update(vars(env))  # Pass values to super wrapper
        super().__init__(env)
        self.counts = {}

    def step(self, action):
        obs, reward, done, info = self.env.step(action)

        env = self.unwrapped
        tup = (tuple(env.agent_pos), env.agent_dir, action)

        # Get the count for this (s,a) pair
        pre_count = self.counts.get(tup, 0)

        # Update the count for this (s,a) pair
        new_count = pre_count + 1
        self.counts[tup] = new_count

        # The bonus decays as the pair is visited more often
        bonus = 1 / math.sqrt(new_count)
        reward += bonus

        return obs, reward, done, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)
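
# Usage sketch for ActionBonus (assumes gym_minigrid is installed and
# registers an env id such as 'MiniGrid-Empty-8x8-v0'):
#
#   env = ActionBonus(gym.make('MiniGrid-Empty-8x8-v0'))
#   obs = env.reset()
#   obs, reward, done, info = env.step(env.action_space.sample())
#   # reward now includes a 1/sqrt(count) bonus for the (pos, dir, action) tuple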


class StateBonus(gym.core.Wrapper):
    """
    Adds an exploration bonus based on which positions
    are visited on the grid.
    """

    def __init__(self, env):
        self.__dict__.update(vars(env))  # Pass values to super wrapper
        super().__init__(env)
        self.counts = {}

    def step(self, action):
        obs, reward, done, info = self.env.step(action)

        # Tuple based on which we index the counts
        # We use the position after the update
        env = self.unwrapped
        tup = tuple(env.agent_pos)

        # Get the count for this key
        pre_count = self.counts.get(tup, 0)

        # Update the count for this key
        new_count = pre_count + 1
        self.counts[tup] = new_count

        bonus = 1 / math.sqrt(new_count)
        reward += bonus

        return obs, reward, done, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)
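
# StateBonus is applied the same way as ActionBonus, and the two wrappers can
# be stacked; StateBonus counts only the agent's position, not the action:
#
#   env = StateBonus(ActionBonus(gym.make('MiniGrid-Empty-8x8-v0')))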


class ImgObsWrapper(gym.core.ObservationWrapper):
    """
    Use the image as the only observation output, no language/mission.
    """

    def __init__(self, env):
        self.__dict__.update(vars(env))  # Pass values to super wrapper
        super().__init__(env)
        self.observation_space = env.observation_space.spaces['image']

    def observation(self, obs):
        return obs['image']
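
# Usage sketch: with ImgObsWrapper applied, observations are the raw image
# array rather than the usual {'image': ..., 'mission': ...} dict:
#
#   env = ImgObsWrapper(gym.make('MiniGrid-Empty-8x8-v0'))
#   obs = env.reset()   # numpy array, shape (7, 7, 3) with the default view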


class FullyObsWrapper(gym.core.ObservationWrapper):
    """
    Fully observable gridworld using a compact grid encoding
    """

    def __init__(self, env):
        self.__dict__.update(vars(env))  # Pass values to super wrapper
        super().__init__(env)

        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.env.width, self.env.height, 3),  # number of cells
            dtype='uint8'
        )

    def observation(self, obs):
        env = self.unwrapped
        full_grid = env.grid.encode()

        # Mark the agent's cell with its own encoding: object index,
        # direction, and a zero state
        full_grid[env.agent_pos[0]][env.agent_pos[1]] = np.array([
            OBJECT_TO_IDX['agent'],
            env.agent_dir,
            0
        ])

        return full_grid
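
# Usage sketch: the observation becomes the full (width, height, 3) grid
# encoding instead of the agent's partial view:
#
#   env = FullyObsWrapper(gym.make('MiniGrid-Empty-8x8-v0'))
#   obs = env.reset()   # shape (env.width, env.height, 3)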


class FlatObsWrapper(gym.core.ObservationWrapper):
    """
    Encode mission strings using a one-hot scheme,
    and combine these with observed images into one flat array.
    """

    def __init__(self, env, maxStrLen=96):
        self.__dict__.update(vars(env))  # Pass values to super wrapper
        super().__init__(env)

        self.maxStrLen = maxStrLen
        self.numCharCodes = 27  # 'a'-'z' plus the space character

        imgSpace = env.observation_space.spaces['image']
        imgSize = reduce(operator.mul, imgSpace.shape, 1)

        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(1, imgSize + self.numCharCodes * self.maxStrLen),
            dtype='uint8'
        )

        self.cachedStr = None
        self.cachedArray = None

    def observation(self, obs):
        image = obs['image']
        mission = obs['mission']

        # Cache the last-encoded mission string
        if mission != self.cachedStr:
            assert len(mission) <= self.maxStrLen, 'mission string too long ({} chars)'.format(len(mission))
            mission = mission.lower()

            strArray = np.zeros(shape=(self.maxStrLen, self.numCharCodes), dtype='float32')

            for idx, ch in enumerate(mission):
                if ch >= 'a' and ch <= 'z':
                    chNo = ord(ch) - ord('a')
                elif ch == ' ':
                    chNo = ord('z') - ord('a') + 1
                else:
                    raise ValueError('unsupported character in mission string: %r' % ch)
                assert chNo < self.numCharCodes, '%s : %d' % (ch, chNo)
                strArray[idx, chNo] = 1

            self.cachedStr = mission
            self.cachedArray = strArray

        obs = np.concatenate((image.flatten(), self.cachedArray.flatten()))

        return obs
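
# Usage sketch: the flat observation is the flattened image concatenated with
# a (maxStrLen x numCharCodes) one-hot encoding of the mission text:
#
#   env = FlatObsWrapper(gym.make('MiniGrid-Empty-8x8-v0'))
#   obs = env.reset()   # 1D array of length imgSize + 27 * 96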


class AgentViewWrapper(gym.core.Wrapper):
    """
    Wrapper to customize the agent's field of view.
    """

    def __init__(self, env, agent_view_size=7):
        self.__dict__.update(vars(env))  # Pass values to super wrapper
        super().__init__(env)

        # Override the default view size
        env.unwrapped.agent_view_size = agent_view_size

        # Compute the observation space with the specified view size
        observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(agent_view_size, agent_view_size, 3),
            dtype='uint8'
        )

        # Override the environment's observation space
        self.observation_space = spaces.Dict({
            'image': observation_space
        })

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)

    def step(self, action):
        return self.env.step(action)
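
# Usage sketch: enlarge the agent's egocentric view from the default 7x7 to
# 9x9 (MiniGrid view sizes are odd, so the agent sits on a center cell):
#
#   env = AgentViewWrapper(gym.make('MiniGrid-Empty-8x8-v0'), agent_view_size=9)
#   obs = env.reset()
#   obs['image'].shape   # (9, 9, 3)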