test_envs.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. import warnings
  2. import gymnasium as gym
  3. import numpy as np
  4. import pytest
  5. from gymnasium.envs.registration import EnvSpec
  6. from gymnasium.utils.env_checker import check_env
  7. from minigrid.core.grid import Grid
  8. from minigrid.core.mission import MissionSpace
  9. from tests.utils import all_testing_env_specs, assert_equals
  10. CHECK_ENV_IGNORE_WARNINGS = [
  11. f"\x1b[33mWARN: {message}\x1b[0m"
  12. for message in [
  13. "A Box observation space minimum value is -infinity. This is probably too low.",
  14. "A Box observation space maximum value is -infinity. This is probably too high.",
  15. "For Box action spaces, we recommend using a symmetric and normalized space (range=[-1, 1] or [0, 1]). See https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html for more information.",
  16. ]
  17. ]
  18. @pytest.mark.parametrize(
  19. "spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  20. )
  21. def test_env(spec):
  22. # Capture warnings
  23. env = spec.make(disable_env_checker=True).unwrapped
  24. warnings.simplefilter("always")
  25. # Test if env adheres to Gym API
  26. with warnings.catch_warnings(record=True) as w:
  27. check_env(env)
  28. for warning in w:
  29. if warning.message.args[0] not in CHECK_ENV_IGNORE_WARNINGS:
  30. raise gym.error.Error(f"Unexpected warning: {warning.message}")
# Note that this precludes running this test in multiple threads.
# However, we probably already can't do multithreading due to some environments.
# Seed passed to ``reset`` and to the action space in the determinism rollout test.
SEED = 0
# Number of environment steps taken by the determinism rollout test.
NUM_STEPS = 50
  35. @pytest.mark.parametrize(
  36. "env_spec", all_testing_env_specs, ids=[env.id for env in all_testing_env_specs]
  37. )
  38. def test_env_determinism_rollout(env_spec: EnvSpec):
  39. """Run a rollout with two environments and assert equality.
  40. This test run a rollout of NUM_STEPS steps with two environments
  41. initialized with the same seed and assert that:
  42. - observation after first reset are the same
  43. - same actions are sampled by the two envs
  44. - observations are contained in the observation space
  45. - obs, rew, terminated, truncated and info are equals between the two envs
  46. """
  47. # Don't check rollout equality if it's a nondeterministic environment.
  48. if env_spec.nondeterministic is True:
  49. return
  50. env_1 = env_spec.make(disable_env_checker=True)
  51. env_2 = env_spec.make(disable_env_checker=True)
  52. initial_obs_1 = env_1.reset(seed=SEED)
  53. initial_obs_2 = env_2.reset(seed=SEED)
  54. assert_equals(initial_obs_1, initial_obs_2)
  55. env_1.action_space.seed(SEED)
  56. for time_step in range(NUM_STEPS):
  57. # We don't evaluate the determinism of actions
  58. action = env_1.action_space.sample()
  59. obs_1, rew_1, terminated_1, truncated_1, info_1 = env_1.step(action)
  60. obs_2, rew_2, terminated_2, truncated_2, info_2 = env_2.step(action)
  61. assert_equals(obs_1, obs_2, f"[{time_step}] ")
  62. assert env_1.observation_space.contains(
  63. obs_1
  64. ) # obs_2 verified by previous assertion
  65. assert rew_1 == rew_2, f"[{time_step}] reward 1={rew_1}, reward 2={rew_2}"
  66. assert (
  67. terminated_1 == terminated_2
  68. ), f"[{time_step}] terminated 1={terminated_1}, terminated 2={terminated_2}"
  69. assert (
  70. truncated_1 == truncated_2
  71. ), f"[{time_step}] truncated 1={truncated_1}, truncated 2={truncated_2}"
  72. assert_equals(info_1, info_2, f"[{time_step}] ")
  73. if (
  74. terminated_1 or truncated_1
  75. ): # terminated_2 and truncated_2 verified by previous assertion
  76. env_1.reset(seed=SEED)
  77. env_2.reset(seed=SEED)
  78. env_1.close()
  79. env_2.close()
  80. @pytest.mark.parametrize(
  81. "spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  82. )
  83. def test_render_modes(spec):
  84. env = spec.make()
  85. for mode in env.metadata.get("render_modes", []):
  86. if mode != "human":
  87. new_env = spec.make(render_mode=mode)
  88. new_env.reset()
  89. new_env.step(new_env.action_space.sample())
  90. new_env.render()
  91. @pytest.mark.parametrize("env_id", ["MiniGrid-DoorKey-6x6-v0"])
  92. def test_agent_sees_method(env_id):
  93. env = gym.make(env_id)
  94. goal_pos = (env.grid.width - 2, env.grid.height - 2)
  95. # Test the env.agent_sees() function
  96. env.reset()
  97. # Test the "in" operator on grid objects
  98. assert ("green", "goal") in env.grid
  99. assert ("blue", "key") not in env.grid
  100. for i in range(0, 500):
  101. action = env.action_space.sample()
  102. obs, reward, terminated, truncated, info = env.step(action)
  103. grid, _ = Grid.decode(obs["image"])
  104. goal_visible = ("green", "goal") in grid
  105. agent_sees_goal = env.agent_sees(*goal_pos)
  106. assert agent_sees_goal == goal_visible
  107. if terminated or truncated:
  108. env.reset()
  109. env.close()
  110. @pytest.mark.parametrize(
  111. "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  112. )
  113. def old_run_test(env_spec):
  114. # Load the gym environment
  115. env = env_spec.make()
  116. env.max_steps = min(env.max_steps, 200)
  117. env.reset()
  118. env.render()
  119. # Verify that the same seed always produces the same environment
  120. for i in range(0, 5):
  121. seed = 1337 + i
  122. _ = env.reset(seed=seed)
  123. grid1 = env.grid
  124. _ = env.reset(seed=seed)
  125. grid2 = env.grid
  126. assert grid1 == grid2
  127. env.reset()
  128. # Run for a few episodes
  129. num_episodes = 0
  130. while num_episodes < 5:
  131. # Pick a random action
  132. action = env.action_space.sample()
  133. obs, reward, terminated, truncated, info = env.step(action)
  134. # Validate the agent position
  135. assert env.agent_pos[0] < env.width
  136. assert env.agent_pos[1] < env.height
  137. # Test observation encode/decode roundtrip
  138. img = obs["image"]
  139. grid, vis_mask = Grid.decode(img)
  140. img2 = grid.encode(vis_mask=vis_mask)
  141. assert np.array_equal(img, img2)
  142. # Test the env to string function
  143. str(env)
  144. # Check that the reward is within the specified range
  145. assert reward >= env.reward_range[0], reward
  146. assert reward <= env.reward_range[1], reward
  147. if terminated or truncated:
  148. num_episodes += 1
  149. env.reset()
  150. env.render()
  151. # Test the close method
  152. env.close()
  153. @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-8x8-v0"])
  154. def test_interactive_mode(env_id):
  155. env = gym.make(env_id)
  156. env.reset()
  157. for i in range(0, 100):
  158. print(f"step {i}")
  159. # Pick a random action
  160. action = env.action_space.sample()
  161. obs, reward, terminated, truncated, info = env.step(action)
  162. # Test the close method
  163. env.close()
  164. def test_mission_space():
  165. # Test placeholders
  166. mission_space = MissionSpace(
  167. mission_func=lambda color, obj_type: f"Get the {color} {obj_type}.",
  168. ordered_placeholders=[["green", "red"], ["ball", "key"]],
  169. )
  170. assert mission_space.contains("Get the green ball.")
  171. assert mission_space.contains("Get the red key.")
  172. assert not mission_space.contains("Get the purple box.")
  173. # Test passing inverted placeholders
  174. assert not mission_space.contains("Get the key red.")
  175. # Test passing extra repeated placeholders
  176. assert not mission_space.contains("Get the key red key.")
  177. # Test contained placeholders like "get the" and "go get the". "get the" string is contained in both placeholders.
  178. mission_space = MissionSpace(
  179. mission_func=lambda get_syntax, obj_type: f"{get_syntax} {obj_type}.",
  180. ordered_placeholders=[
  181. ["go get the", "get the", "go fetch the", "fetch the"],
  182. ["ball", "key"],
  183. ],
  184. )
  185. assert mission_space.contains("get the ball.")
  186. assert mission_space.contains("go get the key.")
  187. assert mission_space.contains("go fetch the ball.")
  188. # Test repeated placeholders
  189. mission_space = MissionSpace(
  190. mission_func=lambda get_syntax, color_1, obj_type_1, color_2, obj_type_2: f"{get_syntax} {color_1} {obj_type_1} and the {color_2} {obj_type_2}.",
  191. ordered_placeholders=[
  192. ["go get the", "get the", "go fetch the", "fetch the"],
  193. ["green", "red"],
  194. ["ball", "key"],
  195. ["green", "red"],
  196. ["ball", "key"],
  197. ],
  198. )
  199. assert mission_space.contains("get the green key and the green key.")
  200. assert mission_space.contains("go fetch the red ball and the green key.")