test_envs.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. from __future__ import annotations
  2. import pickle
  3. import warnings
  4. import gymnasium as gym
  5. import numpy as np
  6. import pytest
  7. from gymnasium.envs.registration import EnvSpec
  8. from gymnasium.utils.env_checker import check_env, data_equivalence
  9. from minigrid.core.grid import Grid
  10. from minigrid.core.mission import MissionSpace
  11. from tests.utils import all_testing_env_specs, assert_equals
  12. CHECK_ENV_IGNORE_WARNINGS = [
  13. f"\x1b[33mWARN: {message}\x1b[0m"
  14. for message in [
  15. "A Box observation space minimum value is -infinity. This is probably too low.",
  16. "A Box observation space maximum value is -infinity. This is probably too high.",
  17. "For Box action spaces, we recommend using a symmetric and normalized space (range=[-1, 1] or [0, 1]). See https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html for more information.",
  18. ]
  19. ]
  20. @pytest.mark.parametrize(
  21. "spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  22. )
  23. def test_env(spec):
  24. # Capture warnings
  25. env = spec.make(disable_env_checker=True).unwrapped
  26. warnings.simplefilter("always")
  27. # Test if env adheres to Gym API
  28. with warnings.catch_warnings(record=True) as w:
  29. check_env(env)
  30. for warning in w:
  31. if warning.message.args[0] not in CHECK_ENV_IGNORE_WARNINGS:
  32. raise gym.error.Error(f"Unexpected warning: {warning.message}")
  33. # Note that this precludes running this test in multiple threads.
  34. # However, we probably already can't do multithreading due to some environments.
  35. SEED = 0
  36. NUM_STEPS = 50
  37. @pytest.mark.parametrize(
  38. "env_spec", all_testing_env_specs, ids=[env.id for env in all_testing_env_specs]
  39. )
  40. def test_env_determinism_rollout(env_spec: EnvSpec):
  41. """Run a rollout with two environments and assert equality.
  42. This test run a rollout of NUM_STEPS steps with two environments
  43. initialized with the same seed and assert that:
  44. - observation after first reset are the same
  45. - same actions are sampled by the two envs
  46. - observations are contained in the observation space
  47. - obs, rew, terminated, truncated and info are equals between the two envs
  48. """
  49. # Don't check rollout equality if it's a nondeterministic environment.
  50. if env_spec.nondeterministic is True:
  51. return
  52. env_1 = env_spec.make(disable_env_checker=True)
  53. env_2 = env_spec.make(disable_env_checker=True)
  54. initial_obs_1 = env_1.reset(seed=SEED)
  55. initial_obs_2 = env_2.reset(seed=SEED)
  56. assert_equals(initial_obs_1, initial_obs_2)
  57. env_1.action_space.seed(SEED)
  58. for time_step in range(NUM_STEPS):
  59. # We don't evaluate the determinism of actions
  60. action = env_1.action_space.sample()
  61. obs_1, rew_1, terminated_1, truncated_1, info_1 = env_1.step(action)
  62. obs_2, rew_2, terminated_2, truncated_2, info_2 = env_2.step(action)
  63. assert_equals(obs_1, obs_2, f"[{time_step}] ")
  64. assert env_1.observation_space.contains(
  65. obs_1
  66. ) # obs_2 verified by previous assertion
  67. assert rew_1 == rew_2, f"[{time_step}] reward 1={rew_1}, reward 2={rew_2}"
  68. assert (
  69. terminated_1 == terminated_2
  70. ), f"[{time_step}] terminated 1={terminated_1}, terminated 2={terminated_2}"
  71. assert (
  72. truncated_1 == truncated_2
  73. ), f"[{time_step}] truncated 1={truncated_1}, truncated 2={truncated_2}"
  74. assert_equals(info_1, info_2, f"[{time_step}] ")
  75. if (
  76. terminated_1 or truncated_1
  77. ): # terminated_2 and truncated_2 verified by previous assertion
  78. env_1.reset(seed=SEED)
  79. env_2.reset(seed=SEED)
  80. env_1.close()
  81. env_2.close()
  82. @pytest.mark.parametrize(
  83. "spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  84. )
  85. def test_render_modes(spec):
  86. env = spec.make()
  87. for mode in env.metadata.get("render_modes", []):
  88. if mode != "human":
  89. new_env = spec.make(render_mode=mode)
  90. new_env.reset()
  91. new_env.step(new_env.action_space.sample())
  92. new_env.render()
  93. @pytest.mark.parametrize("env_id", ["MiniGrid-DoorKey-6x6-v0"])
  94. def test_agent_sees_method(env_id):
  95. env = gym.make(env_id)
  96. goal_pos = (env.grid.width - 2, env.grid.height - 2)
  97. # Test the env.agent_sees() function
  98. env.reset()
  99. # Test the "in" operator on grid objects
  100. assert ("green", "goal") in env.grid
  101. assert ("blue", "key") not in env.grid
  102. for i in range(0, 500):
  103. action = env.action_space.sample()
  104. obs, reward, terminated, truncated, info = env.step(action)
  105. grid, _ = Grid.decode(obs["image"])
  106. goal_visible = ("green", "goal") in grid
  107. agent_sees_goal = env.agent_sees(*goal_pos)
  108. assert agent_sees_goal == goal_visible
  109. if terminated or truncated:
  110. env.reset()
  111. env.close()
  112. @pytest.mark.parametrize(
  113. "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  114. )
  115. def test_max_steps_argument(env_spec):
  116. """
  117. Test that when initializing an environment with a fixed number of steps per episode (`max_steps` argument),
  118. the episode will be truncated after taking that number of steps.
  119. """
  120. max_steps = 50
  121. env = env_spec.make(max_steps=max_steps)
  122. env.reset()
  123. step_count = 0
  124. while True:
  125. _, _, terminated, truncated, _ = env.step(4)
  126. step_count += 1
  127. if truncated:
  128. assert step_count == max_steps
  129. step_count = 0
  130. break
  131. env.close()
  132. @pytest.mark.parametrize(
  133. "env_spec",
  134. all_testing_env_specs,
  135. ids=[spec.id for spec in all_testing_env_specs],
  136. )
  137. def test_pickle_env(env_spec):
  138. """Test that all environments are picklable."""
  139. env: gym.Env = env_spec.make()
  140. pickled_env: gym.Env = pickle.loads(pickle.dumps(env))
  141. data_equivalence(env.reset(), pickled_env.reset())
  142. action = env.action_space.sample()
  143. data_equivalence(env.step(action), pickled_env.step(action))
  144. env.close()
  145. pickled_env.close()
  146. @pytest.mark.parametrize(
  147. "env_spec",
  148. all_testing_env_specs,
  149. ids=[spec.id for spec in all_testing_env_specs],
  150. )
  151. def old_run_test(env_spec):
  152. # Load the gym environment
  153. env = env_spec.make()
  154. env.max_steps = min(env.max_steps, 200)
  155. env.reset()
  156. env.render()
  157. # Verify that the same seed always produces the same environment
  158. for i in range(0, 5):
  159. seed = 1337 + i
  160. _ = env.reset(seed=seed)
  161. grid1 = env.grid
  162. _ = env.reset(seed=seed)
  163. grid2 = env.grid
  164. assert grid1 == grid2
  165. env.reset()
  166. # Run for a few episodes
  167. num_episodes = 0
  168. while num_episodes < 5:
  169. # Pick a random action
  170. action = env.action_space.sample()
  171. obs, reward, terminated, truncated, info = env.step(action)
  172. # Validate the agent position
  173. assert env.agent_pos[0] < env.width
  174. assert env.agent_pos[1] < env.height
  175. # Test observation encode/decode roundtrip
  176. img = obs["image"]
  177. grid, vis_mask = Grid.decode(img)
  178. img2 = grid.encode(vis_mask=vis_mask)
  179. assert np.array_equal(img, img2)
  180. # Test the env to string function
  181. str(env)
  182. # Check that the reward is within the specified range
  183. assert reward >= env.reward_range[0], reward
  184. assert reward <= env.reward_range[1], reward
  185. if terminated or truncated:
  186. num_episodes += 1
  187. env.reset()
  188. env.render()
  189. # Test the close method
  190. env.close()
  191. @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-8x8-v0"])
  192. def test_interactive_mode(env_id):
  193. env = gym.make(env_id)
  194. env.reset()
  195. for i in range(0, 100):
  196. print(f"step {i}")
  197. # Pick a random action
  198. action = env.action_space.sample()
  199. obs, reward, terminated, truncated, info = env.step(action)
  200. # Test the close method
  201. env.close()
  202. def test_mission_space():
  203. # Test placeholders
  204. mission_space = MissionSpace(
  205. mission_func=lambda color, obj_type: f"Get the {color} {obj_type}.",
  206. ordered_placeholders=[["green", "red"], ["ball", "key"]],
  207. )
  208. assert mission_space.contains("Get the green ball.")
  209. assert mission_space.contains("Get the red key.")
  210. assert not mission_space.contains("Get the purple box.")
  211. # Test passing inverted placeholders
  212. assert not mission_space.contains("Get the key red.")
  213. # Test passing extra repeated placeholders
  214. assert not mission_space.contains("Get the key red key.")
  215. # Test contained placeholders like "get the" and "go get the". "get the" string is contained in both placeholders.
  216. mission_space = MissionSpace(
  217. mission_func=lambda get_syntax, obj_type: f"{get_syntax} {obj_type}.",
  218. ordered_placeholders=[
  219. ["go get the", "get the", "go fetch the", "fetch the"],
  220. ["ball", "key"],
  221. ],
  222. )
  223. assert mission_space.contains("get the ball.")
  224. assert mission_space.contains("go get the key.")
  225. assert mission_space.contains("go fetch the ball.")
  226. # Test repeated placeholders
  227. mission_space = MissionSpace(
  228. mission_func=lambda get_syntax, color_1, obj_type_1, color_2, obj_type_2: f"{get_syntax} {color_1} {obj_type_1} and the {color_2} {obj_type_2}.",
  229. ordered_placeholders=[
  230. ["go get the", "get the", "go fetch the", "fetch the"],
  231. ["green", "red"],
  232. ["ball", "key"],
  233. ["green", "red"],
  234. ["ball", "key"],
  235. ],
  236. )
  237. assert mission_space.contains("get the green key and the green key.")
  238. assert mission_space.contains("go fetch the red ball and the green key.")