test_envs.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. import warnings
  2. import gym
  3. import numpy as np
  4. import pytest
  5. from gym.envs.registration import EnvSpec
  6. from gym.utils.env_checker import check_env
  7. from gym_minigrid.minigrid import Grid, MissionSpace
  8. from tests.utils import all_testing_env_specs, assert_equals
  9. CHECK_ENV_IGNORE_WARNINGS = [
  10. f"\x1b[33mWARN: {message}\x1b[0m"
  11. for message in [
  12. "A Box observation space minimum value is -infinity. This is probably too low.",
  13. "A Box observation space maximum value is -infinity. This is probably too high.",
  14. "For Box action spaces, we recommend using a symmetric and normalized space (range=[-1, 1] or [0, 1]). See https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html for more information.",
  15. "Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.",
  16. "Initializing environment in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.",
  17. "Core environment is written in old step API which returns one bool instead of two. It is recommended to norewrite the environment with new step API. ",
  18. ]
  19. ]
  20. @pytest.mark.parametrize(
  21. "spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  22. )
  23. def test_env(spec):
  24. # Capture warnings
  25. env = spec.make(disable_env_checker=True).unwrapped
  26. warnings.simplefilter("always")
  27. # Test if env adheres to Gym API
  28. with warnings.catch_warnings(record=True) as w:
  29. check_env(env)
  30. for warning in w:
  31. if warning.message.args[0] not in CHECK_ENV_IGNORE_WARNINGS:
  32. raise gym.error.Error(f"Unexpected warning: {warning.message}")
  33. # Note that this precludes running this test in multiple threads.
  34. # However, we probably already can't do multithreading due to some environments.
  35. SEED = 0
  36. NUM_STEPS = 50
  37. @pytest.mark.parametrize(
  38. "env_spec", all_testing_env_specs, ids=[env.id for env in all_testing_env_specs]
  39. )
  40. def test_env_determinism_rollout(env_spec: EnvSpec):
  41. """Run a rollout with two environments and assert equality.
  42. This test run a rollout of NUM_STEPS steps with two environments
  43. initialized with the same seed and assert that:
  44. - observation after first reset are the same
  45. - same actions are sampled by the two envs
  46. - observations are contained in the observation space
  47. - obs, rew, terminated, truncated and info are equals between the two envs
  48. """
  49. # Don't check rollout equality if it's a nondeterministic environment.
  50. if env_spec.nondeterministic is True:
  51. return
  52. env_1 = env_spec.make(disable_env_checker=True, new_step_api=True)
  53. env_2 = env_spec.make(disable_env_checker=True, new_step_api=True)
  54. initial_obs_1 = env_1.reset(seed=SEED)
  55. initial_obs_2 = env_2.reset(seed=SEED)
  56. assert_equals(initial_obs_1, initial_obs_2)
  57. env_1.action_space.seed(SEED)
  58. for time_step in range(NUM_STEPS):
  59. # We don't evaluate the determinism of actions
  60. action = env_1.action_space.sample()
  61. obs_1, rew_1, terminated_1, truncated_1, info_1 = env_1.step(action)
  62. obs_2, rew_2, terminated_2, truncated_2, info_2 = env_2.step(action)
  63. assert_equals(obs_1, obs_2, f"[{time_step}] ")
  64. assert env_1.observation_space.contains(
  65. obs_1
  66. ) # obs_2 verified by previous assertion
  67. assert rew_1 == rew_2, f"[{time_step}] reward 1={rew_1}, reward 2={rew_2}"
  68. assert (
  69. terminated_1 == terminated_2
  70. ), f"[{time_step}] terminated 1={terminated_1}, terminated 2={terminated_2}"
  71. assert (
  72. truncated_1 == truncated_2
  73. ), f"[{time_step}] truncated 1={truncated_1}, truncated 2={truncated_2}"
  74. assert_equals(info_1, info_2, f"[{time_step}] ")
  75. if (
  76. terminated_1 or truncated_1
  77. ): # terminated_2 and truncated_2 verified by previous assertion
  78. env_1.reset(seed=SEED)
  79. env_2.reset(seed=SEED)
  80. env_1.close()
  81. env_2.close()
  82. @pytest.mark.parametrize(
  83. "spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  84. )
  85. def test_render_modes(spec):
  86. env = spec.make(new_step_api=True)
  87. for mode in env.metadata.get("render_modes", []):
  88. if mode != "human":
  89. new_env = spec.make(new_step_api=True)
  90. new_env.reset()
  91. new_env.step(new_env.action_space.sample())
  92. new_env.render(mode=mode)
  93. @pytest.mark.parametrize("env_id", ["MiniGrid-DoorKey-6x6-v0"])
  94. def test_agent_sees_method(env_id):
  95. env = gym.make(env_id, new_step_api=True)
  96. goal_pos = (env.grid.width - 2, env.grid.height - 2)
  97. # Test the "in" operator on grid objects
  98. assert ("green", "goal") in env.grid
  99. assert ("blue", "key") not in env.grid
  100. # Test the env.agent_sees() function
  101. env.reset()
  102. for i in range(0, 500):
  103. action = env.action_space.sample()
  104. obs, reward, terminated, truncated, info = env.step(action)
  105. grid, _ = Grid.decode(obs["image"])
  106. goal_visible = ("green", "goal") in grid
  107. agent_sees_goal = env.agent_sees(*goal_pos)
  108. assert agent_sees_goal == goal_visible
  109. if terminated or truncated:
  110. env.reset()
  111. env.close()
  112. @pytest.mark.parametrize(
  113. "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  114. )
  115. def old_run_test(env_spec):
  116. # Load the gym environment
  117. env = env_spec.make(new_step_api=True)
  118. env.max_steps = min(env.max_steps, 200)
  119. env.reset()
  120. env.render()
  121. # Verify that the same seed always produces the same environment
  122. for i in range(0, 5):
  123. seed = 1337 + i
  124. _ = env.reset(seed=seed)
  125. grid1 = env.grid
  126. _ = env.reset(seed=seed)
  127. grid2 = env.grid
  128. assert grid1 == grid2
  129. env.reset()
  130. # Run for a few episodes
  131. num_episodes = 0
  132. while num_episodes < 5:
  133. # Pick a random action
  134. action = env.action_space.sample()
  135. obs, reward, terminated, truncated, info = env.step(action)
  136. # Validate the agent position
  137. assert env.agent_pos[0] < env.width
  138. assert env.agent_pos[1] < env.height
  139. # Test observation encode/decode roundtrip
  140. img = obs["image"]
  141. grid, vis_mask = Grid.decode(img)
  142. img2 = grid.encode(vis_mask=vis_mask)
  143. assert np.array_equal(img, img2)
  144. # Test the env to string function
  145. str(env)
  146. # Check that the reward is within the specified range
  147. assert reward >= env.reward_range[0], reward
  148. assert reward <= env.reward_range[1], reward
  149. if terminated or truncated:
  150. num_episodes += 1
  151. env.reset()
  152. env.render()
  153. # Test the close method
  154. env.close()
  155. @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-8x8-v0"])
  156. def test_interactive_mode(env_id):
  157. env = gym.make(env_id, new_step_api=True)
  158. env.reset()
  159. for i in range(0, 100):
  160. print(f"step {i}")
  161. # Pick a random action
  162. action = env.action_space.sample()
  163. obs, reward, terminated, truncated, info = env.step(action)
  164. # Test the close method
  165. env.close()
  166. def test_mission_space():
  167. # Test placeholders
  168. mission_space = MissionSpace(
  169. mission_func=lambda color, obj_type: f"Get the {color} {obj_type}.",
  170. ordered_placeholders=[["green", "red"], ["ball", "key"]],
  171. )
  172. assert mission_space.contains("Get the green ball.")
  173. assert mission_space.contains("Get the red key.")
  174. assert not mission_space.contains("Get the purple box.")
  175. # Test passing inverted placeholders
  176. assert not mission_space.contains("Get the key red.")
  177. # Test passing extra repeated placeholders
  178. assert not mission_space.contains("Get the key red key.")
  179. # Test contained placeholders like "get the" and "go get the". "get the" string is contained in both placeholders.
  180. mission_space = MissionSpace(
  181. mission_func=lambda get_syntax, obj_type: f"{get_syntax} {obj_type}.",
  182. ordered_placeholders=[
  183. ["go get the", "get the", "go fetch the", "fetch the"],
  184. ["ball", "key"],
  185. ],
  186. )
  187. assert mission_space.contains("get the ball.")
  188. assert mission_space.contains("go get the key.")
  189. assert mission_space.contains("go fetch the ball.")
  190. # Test repeated placeholders
  191. mission_space = MissionSpace(
  192. mission_func=lambda get_syntax, color_1, obj_type_1, color_2, obj_type_2: f"{get_syntax} {color_1} {obj_type_1} and the {color_2} {obj_type_2}.",
  193. ordered_placeholders=[
  194. ["go get the", "get the", "go fetch the", "fetch the"],
  195. ["green", "red"],
  196. ["ball", "key"],
  197. ["green", "red"],
  198. ["ball", "key"],
  199. ],
  200. )
  201. assert mission_space.contains("get the green key and the green key.")
  202. assert mission_space.contains("go fetch the red ball and the green key.")