test_wrappers.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. import math
  2. import gymnasium as gym
  3. import numpy as np
  4. import pytest
  5. from minigrid.core.actions import Actions
  6. from minigrid.envs import EmptyEnv
  7. from minigrid.wrappers import (
  8. ActionBonus,
  9. DictObservationSpaceWrapper,
  10. FlatObsWrapper,
  11. FullyObsWrapper,
  12. ImgObsWrapper,
  13. OneHotPartialObsWrapper,
  14. ReseedWrapper,
  15. RGBImgObsWrapper,
  16. RGBImgPartialObsWrapper,
  17. StateBonus,
  18. ViewSizeWrapper,
  19. )
  20. from tests.utils import all_testing_env_specs, assert_equals, minigrid_testing_env_specs
  21. SEEDS = [100, 243, 500]
  22. NUM_STEPS = 100
  23. @pytest.mark.parametrize(
  24. "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  25. )
  26. def test_reseed_wrapper(env_spec):
  27. """
  28. Test the ReseedWrapper with a list of SEEDS.
  29. """
  30. unwrapped_env = env_spec.make()
  31. env = env_spec.make()
  32. env = ReseedWrapper(env, seeds=SEEDS)
  33. env.action_space.seed(0)
  34. for seed in SEEDS:
  35. env.reset()
  36. unwrapped_env.reset(seed=seed)
  37. for time_step in range(NUM_STEPS):
  38. action = env.action_space.sample()
  39. obs, rew, terminated, truncated, info = env.step(action)
  40. (
  41. unwrapped_obs,
  42. unwrapped_rew,
  43. unwrapped_terminated,
  44. unwrapped_truncated,
  45. unwrapped_info,
  46. ) = unwrapped_env.step(action)
  47. assert_equals(obs, unwrapped_obs, f"[{time_step}] ")
  48. assert unwrapped_env.observation_space.contains(obs)
  49. assert (
  50. rew == unwrapped_rew
  51. ), f"[{time_step}] reward={rew}, unwrapped reward={unwrapped_rew}"
  52. assert (
  53. terminated == unwrapped_terminated
  54. ), f"[{time_step}] terminated={terminated}, unwrapped terminated={unwrapped_terminated}"
  55. assert (
  56. truncated == unwrapped_truncated
  57. ), f"[{time_step}] truncated={truncated}, unwrapped truncated={unwrapped_truncated}"
  58. assert_equals(info, unwrapped_info, f"[{time_step}] ")
  59. # Start the next seed
  60. if terminated or truncated:
  61. break
  62. env.close()
  63. unwrapped_env.close()
  64. @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-16x16-v0"])
  65. def test_state_bonus_wrapper(env_id):
  66. env = gym.make(env_id)
  67. wrapped_env = StateBonus(gym.make(env_id))
  68. action_forward = Actions.forward
  69. action_left = Actions.left
  70. action_right = Actions.right
  71. for _ in range(10):
  72. wrapped_env.reset()
  73. for _ in range(5):
  74. wrapped_env.step(action_forward)
  75. # Turn lef 3 times (check that actions don't influence bonus)
  76. for _ in range(3):
  77. _, wrapped_rew, _, _, _ = wrapped_env.step(action_left)
  78. env.reset()
  79. for _ in range(5):
  80. env.step(action_forward)
  81. # Turn right 3 times
  82. for _ in range(3):
  83. _, rew, _, _, _ = env.step(action_right)
  84. expected_bonus_reward = rew + 1 / math.sqrt(13)
  85. assert expected_bonus_reward == wrapped_rew
  86. @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-16x16-v0"])
  87. def test_action_bonus_wrapper(env_id):
  88. env = gym.make(env_id)
  89. wrapped_env = ActionBonus(gym.make(env_id))
  90. action = Actions.forward
  91. for _ in range(10):
  92. wrapped_env.reset()
  93. for _ in range(5):
  94. _, wrapped_rew, _, _, _ = wrapped_env.step(action)
  95. env.reset()
  96. for _ in range(5):
  97. _, rew, _, _, _ = env.step(action)
  98. expected_bonus_reward = rew + 1 / math.sqrt(10)
  99. assert expected_bonus_reward == wrapped_rew
  100. @pytest.mark.parametrize(
  101. "env_spec",
  102. minigrid_testing_env_specs,
  103. ids=[spec.id for spec in minigrid_testing_env_specs],
  104. ) # DictObservationSpaceWrapper is not compatible with BabyAI levels. See minigrid/wrappers.py for more details.
  105. def test_dict_observation_space_wrapper(env_spec):
  106. env = env_spec.make()
  107. env = DictObservationSpaceWrapper(env)
  108. env.reset()
  109. mission = env.mission
  110. obs, _, _, _, _ = env.step(0)
  111. assert env.string_to_indices(mission) == [
  112. value for value in obs["mission"] if value != 0
  113. ]
  114. env.close()
  115. @pytest.mark.parametrize(
  116. "wrapper",
  117. [
  118. ReseedWrapper,
  119. ImgObsWrapper,
  120. FlatObsWrapper,
  121. ViewSizeWrapper,
  122. DictObservationSpaceWrapper,
  123. OneHotPartialObsWrapper,
  124. RGBImgPartialObsWrapper,
  125. FullyObsWrapper,
  126. ],
  127. )
  128. @pytest.mark.parametrize(
  129. "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  130. )
  131. def test_main_wrappers(wrapper, env_spec):
  132. if (
  133. wrapper in (DictObservationSpaceWrapper, FlatObsWrapper)
  134. and env_spec not in minigrid_testing_env_specs
  135. ):
  136. # DictObservationSpaceWrapper and FlatObsWrapper are not compatible with BabyAI levels
  137. # See minigrid/wrappers.py for more details
  138. pytest.skip()
  139. env = env_spec.make()
  140. env = wrapper(env)
  141. for _ in range(10):
  142. env.reset()
  143. env.step(0)
  144. env.close()
  145. @pytest.mark.parametrize(
  146. "wrapper",
  147. [
  148. OneHotPartialObsWrapper,
  149. RGBImgPartialObsWrapper,
  150. FullyObsWrapper,
  151. ],
  152. )
  153. @pytest.mark.parametrize(
  154. "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  155. )
  156. def test_observation_space_wrappers(wrapper, env_spec):
  157. env = wrapper(env_spec.make(disable_env_checker=True))
  158. obs_space, wrapper_name = env.observation_space, wrapper.__name__
  159. assert isinstance(
  160. obs_space, gym.spaces.Dict
  161. ), f"Observation space for {wrapper_name} is not a Dict: {obs_space}."
  162. # This should not fail either
  163. ImgObsWrapper(env)
  164. env.reset()
  165. env.step(0)
  166. env.close()
  167. class EmptyEnvWithExtraObs(EmptyEnv):
  168. """
  169. Custom environment with an extra observation
  170. """
  171. def __init__(self) -> None:
  172. super().__init__(size=5)
  173. self.observation_space["size"] = gym.spaces.Box(
  174. low=0, high=np.iinfo(np.uint).max, shape=(2,), dtype=np.uint
  175. )
  176. def reset(self, **kwargs):
  177. obs, info = super().reset(**kwargs)
  178. obs["size"] = np.array([self.width, self.height])
  179. return obs, info
  180. def step(self, action):
  181. obs, reward, terminated, truncated, info = super().step(action)
  182. obs["size"] = np.array([self.width, self.height])
  183. return obs, reward, terminated, truncated, info
  184. @pytest.mark.parametrize(
  185. "wrapper",
  186. [
  187. OneHotPartialObsWrapper,
  188. RGBImgObsWrapper,
  189. RGBImgPartialObsWrapper,
  190. FullyObsWrapper,
  191. ],
  192. )
  193. def test_agent_sees_method(wrapper):
  194. env1 = wrapper(EmptyEnvWithExtraObs())
  195. env2 = wrapper(gym.make("MiniGrid-Empty-5x5-v0"))
  196. obs1, _ = env1.reset(seed=0)
  197. obs2, _ = env2.reset(seed=0)
  198. assert "size" in obs1
  199. assert obs1["size"].shape == (2,)
  200. assert (obs1["size"] == [5, 5]).all()
  201. for key in obs2:
  202. assert np.array_equal(obs1[key], obs2[key])
  203. obs1, reward1, terminated1, truncated1, _ = env1.step(0)
  204. obs2, reward2, terminated2, truncated2, _ = env2.step(0)
  205. assert "size" in obs1
  206. assert obs1["size"].shape == (2,)
  207. assert (obs1["size"] == [5, 5]).all()
  208. for key in obs2:
  209. assert np.array_equal(obs1[key], obs2[key])