run_tests.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. #!/usr/bin/env python3
  2. import random
  3. import gym
  4. import numpy as np
  5. import gym_minigrid
  6. # Test specifically importing a specific environment
  7. from gym_minigrid.minigrid import Grid
  8. from gym_minigrid.register import env_list
  9. from gym_minigrid.wrappers import (
  10. FlatObsWrapper,
  11. FullyObsWrapper,
  12. ImgObsWrapper,
  13. OneHotPartialObsWrapper,
  14. ReseedWrapper,
  15. RGBImgObsWrapper,
  16. RGBImgPartialObsWrapper,
  17. ViewSizeWrapper,
  18. )
  19. ##############################################################################
  20. # Test importing wrappers
  21. print("%d environments registered" % len(env_list))
  22. for env_idx, env_name in enumerate(env_list):
  23. print(f"testing {env_name} ({env_idx + 1}/{len(env_list)})")
  24. # Load the gym environment
  25. env = gym.make(env_name)
  26. env.max_steps = min(env.max_steps, 200)
  27. env.reset()
  28. env.render("rgb_array")
  29. # Verify that the same seed always produces the same environment
  30. for i in range(0, 5):
  31. seed = 1337 + i
  32. env.seed(seed)
  33. grid1 = env.grid
  34. env.seed(seed)
  35. grid2 = env.grid
  36. assert grid1 == grid2
  37. env.reset()
  38. # Run for a few episodes
  39. num_episodes = 0
  40. while num_episodes < 5:
  41. # Pick a random action
  42. action = random.randint(0, env.action_space.n - 1)
  43. obs, reward, done, info = env.step(action)
  44. # Validate the agent position
  45. assert env.agent_pos[0] < env.width
  46. assert env.agent_pos[1] < env.height
  47. # Test observation encode/decode roundtrip
  48. img = obs["image"]
  49. grid, vis_mask = Grid.decode(img)
  50. img2 = grid.encode(vis_mask=vis_mask)
  51. assert np.array_equal(img, img2)
  52. # Test the env to string function
  53. str(env)
  54. # Check that the reward is within the specified range
  55. assert reward >= env.reward_range[0], reward
  56. assert reward <= env.reward_range[1], reward
  57. if done:
  58. num_episodes += 1
  59. env.reset()
  60. env.render("rgb_array")
  61. # Test the close method
  62. env.close()
  63. env = gym.make(env_name)
  64. env = ReseedWrapper(env)
  65. for _ in range(10):
  66. env.reset()
  67. env.step(0)
  68. env.close()
  69. env = gym.make(env_name)
  70. env = ImgObsWrapper(env)
  71. env.reset()
  72. env.step(0)
  73. env.close()
  74. # Test the fully observable wrapper
  75. env = gym.make(env_name)
  76. env = FullyObsWrapper(env)
  77. env.reset()
  78. obs, _, _, _ = env.step(0)
  79. assert obs["image"].shape == env.observation_space.spaces["image"].shape
  80. env.close()
  81. # RGB image observation wrapper
  82. env = gym.make(env_name)
  83. env = RGBImgPartialObsWrapper(env)
  84. env.reset()
  85. obs, _, _, _ = env.step(0)
  86. assert obs["image"].mean() > 0
  87. env.close()
  88. env = gym.make(env_name)
  89. env = FlatObsWrapper(env)
  90. env.reset()
  91. env.step(0)
  92. env.close()
  93. env = gym.make(env_name)
  94. env = ViewSizeWrapper(env, 5)
  95. env.reset()
  96. env.step(0)
  97. env.close()
  98. # Test the wrappers return proper observation spaces.
  99. wrappers = [RGBImgObsWrapper, RGBImgPartialObsWrapper, OneHotPartialObsWrapper]
  100. for wrapper in wrappers:
  101. env = wrapper(gym.make(env_name))
  102. obs_space, wrapper_name = env.observation_space, wrapper.__name__
  103. assert isinstance(
  104. obs_space, gym.spaces.Dict
  105. ), f"Observation space for {wrapper_name} is not a Dict: {obs_space}."
  106. # This should not fail either
  107. ImgObsWrapper(env)
  108. env.reset()
  109. env.step(0)
  110. env.close()
  111. ##############################################################################
  112. print("testing extra observations")
  113. class EmptyEnvWithExtraObs(gym_minigrid.envs.EmptyEnv5x5):
  114. """
  115. Custom environment with an extra observation
  116. """
  117. def __init__(self) -> None:
  118. super().__init__()
  119. self.observation_space["size"] = gym.spaces.Box(
  120. low=0, high=np.iinfo(np.uint).max, shape=(2,), dtype=np.uint
  121. )
  122. def reset(self):
  123. obs = super().reset()
  124. obs["size"] = np.array([self.width, self.height])
  125. return obs
  126. def step(self, action):
  127. obs, reward, done, info = super().step(action)
  128. obs["size"] = np.array([self.width, self.height])
  129. return obs, reward, done, info
  130. wrappers = [
  131. OneHotPartialObsWrapper,
  132. RGBImgObsWrapper,
  133. RGBImgPartialObsWrapper,
  134. FullyObsWrapper,
  135. ]
  136. for wrapper in wrappers:
  137. env1 = wrapper(EmptyEnvWithExtraObs())
  138. env2 = wrapper(gym.make("MiniGrid-Empty-5x5-v0"))
  139. env1.seed(0)
  140. env2.seed(0)
  141. obs1 = env1.reset()
  142. obs2 = env2.reset()
  143. assert "size" in obs1
  144. assert obs1["size"].shape == (2,)
  145. assert (obs1["size"] == [5, 5]).all()
  146. for key in obs2:
  147. assert np.array_equal(obs1[key], obs2[key])
  148. obs1, reward1, done1, _ = env1.step(0)
  149. obs2, reward2, done2, _ = env2.step(0)
  150. assert "size" in obs1
  151. assert obs1["size"].shape == (2,)
  152. assert (obs1["size"] == [5, 5]).all()
  153. for key in obs2:
  154. assert np.array_equal(obs1[key], obs2[key])
  155. ##############################################################################
  156. print("testing agent_sees method")
  157. env = gym.make("MiniGrid-DoorKey-6x6-v0")
  158. goal_pos = (env.grid.width - 2, env.grid.height - 2)
  159. # Test the "in" operator on grid objects
  160. assert ("green", "goal") in env.grid
  161. assert ("blue", "key") not in env.grid
  162. # Test the env.agent_sees() function
  163. env.reset()
  164. for i in range(0, 500):
  165. action = random.randint(0, env.action_space.n - 1)
  166. obs, reward, done, info = env.step(action)
  167. grid, _ = Grid.decode(obs["image"])
  168. goal_visible = ("green", "goal") in grid
  169. agent_sees_goal = env.agent_sees(*goal_pos)
  170. assert agent_sees_goal == goal_visible
  171. if done:
  172. env.reset()