import math

import gym


class ActionBonus(gym.core.Wrapper):
    """
    Wrapper which adds an exploration bonus.
    This is a reward to encourage exploration of less
    visited (state, action) pairs.
    """

    def __init__(self, env):
        super().__init__(env)
        self.counts = {}

    def step(self, action):
        obs, reward, done, info = self.env.step(action)

        # Key on (position, direction, action); tuple() makes the
        # (possibly array-valued) position hashable
        env = self.unwrapped
        tup = (tuple(env.agentPos), env.agentDir, action)

        # Get the current count for this (s, a) pair
        preCnt = self.counts.get(tup, 0)

        # Update the count for this (s, a) pair
        newCnt = preCnt + 1
        self.counts[tup] = newCnt

        # Count-based bonus: decays as the pair is visited more often
        bonus = 1 / math.sqrt(newCnt)
        reward += bonus

        return obs, reward, done, info
class StateBonus(gym.core.Wrapper):
    """
    Wrapper which adds an exploration bonus based on which
    positions on the grid have been visited.
    """

    def __init__(self, env):
        super().__init__(env)
        self.counts = {}

    def step(self, action):
        obs, reward, done, info = self.env.step(action)

        # Key on the agent's position *after* the step; tuple() makes
        # the position hashable so it can index the counts dict
        env = self.unwrapped
        tup = tuple(env.agentPos)

        # Get the current count for this position
        preCnt = self.counts.get(tup, 0)

        # Update the count for this position
        newCnt = preCnt + 1
        self.counts[tup] = newCnt

        # Count-based bonus: decays as the position is revisited
        bonus = 1 / math.sqrt(newCnt)
        reward += bonus

        return obs, reward, done, info
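

# --------------------------------------------------------------------------
# Usage sketch (assumptions: gym_minigrid is installed and registers the
# MiniGrid environments with gym on import, and its unwrapped envs expose
# the agentPos/agentDir attributes used above; later minigrid releases
# renamed these to agent_pos/agent_dir).
if __name__ == "__main__":
    import gym_minigrid  # noqa: F401  -- registers 'MiniGrid-*' envs with gym

    env = ActionBonus(gym.make("MiniGrid-Empty-8x8-v0"))
    obs = env.reset()

    # Each step's reward now includes the 1 / sqrt(count) bonus, so a
    # never-tried (state, action) pair earns an extra +1 the first time.
    obs, reward, done, info = env.step(env.action_space.sample())
    print(reward)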