wrappers.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. import math
  2. import operator
  3. from functools import reduce
  4. import numpy as np
  5. import gym
  6. from gym import error, spaces, utils
  class ActionBonus(gym.core.Wrapper):
      """
      Wrapper which adds an exploration bonus.

      This is a reward to encourage exploration of less visited
      (state, action) pairs. On every step the wrapper adds
      ``1 / sqrt(n)`` to the reward returned by the wrapped environment,
      where ``n`` is how many times this wrapper has seen the key
      (agent position, agent direction, action), counting the current step.
      The position and direction are read after the environment has stepped.
      """

      def __init__(self, env):
          super().__init__(env)
          # Visit counts keyed by (agent position, agent direction, action)
          self.counts = {}

      def step(self, action):
          """
          Step the wrapped environment and add the exploration bonus.

          Returns the wrapped environment's ``(obs, reward, done, info)``
          tuple with the bonus added to ``reward``.
          """
          obs, reward, done, info = self.env.step(action)

          env = self.unwrapped

          # Prefer the snake_case attribute names that FullyObsWrapper reads
          # from the environment; fall back to the camelCase spelling this
          # wrapper originally used so older environments keep working.
          pos = env.agent_pos if hasattr(env, 'agent_pos') else env.agentPos
          direction = env.agent_dir if hasattr(env, 'agent_dir') else env.agentDir

          # tuple() keeps the key hashable even when the position is stored
          # as a list or an array rather than a tuple
          key = (tuple(pos), direction, action)

          # Update the count for this (s,a) pair
          visits = self.counts.get(key, 0) + 1
          self.counts[key] = visits

          reward += 1 / math.sqrt(visits)

          return obs, reward, done, info
  class StateBonus(gym.core.Wrapper):
      """
      Adds an exploration bonus based on which positions
      are visited on the grid.

      On every step the wrapper adds ``1 / sqrt(n)`` to the reward returned
      by the wrapped environment, where ``n`` is how many times this wrapper
      has seen the agent at its current position, counting the current step.
      """

      def __init__(self, env):
          super().__init__(env)
          # Visit counts keyed by agent position
          self.counts = {}

      def step(self, action):
          """
          Step the wrapped environment and add the exploration bonus.

          Returns the wrapped environment's ``(obs, reward, done, info)``
          tuple with the bonus added to ``reward``.
          """
          obs, reward, done, info = self.env.step(action)

          # We use the position after an update
          env = self.unwrapped

          # Prefer the snake_case attribute name that FullyObsWrapper reads
          # from the environment; fall back to the camelCase spelling this
          # wrapper originally used so older environments keep working.
          pos = env.agent_pos if hasattr(env, 'agent_pos') else env.agentPos

          # Build a real tuple to index the counts: a bare parenthesised
          # expression is not a tuple, and a list/array position would not
          # be hashable as a dict key.
          key = tuple(pos)

          # Update the count for this key
          visits = self.counts.get(key, 0) + 1
          self.counts[key] = visits

          reward += 1 / math.sqrt(visits)

          return obs, reward, done, info
  class ImgObsWrapper(gym.core.ObservationWrapper):
      """
      Use the 'image' entry of the dict observation as the only
      observation output.
      """

      def __init__(self, env):
          super().__init__(env)

          # hack to pass values to super wrapper: mirror the wrapped object's
          # instance attributes onto this wrapper. The 'env' entry is skipped
          # because, when the argument is itself a wrapper, copying it would
          # rebind self.env to the inner environment and silently bypass the
          # wrapper that was passed in.
          self.__dict__.update(
              {name: value for name, value in vars(env).items() if name != 'env'}
          )

          # Must come after the copy above, which may carry the wrapped
          # object's own observation_space.
          self.observation_space = env.observation_space.spaces['image']

      def observation(self, obs):
          """Return only the 'image' entry of the observation dict."""
          return obs['image']
  class FullyObsWrapper(gym.core.ObservationWrapper):
      """
      Fully observable gridworld using a compact grid encoding.

      Every observation is the encoding of the whole grid, with the cell
      at the agent's position overwritten by ``[255, agent_dir, 0]``.
      """

      def __init__(self, env):
          super().__init__(env)

          # hack to pass values to super wrapper: mirror the wrapped object's
          # instance attributes onto this wrapper. The 'env' entry is skipped
          # because, when the argument is itself a wrapper, copying it would
          # rebind self.env to the inner environment and silently bypass the
          # wrapper that was passed in.
          self.__dict__.update(
              {name: value for name, value in vars(env).items() if name != 'env'}
          )

          # Read grid attributes from the base environment so this also
          # works when other wrappers sit in between.
          grid_size = self.unwrapped.grid_size

          self.observation_space = spaces.Box(
              low=0,
              # observation() writes 255 into the agent's cell, so the bound
              # has to cover the full uint8 range rather than grid_size
              high=255,
              shape=(grid_size, grid_size, 3),  # number of cells
              dtype='uint8'
          )

      def observation(self, obs):
          """
          Return the encoded full grid with the agent's cell marked.

          The incoming ``obs`` is not used; the grid is read from the
          base environment instead.
          """
          env = self.unwrapped
          full_grid = env.grid.encode()

          # Mark the agent's cell with the value 255 and its direction
          col, row = env.agent_pos[0], env.agent_pos[1]
          full_grid[col][row] = np.array([255, env.agent_dir, 0])

          return full_grid
  class FlatObsWrapper(gym.core.ObservationWrapper):
      """
      Encode mission strings using a one-hot scheme,
      and combine these with observed images into one flat array.

      The mission is lowercased and each character position gets one row of
      ``numCharCodes`` entries: letters 'a'-'z' map to codes 0-25 and the
      space character maps to code 26. Any other character leaves its row
      all zeros.
      """

      def __init__(self, env, maxStrLen=64):
          super().__init__(env)

          self.maxStrLen = maxStrLen
          # 26 lowercase letters plus one code for the space character
          self.numCharCodes = 27

          imgSpace = env.observation_space.spaces['image']
          imgSize = reduce(operator.mul, imgSpace.shape, 1)

          # NOTE(review): observation() returns a 1-D float array, while this
          # space declares a leading axis of length 1 and dtype uint8. Left
          # unchanged because callers may read the declared shape - confirm.
          self.observation_space = spaces.Box(
              low=0,
              high=255,
              shape=(1, imgSize + self.numCharCodes * self.maxStrLen),
              dtype='uint8'
          )

          # Last mission string seen (as received) and its one-hot encoding
          self.cachedStr = None
          self.cachedArray = None

      def observation(self, obs):
          """
          Flatten the image and the one-hot mission encoding into one array.

          Raises AssertionError if the mission is longer than ``maxStrLen``.
          """
          image = obs['image']
          mission = obs['mission']

          # Cache the last-encoded mission string. The comparison uses the
          # string as received, so a mission containing uppercase letters
          # still hits the cache on the next step.
          if mission != self.cachedStr:
              assert len(mission) <= self.maxStrLen, "mission string too long"

              strArray = np.zeros(shape=(self.maxStrLen, self.numCharCodes), dtype='float32')

              for idx, ch in enumerate(mission.lower()):
                  if 'a' <= ch <= 'z':
                      chNo = ord(ch) - ord('a')
                  elif ch == ' ':
                      chNo = ord('z') - ord('a') + 1
                  else:
                      # No code for this character: keep its row all zeros
                      # instead of reusing the previous character's code
                      # (or failing on an unbound name for a first character).
                      continue
                  assert chNo < self.numCharCodes, '%s : %d' % (ch, chNo)
                  strArray[idx, chNo] = 1

              self.cachedStr = mission
              self.cachedArray = strArray

          return np.concatenate((image.flatten(), self.cachedArray.flatten()))