#!/usr/bin/env python3
import random

import gym
import numpy as np
from gym import spaces

from gym_minigrid.envs.empty import EmptyEnv
from gym_minigrid.minigrid import Grid
from gym_minigrid.register import env_list
from gym_minigrid.wrappers import (
    DictObservationSpaceWrapper,
    FlatObsWrapper,
    FullyObsWrapper,
    ImgObsWrapper,
    OneHotPartialObsWrapper,
    ReseedWrapper,
    RGBImgObsWrapper,
    RGBImgPartialObsWrapper,
    ViewSizeWrapper,
)

# Test importing wrappers
print("%d environments registered" % len(env_list))

for env_idx, env_name in enumerate(env_list):
    print(f"testing {env_name} ({env_idx + 1}/{len(env_list)})")

    # Load the gym environment
    env = gym.make(env_name, render_mode="rgb_array", new_step_api=True)
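    # Cap the episode length at 200 steps so each environment is exercised quickly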
    env.max_steps = min(env.max_steps, 200)
    env.reset()
    env.render()

    # Verify that the same seed always produces the same environment
    for i in range(0, 5):
        seed = 1337 + i
        _ = env.reset(seed=seed)
        grid1 = env.grid
        _ = env.reset(seed=seed)
        grid2 = env.grid
        assert grid1 == grid2

    env.reset()

    # Run for a few episodes
    num_episodes = 0
    while num_episodes < 5:
        # Pick a random action
        action = random.randint(0, env.action_space.n - 1)

        obs, reward, terminated, truncated, info = env.step(action)

        # Validate the agent position
        assert env.agent_pos[0] < env.width
        assert env.agent_pos[1] < env.height

        # Test observation encode/decode roundtrip
        img = obs["image"]
        grid, vis_mask = Grid.decode(img)
        img2 = grid.encode(vis_mask=vis_mask)
        assert np.array_equal(img, img2)

        # Test the env to string function
        str(env)

        # Check that the reward is within the specified range
        assert reward >= env.reward_range[0], reward
        assert reward <= env.reward_range[1], reward

        if terminated or truncated:
            num_episodes += 1
            env.reset()

        env.render()

    # Test the close method
    env.close()

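    # Test the ReseedWrapper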
    env = gym.make(env_name, new_step_api=True)
    env = ReseedWrapper(env)
    for _ in range(10):
        env.reset()
        env.step(0)
    env.close()

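    # Test the ImgObsWrapper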
    env = gym.make(env_name, new_step_api=True)
    env = ImgObsWrapper(env)
    env.reset()
    env.step(0)
    env.close()

    # Test the fully observable wrapper
    env = gym.make(env_name, new_step_api=True)
    env = FullyObsWrapper(env)
    env.reset()
    obs, _, _, _, _ = env.step(0)
    assert obs["image"].shape == env.observation_space.spaces["image"].shape
    env.close()

    # RGB image observation wrapper
    env = gym.make(env_name, new_step_api=True)
    env = RGBImgPartialObsWrapper(env)
    env.reset()
    obs, _, _, _, _ = env.step(0)
    assert obs["image"].mean() > 0
    env.close()

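    # Test the FlatObsWrapper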
    env = gym.make(env_name, new_step_api=True)
    env = FlatObsWrapper(env)
    env.reset()
    env.step(0)
    env.close()

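    # Test the ViewSizeWrapper with a 5-tile agent view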
    env = gym.make(env_name, new_step_api=True)
    env = ViewSizeWrapper(env, 5)
    env.reset()
    env.step(0)
    env.close()

    # Test the DictObservationSpaceWrapper
    env = gym.make(env_name, new_step_api=True)
    env = DictObservationSpaceWrapper(env)
    env.reset()
    mission = env.mission
    obs, _, _, _, _ = env.step(0)
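    # The encoded mission, with zero padding stripped, should match the token indices of the mission string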
    assert env.string_to_indices(mission) == [
        value for value in obs["mission"] if value != 0
    ]
    env.close()

    # Test the wrappers return proper observation spaces.
    wrappers = [RGBImgObsWrapper, RGBImgPartialObsWrapper, OneHotPartialObsWrapper]
    for wrapper in wrappers:
        env = wrapper(gym.make(env_name, render_mode="rgb_array", new_step_api=True))
        obs_space, wrapper_name = env.observation_space, wrapper.__name__
        assert isinstance(
            obs_space, spaces.Dict
        ), f"Observation space for {wrapper_name} is not a Dict: {obs_space}."
        # This should not fail either
        ImgObsWrapper(env)
        env.reset()
        env.step(0)
        env.close()

##############################################################################

print("testing extra observations")


class EmptyEnvWithExtraObs(EmptyEnv):
    """
    Custom environment with an extra observation
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(size=5, **kwargs)
        self.observation_space["size"] = spaces.Box(
            low=0,
            high=1000,  # gym does not like np.iinfo(np.uint).max,
            shape=(2,),
            dtype=np.uint,
        )

    def reset(self, **kwargs):
        obs = super().reset(**kwargs)
        obs["size"] = np.array([self.width, self.height], dtype=np.uint)
        return obs

    def step(self, action):
        obs, reward, terminated, truncated, info = super().step(action)
        obs["size"] = np.array([self.width, self.height], dtype=np.uint)
        return obs, reward, terminated, truncated, info


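# Check that the extra "size" observation survives each observation wrapper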
wrappers = [
    OneHotPartialObsWrapper,
    RGBImgObsWrapper,
    RGBImgPartialObsWrapper,
    FullyObsWrapper,
]
for wrapper in wrappers:
    env1 = wrapper(EmptyEnvWithExtraObs(render_mode="rgb_array"))
    env2 = wrapper(
        gym.make("MiniGrid-Empty-5x5-v0", render_mode="rgb_array", new_step_api=True)
    )

    obs1 = env1.reset(seed=0)
    obs2 = env2.reset(seed=0)
    assert "size" in obs1
    assert obs1["size"].shape == (2,)
    assert (obs1["size"] == [5, 5]).all()
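    # The standard observation keys should match a stock Empty-5x5 environment exactly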
    for key in obs2:
        assert np.array_equal(obs1[key], obs2[key])

    obs1, reward1, terminated1, truncated1, _ = env1.step(0)
    obs2, reward2, terminated2, truncated2, _ = env2.step(0)
    assert "size" in obs1
    assert obs1["size"].shape == (2,)
    assert (obs1["size"] == [5, 5]).all()
    for key in obs2:
        assert np.array_equal(obs1[key], obs2[key])

##############################################################################

print("testing agent_sees method")

env = gym.make("MiniGrid-DoorKey-6x6-v0", new_step_api=True)
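# DoorKey places its goal in the bottom-right inner corner of the grid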
goal_pos = (env.grid.width - 2, env.grid.height - 2)

# Test the "in" operator on grid objects
assert ("green", "goal") in env.grid
assert ("blue", "key") not in env.grid

# Test the env.agent_sees() function
env.reset()
for i in range(0, 500):
    action = random.randint(0, env.action_space.n - 1)
    obs, reward, terminated, truncated, info = env.step(action)

    grid, _ = Grid.decode(obs["image"])
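    # agent_sees() should agree with whether the goal appears in the decoded partial view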
    goal_visible = ("green", "goal") in grid
    agent_sees_goal = env.agent_sees(*goal_pos)
    assert agent_sees_goal == goal_visible

    if terminated or truncated:
        env.reset()