# envs.py
  1. import os
  2. import numpy
  3. import gym
  4. from gym.spaces.box import Box
  5. from baselines import bench
  6. from baselines.common.atari_wrappers import make_atari, wrap_deepmind
  7. try:
  8. import pybullet_envs
  9. except ImportError:
  10. pass
  11. try:
  12. import gym_minigrid
  13. except:
  14. pass
  15. class ScaleActions(gym.ActionWrapper):
  16. def __init__(self, env=None):
  17. super(ScaleActions, self).__init__(env)
  18. def _step(self, action):
  19. action = (numpy.tanh(action) + 1) / 2 * (self.action_space.high - self.action_space.low) + self.action_space.low
  20. return self.env.step(action)
  21. def make_env(env_id, seed, rank, log_dir):
  22. def _thunk():
  23. env = gym.make(env_id)
  24. is_atari = hasattr(gym.envs, 'atari') and isinstance(env.unwrapped, gym.envs.atari.atari_env.AtariEnv)
  25. if is_atari:
  26. env = make_atari(env_id)
  27. env.seed(seed + rank)
  28. if log_dir is not None:
  29. env = bench.Monitor(env, os.path.join(log_dir, str(rank)))
  30. if is_atari:
  31. env = wrap_deepmind(env)
  32. # If the input has shape (W,H,3), wrap for PyTorch convolutions
  33. obs_shape = env.observation_space.shape
  34. if len(obs_shape) == 3 and obs_shape[2] == 3:
  35. env = WrapPyTorch(env)
  36. #env = ScaleActions(env)
  37. return env
  38. return _thunk
  39. class WrapPyTorch(gym.ObservationWrapper):
  40. def __init__(self, env=None):
  41. super(WrapPyTorch, self).__init__(env)
  42. obs_shape = self.observation_space.shape
  43. self.observation_space = Box(
  44. self.observation_space.low[0,0,0],
  45. self.observation_space.high[0,0,0],
  46. [obs_shape[2], obs_shape[1], obs_shape[0]]
  47. )
  48. def _observation(self, observation):
  49. return observation.transpose(2, 0, 1)