import os

import numpy
import gym
from gym.spaces.box import Box

from baselines import bench
from baselines.common.atari_wrappers import make_atari, wrap_deepmind

# Optional environment suites. They are not referenced below, so they are
# presumably imported only for their side effect of registering extra env ids
# with gym; missing packages are tolerated.
try:
    import pybullet_envs  # noqa: F401
except ImportError:
    pass

try:
    import gym_minigrid  # noqa: F401
except ImportError:
    pass


class ScaleActions(gym.ActionWrapper):
    """Squash unbounded actions into the wrapped env's action bounds.

    Each action component is passed through ``tanh`` (range (-1, 1)) and then
    linearly rescaled onto ``[action_space.low, action_space.high]`` before
    being forwarded to the wrapped environment.
    """

    def __init__(self, env=None):
        super(ScaleActions, self).__init__(env)

    # NOTE(review): ``_step`` is the hook name used by older gym wrapper APIs;
    # newer gym versions dispatch through ``step``/``action`` instead -- confirm
    # against the gym version this project pins.
    def _step(self, action):
        # (tanh(a) + 1) / 2 lies in (0, 1); stretch it onto [low, high].
        action = (numpy.tanh(action) + 1) / 2 * (self.action_space.high - self.action_space.low) + self.action_space.low
        return self.env.step(action)


def make_env(env_id, seed, rank, log_dir):
    """Return a zero-argument factory that builds one seeded environment.

    Args:
        env_id: gym environment id passed to ``gym.make`` / ``make_atari``.
        seed: base random seed.
        rank: index of this environment. The env is seeded with
            ``seed + rank`` and, when logging, monitored into
            ``os.path.join(log_dir, str(rank))``.
        log_dir: directory for ``bench.Monitor`` output, or ``None`` to
            disable monitoring.

    Returns:
        A callable ``_thunk()`` that creates, wraps and returns the env.
    """
    def _thunk():
        env = gym.make(env_id)
        is_atari = hasattr(gym.envs, 'atari') and isinstance(
            env.unwrapped, gym.envs.atari.atari_env.AtariEnv)
        if is_atari:
            # make_atari builds its own instance; release the probe env
            # instead of leaking it.
            env.close()
            env = make_atari(env_id)
        env.seed(seed + rank)

        if log_dir is not None:
            env = bench.Monitor(env, os.path.join(log_dir, str(rank)))

        if is_atari:
            env = wrap_deepmind(env)

        # If the observation has shape (H, W, 3), transpose it to
        # channel-first (3, H, W) for PyTorch convolutions.
        obs_shape = env.observation_space.shape
        if len(obs_shape) == 3 and obs_shape[2] == 3:
            env = WrapPyTorch(env)

        # Optional tanh action squashing; currently disabled.
        # env = ScaleActions(env)
        return env

    return _thunk


class WrapPyTorch(gym.ObservationWrapper):
    """Transpose (H, W, C) image observations to channel-first (C, H, W)."""

    def __init__(self, env=None):
        super(WrapPyTorch, self).__init__(env)
        obs_shape = self.observation_space.shape
        # The declared space must match what _observation returns:
        # transpose(2, 0, 1) maps (H, W, C) -> (C, H, W). The order is
        # (C, H, W), not (C, W, H); the two differ for non-square frames.
        # NOTE(review): bounds are taken from element [0, 0, 0], which
        # assumes low/high are uniform across the whole image -- confirm.
        self.observation_space = Box(
            self.observation_space.low[0, 0, 0],
            self.observation_space.high[0, 0, 0],
            [obs_shape[2], obs_shape[0], obs_shape[1]],
        )

    def _observation(self, observation):
        return observation.transpose(2, 0, 1)