envs.py 1.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. import os
  2. import numpy
  3. import gym
  4. from gym import spaces
  5. try:
  6. import gym_minigrid
  7. from gym_minigrid.wrappers import *
  8. except:
  9. pass
  10. def make_env(env_id, seed, rank, log_dir):
  11. def _thunk():
  12. env = gym.make(env_id)
  13. env.seed(seed + rank)
  14. # Maxime: until RL code supports dict observations, squash observations into a flat vector
  15. if isinstance(env.observation_space, spaces.Dict):
  16. env = FlatObsWrapper(env)
  17. # If the input has shape (W,H,3), wrap for PyTorch convolutions
  18. obs_shape = env.observation_space.shape
  19. if len(obs_shape) == 3 and obs_shape[2] == 3:
  20. env = WrapPyTorch(env)
  21. return env
  22. return _thunk
  23. class WrapPyTorch(gym.ObservationWrapper):
  24. def __init__(self, env=None):
  25. super(WrapPyTorch, self).__init__(env)
  26. obs_shape = self.observation_space.shape
  27. self.observation_space = spaces.Box(
  28. self.observation_space.low[0,0,0],
  29. self.observation_space.high[0,0,0],
  30. [obs_shape[2], obs_shape[1], obs_shape[0]]
  31. )
  32. def _observation(self, observation):
  33. return observation.transpose(2, 0, 1)