test_wrappers.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359
  1. from __future__ import annotations
  2. import math
  3. import gymnasium as gym
  4. import numpy as np
  5. import pytest
  6. from minigrid.core.actions import Actions
  7. from minigrid.core.constants import OBJECT_TO_IDX
  8. from minigrid.envs import EmptyEnv
  9. from minigrid.wrappers import (
  10. ActionBonus,
  11. DictObservationSpaceWrapper,
  12. DirectionObsWrapper,
  13. FlatObsWrapper,
  14. FullyObsWrapper,
  15. ImgObsWrapper,
  16. OneHotPartialObsWrapper,
  17. PositionBonus,
  18. ReseedWrapper,
  19. RGBImgObsWrapper,
  20. RGBImgPartialObsWrapper,
  21. StochasticActionWrapper,
  22. SymbolicObsWrapper,
  23. ViewSizeWrapper,
  24. )
  25. from tests.utils import all_testing_env_specs, assert_equals, minigrid_testing_env_specs
  26. SEEDS = [100, 243, 500]
  27. NUM_STEPS = 100
  28. @pytest.mark.parametrize(
  29. "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  30. )
  31. def test_reseed_wrapper(env_spec):
  32. """
  33. Test the ReseedWrapper with a list of SEEDS.
  34. """
  35. unwrapped_env = env_spec.make()
  36. env = env_spec.make()
  37. env = ReseedWrapper(env, seeds=SEEDS)
  38. env.action_space.seed(0)
  39. for seed in SEEDS:
  40. env.reset()
  41. unwrapped_env.reset(seed=seed)
  42. for time_step in range(NUM_STEPS):
  43. action = env.action_space.sample()
  44. obs, rew, terminated, truncated, info = env.step(action)
  45. (
  46. unwrapped_obs,
  47. unwrapped_rew,
  48. unwrapped_terminated,
  49. unwrapped_truncated,
  50. unwrapped_info,
  51. ) = unwrapped_env.step(action)
  52. assert_equals(obs, unwrapped_obs, f"[{time_step}] ")
  53. assert unwrapped_env.observation_space.contains(obs)
  54. assert (
  55. rew == unwrapped_rew
  56. ), f"[{time_step}] reward={rew}, unwrapped reward={unwrapped_rew}"
  57. assert (
  58. terminated == unwrapped_terminated
  59. ), f"[{time_step}] terminated={terminated}, unwrapped terminated={unwrapped_terminated}"
  60. assert (
  61. truncated == unwrapped_truncated
  62. ), f"[{time_step}] truncated={truncated}, unwrapped truncated={unwrapped_truncated}"
  63. assert_equals(info, unwrapped_info, f"[{time_step}] ")
  64. # Start the next seed
  65. if terminated or truncated:
  66. break
  67. env.close()
  68. unwrapped_env.close()
  69. @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-16x16-v0"])
  70. def test_position_bonus_wrapper(env_id):
  71. env = gym.make(env_id)
  72. wrapped_env = PositionBonus(gym.make(env_id))
  73. action_forward = Actions.forward
  74. action_left = Actions.left
  75. action_right = Actions.right
  76. for _ in range(10):
  77. wrapped_env.reset()
  78. for _ in range(5):
  79. wrapped_env.step(action_forward)
  80. # Turn lef 3 times (check that actions don't influence bonus)
  81. for _ in range(3):
  82. _, wrapped_rew, _, _, _ = wrapped_env.step(action_left)
  83. env.reset()
  84. for _ in range(5):
  85. env.step(action_forward)
  86. # Turn right 3 times
  87. for _ in range(3):
  88. _, rew, _, _, _ = env.step(action_right)
  89. expected_bonus_reward = rew + 1 / math.sqrt(13)
  90. assert expected_bonus_reward == wrapped_rew
  91. @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-16x16-v0"])
  92. def test_action_bonus_wrapper(env_id):
  93. env = gym.make(env_id)
  94. wrapped_env = ActionBonus(gym.make(env_id))
  95. action = Actions.forward
  96. for _ in range(10):
  97. wrapped_env.reset()
  98. for _ in range(5):
  99. _, wrapped_rew, _, _, _ = wrapped_env.step(action)
  100. env.reset()
  101. for _ in range(5):
  102. _, rew, _, _, _ = env.step(action)
  103. expected_bonus_reward = rew + 1 / math.sqrt(10)
  104. assert expected_bonus_reward == wrapped_rew
  105. @pytest.mark.parametrize(
  106. "env_spec",
  107. minigrid_testing_env_specs,
  108. ids=[spec.id for spec in minigrid_testing_env_specs],
  109. ) # DictObservationSpaceWrapper is not compatible with BabyAI levels. See minigrid/wrappers.py for more details.
  110. def test_dict_observation_space_wrapper(env_spec):
  111. env = env_spec.make()
  112. env = DictObservationSpaceWrapper(env)
  113. env.reset()
  114. mission = env.mission
  115. obs, _, _, _, _ = env.step(0)
  116. assert env.string_to_indices(mission) == [
  117. value for value in obs["mission"] if value != 0
  118. ]
  119. env.close()
  120. @pytest.mark.parametrize(
  121. "wrapper",
  122. [
  123. ReseedWrapper,
  124. ImgObsWrapper,
  125. FlatObsWrapper,
  126. ViewSizeWrapper,
  127. DictObservationSpaceWrapper,
  128. OneHotPartialObsWrapper,
  129. RGBImgPartialObsWrapper,
  130. FullyObsWrapper,
  131. ],
  132. )
  133. @pytest.mark.parametrize(
  134. "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  135. )
  136. def test_main_wrappers(wrapper, env_spec):
  137. if (
  138. wrapper in (DictObservationSpaceWrapper, FlatObsWrapper)
  139. and env_spec not in minigrid_testing_env_specs
  140. ):
  141. # DictObservationSpaceWrapper and FlatObsWrapper are not compatible with BabyAI levels
  142. # See minigrid/wrappers.py for more details
  143. pytest.skip()
  144. env = env_spec.make()
  145. env = wrapper(env)
  146. for _ in range(10):
  147. env.reset()
  148. env.step(0)
  149. env.close()
  150. @pytest.mark.parametrize(
  151. "wrapper",
  152. [
  153. OneHotPartialObsWrapper,
  154. RGBImgPartialObsWrapper,
  155. FullyObsWrapper,
  156. ],
  157. )
  158. @pytest.mark.parametrize(
  159. "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  160. )
  161. def test_observation_space_wrappers(wrapper, env_spec):
  162. env = wrapper(env_spec.make(disable_env_checker=True))
  163. obs_space, wrapper_name = env.observation_space, wrapper.__name__
  164. assert isinstance(
  165. obs_space, gym.spaces.Dict
  166. ), f"Observation space for {wrapper_name} is not a Dict: {obs_space}."
  167. # This should not fail either
  168. ImgObsWrapper(env)
  169. env.reset()
  170. env.step(0)
  171. env.close()
  172. class EmptyEnvWithExtraObs(EmptyEnv):
  173. """
  174. Custom environment with an extra observation
  175. """
  176. def __init__(self) -> None:
  177. super().__init__(size=5)
  178. self.observation_space["size"] = gym.spaces.Box(
  179. low=0, high=np.iinfo(np.uint).max, shape=(2,), dtype=np.uint
  180. )
  181. def reset(self, **kwargs):
  182. obs, info = super().reset(**kwargs)
  183. obs["size"] = np.array([self.width, self.height])
  184. return obs, info
  185. def step(self, action):
  186. obs, reward, terminated, truncated, info = super().step(action)
  187. obs["size"] = np.array([self.width, self.height])
  188. return obs, reward, terminated, truncated, info
  189. @pytest.mark.parametrize(
  190. "wrapper",
  191. [
  192. OneHotPartialObsWrapper,
  193. RGBImgObsWrapper,
  194. RGBImgPartialObsWrapper,
  195. FullyObsWrapper,
  196. ],
  197. )
  198. def test_agent_sees_method(wrapper):
  199. env1 = wrapper(EmptyEnvWithExtraObs())
  200. env2 = wrapper(gym.make("MiniGrid-Empty-5x5-v0"))
  201. obs1, _ = env1.reset(seed=0)
  202. obs2, _ = env2.reset(seed=0)
  203. assert "size" in obs1
  204. assert obs1["size"].shape == (2,)
  205. assert (obs1["size"] == [5, 5]).all()
  206. for key in obs2:
  207. assert np.array_equal(obs1[key], obs2[key])
  208. obs1, reward1, terminated1, truncated1, _ = env1.step(0)
  209. obs2, reward2, terminated2, truncated2, _ = env2.step(0)
  210. assert "size" in obs1
  211. assert obs1["size"].shape == (2,)
  212. assert (obs1["size"] == [5, 5]).all()
  213. for key in obs2:
  214. assert np.array_equal(obs1[key], obs2[key])
  215. @pytest.mark.parametrize("view_size", [5, 7, 9])
  216. def test_viewsize_wrapper(view_size):
  217. env = gym.make("MiniGrid-Empty-5x5-v0")
  218. env = ViewSizeWrapper(env, agent_view_size=view_size)
  219. env.reset()
  220. obs, _, _, _, _ = env.step(0)
  221. assert obs["image"].shape == (view_size, view_size, 3)
  222. env.close()
  223. @pytest.mark.parametrize("env_id", ["MiniGrid-LavaCrossingS11N5-v0"])
  224. @pytest.mark.parametrize("type", ["slope", "angle"])
  225. def test_direction_obs_wrapper(env_id, type):
  226. env = gym.make(env_id)
  227. env = DirectionObsWrapper(env, type=type)
  228. obs, _ = env.reset()
  229. slope = np.divide(
  230. env.goal_position[1] - env.agent_pos[1],
  231. env.goal_position[0] - env.agent_pos[0],
  232. )
  233. if type == "slope":
  234. assert obs["goal_direction"] == slope
  235. elif type == "angle":
  236. assert obs["goal_direction"] == np.arctan(slope)
  237. obs, _, _, _, _ = env.step(0)
  238. slope = np.divide(
  239. env.goal_position[1] - env.agent_pos[1],
  240. env.goal_position[0] - env.agent_pos[0],
  241. )
  242. if type == "slope":
  243. assert obs["goal_direction"] == slope
  244. elif type == "angle":
  245. assert obs["goal_direction"] == np.arctan(slope)
  246. env.close()
  247. @pytest.mark.parametrize("env_id", ["MiniGrid-DistShift1-v0"])
  248. def test_symbolic_obs_wrapper(env_id):
  249. env = gym.make(env_id)
  250. env = SymbolicObsWrapper(env)
  251. obs, _ = env.reset(seed=123)
  252. agent_pos = env.agent_pos
  253. goal_pos = env.goal_pos
  254. assert obs["image"].shape == (env.width, env.height, 3)
  255. assert np.alltrue(
  256. obs["image"][agent_pos[0], agent_pos[1], :]
  257. == np.array([agent_pos[0], agent_pos[1], OBJECT_TO_IDX["agent"]])
  258. )
  259. assert np.alltrue(
  260. obs["image"][goal_pos[0], goal_pos[1], :]
  261. == np.array([goal_pos[0], goal_pos[1], OBJECT_TO_IDX["goal"]])
  262. )
  263. obs, _, _, _, _ = env.step(2)
  264. agent_pos = env.agent_pos
  265. goal_pos = env.goal_pos
  266. assert obs["image"].shape == (env.width, env.height, 3)
  267. assert np.alltrue(
  268. obs["image"][agent_pos[0], agent_pos[1], :]
  269. == np.array([agent_pos[0], agent_pos[1], OBJECT_TO_IDX["agent"]])
  270. )
  271. assert np.alltrue(
  272. obs["image"][goal_pos[0], goal_pos[1], :]
  273. == np.array([goal_pos[0], goal_pos[1], OBJECT_TO_IDX["goal"]])
  274. )
  275. env.close()
  276. @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-16x16-v0"])
  277. def test_stochastic_action_wrapper(env_id):
  278. env = gym.make(env_id)
  279. env = StochasticActionWrapper(env, prob=0.2)
  280. _, _ = env.reset()
  281. for _ in range(20):
  282. _, _, _, _, _ = env.step(0)
  283. env.close()
  284. env = gym.make(env_id)
  285. env = StochasticActionWrapper(env, prob=0.2, random_action=1)
  286. _, _ = env.reset()
  287. for _ in range(20):
  288. _, _, _, _, _ = env.step(0)
  289. env.close()
  290. def test_dict_observation_space_doesnt_clash_with_one_hot():
  291. env = gym.make("MiniGrid-Empty-5x5-v0")
  292. env = OneHotPartialObsWrapper(env)
  293. env = DictObservationSpaceWrapper(env)
  294. env.reset()
  295. obs, _, _, _, _ = env.step(0)
  296. assert obs["image"].shape == (7, 7, 20)
  297. assert env.observation_space["image"].shape == (7, 7, 20)
  298. env.close()