test_envs.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. import gym
  2. import pytest
  3. from gym.envs.registration import EnvSpec
  4. from gym.utils.env_checker import check_env
  5. from tests.envs.utils import all_testing_env_specs, assert_equals
  6. # This runs a smoketest on each official registered env. We may want
  7. # to try also running environments which are not officially registered envs.
  8. IGNORE_WARNINGS = [
  9. "Agent's minimum observation space value is -infinity. This is probably too low.",
  10. "Agent's maximum observation space value is infinity. This is probably too high.",
  11. "We recommend you to use a symmetric and normalized Box action space (range=[-1, 1]) https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html",
  12. ]
  13. IGNORE_WARNINGS = [f"\x1b[33mWARN: {message}\x1b[0m" for message in IGNORE_WARNINGS]
  14. @pytest.mark.parametrize(
  15. "spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  16. )
  17. def test_env(spec):
  18. # Capture warnings
  19. env = spec.make(disable_env_checker=True).unwrapped
  20. # Test if env adheres to Gym API
  21. with pytest.warns(None) as warnings:
  22. check_env(env)
  23. for warning in warnings.list:
  24. if warning.message.args[0] not in IGNORE_WARNINGS:
  25. raise gym.error.Error(f"Unexpected warning: {warning.message}")
  26. # Note that this precludes running this test in multiple threads.
  27. # However, we probably already can't do multithreading due to some environments.
  28. SEED = 0
  29. NUM_STEPS = 50
  30. @pytest.mark.parametrize(
  31. "env_spec", all_testing_env_specs, ids=[env.id for env in all_testing_env_specs]
  32. )
  33. def test_env_determinism_rollout(env_spec: EnvSpec):
  34. """Run a rollout with two environments and assert equality.
  35. This test run a rollout of NUM_STEPS steps with two environments
  36. initialized with the same seed and assert that:
  37. - observation after first reset are the same
  38. - same actions are sampled by the two envs
  39. - observations are contained in the observation space
  40. - obs, rew, done and info are equals between the two envs
  41. """
  42. # Don't check rollout equality if it's a nondeterministic environment.
  43. if env_spec.nondeterministic is True:
  44. return
  45. env_1 = env_spec.make(disable_env_checker=True)
  46. env_2 = env_spec.make(disable_env_checker=True)
  47. initial_obs_1 = env_1.reset(seed=SEED)
  48. initial_obs_2 = env_2.reset(seed=SEED)
  49. assert_equals(initial_obs_1, initial_obs_2)
  50. env_1.action_space.seed(SEED)
  51. for time_step in range(NUM_STEPS):
  52. # We don't evaluate the determinism of actions
  53. action = env_1.action_space.sample()
  54. obs_1, rew_1, done_1, info_1 = env_1.step(action)
  55. obs_2, rew_2, done_2, info_2 = env_2.step(action)
  56. assert_equals(obs_1, obs_2, f"[{time_step}] ")
  57. assert env_1.observation_space.contains(
  58. obs_1
  59. ) # obs_2 verified by previous assertion
  60. assert rew_1 == rew_2, f"[{time_step}] reward 1={rew_1}, reward 2={rew_2}"
  61. assert done_1 == done_2, f"[{time_step}] done 1={done_1}, done 2={done_2}"
  62. assert_equals(info_1, info_2, f"[{time_step}] ")
  63. if done_1: # done_2 verified by previous assertion
  64. env_1.reset(seed=SEED)
  65. env_2.reset(seed=SEED)
  66. env_1.close()
  67. env_2.close()
  68. @pytest.mark.parametrize(
  69. "spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  70. )
  71. def test_render_modes(spec):
  72. env = spec.make()
  73. for mode in env.metadata.get("render_modes", []):
  74. if mode != "human":
  75. new_env = spec.make(render_mode=mode)
  76. new_env.reset()
  77. new_env.step(new_env.action_space.sample())
  78. new_env.render()