run_tests.py (4.0 KB)
#!/usr/bin/env python3
import copy
import random
from contextlib import closing

import numpy as np
import gym
from gym import spaces

from gym_minigrid.register import env_list
from gym_minigrid.minigrid import Grid, OBJECT_TO_IDX
# Test specifically importing a specific environment
from gym_minigrid.envs import DoorKeyEnv
# Test importing wrappers
from gym_minigrid.wrappers import *
  11. ##############################################################################
  12. print('%d environments registered' % len(env_list))
  13. for env_name in env_list:
  14. print('testing "%s"' % env_name)
  15. # Load the gym environment
  16. env = gym.make(env_name)
  17. env.max_steps = min(env.max_steps, 200)
  18. env.reset()
  19. env.render('rgb_array')
  20. # Verify that the same seed always produces the same environment
  21. for i in range(0, 5):
  22. seed = 1337 + i
  23. env.seed(seed)
  24. grid1 = env.grid
  25. env.seed(seed)
  26. grid2 = env.grid
  27. assert grid1 == grid2
  28. env.reset()
  29. # Run for a few episodes
  30. num_episodes = 0
  31. while num_episodes < 5:
  32. # Pick a random action
  33. action = random.randint(0, env.action_space.n - 1)
  34. obs, reward, done, info = env.step(action)
  35. # Validate the agent position
  36. assert env.agent_pos[0] < env.width
  37. assert env.agent_pos[1] < env.height
  38. # Test observation encode/decode roundtrip
  39. img = obs['image']
  40. grid, vis_mask = Grid.decode(img)
  41. img2 = grid.encode(vis_mask=vis_mask)
  42. assert np.array_equal(img, img2)
  43. # Test the env to string function
  44. str(env)
  45. # Check that the reward is within the specified range
  46. assert reward >= env.reward_range[0], reward
  47. assert reward <= env.reward_range[1], reward
  48. if done:
  49. num_episodes += 1
  50. env.reset()
  51. env.render('rgb_array')
  52. # Test the close method
  53. env.close()
  54. env = gym.make(env_name)
  55. env = ReseedWrapper(env)
  56. for _ in range(10):
  57. env.reset()
  58. env.step(0)
  59. env.close()
  60. env = gym.make(env_name)
  61. env = ImgObsWrapper(env)
  62. env.reset()
  63. env.step(0)
  64. env.close()
  65. # Test the fully observable wrapper
  66. env = gym.make(env_name)
  67. env = FullyObsWrapper(env)
  68. env.reset()
  69. obs, _, _, _ = env.step(0)
  70. assert obs['image'].shape == env.observation_space.spaces['image'].shape
  71. env.close()
  72. # RGB image observation wrapper
  73. env = gym.make(env_name)
  74. env = RGBImgPartialObsWrapper(env)
  75. env.reset()
  76. obs, _, _, _ = env.step(0)
  77. assert obs['image'].mean() > 0
  78. env.close()
  79. env = gym.make(env_name)
  80. env = FlatObsWrapper(env)
  81. env.reset()
  82. env.step(0)
  83. env.close()
  84. env = gym.make(env_name)
  85. env = ViewSizeWrapper(env, 5)
  86. env.reset()
  87. env.step(0)
  88. env.close()
  89. # Test the wrappers return proper observation spaces.
  90. wrappers = [
  91. RGBImgObsWrapper,
  92. RGBImgPartialObsWrapper,
  93. OneHotPartialObsWrapper
  94. ]
  95. for wrapper in wrappers:
  96. env = wrapper(gym.make(env_name))
  97. obs_space, wrapper_name = env.observation_space, wrapper.__name__
  98. assert isinstance(
  99. obs_space, spaces.Dict
  100. ), "Observation space for {0} is not a Dict: {1}.".format(
  101. wrapper_name, obs_space
  102. )
  103. # This should not fail either
  104. ImgObsWrapper(env)
  105. ##############################################################################
  106. print('testing agent_sees method')
  107. env = gym.make('MiniGrid-DoorKey-6x6-v0')
  108. goal_pos = (env.grid.width - 2, env.grid.height - 2)
  109. # Test the "in" operator on grid objects
  110. assert ('green', 'goal') in env.grid
  111. assert ('blue', 'key') not in env.grid
  112. # Test the env.agent_sees() function
  113. env.reset()
  114. for i in range(0, 500):
  115. action = random.randint(0, env.action_space.n - 1)
  116. obs, reward, done, info = env.step(action)
  117. grid, _ = Grid.decode(obs['image'])
  118. goal_visible = ('green', 'goal') in grid
  119. agent_sees_goal = env.agent_sees(*goal_pos)
  120. assert agent_sees_goal == goal_visible
  121. if done:
  122. env.reset()