浏览代码

Add scaling for PositionBonus (#433)

Co-authored-by: Mark Towers <mark.m.towers@gmail.com>
Samuel Bubán 2 月之前
父节点
当前提交
8710e91d23
共有 1 个文件被更改,包括 6 次插入7 次删除
  1. 6 7
      minigrid/wrappers.py

+ 6 - 7
minigrid/wrappers.py

@@ -125,7 +125,7 @@ class ActionBonus(gym.Wrapper):
 
 class PositionBonus(Wrapper):
     """
-    Adds an exploration bonus based on which positions
+    Adds a scaled exploration bonus based on which positions
     are visited on the grid.
 
     Note:
@@ -142,7 +142,7 @@ class PositionBonus(Wrapper):
         >>> _, reward, _, _, _ = env.step(1)
         >>> print(reward)
         0
-        >>> env_bonus = PositionBonus(env)
+        >>> env_bonus = PositionBonus(env, scale=1)
         >>> obs, _ = env_bonus.reset(seed=0)
         >>> obs, reward, terminated, truncated, info = env_bonus.step(1)
         >>> print(reward)
@@ -152,7 +152,7 @@ class PositionBonus(Wrapper):
         0.7071067811865475
     """
 
-    def __init__(self, env):
+    def __init__(self, env, scale=1):
         """A wrapper that adds an exploration bonus to less visited positions.
 
         Args:
@@ -160,6 +160,7 @@ class PositionBonus(Wrapper):
         """
         super().__init__(env)
         self.counts = {}
+        self.scale = 1
 
     def step(self, action):
         """Steps through the environment with `action`."""
@@ -171,16 +172,14 @@ class PositionBonus(Wrapper):
         tup = tuple(env.agent_pos)
 
         # Get the count for this key
-        pre_count = 0
-        if tup in self.counts:
-            pre_count = self.counts[tup]
+        pre_count = self.counts.get(tup, 0)
 
         # Update the count for this key
         new_count = pre_count + 1
         self.counts[tup] = new_count
 
         bonus = 1 / math.sqrt(new_count)
-        reward += bonus
+        reward += bonus * self.scale
 
         return obs, reward, terminated, truncated, info