wrappers.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. import math
  2. import operator
  3. from functools import reduce
  4. import numpy as np
  5. import gym
  6. from gym import error, spaces, utils
  7. class ActionBonus(gym.core.Wrapper):
  8. """
  9. Wrapper which adds an exploration bonus.
  10. This is a reward to encourage exploration of less
  11. visited (state,action) pairs.
  12. """
  13. def __init__(self, env):
  14. super().__init__(env)
  15. self.counts = {}
  16. def step(self, action):
  17. obs, reward, done, info = self.env.step(action)
  18. env = self.unwrapped
  19. tup = (env.agentPos, env.agentDir, action)
  20. # Get the count for this (s,a) pair
  21. preCnt = 0
  22. if tup in self.counts:
  23. preCnt = self.counts[tup]
  24. # Update the count for this (s,a) pair
  25. newCnt = preCnt + 1
  26. self.counts[tup] = newCnt
  27. bonus = 1 / math.sqrt(newCnt)
  28. reward += bonus
  29. return obs, reward, done, info
  30. class StateBonus(gym.core.Wrapper):
  31. """
  32. Adds an exploration bonus based on which positions
  33. are visited on the grid.
  34. """
  35. def __init__(self, env):
  36. super().__init__(env)
  37. self.counts = {}
  38. def step(self, action):
  39. obs, reward, done, info = self.env.step(action)
  40. # Tuple based on which we index the counts
  41. # We use the position after an update
  42. env = self.unwrapped
  43. tup = (env.agentPos)
  44. # Get the count for this key
  45. preCnt = 0
  46. if tup in self.counts:
  47. preCnt = self.counts[tup]
  48. # Update the count for this key
  49. newCnt = preCnt + 1
  50. self.counts[tup] = newCnt
  51. bonus = 1 / math.sqrt(newCnt)
  52. reward += bonus
  53. return obs, reward, done, info
  54. class ImgObsWrapper(gym.core.ObservationWrapper):
  55. """
  56. Use rgb image as the only observation output
  57. """
  58. def __init__(self, env):
  59. super().__init__(env)
  60. # Hack to pass values to super wrapper
  61. self.__dict__.update(vars(env))
  62. self.observation_space = env.observation_space.spaces['image']
  63. def observation(self, obs):
  64. return obs['image']
  65. class FullyObsWrapper(gym.core.ObservationWrapper):
  66. """
  67. Fully observable gridworld using a compact grid encoding
  68. """
  69. def __init__(self, env):
  70. super().__init__(env)
  71. self.__dict__.update(vars(env)) # hack to pass values to super wrapper
  72. self.observation_space = spaces.Box(
  73. low=0,
  74. high=self.env.grid_size,
  75. shape=(self.env.grid_size, self.env.grid_size, 3), # number of cells
  76. dtype='uint8'
  77. )
  78. def observation(self, obs):
  79. full_grid = self.env.grid.encode()
  80. full_grid[self.env.agent_pos[0]][self.env.agent_pos[1]] = np.array([255, self.env.agent_dir, 0])
  81. return full_grid
  82. class FlatObsWrapper(gym.core.ObservationWrapper):
  83. """
  84. Encode mission strings using a one-hot scheme,
  85. and combine these with observed images into one flat array
  86. """
  87. def __init__(self, env, maxStrLen=64):
  88. super().__init__(env)
  89. self.maxStrLen = maxStrLen
  90. self.numCharCodes = 27
  91. imgSpace = env.observation_space.spaces['image']
  92. imgSize = reduce(operator.mul, imgSpace.shape, 1)
  93. self.observation_space = spaces.Box(
  94. low=0,
  95. high=255,
  96. shape=(1, imgSize + self.numCharCodes * self.maxStrLen),
  97. dtype='uint8'
  98. )
  99. self.cachedStr = None
  100. self.cachedArray = None
  101. def observation(self, obs):
  102. image = obs['image']
  103. mission = obs['mission']
  104. # Cache the last-encoded mission string
  105. if mission != self.cachedStr:
  106. assert len(mission) <= self.maxStrLen, "mission string too long"
  107. mission = mission.lower()
  108. strArray = np.zeros(shape=(self.maxStrLen, self.numCharCodes), dtype='float32')
  109. for idx, ch in enumerate(mission):
  110. if ch >= 'a' and ch <= 'z':
  111. chNo = ord(ch) - ord('a')
  112. elif ch == ' ':
  113. chNo = ord('z') - ord('a') + 1
  114. assert chNo < self.numCharCodes, '%s : %d' % (ch, chNo)
  115. strArray[idx, chNo] = 1
  116. self.cachedStr = mission
  117. self.cachedArray = strArray
  118. obs = np.concatenate((image.flatten(), self.cachedArray.flatten()))
  119. return obs