test_wrappers.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330
  1. from __future__ import annotations
  2. import math
  3. import gymnasium as gym
  4. import numpy as np
  5. import pytest
  6. from minigrid.core.actions import Actions
  7. from minigrid.core.constants import OBJECT_TO_IDX
  8. from minigrid.envs import EmptyEnv
  9. from minigrid.wrappers import (
  10. ActionBonus,
  11. DictObservationSpaceWrapper,
  12. DirectionObsWrapper,
  13. FlatObsWrapper,
  14. FullyObsWrapper,
  15. ImgObsWrapper,
  16. OneHotPartialObsWrapper,
  17. PositionBonus,
  18. ReseedWrapper,
  19. RGBImgObsWrapper,
  20. RGBImgPartialObsWrapper,
  21. SymbolicObsWrapper,
  22. ViewSizeWrapper,
  23. )
  24. from tests.utils import all_testing_env_specs, assert_equals, minigrid_testing_env_specs
  25. SEEDS = [100, 243, 500]
  26. NUM_STEPS = 100
  27. @pytest.mark.parametrize(
  28. "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  29. )
  30. def test_reseed_wrapper(env_spec):
  31. """
  32. Test the ReseedWrapper with a list of SEEDS.
  33. """
  34. unwrapped_env = env_spec.make()
  35. env = env_spec.make()
  36. env = ReseedWrapper(env, seeds=SEEDS)
  37. env.action_space.seed(0)
  38. for seed in SEEDS:
  39. env.reset()
  40. unwrapped_env.reset(seed=seed)
  41. for time_step in range(NUM_STEPS):
  42. action = env.action_space.sample()
  43. obs, rew, terminated, truncated, info = env.step(action)
  44. (
  45. unwrapped_obs,
  46. unwrapped_rew,
  47. unwrapped_terminated,
  48. unwrapped_truncated,
  49. unwrapped_info,
  50. ) = unwrapped_env.step(action)
  51. assert_equals(obs, unwrapped_obs, f"[{time_step}] ")
  52. assert unwrapped_env.observation_space.contains(obs)
  53. assert (
  54. rew == unwrapped_rew
  55. ), f"[{time_step}] reward={rew}, unwrapped reward={unwrapped_rew}"
  56. assert (
  57. terminated == unwrapped_terminated
  58. ), f"[{time_step}] terminated={terminated}, unwrapped terminated={unwrapped_terminated}"
  59. assert (
  60. truncated == unwrapped_truncated
  61. ), f"[{time_step}] truncated={truncated}, unwrapped truncated={unwrapped_truncated}"
  62. assert_equals(info, unwrapped_info, f"[{time_step}] ")
  63. # Start the next seed
  64. if terminated or truncated:
  65. break
  66. env.close()
  67. unwrapped_env.close()
  68. @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-16x16-v0"])
  69. def test_position_bonus_wrapper(env_id):
  70. env = gym.make(env_id)
  71. wrapped_env = PositionBonus(gym.make(env_id))
  72. action_forward = Actions.forward
  73. action_left = Actions.left
  74. action_right = Actions.right
  75. for _ in range(10):
  76. wrapped_env.reset()
  77. for _ in range(5):
  78. wrapped_env.step(action_forward)
  79. # Turn lef 3 times (check that actions don't influence bonus)
  80. for _ in range(3):
  81. _, wrapped_rew, _, _, _ = wrapped_env.step(action_left)
  82. env.reset()
  83. for _ in range(5):
  84. env.step(action_forward)
  85. # Turn right 3 times
  86. for _ in range(3):
  87. _, rew, _, _, _ = env.step(action_right)
  88. expected_bonus_reward = rew + 1 / math.sqrt(13)
  89. assert expected_bonus_reward == wrapped_rew
  90. @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-16x16-v0"])
  91. def test_action_bonus_wrapper(env_id):
  92. env = gym.make(env_id)
  93. wrapped_env = ActionBonus(gym.make(env_id))
  94. action = Actions.forward
  95. for _ in range(10):
  96. wrapped_env.reset()
  97. for _ in range(5):
  98. _, wrapped_rew, _, _, _ = wrapped_env.step(action)
  99. env.reset()
  100. for _ in range(5):
  101. _, rew, _, _, _ = env.step(action)
  102. expected_bonus_reward = rew + 1 / math.sqrt(10)
  103. assert expected_bonus_reward == wrapped_rew
  104. @pytest.mark.parametrize(
  105. "env_spec",
  106. minigrid_testing_env_specs,
  107. ids=[spec.id for spec in minigrid_testing_env_specs],
  108. ) # DictObservationSpaceWrapper is not compatible with BabyAI levels. See minigrid/wrappers.py for more details.
  109. def test_dict_observation_space_wrapper(env_spec):
  110. env = env_spec.make()
  111. env = DictObservationSpaceWrapper(env)
  112. env.reset()
  113. mission = env.mission
  114. obs, _, _, _, _ = env.step(0)
  115. assert env.string_to_indices(mission) == [
  116. value for value in obs["mission"] if value != 0
  117. ]
  118. env.close()
  119. @pytest.mark.parametrize(
  120. "wrapper",
  121. [
  122. ReseedWrapper,
  123. ImgObsWrapper,
  124. FlatObsWrapper,
  125. ViewSizeWrapper,
  126. DictObservationSpaceWrapper,
  127. OneHotPartialObsWrapper,
  128. RGBImgPartialObsWrapper,
  129. FullyObsWrapper,
  130. ],
  131. )
  132. @pytest.mark.parametrize(
  133. "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  134. )
  135. def test_main_wrappers(wrapper, env_spec):
  136. if (
  137. wrapper in (DictObservationSpaceWrapper, FlatObsWrapper)
  138. and env_spec not in minigrid_testing_env_specs
  139. ):
  140. # DictObservationSpaceWrapper and FlatObsWrapper are not compatible with BabyAI levels
  141. # See minigrid/wrappers.py for more details
  142. pytest.skip()
  143. env = env_spec.make()
  144. env = wrapper(env)
  145. for _ in range(10):
  146. env.reset()
  147. env.step(0)
  148. env.close()
  149. @pytest.mark.parametrize(
  150. "wrapper",
  151. [
  152. OneHotPartialObsWrapper,
  153. RGBImgPartialObsWrapper,
  154. FullyObsWrapper,
  155. ],
  156. )
  157. @pytest.mark.parametrize(
  158. "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  159. )
  160. def test_observation_space_wrappers(wrapper, env_spec):
  161. env = wrapper(env_spec.make(disable_env_checker=True))
  162. obs_space, wrapper_name = env.observation_space, wrapper.__name__
  163. assert isinstance(
  164. obs_space, gym.spaces.Dict
  165. ), f"Observation space for {wrapper_name} is not a Dict: {obs_space}."
  166. # This should not fail either
  167. ImgObsWrapper(env)
  168. env.reset()
  169. env.step(0)
  170. env.close()
  171. class EmptyEnvWithExtraObs(EmptyEnv):
  172. """
  173. Custom environment with an extra observation
  174. """
  175. def __init__(self) -> None:
  176. super().__init__(size=5)
  177. self.observation_space["size"] = gym.spaces.Box(
  178. low=0, high=np.iinfo(np.uint).max, shape=(2,), dtype=np.uint
  179. )
  180. def reset(self, **kwargs):
  181. obs, info = super().reset(**kwargs)
  182. obs["size"] = np.array([self.width, self.height])
  183. return obs, info
  184. def step(self, action):
  185. obs, reward, terminated, truncated, info = super().step(action)
  186. obs["size"] = np.array([self.width, self.height])
  187. return obs, reward, terminated, truncated, info
  188. @pytest.mark.parametrize(
  189. "wrapper",
  190. [
  191. OneHotPartialObsWrapper,
  192. RGBImgObsWrapper,
  193. RGBImgPartialObsWrapper,
  194. FullyObsWrapper,
  195. ],
  196. )
  197. def test_agent_sees_method(wrapper):
  198. env1 = wrapper(EmptyEnvWithExtraObs())
  199. env2 = wrapper(gym.make("MiniGrid-Empty-5x5-v0"))
  200. obs1, _ = env1.reset(seed=0)
  201. obs2, _ = env2.reset(seed=0)
  202. assert "size" in obs1
  203. assert obs1["size"].shape == (2,)
  204. assert (obs1["size"] == [5, 5]).all()
  205. for key in obs2:
  206. assert np.array_equal(obs1[key], obs2[key])
  207. obs1, reward1, terminated1, truncated1, _ = env1.step(0)
  208. obs2, reward2, terminated2, truncated2, _ = env2.step(0)
  209. assert "size" in obs1
  210. assert obs1["size"].shape == (2,)
  211. assert (obs1["size"] == [5, 5]).all()
  212. for key in obs2:
  213. assert np.array_equal(obs1[key], obs2[key])
  214. @pytest.mark.parametrize("view_size", [5, 7, 9])
  215. def test_viewsize_wrapper(view_size):
  216. env = gym.make("MiniGrid-Empty-5x5-v0")
  217. env = ViewSizeWrapper(env, agent_view_size=view_size)
  218. env.reset()
  219. obs, _, _, _, _ = env.step(0)
  220. assert obs["image"].shape == (view_size, view_size, 3)
  221. env.close()
  222. @pytest.mark.parametrize("env_id", ["MiniGrid-LavaCrossingS11N5-v0"])
  223. @pytest.mark.parametrize("type", ["slope", "angle"])
  224. def test_direction_obs_wrapper(env_id, type):
  225. env = gym.make(env_id)
  226. env = DirectionObsWrapper(env, type=type)
  227. obs, _ = env.reset()
  228. slope = np.divide(
  229. env.goal_position[1] - env.agent_pos[1],
  230. env.goal_position[0] - env.agent_pos[0],
  231. )
  232. if type == "slope":
  233. assert obs["goal_direction"] == slope
  234. elif type == "angle":
  235. assert obs["goal_direction"] == np.arctan(slope)
  236. obs, _, _, _, _ = env.step(0)
  237. slope = np.divide(
  238. env.goal_position[1] - env.agent_pos[1],
  239. env.goal_position[0] - env.agent_pos[0],
  240. )
  241. if type == "slope":
  242. assert obs["goal_direction"] == slope
  243. elif type == "angle":
  244. assert obs["goal_direction"] == np.arctan(slope)
  245. env.close()
  246. @pytest.mark.parametrize("env_id", ["MiniGrid-DistShift1-v0"])
  247. def test_symbolic_obs_wrapper(env_id):
  248. env = gym.make(env_id)
  249. env = SymbolicObsWrapper(env)
  250. obs, _ = env.reset(seed=123)
  251. agent_pos = env.agent_pos
  252. goal_pos = env.goal_pos
  253. assert obs["image"].shape == (env.width, env.height, 3)
  254. assert np.alltrue(
  255. obs["image"][agent_pos[0], agent_pos[1], :]
  256. == np.array([agent_pos[0], agent_pos[1], OBJECT_TO_IDX["agent"]])
  257. )
  258. assert np.alltrue(
  259. obs["image"][goal_pos[0], goal_pos[1], :]
  260. == np.array([goal_pos[0], goal_pos[1], OBJECT_TO_IDX["goal"]])
  261. )
  262. obs, _, _, _, _ = env.step(2)
  263. agent_pos = env.agent_pos
  264. goal_pos = env.goal_pos
  265. assert obs["image"].shape == (env.width, env.height, 3)
  266. assert np.alltrue(
  267. obs["image"][agent_pos[0], agent_pos[1], :]
  268. == np.array([agent_pos[0], agent_pos[1], OBJECT_TO_IDX["agent"]])
  269. )
  270. assert np.alltrue(
  271. obs["image"][goal_pos[0], goal_pos[1], :]
  272. == np.array([goal_pos[0], goal_pos[1], OBJECT_TO_IDX["goal"]])
  273. )
  274. env.close()