run_tests.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. #!/usr/bin/env python3
  2. import random
  3. import numpy as np
  4. import gym
  5. import gym_minigrid
  6. from gym_minigrid.register import env_list
  7. from gym_minigrid.minigrid import Grid, OBJECT_TO_IDX
  8. # Test specifically importing a specific environment
  9. from gym_minigrid.envs import DoorKeyEnv
  10. # Test importing wrappers
  11. from gym_minigrid.wrappers import *
  12. ##############################################################################
  13. print('%d environments registered' % len(env_list))
  14. for env_idx, env_name in enumerate(env_list):
  15. print('testing {} ({}/{})'.format(env_name, env_idx+1, len(env_list)))
  16. # Load the gym environment
  17. env = gym.make(env_name)
  18. env.max_steps = min(env.max_steps, 200)
  19. env.reset()
  20. env.render('rgb_array')
  21. # Verify that the same seed always produces the same environment
  22. for i in range(0, 5):
  23. seed = 1337 + i
  24. env.seed(seed)
  25. grid1 = env.grid
  26. env.seed(seed)
  27. grid2 = env.grid
  28. assert grid1 == grid2
  29. env.reset()
  30. # Run for a few episodes
  31. num_episodes = 0
  32. while num_episodes < 5:
  33. # Pick a random action
  34. action = random.randint(0, env.action_space.n - 1)
  35. obs, reward, done, info = env.step(action)
  36. # Validate the agent position
  37. assert env.agent_pos[0] < env.width
  38. assert env.agent_pos[1] < env.height
  39. # Test observation encode/decode roundtrip
  40. img = obs['image']
  41. grid, vis_mask = Grid.decode(img)
  42. img2 = grid.encode(vis_mask=vis_mask)
  43. assert np.array_equal(img, img2)
  44. # Test the env to string function
  45. str(env)
  46. # Check that the reward is within the specified range
  47. assert reward >= env.reward_range[0], reward
  48. assert reward <= env.reward_range[1], reward
  49. if done:
  50. num_episodes += 1
  51. env.reset()
  52. env.render('rgb_array')
  53. # Test the close method
  54. env.close()
  55. env = gym.make(env_name)
  56. env = ReseedWrapper(env)
  57. for _ in range(10):
  58. env.reset()
  59. env.step(0)
  60. env.close()
  61. env = gym.make(env_name)
  62. env = ImgObsWrapper(env)
  63. env.reset()
  64. env.step(0)
  65. env.close()
  66. # Test the fully observable wrapper
  67. env = gym.make(env_name)
  68. env = FullyObsWrapper(env)
  69. env.reset()
  70. obs, _, _, _ = env.step(0)
  71. assert obs['image'].shape == env.observation_space.spaces['image'].shape
  72. env.close()
  73. # RGB image observation wrapper
  74. env = gym.make(env_name)
  75. env = RGBImgPartialObsWrapper(env)
  76. env.reset()
  77. obs, _, _, _ = env.step(0)
  78. assert obs['image'].mean() > 0
  79. env.close()
  80. env = gym.make(env_name)
  81. env = FlatObsWrapper(env)
  82. env.reset()
  83. env.step(0)
  84. env.close()
  85. env = gym.make(env_name)
  86. env = ViewSizeWrapper(env, 5)
  87. env.reset()
  88. env.step(0)
  89. env.close()
  90. # Test the wrappers return proper observation spaces.
  91. wrappers = [
  92. RGBImgObsWrapper,
  93. RGBImgPartialObsWrapper,
  94. OneHotPartialObsWrapper
  95. ]
  96. for wrapper in wrappers:
  97. env = wrapper(gym.make(env_name))
  98. obs_space, wrapper_name = env.observation_space, wrapper.__name__
  99. assert isinstance(
  100. obs_space, spaces.Dict
  101. ), "Observation space for {0} is not a Dict: {1}.".format(
  102. wrapper_name, obs_space
  103. )
  104. # This should not fail either
  105. ImgObsWrapper(env)
  106. env.reset()
  107. env.step(0)
  108. env.close()
  109. ##############################################################################
  110. print('testing extra observations')
  111. class EmptyEnvWithExtraObs(gym_minigrid.envs.EmptyEnv5x5):
  112. """
  113. Custom environment with an extra observation
  114. """
  115. def __init__(self) -> None:
  116. super().__init__()
  117. self.observation_space['size'] = spaces.Box(
  118. low=0,
  119. high=np.iinfo(np.uint).max,
  120. shape=(2,),
  121. dtype=np.uint
  122. )
  123. def reset(self):
  124. obs = super().reset()
  125. obs['size'] = np.array([self.width, self.height])
  126. return obs
  127. def step(self, action):
  128. obs, reward, done, info = super().step(action)
  129. obs['size'] = np.array([self.width, self.height])
  130. return obs, reward, done, info
  131. wrappers = [
  132. OneHotPartialObsWrapper,
  133. RGBImgObsWrapper,
  134. RGBImgPartialObsWrapper,
  135. FullyObsWrapper,
  136. ]
  137. for wrapper in wrappers:
  138. env1 = wrapper(EmptyEnvWithExtraObs())
  139. env2 = wrapper(gym.make('MiniGrid-Empty-5x5-v0'))
  140. env1.seed(0)
  141. env2.seed(0)
  142. obs1 = env1.reset()
  143. obs2 = env2.reset()
  144. assert 'size' in obs1
  145. assert obs1['size'].shape == (2,)
  146. assert (obs1['size'] == [5,5]).all()
  147. for key in obs2:
  148. assert np.array_equal(obs1[key], obs2[key])
  149. obs1, reward1, done1, _ = env1.step(0)
  150. obs2, reward2, done2, _ = env2.step(0)
  151. assert 'size' in obs1
  152. assert obs1['size'].shape == (2,)
  153. assert (obs1['size'] == [5,5]).all()
  154. for key in obs2:
  155. assert np.array_equal(obs1[key], obs2[key])
  156. ##############################################################################
  157. print('testing agent_sees method')
  158. env = gym.make('MiniGrid-DoorKey-6x6-v0')
  159. goal_pos = (env.grid.width - 2, env.grid.height - 2)
  160. # Test the "in" operator on grid objects
  161. assert ('green', 'goal') in env.grid
  162. assert ('blue', 'key') not in env.grid
  163. # Test the env.agent_sees() function
  164. env.reset()
  165. for i in range(0, 500):
  166. action = random.randint(0, env.action_space.n - 1)
  167. obs, reward, done, info = env.step(action)
  168. grid, _ = Grid.decode(obs['image'])
  169. goal_visible = ('green', 'goal') in grid
  170. agent_sees_goal = env.agent_sees(*goal_pos)
  171. assert agent_sees_goal == goal_visible
  172. if done:
  173. env.reset()