|
@@ -8,7 +8,7 @@ from typing import Any
|
|
|
import gymnasium as gym
|
|
|
import numpy as np
|
|
|
from gymnasium import logger, spaces
|
|
|
-from gymnasium.core import ObservationWrapper, ObsType, Wrapper
|
|
|
+from gymnasium.core import ActionWrapper, ObservationWrapper, ObsType, Wrapper
|
|
|
|
|
|
from minigrid.core.constants import COLOR_TO_IDX, OBJECT_TO_IDX, STATE_TO_IDX
|
|
|
from minigrid.core.world_object import Goal
|
|
@@ -764,3 +764,27 @@ class SymbolicObsWrapper(ObservationWrapper):
|
|
|
obs["image"] = grid
|
|
|
|
|
|
return obs
|
|
|
+
|
|
|
+
|
|
|
+class StochasticActionWrapper(ActionWrapper):
|
|
|
+ """
|
|
|
+ Add stochasticity to the actions
|
|
|
+
|
|
|
+ If a random action is provided, it is returned with probability `1 - prob`.
|
|
|
+ Else, a random action is sampled from the action space.
|
|
|
+ """
|
|
|
+
|
|
|
+ def __init__(self, env=None, prob=0.9, random_action=None):
|
|
|
+ super().__init__(env)
|
|
|
+ self.prob = prob
|
|
|
+ self.random_action = random_action
|
|
|
+
|
|
|
+ def action(self, action):
|
|
|
+ """ """
|
|
|
+ if np.random.uniform() < self.prob:
|
|
|
+ return action
|
|
|
+ else:
|
|
|
+ if self.random_action is None:
|
|
|
+ return self.np_random.integers(0, high=6)
|
|
|
+ else:
|
|
|
+ return self.random_action
|