# wrappers.py
import math

import gym
  3. class ExplBonus(gym.core.Wrapper):
  4. """
  5. Wrapper which adds an exploration bonus.
  6. This is a reward to encourage exploration of less
  7. visited (state,action) pairs.
  8. """
  9. def __init__(self, env):
  10. super().__init__(env)
  11. self.counts = {}
  12. def _step(self, action):
  13. obs, reward, done, info = self.env.step(action)
  14. env = self.unwrapped
  15. tup = (env.agentPos, env.agentDir, action)
  16. # Get the count for this (s,a) pair
  17. preCnt = 0
  18. if tup in self.counts:
  19. preCnt = self.counts[tup]
  20. # Update the count for this (s,a) pair
  21. newCnt = preCnt + 1
  22. self.counts[tup] = newCnt
  23. bonus = 1 / math.sqrt(newCnt)
  24. reward += bonus
  25. return obs, reward, done, info