envs.py 1.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. import os
  2. import numpy
  3. import gym
  4. from gym.spaces.box import Box
  5. try:
  6. import gym_minigrid
  7. from gym_minigrid.wrappers import *
  8. #from gym_minigrid.envs import *
  9. except:
  10. pass
  11. def make_env(env_id, seed, rank, log_dir):
  12. def _thunk():
  13. env = gym.make(env_id)
  14. env.seed(seed + rank)
  15. #env = FlatObsWrapper(env)
  16. # If the input has shape (W,H,3), wrap for PyTorch convolutions
  17. obs_shape = env.observation_space.shape
  18. if len(obs_shape) == 3 and obs_shape[2] == 3:
  19. env = WrapPyTorch(env)
  20. return env
  21. return _thunk
  22. class WrapPyTorch(gym.ObservationWrapper):
  23. def __init__(self, env=None):
  24. super(WrapPyTorch, self).__init__(env)
  25. obs_shape = self.observation_space.shape
  26. self.observation_space = Box(
  27. self.observation_space.low[0,0,0],
  28. self.observation_space.high[0,0,0],
  29. [obs_shape[2], obs_shape[1], obs_shape[0]]
  30. )
  31. def _observation(self, observation):
  32. return observation.transpose(2, 0, 1)