test_envs.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. import warnings
  2. import gym
  3. import numpy as np
  4. import pytest
  5. from gym.envs.registration import EnvSpec
  6. from gym.utils.env_checker import check_env
  7. from minigrid.minigrid import Grid, MissionSpace
  8. from tests.utils import all_testing_env_specs, assert_equals
  9. CHECK_ENV_IGNORE_WARNINGS = [
  10. f"\x1b[33mWARN: {message}\x1b[0m"
  11. for message in [
  12. "A Box observation space minimum value is -infinity. This is probably too low.",
  13. "A Box observation space maximum value is -infinity. This is probably too high.",
  14. "For Box action spaces, we recommend using a symmetric and normalized space (range=[-1, 1] or [0, 1]). See https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html for more information.",
  15. ]
  16. ]
  17. @pytest.mark.parametrize(
  18. "spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  19. )
  20. def test_env(spec):
  21. # Capture warnings
  22. env = spec.make(disable_env_checker=True).unwrapped
  23. warnings.simplefilter("always")
  24. # Test if env adheres to Gym API
  25. with warnings.catch_warnings(record=True) as w:
  26. check_env(env)
  27. for warning in w:
  28. if warning.message.args[0] not in CHECK_ENV_IGNORE_WARNINGS:
  29. raise gym.error.Error(f"Unexpected warning: {warning.message}")
  30. # Note that this precludes running this test in multiple threads.
  31. # However, we probably already can't do multithreading due to some environments.
  32. SEED = 0
  33. NUM_STEPS = 50
  34. @pytest.mark.parametrize(
  35. "env_spec", all_testing_env_specs, ids=[env.id for env in all_testing_env_specs]
  36. )
  37. def test_env_determinism_rollout(env_spec: EnvSpec):
  38. """Run a rollout with two environments and assert equality.
  39. This test run a rollout of NUM_STEPS steps with two environments
  40. initialized with the same seed and assert that:
  41. - observation after first reset are the same
  42. - same actions are sampled by the two envs
  43. - observations are contained in the observation space
  44. - obs, rew, terminated, truncated and info are equals between the two envs
  45. """
  46. # Don't check rollout equality if it's a nondeterministic environment.
  47. if env_spec.nondeterministic is True:
  48. return
  49. env_1 = env_spec.make(disable_env_checker=True)
  50. env_2 = env_spec.make(disable_env_checker=True)
  51. initial_obs_1 = env_1.reset(seed=SEED)
  52. initial_obs_2 = env_2.reset(seed=SEED)
  53. assert_equals(initial_obs_1, initial_obs_2)
  54. env_1.action_space.seed(SEED)
  55. for time_step in range(NUM_STEPS):
  56. # We don't evaluate the determinism of actions
  57. action = env_1.action_space.sample()
  58. obs_1, rew_1, terminated_1, truncated_1, info_1 = env_1.step(action)
  59. obs_2, rew_2, terminated_2, truncated_2, info_2 = env_2.step(action)
  60. assert_equals(obs_1, obs_2, f"[{time_step}] ")
  61. assert env_1.observation_space.contains(
  62. obs_1
  63. ) # obs_2 verified by previous assertion
  64. assert rew_1 == rew_2, f"[{time_step}] reward 1={rew_1}, reward 2={rew_2}"
  65. assert (
  66. terminated_1 == terminated_2
  67. ), f"[{time_step}] terminated 1={terminated_1}, terminated 2={terminated_2}"
  68. assert (
  69. truncated_1 == truncated_2
  70. ), f"[{time_step}] truncated 1={truncated_1}, truncated 2={truncated_2}"
  71. assert_equals(info_1, info_2, f"[{time_step}] ")
  72. if (
  73. terminated_1 or truncated_1
  74. ): # terminated_2 and truncated_2 verified by previous assertion
  75. env_1.reset(seed=SEED)
  76. env_2.reset(seed=SEED)
  77. env_1.close()
  78. env_2.close()
  79. @pytest.mark.parametrize(
  80. "spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  81. )
  82. def test_render_modes(spec):
  83. env = spec.make()
  84. for mode in env.metadata.get("render_modes", []):
  85. if mode != "human":
  86. new_env = spec.make(render_mode=mode)
  87. new_env.reset()
  88. new_env.step(new_env.action_space.sample())
  89. new_env.render()
  90. @pytest.mark.parametrize("env_id", ["MiniGrid-DoorKey-6x6-v0"])
  91. def test_agent_sees_method(env_id):
  92. env = gym.make(env_id)
  93. goal_pos = (env.grid.width - 2, env.grid.height - 2)
  94. # Test the env.agent_sees() function
  95. env.reset()
  96. # Test the "in" operator on grid objects
  97. assert ("green", "goal") in env.grid
  98. assert ("blue", "key") not in env.grid
  99. for i in range(0, 500):
  100. action = env.action_space.sample()
  101. obs, reward, terminated, truncated, info = env.step(action)
  102. grid, _ = Grid.decode(obs["image"])
  103. goal_visible = ("green", "goal") in grid
  104. agent_sees_goal = env.agent_sees(*goal_pos)
  105. assert agent_sees_goal == goal_visible
  106. if terminated or truncated:
  107. env.reset()
  108. env.close()
  109. @pytest.mark.parametrize(
  110. "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  111. )
  112. def old_run_test(env_spec):
  113. # Load the gym environment
  114. env = env_spec.make()
  115. env.max_steps = min(env.max_steps, 200)
  116. env.reset()
  117. env.render()
  118. # Verify that the same seed always produces the same environment
  119. for i in range(0, 5):
  120. seed = 1337 + i
  121. _ = env.reset(seed=seed)
  122. grid1 = env.grid
  123. _ = env.reset(seed=seed)
  124. grid2 = env.grid
  125. assert grid1 == grid2
  126. env.reset()
  127. # Run for a few episodes
  128. num_episodes = 0
  129. while num_episodes < 5:
  130. # Pick a random action
  131. action = env.action_space.sample()
  132. obs, reward, terminated, truncated, info = env.step(action)
  133. # Validate the agent position
  134. assert env.agent_pos[0] < env.width
  135. assert env.agent_pos[1] < env.height
  136. # Test observation encode/decode roundtrip
  137. img = obs["image"]
  138. grid, vis_mask = Grid.decode(img)
  139. img2 = grid.encode(vis_mask=vis_mask)
  140. assert np.array_equal(img, img2)
  141. # Test the env to string function
  142. str(env)
  143. # Check that the reward is within the specified range
  144. assert reward >= env.reward_range[0], reward
  145. assert reward <= env.reward_range[1], reward
  146. if terminated or truncated:
  147. num_episodes += 1
  148. env.reset()
  149. env.render()
  150. # Test the close method
  151. env.close()
  152. @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-8x8-v0"])
  153. def test_interactive_mode(env_id):
  154. env = gym.make(env_id)
  155. env.reset()
  156. for i in range(0, 100):
  157. print(f"step {i}")
  158. # Pick a random action
  159. action = env.action_space.sample()
  160. obs, reward, terminated, truncated, info = env.step(action)
  161. # Test the close method
  162. env.close()
  163. def test_mission_space():
  164. # Test placeholders
  165. mission_space = MissionSpace(
  166. mission_func=lambda color, obj_type: f"Get the {color} {obj_type}.",
  167. ordered_placeholders=[["green", "red"], ["ball", "key"]],
  168. )
  169. assert mission_space.contains("Get the green ball.")
  170. assert mission_space.contains("Get the red key.")
  171. assert not mission_space.contains("Get the purple box.")
  172. # Test passing inverted placeholders
  173. assert not mission_space.contains("Get the key red.")
  174. # Test passing extra repeated placeholders
  175. assert not mission_space.contains("Get the key red key.")
  176. # Test contained placeholders like "get the" and "go get the". "get the" string is contained in both placeholders.
  177. mission_space = MissionSpace(
  178. mission_func=lambda get_syntax, obj_type: f"{get_syntax} {obj_type}.",
  179. ordered_placeholders=[
  180. ["go get the", "get the", "go fetch the", "fetch the"],
  181. ["ball", "key"],
  182. ],
  183. )
  184. assert mission_space.contains("get the ball.")
  185. assert mission_space.contains("go get the key.")
  186. assert mission_space.contains("go fetch the ball.")
  187. # Test repeated placeholders
  188. mission_space = MissionSpace(
  189. mission_func=lambda get_syntax, color_1, obj_type_1, color_2, obj_type_2: f"{get_syntax} {color_1} {obj_type_1} and the {color_2} {obj_type_2}.",
  190. ordered_placeholders=[
  191. ["go get the", "get the", "go fetch the", "fetch the"],
  192. ["green", "red"],
  193. ["ball", "key"],
  194. ["green", "red"],
  195. ["ball", "key"],
  196. ],
  197. )
  198. assert mission_space.contains("get the green key and the green key.")
  199. assert mission_space.contains("go fetch the red ball and the green key.")