test_wrappers.py 7.1 KB


  1. import math
  2. import gymnasium as gym
  3. import numpy as np
  4. import pytest
  5. from minigrid.core.actions import Actions
  6. from minigrid.envs import EmptyEnv
  7. from minigrid.wrappers import (
  8. ActionBonus,
  9. DictObservationSpaceWrapper,
  10. FlatObsWrapper,
  11. FullyObsWrapper,
  12. ImgObsWrapper,
  13. OneHotPartialObsWrapper,
  14. ReseedWrapper,
  15. RGBImgObsWrapper,
  16. RGBImgPartialObsWrapper,
  17. StateBonus,
  18. ViewSizeWrapper,
  19. )
  20. from tests.utils import all_testing_env_specs, assert_equals, minigrid_testing_env_specs
# Seeds handed to ReseedWrapper and replayed one by one on the reference
# environment in test_reseed_wrapper.
SEEDS = [100, 243, 500]
# Upper bound on the number of steps compared per seed in test_reseed_wrapper.
NUM_STEPS = 100
  23. @pytest.mark.parametrize(
  24. "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  25. )
  26. def test_reseed_wrapper(env_spec):
  27. """
  28. Test the ReseedWrapper with a list of SEEDS.
  29. """
  30. unwrapped_env = env_spec.make()
  31. env = env_spec.make()
  32. env = ReseedWrapper(env, seeds=SEEDS)
  33. env.action_space.seed(0)
  34. for seed in SEEDS:
  35. env.reset()
  36. unwrapped_env.reset(seed=seed)
  37. for time_step in range(NUM_STEPS):
  38. action = env.action_space.sample()
  39. obs, rew, terminated, truncated, info = env.step(action)
  40. (
  41. unwrapped_obs,
  42. unwrapped_rew,
  43. unwrapped_terminated,
  44. unwrapped_truncated,
  45. unwrapped_info,
  46. ) = unwrapped_env.step(action)
  47. assert_equals(obs, unwrapped_obs, f"[{time_step}] ")
  48. assert unwrapped_env.observation_space.contains(obs)
  49. assert (
  50. rew == unwrapped_rew
  51. ), f"[{time_step}] reward={rew}, unwrapped reward={unwrapped_rew}"
  52. assert (
  53. terminated == unwrapped_terminated
  54. ), f"[{time_step}] terminated={terminated}, unwrapped terminated={unwrapped_terminated}"
  55. assert (
  56. truncated == unwrapped_truncated
  57. ), f"[{time_step}] truncated={truncated}, unwrapped truncated={unwrapped_truncated}"
  58. assert_equals(info, unwrapped_info, f"[{time_step}] ")
  59. # Start the next seed
  60. if terminated or truncated:
  61. break
  62. env.close()
  63. unwrapped_env.close()
  64. @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-16x16-v0"])
  65. def test_state_bonus_wrapper(env_id):
  66. env = gym.make(env_id)
  67. wrapped_env = StateBonus(gym.make(env_id))
  68. action_forward = Actions.forward
  69. action_left = Actions.left
  70. action_right = Actions.right
  71. for _ in range(10):
  72. wrapped_env.reset()
  73. for _ in range(5):
  74. wrapped_env.step(action_forward)
  75. # Turn lef 3 times (check that actions don't influence bonus)
  76. for _ in range(3):
  77. _, wrapped_rew, _, _, _ = wrapped_env.step(action_left)
  78. env.reset()
  79. for _ in range(5):
  80. env.step(action_forward)
  81. # Turn right 3 times
  82. for _ in range(3):
  83. _, rew, _, _, _ = env.step(action_right)
  84. expected_bonus_reward = rew + 1 / math.sqrt(13)
  85. assert expected_bonus_reward == wrapped_rew
  86. @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-16x16-v0"])
  87. def test_action_bonus_wrapper(env_id):
  88. env = gym.make(env_id)
  89. wrapped_env = ActionBonus(gym.make(env_id))
  90. action = Actions.forward
  91. for _ in range(10):
  92. wrapped_env.reset()
  93. for _ in range(5):
  94. _, wrapped_rew, _, _, _ = wrapped_env.step(action)
  95. env.reset()
  96. for _ in range(5):
  97. _, rew, _, _, _ = env.step(action)
  98. expected_bonus_reward = rew + 1 / math.sqrt(10)
  99. assert expected_bonus_reward == wrapped_rew
  100. @pytest.mark.parametrize(
  101. "env_spec",
  102. minigrid_testing_env_specs,
  103. ids=[spec.id for spec in minigrid_testing_env_specs],
  104. ) # DictObservationSpaceWrapper is not compatible with BabyAI levels. See minigrid/wrappers.py for more details.
  105. def test_dict_observation_space_wrapper(env_spec):
  106. env = env_spec.make()
  107. env = DictObservationSpaceWrapper(env)
  108. env.reset()
  109. mission = env.mission
  110. obs, _, _, _, _ = env.step(0)
  111. assert env.string_to_indices(mission) == [
  112. value for value in obs["mission"] if value != 0
  113. ]
  114. env.close()
  115. @pytest.mark.parametrize(
  116. "wrapper",
  117. [
  118. ReseedWrapper,
  119. ImgObsWrapper,
  120. FlatObsWrapper,
  121. ViewSizeWrapper,
  122. DictObservationSpaceWrapper,
  123. OneHotPartialObsWrapper,
  124. RGBImgPartialObsWrapper,
  125. FullyObsWrapper,
  126. ],
  127. )
  128. @pytest.mark.parametrize(
  129. "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  130. )
  131. def test_main_wrappers(wrapper, env_spec):
  132. if (
  133. wrapper in (DictObservationSpaceWrapper, FlatObsWrapper)
  134. and env_spec not in minigrid_testing_env_specs
  135. ):
  136. # DictObservationSpaceWrapper and FlatObsWrapper are not compatible with BabyAI levels
  137. # See minigrid/wrappers.py for more details
  138. pytest.skip()
  139. env = env_spec.make()
  140. env = wrapper(env)
  141. for _ in range(10):
  142. env.reset()
  143. env.step(0)
  144. env.close()
  145. @pytest.mark.parametrize(
  146. "wrapper",
  147. [
  148. OneHotPartialObsWrapper,
  149. RGBImgPartialObsWrapper,
  150. FullyObsWrapper,
  151. ],
  152. )
  153. @pytest.mark.parametrize(
  154. "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  155. )
  156. def test_observation_space_wrappers(wrapper, env_spec):
  157. env = wrapper(env_spec.make(disable_env_checker=True))
  158. obs_space, wrapper_name = env.observation_space, wrapper.__name__
  159. assert isinstance(
  160. obs_space, gym.spaces.Dict
  161. ), f"Observation space for {wrapper_name} is not a Dict: {obs_space}."
  162. # This should not fail either
  163. ImgObsWrapper(env)
  164. env.reset()
  165. env.step(0)
  166. env.close()
  167. class EmptyEnvWithExtraObs(EmptyEnv):
  168. """
  169. Custom environment with an extra observation
  170. """
  171. def __init__(self) -> None:
  172. super().__init__(size=5)
  173. self.observation_space["size"] = gym.spaces.Box(
  174. low=0, high=np.iinfo(np.uint).max, shape=(2,), dtype=np.uint
  175. )
  176. def reset(self, **kwargs):
  177. obs, info = super().reset(**kwargs)
  178. obs["size"] = np.array([self.width, self.height])
  179. return obs, info
  180. def step(self, action):
  181. obs, reward, terminated, truncated, info = super().step(action)
  182. obs["size"] = np.array([self.width, self.height])
  183. return obs, reward, terminated, truncated, info
  184. @pytest.mark.parametrize(
  185. "wrapper",
  186. [
  187. OneHotPartialObsWrapper,
  188. RGBImgObsWrapper,
  189. RGBImgPartialObsWrapper,
  190. FullyObsWrapper,
  191. ],
  192. )
  193. def test_agent_sees_method(wrapper):
  194. env1 = wrapper(EmptyEnvWithExtraObs())
  195. env2 = wrapper(gym.make("MiniGrid-Empty-5x5-v0"))
  196. obs1, _ = env1.reset(seed=0)
  197. obs2, _ = env2.reset(seed=0)
  198. assert "size" in obs1
  199. assert obs1["size"].shape == (2,)
  200. assert (obs1["size"] == [5, 5]).all()
  201. for key in obs2:
  202. assert np.array_equal(obs1[key], obs2[key])
  203. obs1, reward1, terminated1, truncated1, _ = env1.step(0)
  204. obs2, reward2, terminated2, truncated2, _ = env2.step(0)
  205. assert "size" in obs1
  206. assert obs1["size"].shape == (2,)
  207. assert (obs1["size"] == [5, 5]).all()
  208. for key in obs2:
  209. assert np.array_equal(obs1[key], obs2[key])