test_envs.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. import pickle
  2. import warnings
  3. import gymnasium as gym
  4. import numpy as np
  5. import pytest
  6. from gymnasium.envs.registration import EnvSpec
  7. from gymnasium.utils.env_checker import check_env, data_equivalence
  8. from minigrid.core.grid import Grid
  9. from minigrid.core.mission import MissionSpace
  10. from tests.utils import all_testing_env_specs, assert_equals
  11. CHECK_ENV_IGNORE_WARNINGS = [
  12. f"\x1b[33mWARN: {message}\x1b[0m"
  13. for message in [
  14. "A Box observation space minimum value is -infinity. This is probably too low.",
  15. "A Box observation space maximum value is -infinity. This is probably too high.",
  16. "For Box action spaces, we recommend using a symmetric and normalized space (range=[-1, 1] or [0, 1]). See https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html for more information.",
  17. ]
  18. ]
  19. @pytest.mark.parametrize(
  20. "spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  21. )
  22. def test_env(spec):
  23. # Capture warnings
  24. env = spec.make(disable_env_checker=True).unwrapped
  25. warnings.simplefilter("always")
  26. # Test if env adheres to Gym API
  27. with warnings.catch_warnings(record=True) as w:
  28. check_env(env)
  29. for warning in w:
  30. if warning.message.args[0] not in CHECK_ENV_IGNORE_WARNINGS:
  31. raise gym.error.Error(f"Unexpected warning: {warning.message}")
# Note that this precludes running this test in multiple threads.
# However, we probably already can't do multithreading due to some environments.
# Seed passed to `reset(seed=...)` and `action_space.seed(...)` by
# `test_env_determinism_rollout`.
SEED = 0
# Number of steps taken per environment by `test_env_determinism_rollout`.
NUM_STEPS = 50
  36. @pytest.mark.parametrize(
  37. "env_spec", all_testing_env_specs, ids=[env.id for env in all_testing_env_specs]
  38. )
  39. def test_env_determinism_rollout(env_spec: EnvSpec):
  40. """Run a rollout with two environments and assert equality.
  41. This test run a rollout of NUM_STEPS steps with two environments
  42. initialized with the same seed and assert that:
  43. - observation after first reset are the same
  44. - same actions are sampled by the two envs
  45. - observations are contained in the observation space
  46. - obs, rew, terminated, truncated and info are equals between the two envs
  47. """
  48. # Don't check rollout equality if it's a nondeterministic environment.
  49. if env_spec.nondeterministic is True:
  50. return
  51. env_1 = env_spec.make(disable_env_checker=True)
  52. env_2 = env_spec.make(disable_env_checker=True)
  53. initial_obs_1 = env_1.reset(seed=SEED)
  54. initial_obs_2 = env_2.reset(seed=SEED)
  55. assert_equals(initial_obs_1, initial_obs_2)
  56. env_1.action_space.seed(SEED)
  57. for time_step in range(NUM_STEPS):
  58. # We don't evaluate the determinism of actions
  59. action = env_1.action_space.sample()
  60. obs_1, rew_1, terminated_1, truncated_1, info_1 = env_1.step(action)
  61. obs_2, rew_2, terminated_2, truncated_2, info_2 = env_2.step(action)
  62. assert_equals(obs_1, obs_2, f"[{time_step}] ")
  63. assert env_1.observation_space.contains(
  64. obs_1
  65. ) # obs_2 verified by previous assertion
  66. assert rew_1 == rew_2, f"[{time_step}] reward 1={rew_1}, reward 2={rew_2}"
  67. assert (
  68. terminated_1 == terminated_2
  69. ), f"[{time_step}] terminated 1={terminated_1}, terminated 2={terminated_2}"
  70. assert (
  71. truncated_1 == truncated_2
  72. ), f"[{time_step}] truncated 1={truncated_1}, truncated 2={truncated_2}"
  73. assert_equals(info_1, info_2, f"[{time_step}] ")
  74. if (
  75. terminated_1 or truncated_1
  76. ): # terminated_2 and truncated_2 verified by previous assertion
  77. env_1.reset(seed=SEED)
  78. env_2.reset(seed=SEED)
  79. env_1.close()
  80. env_2.close()
  81. @pytest.mark.parametrize(
  82. "spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  83. )
  84. def test_render_modes(spec):
  85. env = spec.make()
  86. for mode in env.metadata.get("render_modes", []):
  87. if mode != "human":
  88. new_env = spec.make(render_mode=mode)
  89. new_env.reset()
  90. new_env.step(new_env.action_space.sample())
  91. new_env.render()
  92. @pytest.mark.parametrize("env_id", ["MiniGrid-DoorKey-6x6-v0"])
  93. def test_agent_sees_method(env_id):
  94. env = gym.make(env_id)
  95. goal_pos = (env.grid.width - 2, env.grid.height - 2)
  96. # Test the env.agent_sees() function
  97. env.reset()
  98. # Test the "in" operator on grid objects
  99. assert ("green", "goal") in env.grid
  100. assert ("blue", "key") not in env.grid
  101. for i in range(0, 500):
  102. action = env.action_space.sample()
  103. obs, reward, terminated, truncated, info = env.step(action)
  104. grid, _ = Grid.decode(obs["image"])
  105. goal_visible = ("green", "goal") in grid
  106. agent_sees_goal = env.agent_sees(*goal_pos)
  107. assert agent_sees_goal == goal_visible
  108. if terminated or truncated:
  109. env.reset()
  110. env.close()
  111. @pytest.mark.parametrize(
  112. "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  113. )
  114. def test_max_steps_argument(env_spec):
  115. """
  116. Test that when initializing an environment with a fixed number of steps per episode (`max_steps` argument),
  117. the episode will be truncated after taking that number of steps.
  118. """
  119. max_steps = 50
  120. env = env_spec.make(max_steps=max_steps)
  121. env.reset()
  122. step_count = 0
  123. while True:
  124. _, _, terminated, truncated, _ = env.step(4)
  125. step_count += 1
  126. if truncated:
  127. assert step_count == max_steps
  128. step_count = 0
  129. break
  130. env.close()
  131. @pytest.mark.parametrize(
  132. "env_spec",
  133. all_testing_env_specs,
  134. ids=[spec.id for spec in all_testing_env_specs],
  135. )
  136. def test_pickle_env(env_spec):
  137. env: gym.Env = env_spec.make()
  138. pickled_env: gym.Env = pickle.loads(pickle.dumps(env))
  139. data_equivalence(env.reset(), pickled_env.reset())
  140. action = env.action_space.sample()
  141. data_equivalence(env.step(action), pickled_env.step(action))
  142. env.close()
  143. pickled_env.close()
  144. @pytest.mark.parametrize(
  145. "env_spec",
  146. all_testing_env_specs,
  147. ids=[spec.id for spec in all_testing_env_specs],
  148. )
  149. def old_run_test(env_spec):
  150. # Load the gym environment
  151. env = env_spec.make()
  152. env.max_steps = min(env.max_steps, 200)
  153. env.reset()
  154. env.render()
  155. # Verify that the same seed always produces the same environment
  156. for i in range(0, 5):
  157. seed = 1337 + i
  158. _ = env.reset(seed=seed)
  159. grid1 = env.grid
  160. _ = env.reset(seed=seed)
  161. grid2 = env.grid
  162. assert grid1 == grid2
  163. env.reset()
  164. # Run for a few episodes
  165. num_episodes = 0
  166. while num_episodes < 5:
  167. # Pick a random action
  168. action = env.action_space.sample()
  169. obs, reward, terminated, truncated, info = env.step(action)
  170. # Validate the agent position
  171. assert env.agent_pos[0] < env.width
  172. assert env.agent_pos[1] < env.height
  173. # Test observation encode/decode roundtrip
  174. img = obs["image"]
  175. grid, vis_mask = Grid.decode(img)
  176. img2 = grid.encode(vis_mask=vis_mask)
  177. assert np.array_equal(img, img2)
  178. # Test the env to string function
  179. str(env)
  180. # Check that the reward is within the specified range
  181. assert reward >= env.reward_range[0], reward
  182. assert reward <= env.reward_range[1], reward
  183. if terminated or truncated:
  184. num_episodes += 1
  185. env.reset()
  186. env.render()
  187. # Test the close method
  188. env.close()
  189. @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-8x8-v0"])
  190. def test_interactive_mode(env_id):
  191. env = gym.make(env_id)
  192. env.reset()
  193. for i in range(0, 100):
  194. print(f"step {i}")
  195. # Pick a random action
  196. action = env.action_space.sample()
  197. obs, reward, terminated, truncated, info = env.step(action)
  198. # Test the close method
  199. env.close()
  200. def test_mission_space():
  201. # Test placeholders
  202. mission_space = MissionSpace(
  203. mission_func=lambda color, obj_type: f"Get the {color} {obj_type}.",
  204. ordered_placeholders=[["green", "red"], ["ball", "key"]],
  205. )
  206. assert mission_space.contains("Get the green ball.")
  207. assert mission_space.contains("Get the red key.")
  208. assert not mission_space.contains("Get the purple box.")
  209. # Test passing inverted placeholders
  210. assert not mission_space.contains("Get the key red.")
  211. # Test passing extra repeated placeholders
  212. assert not mission_space.contains("Get the key red key.")
  213. # Test contained placeholders like "get the" and "go get the". "get the" string is contained in both placeholders.
  214. mission_space = MissionSpace(
  215. mission_func=lambda get_syntax, obj_type: f"{get_syntax} {obj_type}.",
  216. ordered_placeholders=[
  217. ["go get the", "get the", "go fetch the", "fetch the"],
  218. ["ball", "key"],
  219. ],
  220. )
  221. assert mission_space.contains("get the ball.")
  222. assert mission_space.contains("go get the key.")
  223. assert mission_space.contains("go fetch the ball.")
  224. # Test repeated placeholders
  225. mission_space = MissionSpace(
  226. mission_func=lambda get_syntax, color_1, obj_type_1, color_2, obj_type_2: f"{get_syntax} {color_1} {obj_type_1} and the {color_2} {obj_type_2}.",
  227. ordered_placeholders=[
  228. ["go get the", "get the", "go fetch the", "fetch the"],
  229. ["green", "red"],
  230. ["ball", "key"],
  231. ["green", "red"],
  232. ["ball", "key"],
  233. ],
  234. )
  235. assert mission_space.contains("get the green key and the green key.")
  236. assert mission_space.contains("go fetch the red ball and the green key.")