test_wrappers.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. import math
  2. import gymnasium as gym
  3. import numpy as np
  4. import pytest
  5. from minigrid.core.actions import Actions
  6. from minigrid.envs import EmptyEnv
  7. from minigrid.wrappers import (
  8. ActionBonus,
  9. DictObservationSpaceWrapper,
  10. FlatObsWrapper,
  11. FullyObsWrapper,
  12. ImgObsWrapper,
  13. OneHotPartialObsWrapper,
  14. ReseedWrapper,
  15. RGBImgObsWrapper,
  16. RGBImgPartialObsWrapper,
  17. StateBonus,
  18. ViewSizeWrapper,
  19. )
  20. from tests.utils import all_testing_env_specs, assert_equals
  21. SEEDS = [100, 243, 500]
  22. NUM_STEPS = 100
  23. @pytest.mark.parametrize(
  24. "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  25. )
  26. def test_reseed_wrapper(env_spec):
  27. """
  28. Test the ReseedWrapper with a list of SEEDS.
  29. """
  30. unwrapped_env = env_spec.make()
  31. env = env_spec.make()
  32. env = ReseedWrapper(env, seeds=SEEDS)
  33. env.action_space.seed(0)
  34. for seed in SEEDS:
  35. env.reset()
  36. unwrapped_env.reset(seed=seed)
  37. for time_step in range(NUM_STEPS):
  38. action = env.action_space.sample()
  39. obs, rew, terminated, truncated, info = env.step(action)
  40. (
  41. unwrapped_obs,
  42. unwrapped_rew,
  43. unwrapped_terminated,
  44. unwrapped_truncated,
  45. unwrapped_info,
  46. ) = unwrapped_env.step(action)
  47. assert_equals(obs, unwrapped_obs, f"[{time_step}] ")
  48. assert unwrapped_env.observation_space.contains(obs)
  49. assert (
  50. rew == unwrapped_rew
  51. ), f"[{time_step}] reward={rew}, unwrapped reward={unwrapped_rew}"
  52. assert (
  53. terminated == unwrapped_terminated
  54. ), f"[{time_step}] terminated={terminated}, unwrapped terminated={unwrapped_terminated}"
  55. assert (
  56. truncated == unwrapped_truncated
  57. ), f"[{time_step}] truncated={truncated}, unwrapped truncated={unwrapped_truncated}"
  58. assert_equals(info, unwrapped_info, f"[{time_step}] ")
  59. # Start the next seed
  60. if terminated or truncated:
  61. break
  62. env.close()
  63. unwrapped_env.close()
  64. @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-16x16-v0"])
  65. def test_state_bonus_wrapper(env_id):
  66. env = gym.make(env_id)
  67. wrapped_env = StateBonus(gym.make(env_id))
  68. action_forward = Actions.forward
  69. action_left = Actions.left
  70. action_right = Actions.right
  71. for _ in range(10):
  72. wrapped_env.reset()
  73. for _ in range(5):
  74. wrapped_env.step(action_forward)
  75. # Turn lef 3 times (check that actions don't influence bonus)
  76. for _ in range(3):
  77. _, wrapped_rew, _, _, _ = wrapped_env.step(action_left)
  78. env.reset()
  79. for _ in range(5):
  80. env.step(action_forward)
  81. # Turn right 3 times
  82. for _ in range(3):
  83. _, rew, _, _, _ = env.step(action_right)
  84. expected_bonus_reward = rew + 1 / math.sqrt(13)
  85. assert expected_bonus_reward == wrapped_rew
  86. @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-16x16-v0"])
  87. def test_action_bonus_wrapper(env_id):
  88. env = gym.make(env_id)
  89. wrapped_env = ActionBonus(gym.make(env_id))
  90. action = Actions.forward
  91. for _ in range(10):
  92. wrapped_env.reset()
  93. for _ in range(5):
  94. _, wrapped_rew, _, _, _ = wrapped_env.step(action)
  95. env.reset()
  96. for _ in range(5):
  97. _, rew, _, _, _ = env.step(action)
  98. expected_bonus_reward = rew + 1 / math.sqrt(10)
  99. assert expected_bonus_reward == wrapped_rew
  100. @pytest.mark.parametrize(
  101. "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  102. )
  103. def test_dict_observation_space_wrapper(env_spec):
  104. env = env_spec.make()
  105. env = DictObservationSpaceWrapper(env)
  106. env.reset()
  107. mission = env.mission
  108. obs, _, _, _, _ = env.step(0)
  109. assert env.string_to_indices(mission) == [
  110. value for value in obs["mission"] if value != 0
  111. ]
  112. env.close()
  113. @pytest.mark.parametrize(
  114. "wrapper",
  115. [
  116. ReseedWrapper,
  117. ImgObsWrapper,
  118. FlatObsWrapper,
  119. ViewSizeWrapper,
  120. DictObservationSpaceWrapper,
  121. OneHotPartialObsWrapper,
  122. RGBImgPartialObsWrapper,
  123. FullyObsWrapper,
  124. ],
  125. )
  126. @pytest.mark.parametrize(
  127. "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  128. )
  129. def test_main_wrappers(wrapper, env_spec):
  130. env = env_spec.make()
  131. env = wrapper(env)
  132. for _ in range(10):
  133. env.reset()
  134. env.step(0)
  135. env.close()
  136. @pytest.mark.parametrize(
  137. "wrapper",
  138. [
  139. OneHotPartialObsWrapper,
  140. RGBImgPartialObsWrapper,
  141. FullyObsWrapper,
  142. ],
  143. )
  144. @pytest.mark.parametrize(
  145. "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  146. )
  147. def test_observation_space_wrappers(wrapper, env_spec):
  148. env = wrapper(env_spec.make(disable_env_checker=True))
  149. obs_space, wrapper_name = env.observation_space, wrapper.__name__
  150. assert isinstance(
  151. obs_space, gym.spaces.Dict
  152. ), f"Observation space for {wrapper_name} is not a Dict: {obs_space}."
  153. # This should not fail either
  154. ImgObsWrapper(env)
  155. env.reset()
  156. env.step(0)
  157. env.close()
  158. class EmptyEnvWithExtraObs(EmptyEnv):
  159. """
  160. Custom environment with an extra observation
  161. """
  162. def __init__(self) -> None:
  163. super().__init__(size=5)
  164. self.observation_space["size"] = gym.spaces.Box(
  165. low=0, high=np.iinfo(np.uint).max, shape=(2,), dtype=np.uint
  166. )
  167. def reset(self, **kwargs):
  168. obs, info = super().reset(**kwargs)
  169. obs["size"] = np.array([self.width, self.height])
  170. return obs, info
  171. def step(self, action):
  172. obs, reward, terminated, truncated, info = super().step(action)
  173. obs["size"] = np.array([self.width, self.height])
  174. return obs, reward, terminated, truncated, info
  175. @pytest.mark.parametrize(
  176. "wrapper",
  177. [
  178. OneHotPartialObsWrapper,
  179. RGBImgObsWrapper,
  180. RGBImgPartialObsWrapper,
  181. FullyObsWrapper,
  182. ],
  183. )
  184. def test_agent_sees_method(wrapper):
  185. env1 = wrapper(EmptyEnvWithExtraObs())
  186. env2 = wrapper(gym.make("MiniGrid-Empty-5x5-v0"))
  187. obs1, _ = env1.reset(seed=0)
  188. obs2, _ = env2.reset(seed=0)
  189. assert "size" in obs1
  190. assert obs1["size"].shape == (2,)
  191. assert (obs1["size"] == [5, 5]).all()
  192. for key in obs2:
  193. assert np.array_equal(obs1[key], obs2[key])
  194. obs1, reward1, terminated1, truncated1, _ = env1.step(0)
  195. obs2, reward2, terminated2, truncated2, _ = env2.step(0)
  196. assert "size" in obs1
  197. assert obs1["size"].shape == (2,)
  198. assert (obs1["size"] == [5, 5]).all()
  199. for key in obs2:
  200. assert np.array_equal(obs1[key], obs2[key])