__init__.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. from abc import ABC, abstractmethod
  2. class VecEnv(ABC):
  3. def __init__(self, num_envs, observation_space, action_space):
  4. self.num_envs = num_envs
  5. self.observation_space = observation_space
  6. self.action_space = action_space
  7. """
  8. An abstract asynchronous, vectorized environment.
  9. """
  10. @abstractmethod
  11. def reset(self):
  12. """
  13. Reset all the environments and return an array of
  14. observations.
  15. If step_async is still doing work, that work will
  16. be cancelled and step_wait() should not be called
  17. until step_async() is invoked again.
  18. """
  19. pass
  20. @abstractmethod
  21. def step_async(self, actions):
  22. """
  23. Tell all the environments to start taking a step
  24. with the given actions.
  25. Call step_wait() to get the results of the step.
  26. You should not call this if a step_async run is
  27. already pending.
  28. """
  29. pass
  30. @abstractmethod
  31. def step_wait(self):
  32. """
  33. Wait for the step taken with step_async().
  34. Returns (obs, rews, dones, infos):
  35. - obs: an array of observations
  36. - rews: an array of rewards
  37. - dones: an array of "episode done" booleans
  38. - infos: an array of info objects
  39. """
  40. pass
  41. @abstractmethod
  42. def close(self):
  43. """
  44. Clean up the environments' resources.
  45. """
  46. pass
  47. def step(self, actions):
  48. self.step_async(actions)
  49. return self.step_wait()
  50. def render(self):
  51. logger.warn('Render not defined for %s'%self)
  52. class VecEnvWrapper(VecEnv):
  53. def __init__(self, venv, observation_space=None, action_space=None):
  54. self.venv = venv
  55. VecEnv.__init__(self,
  56. num_envs=venv.num_envs,
  57. observation_space=observation_space or venv.observation_space,
  58. action_space=action_space or venv.action_space)
  59. def step_async(self, actions):
  60. self.venv.step_async(actions)
  61. @abstractmethod
  62. def reset(self):
  63. pass
  64. @abstractmethod
  65. def step_wait(self):
  66. pass
  67. def close(self):
  68. return self.venv.close()
  69. def render(self):
  70. self.venv.render()
  71. class CloudpickleWrapper(object):
  72. """
  73. Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
  74. """
  75. def __init__(self, x):
  76. self.x = x
  77. def __getstate__(self):
  78. import cloudpickle
  79. return cloudpickle.dumps(self.x)
  80. def __setstate__(self, ob):
  81. import pickle
  82. self.x = pickle.loads(ob)