test_wrappers.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. from __future__ import annotations
  2. import math
  3. import gymnasium as gym
  4. import numpy as np
  5. import pytest
  6. from minigrid.core.actions import Actions
  7. from minigrid.envs import EmptyEnv
  8. from minigrid.wrappers import (
  9. ActionBonus,
  10. DictObservationSpaceWrapper,
  11. FlatObsWrapper,
  12. FullyObsWrapper,
  13. ImgObsWrapper,
  14. OneHotPartialObsWrapper,
  15. ReseedWrapper,
  16. RGBImgObsWrapper,
  17. RGBImgPartialObsWrapper,
  18. StateBonus,
  19. ViewSizeWrapper,
  20. )
  21. from tests.utils import all_testing_env_specs, assert_equals, minigrid_testing_env_specs
# Seeds that ReseedWrapper cycles through in test_reseed_wrapper; the test
# also resets an unwrapped reference env explicitly with each of them.
SEEDS = [100, 243, 500]
# Upper bound on the number of steps compared per seed (the comparison stops
# early once the episode terminates or is truncated).
NUM_STEPS = 100
  24. @pytest.mark.parametrize(
  25. "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  26. )
  27. def test_reseed_wrapper(env_spec):
  28. """
  29. Test the ReseedWrapper with a list of SEEDS.
  30. """
  31. unwrapped_env = env_spec.make()
  32. env = env_spec.make()
  33. env = ReseedWrapper(env, seeds=SEEDS)
  34. env.action_space.seed(0)
  35. for seed in SEEDS:
  36. env.reset()
  37. unwrapped_env.reset(seed=seed)
  38. for time_step in range(NUM_STEPS):
  39. action = env.action_space.sample()
  40. obs, rew, terminated, truncated, info = env.step(action)
  41. (
  42. unwrapped_obs,
  43. unwrapped_rew,
  44. unwrapped_terminated,
  45. unwrapped_truncated,
  46. unwrapped_info,
  47. ) = unwrapped_env.step(action)
  48. assert_equals(obs, unwrapped_obs, f"[{time_step}] ")
  49. assert unwrapped_env.observation_space.contains(obs)
  50. assert (
  51. rew == unwrapped_rew
  52. ), f"[{time_step}] reward={rew}, unwrapped reward={unwrapped_rew}"
  53. assert (
  54. terminated == unwrapped_terminated
  55. ), f"[{time_step}] terminated={terminated}, unwrapped terminated={unwrapped_terminated}"
  56. assert (
  57. truncated == unwrapped_truncated
  58. ), f"[{time_step}] truncated={truncated}, unwrapped truncated={unwrapped_truncated}"
  59. assert_equals(info, unwrapped_info, f"[{time_step}] ")
  60. # Start the next seed
  61. if terminated or truncated:
  62. break
  63. env.close()
  64. unwrapped_env.close()
  65. @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-16x16-v0"])
  66. def test_state_bonus_wrapper(env_id):
  67. env = gym.make(env_id)
  68. wrapped_env = StateBonus(gym.make(env_id))
  69. action_forward = Actions.forward
  70. action_left = Actions.left
  71. action_right = Actions.right
  72. for _ in range(10):
  73. wrapped_env.reset()
  74. for _ in range(5):
  75. wrapped_env.step(action_forward)
  76. # Turn lef 3 times (check that actions don't influence bonus)
  77. for _ in range(3):
  78. _, wrapped_rew, _, _, _ = wrapped_env.step(action_left)
  79. env.reset()
  80. for _ in range(5):
  81. env.step(action_forward)
  82. # Turn right 3 times
  83. for _ in range(3):
  84. _, rew, _, _, _ = env.step(action_right)
  85. expected_bonus_reward = rew + 1 / math.sqrt(13)
  86. assert expected_bonus_reward == wrapped_rew
  87. @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-16x16-v0"])
  88. def test_action_bonus_wrapper(env_id):
  89. env = gym.make(env_id)
  90. wrapped_env = ActionBonus(gym.make(env_id))
  91. action = Actions.forward
  92. for _ in range(10):
  93. wrapped_env.reset()
  94. for _ in range(5):
  95. _, wrapped_rew, _, _, _ = wrapped_env.step(action)
  96. env.reset()
  97. for _ in range(5):
  98. _, rew, _, _, _ = env.step(action)
  99. expected_bonus_reward = rew + 1 / math.sqrt(10)
  100. assert expected_bonus_reward == wrapped_rew
  101. @pytest.mark.parametrize(
  102. "env_spec",
  103. minigrid_testing_env_specs,
  104. ids=[spec.id for spec in minigrid_testing_env_specs],
  105. ) # DictObservationSpaceWrapper is not compatible with BabyAI levels. See minigrid/wrappers.py for more details.
  106. def test_dict_observation_space_wrapper(env_spec):
  107. env = env_spec.make()
  108. env = DictObservationSpaceWrapper(env)
  109. env.reset()
  110. mission = env.mission
  111. obs, _, _, _, _ = env.step(0)
  112. assert env.string_to_indices(mission) == [
  113. value for value in obs["mission"] if value != 0
  114. ]
  115. env.close()
  116. @pytest.mark.parametrize(
  117. "wrapper",
  118. [
  119. ReseedWrapper,
  120. ImgObsWrapper,
  121. FlatObsWrapper,
  122. ViewSizeWrapper,
  123. DictObservationSpaceWrapper,
  124. OneHotPartialObsWrapper,
  125. RGBImgPartialObsWrapper,
  126. FullyObsWrapper,
  127. ],
  128. )
  129. @pytest.mark.parametrize(
  130. "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  131. )
  132. def test_main_wrappers(wrapper, env_spec):
  133. if (
  134. wrapper in (DictObservationSpaceWrapper, FlatObsWrapper)
  135. and env_spec not in minigrid_testing_env_specs
  136. ):
  137. # DictObservationSpaceWrapper and FlatObsWrapper are not compatible with BabyAI levels
  138. # See minigrid/wrappers.py for more details
  139. pytest.skip()
  140. env = env_spec.make()
  141. env = wrapper(env)
  142. for _ in range(10):
  143. env.reset()
  144. env.step(0)
  145. env.close()
  146. @pytest.mark.parametrize(
  147. "wrapper",
  148. [
  149. OneHotPartialObsWrapper,
  150. RGBImgPartialObsWrapper,
  151. FullyObsWrapper,
  152. ],
  153. )
  154. @pytest.mark.parametrize(
  155. "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  156. )
  157. def test_observation_space_wrappers(wrapper, env_spec):
  158. env = wrapper(env_spec.make(disable_env_checker=True))
  159. obs_space, wrapper_name = env.observation_space, wrapper.__name__
  160. assert isinstance(
  161. obs_space, gym.spaces.Dict
  162. ), f"Observation space for {wrapper_name} is not a Dict: {obs_space}."
  163. # This should not fail either
  164. ImgObsWrapper(env)
  165. env.reset()
  166. env.step(0)
  167. env.close()
  168. class EmptyEnvWithExtraObs(EmptyEnv):
  169. """
  170. Custom environment with an extra observation
  171. """
  172. def __init__(self) -> None:
  173. super().__init__(size=5)
  174. self.observation_space["size"] = gym.spaces.Box(
  175. low=0, high=np.iinfo(np.uint).max, shape=(2,), dtype=np.uint
  176. )
  177. def reset(self, **kwargs):
  178. obs, info = super().reset(**kwargs)
  179. obs["size"] = np.array([self.width, self.height])
  180. return obs, info
  181. def step(self, action):
  182. obs, reward, terminated, truncated, info = super().step(action)
  183. obs["size"] = np.array([self.width, self.height])
  184. return obs, reward, terminated, truncated, info
  185. @pytest.mark.parametrize(
  186. "wrapper",
  187. [
  188. OneHotPartialObsWrapper,
  189. RGBImgObsWrapper,
  190. RGBImgPartialObsWrapper,
  191. FullyObsWrapper,
  192. ],
  193. )
  194. def test_agent_sees_method(wrapper):
  195. env1 = wrapper(EmptyEnvWithExtraObs())
  196. env2 = wrapper(gym.make("MiniGrid-Empty-5x5-v0"))
  197. obs1, _ = env1.reset(seed=0)
  198. obs2, _ = env2.reset(seed=0)
  199. assert "size" in obs1
  200. assert obs1["size"].shape == (2,)
  201. assert (obs1["size"] == [5, 5]).all()
  202. for key in obs2:
  203. assert np.array_equal(obs1[key], obs2[key])
  204. obs1, reward1, terminated1, truncated1, _ = env1.step(0)
  205. obs2, reward2, terminated2, truncated2, _ = env2.step(0)
  206. assert "size" in obs1
  207. assert obs1["size"].shape == (2,)
  208. assert (obs1["size"] == [5, 5]).all()
  209. for key in obs2:
  210. assert np.array_equal(obs1[key], obs2[key])