Browse Source

Fixed wrappers.py following changes in OpenAI gym

Maxime Chevalier-Boisvert 7 years ago
parent
commit
27e33995ab
2 changed files with 6 additions and 4 deletions
  1. 4 4
      gym_minigrid/wrappers.py
  2. 2 0
      pytorch_rl/main.py

+ 4 - 4
gym_minigrid/wrappers.py

@@ -18,7 +18,7 @@ class ActionBonus(gym.core.Wrapper):
         super().__init__(env)
         self.counts = {}
 
-    def _step(self, action):
+    def step(self, action):
 
         obs, reward, done, info = self.env.step(action)
 
@@ -50,7 +50,7 @@ class StateBonus(gym.core.Wrapper):
         super().__init__(env)
         self.counts = {}
 
-    def _step(self, action):
+    def step(self, action):
 
         obs, reward, done, info = self.env.step(action)
 
@@ -80,7 +80,7 @@ class FlatObsWrapper(gym.core.ObservationWrapper):
     and combine these with observed images into one flat array
     """
 
-    def __init__(self, env, maxStrLen=48):
+    def __init__(self, env, maxStrLen=64):
         super().__init__(env)
 
         self.maxStrLen = maxStrLen
@@ -99,7 +99,7 @@ class FlatObsWrapper(gym.core.ObservationWrapper):
         self.cachedStr = None
         self.cachedArray = None
 
-    def _observation(self, obs):
+    def observation(self, obs):
         image = obs['image']
         mission = obs['mission']
 

+ 2 - 0
pytorch_rl/main.py

@@ -150,6 +150,8 @@ def main():
 
             if current_obs.dim() == 4:
                 current_obs *= masks.unsqueeze(2).unsqueeze(2)
+            elif current_obs.dim() == 3:
+                current_obs *= masks.unsqueeze(2)
             else:
                 current_obs *= masks