Browse Source

Eliminated WrapPyTorch

Maxime Chevalier-Boisvert 7 years ago
parent
commit
d70e134948
2 changed files with 0 additions and 19 deletions
  1. 0 18
      pytorch_rl/envs.py
  2. 0 1
      pytorch_rl/main.py

+ 0 - 18
pytorch_rl/envs.py

@@ -19,24 +19,6 @@ def make_env(env_id, seed, rank, log_dir):
         if isinstance(env.observation_space, spaces.Dict):
             env = FlatObsWrapper(env)
 
-        # If the input has shape (W,H,3), wrap for PyTorch convolutions
-        obs_shape = env.observation_space.shape
-        if len(obs_shape) == 3 and obs_shape[2] == 3:
-            env = WrapPyTorch(env)
-
         return env
 
     return _thunk
-
-class WrapPyTorch(gym.ObservationWrapper):
-    def __init__(self, env=None):
-        super(WrapPyTorch, self).__init__(env)
-        obs_shape = self.observation_space.shape
-        self.observation_space = spaces.Box(
-            self.observation_space.low[0,0,0],
-            self.observation_space.high[0,0,0],
-            [obs_shape[2], obs_shape[1], obs_shape[0]]
-        )
-
-    def _observation(self, observation):
-        return observation.transpose(2, 0, 1)

+ 0 - 1
pytorch_rl/main.py

@@ -113,7 +113,6 @@ def main():
 
     obs = envs.reset()
     update_current_obs(obs)
-
     rollouts.observations[0].copy_(current_obs)
 
     # These variables are used to compute average rewards for all processes.