|
@@ -501,7 +501,7 @@
|
|
|
<a class="viewcode-back" href="../../../api/wrappers/#minigrid.wrappers.PositionBonus">[docs]</a>
|
|
|
<span class="k">class</span><span class="w"> </span><span class="nc">PositionBonus</span><span class="p">(</span><span class="n">Wrapper</span><span class="p">):</span>
|
|
|
<span class="w"> </span><span class="sd">"""</span>
|
|
|
-<span class="sd"> Adds an exploration bonus based on which positions</span>
|
|
|
+<span class="sd"> Adds a scaled exploration bonus based on which positions</span>
|
|
|
<span class="sd"> are visited on the grid.</span>
|
|
|
|
|
|
<span class="sd"> Note:</span>
|
|
@@ -518,7 +518,7 @@
|
|
|
<span class="sd"> >>> _, reward, _, _, _ = env.step(1)</span>
|
|
|
<span class="sd"> >>> print(reward)</span>
|
|
|
<span class="sd"> 0</span>
|
|
|
-<span class="sd"> >>> env_bonus = PositionBonus(env)</span>
|
|
|
+<span class="sd"> >>> env_bonus = PositionBonus(env, scale=1)</span>
|
|
|
<span class="sd"> >>> obs, _ = env_bonus.reset(seed=0)</span>
|
|
|
<span class="sd"> >>> obs, reward, terminated, truncated, info = env_bonus.step(1)</span>
|
|
|
<span class="sd"> >>> print(reward)</span>
|
|
@@ -528,7 +528,7 @@
|
|
|
<span class="sd"> 0.7071067811865475</span>
|
|
|
<span class="sd"> """</span>
|
|
|
|
|
|
- <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">env</span><span class="p">):</span>
|
|
|
+ <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">env</span><span class="p">,</span> <span class="n">scale</span><span class="o">=</span><span class="mi">1</span><span class="p">):</span>
|
|
|
<span class="w"> </span><span class="sd">"""A wrapper that adds an exploration bonus to less visited positions.</span>
|
|
|
|
|
|
<span class="sd"> Args:</span>
|
|
@@ -536,6 +536,7 @@
|
|
|
<span class="sd"> """</span>
|
|
|
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">env</span><span class="p">)</span>
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">counts</span> <span class="o">=</span> <span class="p">{}</span>
|
|
|
+ <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="mi">1</span>
|
|
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">action</span><span class="p">):</span>
|
|
|
<span class="w"> </span><span class="sd">"""Steps through the environment with `action`."""</span>
|
|
@@ -547,16 +548,14 @@
|
|
|
<span class="n">tup</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">env</span><span class="o">.</span><span class="n">agent_pos</span><span class="p">)</span>
|
|
|
|
|
|
<span class="c1"># Get the count for this key</span>
|
|
|
- <span class="n">pre_count</span> <span class="o">=</span> <span class="mi">0</span>
|
|
|
- <span class="k">if</span> <span class="n">tup</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">counts</span><span class="p">:</span>
|
|
|
- <span class="n">pre_count</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">counts</span><span class="p">[</span><span class="n">tup</span><span class="p">]</span>
|
|
|
+ <span class="n">pre_count</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">counts</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">tup</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
|
|
|
|
|
<span class="c1"># Update the count for this key</span>
|
|
|
<span class="n">new_count</span> <span class="o">=</span> <span class="n">pre_count</span> <span class="o">+</span> <span class="mi">1</span>
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">counts</span><span class="p">[</span><span class="n">tup</span><span class="p">]</span> <span class="o">=</span> <span class="n">new_count</span>
|
|
|
|
|
|
<span class="n">bonus</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">new_count</span><span class="p">)</span>
|
|
|
- <span class="n">reward</span> <span class="o">+=</span> <span class="n">bonus</span>
|
|
|
+ <span class="n">reward</span> <span class="o">+=</span> <span class="n">bonus</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span>
|
|
|
|
|
|
<span class="k">return</span> <span class="n">obs</span><span class="p">,</span> <span class="n">reward</span><span class="p">,</span> <span class="n">terminated</span><span class="p">,</span> <span class="n">truncated</span><span class="p">,</span> <span class="n">info</span></div>
|
|
|
|