test_envs.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. import warnings
  2. import gymnasium as gym
  3. import numpy as np
  4. import pytest
  5. from gymnasium.envs.registration import EnvSpec
  6. from gymnasium.utils.env_checker import check_env
  7. from minigrid.core.grid import Grid
  8. from minigrid.core.mission import MissionSpace
  9. from tests.utils import all_testing_env_specs, assert_equals
  10. CHECK_ENV_IGNORE_WARNINGS = [
  11. f"\x1b[33mWARN: {message}\x1b[0m"
  12. for message in [
  13. "A Box observation space minimum value is -infinity. This is probably too low.",
  14. "A Box observation space maximum value is -infinity. This is probably too high.",
  15. "For Box action spaces, we recommend using a symmetric and normalized space (range=[-1, 1] or [0, 1]). See https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html for more information.",
  16. ]
  17. ]
  18. @pytest.mark.parametrize(
  19. "spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  20. )
  21. def test_env(spec):
  22. # Capture warnings
  23. env = spec.make(disable_env_checker=True).unwrapped
  24. warnings.simplefilter("always")
  25. # Test if env adheres to Gym API
  26. with warnings.catch_warnings(record=True) as w:
  27. check_env(env)
  28. for warning in w:
  29. if warning.message.args[0] not in CHECK_ENV_IGNORE_WARNINGS:
  30. raise gym.error.Error(f"Unexpected warning: {warning.message}")
  31. # Note that this precludes running this test in multiple threads.
  32. # However, we probably already can't do multithreading due to some environments.
  33. SEED = 0
  34. NUM_STEPS = 50
  35. @pytest.mark.parametrize(
  36. "env_spec", all_testing_env_specs, ids=[env.id for env in all_testing_env_specs]
  37. )
  38. def test_env_determinism_rollout(env_spec: EnvSpec):
  39. """Run a rollout with two environments and assert equality.
  40. This test run a rollout of NUM_STEPS steps with two environments
  41. initialized with the same seed and assert that:
  42. - observation after first reset are the same
  43. - same actions are sampled by the two envs
  44. - observations are contained in the observation space
  45. - obs, rew, terminated, truncated and info are equals between the two envs
  46. """
  47. # Don't check rollout equality if it's a nondeterministic environment.
  48. if env_spec.nondeterministic is True:
  49. return
  50. env_1 = env_spec.make(disable_env_checker=True)
  51. env_2 = env_spec.make(disable_env_checker=True)
  52. initial_obs_1 = env_1.reset(seed=SEED)
  53. initial_obs_2 = env_2.reset(seed=SEED)
  54. assert_equals(initial_obs_1, initial_obs_2)
  55. env_1.action_space.seed(SEED)
  56. for time_step in range(NUM_STEPS):
  57. # We don't evaluate the determinism of actions
  58. action = env_1.action_space.sample()
  59. obs_1, rew_1, terminated_1, truncated_1, info_1 = env_1.step(action)
  60. obs_2, rew_2, terminated_2, truncated_2, info_2 = env_2.step(action)
  61. assert_equals(obs_1, obs_2, f"[{time_step}] ")
  62. assert env_1.observation_space.contains(
  63. obs_1
  64. ) # obs_2 verified by previous assertion
  65. assert rew_1 == rew_2, f"[{time_step}] reward 1={rew_1}, reward 2={rew_2}"
  66. assert (
  67. terminated_1 == terminated_2
  68. ), f"[{time_step}] terminated 1={terminated_1}, terminated 2={terminated_2}"
  69. assert (
  70. truncated_1 == truncated_2
  71. ), f"[{time_step}] truncated 1={truncated_1}, truncated 2={truncated_2}"
  72. assert_equals(info_1, info_2, f"[{time_step}] ")
  73. if (
  74. terminated_1 or truncated_1
  75. ): # terminated_2 and truncated_2 verified by previous assertion
  76. env_1.reset(seed=SEED)
  77. env_2.reset(seed=SEED)
  78. env_1.close()
  79. env_2.close()
  80. @pytest.mark.parametrize(
  81. "spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  82. )
  83. def test_render_modes(spec):
  84. env = spec.make()
  85. for mode in env.metadata.get("render_modes", []):
  86. if mode != "human":
  87. new_env = spec.make(render_mode=mode)
  88. new_env.reset()
  89. new_env.step(new_env.action_space.sample())
  90. new_env.render()
  91. @pytest.mark.parametrize("env_id", ["MiniGrid-DoorKey-6x6-v0"])
  92. def test_agent_sees_method(env_id):
  93. env = gym.make(env_id)
  94. goal_pos = (env.grid.width - 2, env.grid.height - 2)
  95. # Test the env.agent_sees() function
  96. env.reset()
  97. # Test the "in" operator on grid objects
  98. assert ("green", "goal") in env.grid
  99. assert ("blue", "key") not in env.grid
  100. for i in range(0, 500):
  101. action = env.action_space.sample()
  102. obs, reward, terminated, truncated, info = env.step(action)
  103. grid, _ = Grid.decode(obs["image"])
  104. goal_visible = ("green", "goal") in grid
  105. agent_sees_goal = env.agent_sees(*goal_pos)
  106. assert agent_sees_goal == goal_visible
  107. if terminated or truncated:
  108. env.reset()
  109. env.close()
  110. @pytest.mark.parametrize(
  111. "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  112. )
  113. def test_max_steps_argument(env_spec):
  114. """
  115. Test that when initializing an environment with a fixed number of steps per episode (`max_steps` argument),
  116. the episode will be truncated after taking that number of steps.
  117. """
  118. max_steps = 50
  119. env = env_spec.make(max_steps=max_steps)
  120. env.reset()
  121. step_count = 0
  122. while True:
  123. _, _, terminated, truncated, _ = env.step(4)
  124. step_count += 1
  125. if truncated:
  126. assert step_count == max_steps
  127. step_count = 0
  128. break
  129. env.close()
  130. @pytest.mark.parametrize(
  131. "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  132. )
  133. def old_run_test(env_spec):
  134. # Load the gym environment
  135. env = env_spec.make()
  136. env.max_steps = min(env.max_steps, 200)
  137. env.reset()
  138. env.render()
  139. # Verify that the same seed always produces the same environment
  140. for i in range(0, 5):
  141. seed = 1337 + i
  142. _ = env.reset(seed=seed)
  143. grid1 = env.grid
  144. _ = env.reset(seed=seed)
  145. grid2 = env.grid
  146. assert grid1 == grid2
  147. env.reset()
  148. # Run for a few episodes
  149. num_episodes = 0
  150. while num_episodes < 5:
  151. # Pick a random action
  152. action = env.action_space.sample()
  153. obs, reward, terminated, truncated, info = env.step(action)
  154. # Validate the agent position
  155. assert env.agent_pos[0] < env.width
  156. assert env.agent_pos[1] < env.height
  157. # Test observation encode/decode roundtrip
  158. img = obs["image"]
  159. grid, vis_mask = Grid.decode(img)
  160. img2 = grid.encode(vis_mask=vis_mask)
  161. assert np.array_equal(img, img2)
  162. # Test the env to string function
  163. str(env)
  164. # Check that the reward is within the specified range
  165. assert reward >= env.reward_range[0], reward
  166. assert reward <= env.reward_range[1], reward
  167. if terminated or truncated:
  168. num_episodes += 1
  169. env.reset()
  170. env.render()
  171. # Test the close method
  172. env.close()
  173. @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-8x8-v0"])
  174. def test_interactive_mode(env_id):
  175. env = gym.make(env_id)
  176. env.reset()
  177. for i in range(0, 100):
  178. print(f"step {i}")
  179. # Pick a random action
  180. action = env.action_space.sample()
  181. obs, reward, terminated, truncated, info = env.step(action)
  182. # Test the close method
  183. env.close()
  184. def test_mission_space():
  185. # Test placeholders
  186. mission_space = MissionSpace(
  187. mission_func=lambda color, obj_type: f"Get the {color} {obj_type}.",
  188. ordered_placeholders=[["green", "red"], ["ball", "key"]],
  189. )
  190. assert mission_space.contains("Get the green ball.")
  191. assert mission_space.contains("Get the red key.")
  192. assert not mission_space.contains("Get the purple box.")
  193. # Test passing inverted placeholders
  194. assert not mission_space.contains("Get the key red.")
  195. # Test passing extra repeated placeholders
  196. assert not mission_space.contains("Get the key red key.")
  197. # Test contained placeholders like "get the" and "go get the". "get the" string is contained in both placeholders.
  198. mission_space = MissionSpace(
  199. mission_func=lambda get_syntax, obj_type: f"{get_syntax} {obj_type}.",
  200. ordered_placeholders=[
  201. ["go get the", "get the", "go fetch the", "fetch the"],
  202. ["ball", "key"],
  203. ],
  204. )
  205. assert mission_space.contains("get the ball.")
  206. assert mission_space.contains("go get the key.")
  207. assert mission_space.contains("go fetch the ball.")
  208. # Test repeated placeholders
  209. mission_space = MissionSpace(
  210. mission_func=lambda get_syntax, color_1, obj_type_1, color_2, obj_type_2: f"{get_syntax} {color_1} {obj_type_1} and the {color_2} {obj_type_2}.",
  211. ordered_placeholders=[
  212. ["go get the", "get the", "go fetch the", "fetch the"],
  213. ["green", "red"],
  214. ["ball", "key"],
  215. ["green", "red"],
  216. ["ball", "key"],
  217. ],
  218. )
  219. assert mission_space.contains("get the green key and the green key.")
  220. assert mission_space.contains("go fetch the red ball and the green key.")