Wrapper Improvements (#55)

* Wrapper improvements to avoid overriding env members (from wrapped env); see the sketch after this list

* Fixed State/Action bonus wrappers, ensuring env values are passed through

* Fixed deprecation warnings

* Use unwrapped env with FullyObsWrapper; it doesn't otherwise work with other wrappers (see the usage sketch after the diff)

* Modify unwrapped in agent view wrapper, to work with wrapped env
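
The recurring change in this diff is the ordering of the attribute copy relative to super().__init__(). A minimal sketch of the pattern (PassThroughWrapper is a hypothetical name, not part of this commit):

    import gym

    class PassThroughWrapper(gym.core.Wrapper):
        def __init__(self, env):
            self.__dict__.update(vars(env))  # expose the wrapped env's members first
            super().__init__(env)            # gym.Wrapper's own attributes (self.env, ...) win

Copying before super().__init__() means the wrapper's own bookkeeping is set last, so it is no longer clobbered by attributes coming from the wrapped env.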
Abdelrahman Ahmed, 6 years ago
commit 798ce869b2
1 changed file with 36 additions and 23 deletions

gym_minigrid/wrappers.py (+36 -23)

@@ -15,31 +15,33 @@ class ActionBonus(gym.core.Wrapper):
     """
 
     def __init__(self, env):
+        self.__dict__.update(vars(env))  # Pass values to super wrapper
         super().__init__(env)
         self.counts = {}
 
     def step(self, action):
-
         obs, reward, done, info = self.env.step(action)
 
         env = self.unwrapped
-        tup = (env.agentPos, env.agentDir, action)
+        tup = (tuple(env.agent_pos), env.agent_dir, action)
 
         # Get the count for this (s,a) pair
-        preCnt = 0
+        pre_count = 0
         if tup in self.counts:
-            preCnt = self.counts[tup]
+            pre_count = self.counts[tup]
 
         # Update the count for this (s,a) pair
-        newCnt = preCnt + 1
-        self.counts[tup] = newCnt
-
-        bonus = 1 / math.sqrt(newCnt)
+        new_count = pre_count + 1
+        self.counts[tup] = new_count
 
+        bonus = 1 / math.sqrt(new_count)
         reward += bonus
 
         return obs, reward, done, info
 
+    def reset(self, **kwargs):
+        return self.env.reset(**kwargs)
+
 class StateBonus(gym.core.Wrapper):
     """
     Adds an exploration bonus based on which positions
@@ -47,42 +49,44 @@ class StateBonus(gym.core.Wrapper):
     """
 
     def __init__(self, env):
+        self.__dict__.update(vars(env))  # Pass values to super wrapper
         super().__init__(env)
         self.counts = {}
 
     def step(self, action):
-
         obs, reward, done, info = self.env.step(action)
 
         # Tuple based on which we index the counts
         # We use the position after an update
         env = self.unwrapped
-        tup = (env.agentPos)
+        tup = (tuple(env.agent_pos))
 
         # Get the count for this key
-        preCnt = 0
+        pre_count = 0
         if tup in self.counts:
-            preCnt = self.counts[tup]
+            pre_count = self.counts[tup]
 
         # Update the count for this key
-        newCnt = preCnt + 1
-        self.counts[tup] = newCnt
-
-        bonus = 1 / math.sqrt(newCnt)
+        new_count = pre_count + 1
+        self.counts[tup] = new_count
 
+        bonus = 1 / math.sqrt(new_count)
         reward += bonus
 
         return obs, reward, done, info
 
+    def reset(self, **kwargs):
+        return self.env.reset(**kwargs)
+
 class ImgObsWrapper(gym.core.ObservationWrapper):
     """
     Use the image as the only observation output, no language/mission.
     """
 
     def __init__(self, env):
+        self.__dict__.update(vars(env))  # Pass values to super wrapper
         super().__init__(env)
-        # Hack to pass values to super wrapper
-        self.__dict__.update(vars(env))
+
         self.observation_space = env.observation_space.spaces['image']
 
     def observation(self, obs):
@@ -94,8 +98,9 @@ class FullyObsWrapper(gym.core.ObservationWrapper):
     """
 
     def __init__(self, env):
+        self.__dict__.update(vars(env))  # Pass values to super wrapper
         super().__init__(env)
-        self.__dict__.update(vars(env))  # hack to pass values to super wrapper
+
         self.observation_space = spaces.Box(
             low=0,
             high=255,
@@ -104,8 +109,9 @@ class FullyObsWrapper(gym.core.ObservationWrapper):
         )
 
     def observation(self, obs):
-        full_grid = self.env.grid.encode()
-        full_grid[self.env.agent_pos[0]][self.env.agent_pos[1]] = np.array([255, self.env.agent_dir, 0])
+        env = self.unwrapped
+        full_grid = env.grid.encode()
+        full_grid[env.agent_pos[0]][env.agent_pos[1]] = np.array([255, env.agent_dir, 0])
         return full_grid
 
 class FlatObsWrapper(gym.core.ObservationWrapper):
@@ -115,6 +121,7 @@ class FlatObsWrapper(gym.core.ObservationWrapper):
     """
 
     def __init__(self, env, maxStrLen=64):
+        self.__dict__.update(vars(env))  # Pass values to super wrapper
         super().__init__(env)
 
         self.maxStrLen = maxStrLen
@@ -166,11 +173,11 @@ class AgentViewWrapper(gym.core.Wrapper):
     """
 
     def __init__(self, env, agent_view_size=7):
+        self.__dict__.update(vars(env))  # Pass values to super wrapper
         super(AgentViewWrapper, self).__init__(env)
-        self.__dict__.update(vars(env))  # Hack to pass values to super wrapper
 
         # Override default view size
-        env.agent_view_size = agent_view_size
+        env.unwrapped.agent_view_size = agent_view_size
 
         # Compute observation space with specified view size
         observation_space = gym.spaces.Box(
@@ -184,3 +191,9 @@ class AgentViewWrapper(gym.core.Wrapper):
         self.observation_space = spaces.Dict({
             'image': observation_space
         })
+
+    def reset(self, **kwargs):
+        return self.env.reset(**kwargs)
+
+    def step(self, action):
+        return self.env.step(action)
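
With this change the bonus wrappers compose with observation wrappers. A minimal usage sketch under the old gym step/reset API, assuming gym-minigrid is installed ('MiniGrid-Empty-8x8-v0' is one of its registered environment ids):

    import gym
    import gym_minigrid  # registers the MiniGrid environments
    from gym_minigrid.wrappers import ActionBonus, FullyObsWrapper

    env = gym.make('MiniGrid-Empty-8x8-v0')
    env = ActionBonus(env)      # count-based bonus keyed on (agent_pos, agent_dir, action)
    env = FullyObsWrapper(env)  # reads grid state via env.unwrapped, so stacking works

    obs = env.reset()
    obs, reward, done, info = env.step(env.action_space.sample())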