# vec_frame_stack.py

import numpy as np
from gym import spaces

from vec_env import VecEnvWrapper
  4. class VecFrameStack(VecEnvWrapper):
  5. """
  6. Vectorized environment base class
  7. """
  8. def __init__(self, venv, nstack):
  9. self.venv = venv
  10. self.nstack = nstack
  11. wos = venv.observation_space # wrapped ob space
  12. low = np.repeat(wos.low, self.nstack, axis=-1)
  13. high = np.repeat(wos.high, self.nstack, axis=-1)
  14. self.stackedobs = np.zeros((venv.num_envs,)+low.shape, low.dtype)
  15. observation_space = spaces.Box(low=low, high=high, dtype=venv.observation_space.dtype)
  16. VecEnvWrapper.__init__(self, venv, observation_space=observation_space)
  17. def step_wait(self):
  18. obs, rews, news, infos = self.venv.step_wait()
  19. self.stackedobs = np.roll(self.stackedobs, shift=-1, axis=-1)
  20. for (i, new) in enumerate(news):
  21. if new:
  22. self.stackedobs[i] = 0
  23. self.stackedobs[..., -obs.shape[-1]:] = obs
  24. return self.stackedobs, rews, news, infos
  25. def reset(self):
  26. """
  27. Reset all environments
  28. """
  29. obs = self.venv.reset()
  30. self.stackedobs[...] = 0
  31. self.stackedobs[..., -obs.shape[-1]:] = obs
  32. return self.stackedobs
  33. def close(self):
  34. self.venv.close()