envs.py 504 B

12345678910111213141516171819202122232425
  1. import os
  2. import numpy
  3. import gym
  4. from gym import spaces
  5. try:
  6. import gym_minigrid
  7. from gym_minigrid.wrappers import *
  8. except:
  9. pass
  10. def make_env(env_id, seed, rank, log_dir):
  11. def _thunk():
  12. env = gym.make(env_id)
  13. env.seed(seed + rank)
  14. # Maxime: until RL code supports dict observations, squash observations into a flat vector
  15. if isinstance(env.observation_space, spaces.Dict):
  16. env = FlatObsWrapper(env)
  17. return env
  18. return _thunk