wrappers.py 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. import math
  2. import gym
  3. class ActionBonus(gym.core.Wrapper):
  4. """
  5. Wrapper which adds an exploration bonus.
  6. This is a reward to encourage exploration of less
  7. visited (state,action) pairs.
  8. """
  9. def __init__(self, env):
  10. super().__init__(env)
  11. self.counts = {}
  12. def _step(self, action):
  13. obs, reward, done, info = self.env.step(action)
  14. env = self.unwrapped
  15. tup = (env.agentPos, env.agentDir, action)
  16. # Get the count for this (s,a) pair
  17. preCnt = 0
  18. if tup in self.counts:
  19. preCnt = self.counts[tup]
  20. # Update the count for this (s,a) pair
  21. newCnt = preCnt + 1
  22. self.counts[tup] = newCnt
  23. bonus = 1 / math.sqrt(newCnt)
  24. reward += bonus
  25. return obs, reward, done, info
  26. class StateBonus(gym.core.Wrapper):
  27. """
  28. Adds an exploration bonus based on which positions
  29. are visited on the grid.
  30. """
  31. def __init__(self, env):
  32. super().__init__(env)
  33. self.counts = {}
  34. def _step(self, action):
  35. obs, reward, done, info = self.env.step(action)
  36. # Tuple based on which we index the counts
  37. # We use the position after an update
  38. env = self.unwrapped
  39. tup = (env.agentPos)
  40. # Get the count for this key
  41. preCnt = 0
  42. if tup in self.counts:
  43. preCnt = self.counts[tup]
  44. # Update the count for this key
  45. newCnt = preCnt + 1
  46. self.counts[tup] = newCnt
  47. bonus = 1 / math.sqrt(newCnt)
  48. reward += bonus
  49. return obs, reward, done, info