瀏覽代碼

Added a stochastic action wrapper along with its test (#355)

Bolun 2 年之前
父節點
當前提交
3a6e9cc324
共有 2 個文件被更改,包括 43 次插入1 次删除
  1. 25 1
      minigrid/wrappers.py
  2. 18 0
      tests/test_wrappers.py

+ 25 - 1
minigrid/wrappers.py

@@ -8,7 +8,7 @@ from typing import Any
 import gymnasium as gym
 import numpy as np
 from gymnasium import logger, spaces
-from gymnasium.core import ObservationWrapper, ObsType, Wrapper
+from gymnasium.core import ActionWrapper, ObservationWrapper, ObsType, Wrapper
 
 from minigrid.core.constants import COLOR_TO_IDX, OBJECT_TO_IDX, STATE_TO_IDX
 from minigrid.core.world_object import Goal
@@ -764,3 +764,27 @@ class SymbolicObsWrapper(ObservationWrapper):
         obs["image"] = grid
 
         return obs
+
+
+class StochasticActionWrapper(ActionWrapper):
+    """
+    Add stochasticity to the actions
+
+    If a random action is provided, it is returned with probability `1 - prob`.
+    Else, a random action is sampled from the action space.
+    """
+
+    def __init__(self, env=None, prob=0.9, random_action=None):
+        super().__init__(env)
+        self.prob = prob
+        self.random_action = random_action
+
+    def action(self, action):
+        """ """
+        if np.random.uniform() < self.prob:
+            return action
+        else:
+            if self.random_action is None:
+                return self.np_random.integers(0, high=6)
+            else:
+                return self.random_action

+ 18 - 0
tests/test_wrappers.py

@@ -21,6 +21,7 @@ from minigrid.wrappers import (
     ReseedWrapper,
     RGBImgObsWrapper,
     RGBImgPartialObsWrapper,
+    StochasticActionWrapper,
     SymbolicObsWrapper,
     ViewSizeWrapper,
 )
@@ -329,6 +330,23 @@ def test_symbolic_obs_wrapper(env_id):
     env.close()
 
 
+@pytest.mark.parametrize("env_id", ["MiniGrid-Empty-16x16-v0"])
+def test_stochastic_action_wrapper(env_id):
+    env = gym.make(env_id)
+    env = StochasticActionWrapper(env, prob=0.2)
+    _, _ = env.reset()
+    for _ in range(20):
+        _, _, _, _, _ = env.step(0)
+    env.close()
+
+    env = gym.make(env_id)
+    env = StochasticActionWrapper(env, prob=0.2, random_action=1)
+    _, _ = env.reset()
+    for _ in range(20):
+        _, _, _, _, _ = env.step(0)
+    env.close()
+
+
 def test_dict_observation_space_doesnt_clash_with_one_hot():
     env = gym.make("MiniGrid-Empty-5x5-v0")
     env = OneHotPartialObsWrapper(env)