|
@@ -7,6 +7,28 @@ import gym
|
|
|
from gym import error, spaces, utils
|
|
|
from .minigrid import OBJECT_TO_IDX, COLOR_TO_IDX
|
|
|
|
|
|
class ReseedWrapper(gym.core.Wrapper):
    """
    Wrapper to always regenerate an environment with the same set of seeds.
    This can be used to force an environment to always keep the same
    configuration when reset.

    Each call to ``reset`` seeds the wrapped environment with the next entry
    of ``seeds``; after the last entry the sequence wraps around to the first.
    """

    def __init__(self, env, seeds=(0,), seed_idx=0):
        """
        :param env: environment to wrap
        :param seeds: iterable of seeds to cycle through; it is copied, so
            later changes to the caller's object do not affect the wrapper
        :param seed_idx: index into ``seeds`` of the seed used by the first
            ``reset`` call
        :raises ValueError: if ``seeds`` is empty
        """
        seeds = list(seeds)
        if not seeds:
            # Fail at construction time instead of with an IndexError on the
            # first reset().
            raise ValueError("seeds must contain at least one seed")

        self.seeds = seeds
        self.seed_idx = seed_idx
        super().__init__(env)

    def reset(self, **kwargs):
        """
        Seed the wrapped environment with the current seed, advance to the
        next seed (wrapping around), then reset the wrapped environment.

        :param kwargs: forwarded unchanged to the wrapped ``reset``
        :return: whatever the wrapped environment's ``reset`` returns
        """
        seed = self.seeds[self.seed_idx]
        # Advance before resetting so the next reset() uses the next seed.
        self.seed_idx = (self.seed_idx + 1) % len(self.seeds)
        self.env.seed(seed)
        return self.env.reset(**kwargs)

    def step(self, action):
        """
        Forward ``action`` to the wrapped environment without modification.

        :return: the ``(obs, reward, done, info)`` tuple from the wrapped step
        """
        obs, reward, done, info = self.env.step(action)
        return obs, reward, done, info
|
|
|
+
|
|
|
class ActionBonus(gym.core.Wrapper):
|
|
|
"""
|
|
|
Wrapper which adds an exploration bonus.
|