dummy_vec_env.py 1002 B

12345678910111213141516171819202122232425262728293031
  1. import numpy as np
  2. from . import VecEnv
  3. class DummyVecEnv(VecEnv):
  4. def __init__(self, env_fns):
  5. self.envs = [fn() for fn in env_fns]
  6. env = self.envs[0]
  7. VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
  8. self.ts = np.zeros(len(self.envs), dtype='int')
  9. self.actions = None
  10. def step_async(self, actions):
  11. self.actions = actions
  12. def step_wait(self):
  13. results = [env.step(a) for (a,env) in zip(self.actions, self.envs)]
  14. obs, rews, dones, infos = map(np.array, zip(*results))
  15. self.ts += 1
  16. for (i, done) in enumerate(dones):
  17. if done:
  18. obs[i] = self.envs[i].reset()
  19. self.ts[i] = 0
  20. self.actions = None
  21. return np.array(obs), np.array(rews), np.array(dones), infos
  22. def reset(self):
  23. results = [env.reset() for env in self.envs]
  24. return np.array(results)
  25. def close(self):
  26. return