123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101 |
- from abc import ABC, abstractmethod
class VecEnv(ABC):
    """
    An abstract asynchronous, vectorized environment.

    Subclasses manage ``num_envs`` environments at once. Stepping is split
    into ``step_async`` (start a step) and ``step_wait`` (collect results);
    ``step`` is a synchronous convenience wrapper around the pair.
    """

    def __init__(self, num_envs, observation_space, action_space):
        """
        Store the attributes shared by every vectorized environment.

        :param num_envs: number of environments in the batch.
        :param observation_space: observation space, stored unchanged.
        :param action_space: action space, stored unchanged.
        """
        self.num_envs = num_envs
        self.observation_space = observation_space
        self.action_space = action_space

    @abstractmethod
    def reset(self):
        """
        Reset all the environments and return an array of
        observations.

        If step_async is still doing work, that work will
        be cancelled and step_wait() should not be called
        until step_async() is invoked again.
        """
        pass

    @abstractmethod
    def step_async(self, actions):
        """
        Tell all the environments to start taking a step
        with the given actions.

        Call step_wait() to get the results of the step.
        You should not call this if a step_async run is
        already pending.
        """
        pass

    @abstractmethod
    def step_wait(self):
        """
        Wait for the step taken with step_async().

        Returns (obs, rews, dones, infos):
         - obs: an array of observations
         - rews: an array of rewards
         - dones: an array of "episode done" booleans
         - infos: an array of info objects
        """
        pass

    @abstractmethod
    def close(self):
        """
        Clean up the environments' resources.
        """
        pass

    def step(self, actions):
        """
        Step all environments synchronously.

        Calls ``step_async(actions)`` and then returns whatever
        ``step_wait()`` returns.
        """
        self.step_async(actions)
        return self.step_wait()

    def render(self):
        """
        Default render hook: log a warning that rendering is not supported.

        Subclasses that can render should override this. Returns None.
        """
        # `logger` is not defined anywhere in this module's visible code, so
        # the previous `logger.warn(...)` raised NameError. Use the stdlib
        # logger instead; `warning` replaces the deprecated `warn` alias, and
        # the %-argument is formatted lazily by logging.
        import logging
        logging.getLogger(__name__).warning('Render not defined for %s', self)
class VecEnvWrapper(VecEnv):
    """
    Base class for a VecEnv that wraps another VecEnv.

    ``step_async``, ``close`` and ``render`` are forwarded to the wrapped
    ``venv``; subclasses must implement ``reset`` and ``step_wait``.
    """

    def __init__(self, venv, observation_space=None, action_space=None):
        """
        :param venv: the vectorized environment being wrapped; its
            ``num_envs`` is reused for this wrapper.
        :param observation_space: overrides ``venv.observation_space`` when
            not None.
        :param action_space: overrides ``venv.action_space`` when not None.
        """
        self.venv = venv
        # Compare against None rather than using `or`: an explicitly passed
        # space whose truth value is False (e.g. an object defining __len__
        # that happens to be empty) must not be silently replaced.
        if observation_space is None:
            observation_space = venv.observation_space
        if action_space is None:
            action_space = venv.action_space
        VecEnv.__init__(self,
                        num_envs=venv.num_envs,
                        observation_space=observation_space,
                        action_space=action_space)

    def step_async(self, actions):
        """Forward the actions to the wrapped environment."""
        self.venv.step_async(actions)

    @abstractmethod
    def reset(self):
        """Reset the wrapped environments; must be implemented by subclasses."""
        pass

    @abstractmethod
    def step_wait(self):
        """Collect step results; must be implemented by subclasses."""
        pass

    def close(self):
        """Close the wrapped environment and return its result."""
        return self.venv.close()

    def render(self):
        """Delegate rendering to the wrapped environment and return its result."""
        # Previously the wrapped env's return value was discarded.
        return self.venv.render()
class CloudpickleWrapper:
    """
    Container whose payload is serialized with cloudpickle.

    multiprocessing would otherwise serialize with the standard ``pickle``
    module; wrapping an object here routes pickling through cloudpickle.
    Unpickling feeds the cloudpickle-produced bytes to ``pickle.loads``.
    """

    def __init__(self, x):
        # The wrapped payload, exposed to callers as the ``x`` attribute.
        self.x = x

    def __getstate__(self):
        # Imported lazily, only when an instance is actually pickled.
        from cloudpickle import dumps as cloud_dumps
        payload = self.x
        return cloud_dumps(payload)

    def __setstate__(self, ob):
        from pickle import loads as std_loads
        restored = std_loads(ob)
        self.x = restored
|