test_wrappers.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. from __future__ import annotations
  2. import math
  3. import gymnasium as gym
  4. import numpy as np
  5. import pytest
  6. from minigrid.core.actions import Actions
  7. from minigrid.envs import EmptyEnv
  8. from minigrid.wrappers import (
  9. ActionBonus,
  10. DictObservationSpaceWrapper,
  11. DirectionObsWrapper,
  12. FlatObsWrapper,
  13. FullyObsWrapper,
  14. ImgObsWrapper,
  15. OneHotPartialObsWrapper,
  16. PositionBonus,
  17. ReseedWrapper,
  18. RGBImgObsWrapper,
  19. RGBImgPartialObsWrapper,
  20. SymbolicObsWrapper,
  21. ViewSizeWrapper,
  22. )
  23. from tests.utils import all_testing_env_specs, assert_equals, minigrid_testing_env_specs
# Seeds handed to ReseedWrapper in test_reseed_wrapper; the same values are
# passed one by one to reset(seed=...) on the reference environment.
SEEDS = [100, 243, 500]
# Upper bound on the number of steps compared per seed in test_reseed_wrapper.
NUM_STEPS = 100
  26. @pytest.mark.parametrize(
  27. "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  28. )
  29. def test_reseed_wrapper(env_spec):
  30. """
  31. Test the ReseedWrapper with a list of SEEDS.
  32. """
  33. unwrapped_env = env_spec.make()
  34. env = env_spec.make()
  35. env = ReseedWrapper(env, seeds=SEEDS)
  36. env.action_space.seed(0)
  37. for seed in SEEDS:
  38. env.reset()
  39. unwrapped_env.reset(seed=seed)
  40. for time_step in range(NUM_STEPS):
  41. action = env.action_space.sample()
  42. obs, rew, terminated, truncated, info = env.step(action)
  43. (
  44. unwrapped_obs,
  45. unwrapped_rew,
  46. unwrapped_terminated,
  47. unwrapped_truncated,
  48. unwrapped_info,
  49. ) = unwrapped_env.step(action)
  50. assert_equals(obs, unwrapped_obs, f"[{time_step}] ")
  51. assert unwrapped_env.observation_space.contains(obs)
  52. assert (
  53. rew == unwrapped_rew
  54. ), f"[{time_step}] reward={rew}, unwrapped reward={unwrapped_rew}"
  55. assert (
  56. terminated == unwrapped_terminated
  57. ), f"[{time_step}] terminated={terminated}, unwrapped terminated={unwrapped_terminated}"
  58. assert (
  59. truncated == unwrapped_truncated
  60. ), f"[{time_step}] truncated={truncated}, unwrapped truncated={unwrapped_truncated}"
  61. assert_equals(info, unwrapped_info, f"[{time_step}] ")
  62. # Start the next seed
  63. if terminated or truncated:
  64. break
  65. env.close()
  66. unwrapped_env.close()
  67. @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-16x16-v0"])
  68. def test_position_bonus_wrapper(env_id):
  69. env = gym.make(env_id)
  70. wrapped_env = PositionBonus(gym.make(env_id))
  71. action_forward = Actions.forward
  72. action_left = Actions.left
  73. action_right = Actions.right
  74. for _ in range(10):
  75. wrapped_env.reset()
  76. for _ in range(5):
  77. wrapped_env.step(action_forward)
  78. # Turn lef 3 times (check that actions don't influence bonus)
  79. for _ in range(3):
  80. _, wrapped_rew, _, _, _ = wrapped_env.step(action_left)
  81. env.reset()
  82. for _ in range(5):
  83. env.step(action_forward)
  84. # Turn right 3 times
  85. for _ in range(3):
  86. _, rew, _, _, _ = env.step(action_right)
  87. expected_bonus_reward = rew + 1 / math.sqrt(13)
  88. assert expected_bonus_reward == wrapped_rew
  89. @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-16x16-v0"])
  90. def test_action_bonus_wrapper(env_id):
  91. env = gym.make(env_id)
  92. wrapped_env = ActionBonus(gym.make(env_id))
  93. action = Actions.forward
  94. for _ in range(10):
  95. wrapped_env.reset()
  96. for _ in range(5):
  97. _, wrapped_rew, _, _, _ = wrapped_env.step(action)
  98. env.reset()
  99. for _ in range(5):
  100. _, rew, _, _, _ = env.step(action)
  101. expected_bonus_reward = rew + 1 / math.sqrt(10)
  102. assert expected_bonus_reward == wrapped_rew
  103. @pytest.mark.parametrize(
  104. "env_spec",
  105. minigrid_testing_env_specs,
  106. ids=[spec.id for spec in minigrid_testing_env_specs],
  107. ) # DictObservationSpaceWrapper is not compatible with BabyAI levels. See minigrid/wrappers.py for more details.
  108. def test_dict_observation_space_wrapper(env_spec):
  109. env = env_spec.make()
  110. env = DictObservationSpaceWrapper(env)
  111. env.reset()
  112. mission = env.mission
  113. obs, _, _, _, _ = env.step(0)
  114. assert env.string_to_indices(mission) == [
  115. value for value in obs["mission"] if value != 0
  116. ]
  117. env.close()
  118. @pytest.mark.parametrize(
  119. "wrapper",
  120. [
  121. ReseedWrapper,
  122. ImgObsWrapper,
  123. FlatObsWrapper,
  124. ViewSizeWrapper,
  125. DictObservationSpaceWrapper,
  126. OneHotPartialObsWrapper,
  127. RGBImgPartialObsWrapper,
  128. FullyObsWrapper,
  129. ],
  130. )
  131. @pytest.mark.parametrize(
  132. "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  133. )
  134. def test_main_wrappers(wrapper, env_spec):
  135. if (
  136. wrapper in (DictObservationSpaceWrapper, FlatObsWrapper)
  137. and env_spec not in minigrid_testing_env_specs
  138. ):
  139. # DictObservationSpaceWrapper and FlatObsWrapper are not compatible with BabyAI levels
  140. # See minigrid/wrappers.py for more details
  141. pytest.skip()
  142. env = env_spec.make()
  143. env = wrapper(env)
  144. for _ in range(10):
  145. env.reset()
  146. env.step(0)
  147. env.close()
  148. @pytest.mark.parametrize(
  149. "wrapper",
  150. [
  151. OneHotPartialObsWrapper,
  152. RGBImgPartialObsWrapper,
  153. FullyObsWrapper,
  154. ],
  155. )
  156. @pytest.mark.parametrize(
  157. "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  158. )
  159. def test_observation_space_wrappers(wrapper, env_spec):
  160. env = wrapper(env_spec.make(disable_env_checker=True))
  161. obs_space, wrapper_name = env.observation_space, wrapper.__name__
  162. assert isinstance(
  163. obs_space, gym.spaces.Dict
  164. ), f"Observation space for {wrapper_name} is not a Dict: {obs_space}."
  165. # This should not fail either
  166. ImgObsWrapper(env)
  167. env.reset()
  168. env.step(0)
  169. env.close()
  170. class EmptyEnvWithExtraObs(EmptyEnv):
  171. """
  172. Custom environment with an extra observation
  173. """
  174. def __init__(self) -> None:
  175. super().__init__(size=5)
  176. self.observation_space["size"] = gym.spaces.Box(
  177. low=0, high=np.iinfo(np.uint).max, shape=(2,), dtype=np.uint
  178. )
  179. def reset(self, **kwargs):
  180. obs, info = super().reset(**kwargs)
  181. obs["size"] = np.array([self.width, self.height])
  182. return obs, info
  183. def step(self, action):
  184. obs, reward, terminated, truncated, info = super().step(action)
  185. obs["size"] = np.array([self.width, self.height])
  186. return obs, reward, terminated, truncated, info
  187. @pytest.mark.parametrize(
  188. "wrapper",
  189. [
  190. OneHotPartialObsWrapper,
  191. RGBImgObsWrapper,
  192. RGBImgPartialObsWrapper,
  193. FullyObsWrapper,
  194. ],
  195. )
  196. def test_agent_sees_method(wrapper):
  197. env1 = wrapper(EmptyEnvWithExtraObs())
  198. env2 = wrapper(gym.make("MiniGrid-Empty-5x5-v0"))
  199. obs1, _ = env1.reset(seed=0)
  200. obs2, _ = env2.reset(seed=0)
  201. assert "size" in obs1
  202. assert obs1["size"].shape == (2,)
  203. assert (obs1["size"] == [5, 5]).all()
  204. for key in obs2:
  205. assert np.array_equal(obs1[key], obs2[key])
  206. obs1, reward1, terminated1, truncated1, _ = env1.step(0)
  207. obs2, reward2, terminated2, truncated2, _ = env2.step(0)
  208. assert "size" in obs1
  209. assert obs1["size"].shape == (2,)
  210. assert (obs1["size"] == [5, 5]).all()
  211. for key in obs2:
  212. assert np.array_equal(obs1[key], obs2[key])
  213. @pytest.mark.parametrize("view_size", [5, 7, 9])
  214. def test_viewsize_wrapper(view_size):
  215. env = gym.make("MiniGrid-Empty-5x5-v0")
  216. env = ViewSizeWrapper(env, agent_view_size=view_size)
  217. env.reset()
  218. obs, _, _, _, _ = env.step(0)
  219. assert obs["image"].shape == (view_size, view_size, 3)
  220. env.close()
  221. @pytest.mark.parametrize("env_id", ["MiniGrid-LavaCrossingS11N5-v0"])
  222. @pytest.mark.parametrize("type", ["slope", "angle"])
  223. def test_direction_obs_wrapper(env_id, type):
  224. env = gym.make(env_id)
  225. env = DirectionObsWrapper(env, type=type)
  226. obs = env.reset()
  227. slope = np.divide(
  228. env.goal_position[1] - env.agent_pos[1],
  229. env.goal_position[0] - env.agent_pos[0],
  230. )
  231. if type == "slope":
  232. assert obs["goal_direction"] == slope
  233. elif type == "angle":
  234. assert obs["goal_direction"] == np.arctan(slope)
  235. obs, _, _, _, _ = env.step(0)
  236. slope = np.divide(
  237. env.goal_position[1] - env.agent_pos[1],
  238. env.goal_position[0] - env.agent_pos[0],
  239. )
  240. if type == "slope":
  241. assert obs["goal_direction"] == slope
  242. elif type == "angle":
  243. assert obs["goal_direction"] == np.arctan(slope)
  244. env.close()
  245. @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-16x16-v0"])
  246. def test_symbolic_obs_wrapper(env_id):
  247. env = gym.make(env_id)
  248. env = SymbolicObsWrapper(env)
  249. obs, _ = env.reset()
  250. assert obs["image"].shape == (env.width, env.height, 3)
  251. obs, _, _, _, _ = env.step(0)
  252. assert obs["image"].shape == (env.width, env.height, 3)
  253. env.close()