envs.py 1.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. import os
  2. import numpy
  3. import gym
  4. from gym.spaces.box import Box
  5. from baselines import bench
  6. from baselines.common.atari_wrappers import make_atari, wrap_deepmind
  7. try:
  8. import pybullet_envs
  9. except ImportError:
  10. pass
  11. try:
  12. import gym_minigrid
  13. from gym_minigrid.wrappers import *
  14. except:
  15. pass
  def make_env(env_id, seed, rank, log_dir, size=None):
      """Return a zero-argument factory that builds one seeded environment.

      The environment is only constructed when the returned callable is
      invoked, so the factory can be handed to other code before any env
      exists.

      Args:
          env_id: Gym environment id passed to ``gym.make``.
          seed: Base random seed.
          rank: Offset added to ``seed``; presumably a worker index so that
              each worker's env gets a distinct seed.
          log_dir: Currently unused; kept so existing call sites still work.
          size: Optional grid size. When given, it is assigned to the
              underlying env's ``gridSize`` attribute (presumably a
              gym_minigrid-style env -- confirm for the env in use).

      Returns:
          A callable taking no arguments that returns the constructed env,
          wrapped with ``WrapPyTorch`` when it emits 3-channel images.
      """
      def _thunk():
          env = gym.make(env_id)
          env.seed(seed + rank)
          if size is not None:
              # gym.make may return the env inside wrappers (e.g. TimeLimit);
              # setting the attribute on the wrapper would not reach the
              # underlying env, so target the unwrapped env explicitly.
              env.unwrapped.gridSize = size

          # Assumes image observations are channel-last (H, W, 3); wrap them
          # so they come out channel-first for PyTorch convolutions. Spaces
          # without a fixed shape may report ``shape`` as None, so guard it.
          obs_shape = env.observation_space.shape
          if obs_shape is not None and len(obs_shape) == 3 and obs_shape[2] == 3:
              env = WrapPyTorch(env)
          return env

      return _thunk
  class WrapPyTorch(gym.ObservationWrapper):
      """Move the channel axis of image observations to the front.

      Assumes the wrapped env emits channel-last observations of shape
      (H, W, C); they are returned as (C, H, W), the layout PyTorch
      convolutions expect. The declared ``observation_space`` is updated to
      match.
      """

      def __init__(self, env=None):
          super(WrapPyTorch, self).__init__(env)
          obs_shape = self.observation_space.shape
          # transpose(2, 0, 1) maps (H, W, C) -> (C, H, W), so the declared
          # space must use the same order. The former
          # [obs_shape[2], obs_shape[1], obs_shape[0]] declared (C, W, H),
          # which disagrees with the actual observations whenever H != W.
          # NOTE(review): bounds are taken from the first element only, which
          # assumes low/high are uniform across the image -- confirm.
          self.observation_space = Box(
              self.observation_space.low[0, 0, 0],
              self.observation_space.high[0, 0, 0],
              [obs_shape[2], obs_shape[0], obs_shape[1]],
          )

      def observation(self, observation):
          """Return *observation* with its last axis moved to the front."""
          return observation.transpose(2, 0, 1)

      # Older gym releases dispatch to ``_observation`` while newer ones call
      # ``observation``; expose both so the wrapper works on either.
      _observation = observation