|
@@ -125,7 +125,7 @@ class ActionBonus(gym.Wrapper):
|
|
|
|
|
|
class PositionBonus(Wrapper):
|
|
|
"""
|
|
|
- Adds an exploration bonus based on which positions
|
|
|
+ Adds a scaled exploration bonus based on which positions
|
|
|
are visited on the grid.
|
|
|
|
|
|
Note:
|
|
@@ -142,7 +142,7 @@ class PositionBonus(Wrapper):
|
|
|
>>> _, reward, _, _, _ = env.step(1)
|
|
|
>>> print(reward)
|
|
|
0
|
|
|
- >>> env_bonus = PositionBonus(env)
|
|
|
+ >>> env_bonus = PositionBonus(env, scale=1)
|
|
|
>>> obs, _ = env_bonus.reset(seed=0)
|
|
|
>>> obs, reward, terminated, truncated, info = env_bonus.step(1)
|
|
|
>>> print(reward)
|
|
@@ -152,7 +152,7 @@ class PositionBonus(Wrapper):
|
|
|
0.7071067811865475
|
|
|
"""
|
|
|
|
|
|
- def __init__(self, env):
|
|
|
+ def __init__(self, env, scale=1):
|
|
|
"""A wrapper that adds an exploration bonus to less visited positions.
|
|
|
|
|
|
Args:
|
|
@@ -160,6 +160,7 @@ class PositionBonus(Wrapper):
|
|
|
"""
|
|
|
super().__init__(env)
|
|
|
self.counts = {}
|
|
|
+ self.scale = scale
|
|
|
|
|
|
def step(self, action):
|
|
|
"""Steps through the environment with `action`."""
|
|
@@ -171,16 +172,14 @@ class PositionBonus(Wrapper):
|
|
|
tup = tuple(env.agent_pos)
|
|
|
|
|
|
# Get the count for this key
|
|
|
- pre_count = 0
|
|
|
- if tup in self.counts:
|
|
|
- pre_count = self.counts[tup]
|
|
|
+ pre_count = self.counts.get(tup, 0)
|
|
|
|
|
|
# Update the count for this key
|
|
|
new_count = pre_count + 1
|
|
|
self.counts[tup] = new_count
|
|
|
|
|
|
bonus = 1 / math.sqrt(new_count)
|
|
|
- reward += bonus
|
|
|
+ reward += bonus * self.scale
|
|
|
|
|
|
return obs, reward, terminated, truncated, info
|
|
|
|