# test_wrappers.py (12 KB)
  1. from __future__ import annotations
  2. import math
  3. import warnings
  4. import gymnasium as gym
  5. import numpy as np
  6. import pytest
  7. from minigrid.core.actions import Actions
  8. from minigrid.core.constants import OBJECT_TO_IDX
  9. from minigrid.envs import EmptyEnv
  10. from minigrid.wrappers import (
  11. ActionBonus,
  12. DictObservationSpaceWrapper,
  13. DirectionObsWrapper,
  14. FlatObsWrapper,
  15. FullyObsWrapper,
  16. ImgObsWrapper,
  17. NoDeath,
  18. OneHotPartialObsWrapper,
  19. PositionBonus,
  20. ReseedWrapper,
  21. RGBImgObsWrapper,
  22. RGBImgPartialObsWrapper,
  23. StochasticActionWrapper,
  24. SymbolicObsWrapper,
  25. ViewSizeWrapper,
  26. )
  27. from tests.utils import all_testing_env_specs, assert_equals, minigrid_testing_env_specs
# Seeds handed to ReseedWrapper and replayed one by one on the reference env
# in test_reseed_wrapper.
SEEDS = [100, 243, 500]
# Upper bound on the number of steps compared per seed in test_reseed_wrapper.
NUM_STEPS = 100
  30. @pytest.mark.parametrize(
  31. "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  32. )
  33. def test_reseed_wrapper(env_spec):
  34. """
  35. Test the ReseedWrapper with a list of SEEDS.
  36. """
  37. unwrapped_env = env_spec.make()
  38. env = ReseedWrapper(env_spec.make(), seeds=SEEDS)
  39. env.action_space.seed(0)
  40. for seed in SEEDS:
  41. env.reset()
  42. unwrapped_env.reset(seed=seed)
  43. for time_step in range(NUM_STEPS):
  44. action = env.action_space.sample()
  45. obs, rew, terminated, truncated, info = env.step(action)
  46. (
  47. unwrapped_obs,
  48. unwrapped_rew,
  49. unwrapped_terminated,
  50. unwrapped_truncated,
  51. unwrapped_info,
  52. ) = unwrapped_env.step(action)
  53. assert_equals(obs, unwrapped_obs, f"[{time_step}] ")
  54. assert unwrapped_env.observation_space.contains(obs)
  55. assert (
  56. rew == unwrapped_rew
  57. ), f"[{time_step}] reward={rew}, unwrapped reward={unwrapped_rew}"
  58. assert (
  59. terminated == unwrapped_terminated
  60. ), f"[{time_step}] terminated={terminated}, unwrapped terminated={unwrapped_terminated}"
  61. assert (
  62. truncated == unwrapped_truncated
  63. ), f"[{time_step}] truncated={truncated}, unwrapped truncated={unwrapped_truncated}"
  64. assert_equals(info, unwrapped_info, f"[{time_step}] ")
  65. # Start the next seed
  66. if terminated or truncated:
  67. break
  68. env.close()
  69. unwrapped_env.close()
  70. @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-16x16-v0"])
  71. def test_position_bonus_wrapper(env_id):
  72. env = gym.make(env_id)
  73. wrapped_env = PositionBonus(gym.make(env_id))
  74. action_forward = Actions.forward
  75. action_left = Actions.left
  76. action_right = Actions.right
  77. for _ in range(10):
  78. wrapped_env.reset()
  79. for _ in range(5):
  80. wrapped_env.step(action_forward)
  81. # Turn lef 3 times (check that actions don't influence bonus)
  82. for _ in range(3):
  83. _, wrapped_rew, _, _, _ = wrapped_env.step(action_left)
  84. env.reset()
  85. for _ in range(5):
  86. env.step(action_forward)
  87. # Turn right 3 times
  88. for _ in range(3):
  89. _, rew, _, _, _ = env.step(action_right)
  90. expected_bonus_reward = rew + 1 / math.sqrt(13)
  91. assert expected_bonus_reward == wrapped_rew
  92. @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-16x16-v0"])
  93. def test_action_bonus_wrapper(env_id):
  94. env = gym.make(env_id)
  95. wrapped_env = ActionBonus(gym.make(env_id))
  96. action = Actions.forward
  97. for _ in range(10):
  98. wrapped_env.reset()
  99. for _ in range(5):
  100. _, wrapped_rew, _, _, _ = wrapped_env.step(action)
  101. env.reset()
  102. for _ in range(5):
  103. _, rew, _, _, _ = env.step(action)
  104. expected_bonus_reward = rew + 1 / math.sqrt(10)
  105. assert expected_bonus_reward == wrapped_rew
  106. @pytest.mark.parametrize(
  107. "env_spec",
  108. minigrid_testing_env_specs,
  109. ids=[spec.id for spec in minigrid_testing_env_specs],
  110. ) # DictObservationSpaceWrapper is not compatible with BabyAI levels. See minigrid/wrappers.py for more details.
  111. def test_dict_observation_space_wrapper(env_spec):
  112. env = env_spec.make()
  113. env = DictObservationSpaceWrapper(env)
  114. env.reset()
  115. mission = env.unwrapped.mission
  116. obs, _, _, _, _ = env.step(0)
  117. assert env.string_to_indices(mission) == [
  118. value for value in obs["mission"] if value != 0
  119. ]
  120. env.close()
  121. @pytest.mark.parametrize(
  122. "wrapper",
  123. [
  124. ReseedWrapper,
  125. ImgObsWrapper,
  126. FlatObsWrapper,
  127. ViewSizeWrapper,
  128. DictObservationSpaceWrapper,
  129. OneHotPartialObsWrapper,
  130. RGBImgPartialObsWrapper,
  131. FullyObsWrapper,
  132. ],
  133. )
  134. @pytest.mark.parametrize(
  135. "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  136. )
  137. def test_main_wrappers(wrapper, env_spec):
  138. if (
  139. wrapper in (DictObservationSpaceWrapper, FlatObsWrapper)
  140. and env_spec not in minigrid_testing_env_specs
  141. ):
  142. # DictObservationSpaceWrapper and FlatObsWrapper are not compatible with BabyAI levels
  143. # See minigrid/wrappers.py for more details
  144. pytest.skip()
  145. env = env_spec.make()
  146. env = wrapper(env)
  147. with warnings.catch_warnings():
  148. env.reset(seed=123)
  149. env.step(0)
  150. env.close()
  151. @pytest.mark.parametrize(
  152. "wrapper",
  153. [
  154. OneHotPartialObsWrapper,
  155. RGBImgPartialObsWrapper,
  156. FullyObsWrapper,
  157. ],
  158. )
  159. @pytest.mark.parametrize(
  160. "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  161. )
  162. def test_observation_space_wrappers(wrapper, env_spec):
  163. env = wrapper(env_spec.make(disable_env_checker=True))
  164. obs_space, wrapper_name = env.observation_space, wrapper.__name__
  165. assert isinstance(
  166. obs_space, gym.spaces.Dict
  167. ), f"Observation space for {wrapper_name} is not a Dict: {obs_space}."
  168. # This should not fail either
  169. ImgObsWrapper(env)
  170. env.reset()
  171. env.step(0)
  172. env.close()
  173. class EmptyEnvWithExtraObs(EmptyEnv):
  174. """
  175. Custom environment with an extra observation
  176. """
  177. def __init__(self) -> None:
  178. super().__init__(size=5)
  179. self.observation_space["size"] = gym.spaces.Box(
  180. low=0, high=np.iinfo(np.uint).max, shape=(2,), dtype=np.uint
  181. )
  182. def reset(self, **kwargs):
  183. obs, info = super().reset(**kwargs)
  184. obs["size"] = np.array([self.width, self.height])
  185. return obs, info
  186. def step(self, action):
  187. obs, reward, terminated, truncated, info = super().step(action)
  188. obs["size"] = np.array([self.width, self.height])
  189. return obs, reward, terminated, truncated, info
  190. @pytest.mark.parametrize(
  191. "wrapper",
  192. [
  193. OneHotPartialObsWrapper,
  194. RGBImgObsWrapper,
  195. RGBImgPartialObsWrapper,
  196. FullyObsWrapper,
  197. ],
  198. )
  199. def test_agent_sees_method(wrapper):
  200. env1 = wrapper(EmptyEnvWithExtraObs())
  201. env2 = wrapper(gym.make("MiniGrid-Empty-5x5-v0"))
  202. obs1, _ = env1.reset(seed=0)
  203. obs2, _ = env2.reset(seed=0)
  204. assert "size" in obs1
  205. assert obs1["size"].shape == (2,)
  206. assert (obs1["size"] == [5, 5]).all()
  207. for key in obs2:
  208. assert np.array_equal(obs1[key], obs2[key])
  209. obs1, reward1, terminated1, truncated1, _ = env1.step(0)
  210. obs2, reward2, terminated2, truncated2, _ = env2.step(0)
  211. assert "size" in obs1
  212. assert obs1["size"].shape == (2,)
  213. assert (obs1["size"] == [5, 5]).all()
  214. for key in obs2:
  215. assert np.array_equal(obs1[key], obs2[key])
  216. @pytest.mark.parametrize("view_size", [5, 7, 9])
  217. def test_viewsize_wrapper(view_size):
  218. env = gym.make("MiniGrid-Empty-5x5-v0")
  219. env = ViewSizeWrapper(env, agent_view_size=view_size)
  220. env.reset()
  221. obs, _, _, _, _ = env.step(0)
  222. assert obs["image"].shape == (view_size, view_size, 3)
  223. env.close()
  224. @pytest.mark.parametrize("env_id", ["MiniGrid-LavaCrossingS11N5-v0"])
  225. @pytest.mark.parametrize("type", ["slope", "angle"])
  226. def test_direction_obs_wrapper(env_id, type):
  227. env = gym.make(env_id)
  228. env = DirectionObsWrapper(env, type=type)
  229. obs, _ = env.reset()
  230. slope = np.divide(
  231. env.unwrapped.goal_position[1] - env.unwrapped.agent_pos[1],
  232. env.unwrapped.goal_position[0] - env.unwrapped.agent_pos[0],
  233. )
  234. if type == "slope":
  235. assert obs["goal_direction"] == slope
  236. elif type == "angle":
  237. assert obs["goal_direction"] == np.arctan(slope)
  238. obs, _, _, _, _ = env.step(0)
  239. slope = np.divide(
  240. env.unwrapped.goal_position[1] - env.unwrapped.agent_pos[1],
  241. env.unwrapped.goal_position[0] - env.unwrapped.agent_pos[0],
  242. )
  243. if type == "slope":
  244. assert obs["goal_direction"] == slope
  245. elif type == "angle":
  246. assert obs["goal_direction"] == np.arctan(slope)
  247. env.close()
  248. @pytest.mark.parametrize("env_id", ["MiniGrid-DistShift1-v0"])
  249. def test_symbolic_obs_wrapper(env_id):
  250. env = gym.make(env_id)
  251. env = SymbolicObsWrapper(env)
  252. obs, _ = env.reset(seed=123)
  253. agent_pos = env.unwrapped.agent_pos
  254. goal_pos = env.unwrapped.goal_pos
  255. assert obs["image"].shape == (env.unwrapped.width, env.unwrapped.height, 3)
  256. assert np.all(
  257. obs["image"][agent_pos[0], agent_pos[1], :]
  258. == np.array([agent_pos[0], agent_pos[1], OBJECT_TO_IDX["agent"]])
  259. )
  260. assert np.all(
  261. obs["image"][goal_pos[0], goal_pos[1], :]
  262. == np.array([goal_pos[0], goal_pos[1], OBJECT_TO_IDX["goal"]])
  263. )
  264. obs, _, _, _, _ = env.step(2)
  265. agent_pos = env.unwrapped.agent_pos
  266. goal_pos = env.unwrapped.goal_pos
  267. assert obs["image"].shape == (env.unwrapped.width, env.unwrapped.height, 3)
  268. assert np.all(
  269. obs["image"][agent_pos[0], agent_pos[1], :]
  270. == np.array([agent_pos[0], agent_pos[1], OBJECT_TO_IDX["agent"]])
  271. )
  272. assert np.all(
  273. obs["image"][goal_pos[0], goal_pos[1], :]
  274. == np.array([goal_pos[0], goal_pos[1], OBJECT_TO_IDX["goal"]])
  275. )
  276. env.close()
  277. @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-16x16-v0"])
  278. def test_stochastic_action_wrapper(env_id):
  279. env = gym.make(env_id)
  280. env = StochasticActionWrapper(env, prob=0.2)
  281. _, _ = env.reset()
  282. for _ in range(20):
  283. _, _, _, _, _ = env.step(0)
  284. env.close()
  285. env = gym.make(env_id)
  286. env = StochasticActionWrapper(env, prob=0.2, random_action=1)
  287. _, _ = env.reset()
  288. for _ in range(20):
  289. _, _, _, _, _ = env.step(0)
  290. env.close()
  291. def test_dict_observation_space_doesnt_clash_with_one_hot():
  292. env = gym.make("MiniGrid-Empty-5x5-v0")
  293. env = OneHotPartialObsWrapper(env)
  294. env = DictObservationSpaceWrapper(env)
  295. env.reset()
  296. obs, _, _, _, _ = env.step(0)
  297. assert obs["image"].shape == (7, 7, 20)
  298. assert env.observation_space["image"].shape == (7, 7, 20)
  299. env.close()
  300. def test_no_death_wrapper():
  301. death_cost = -1
  302. env = gym.make("MiniGrid-LavaCrossingS9N1-v0")
  303. _, _ = env.reset(seed=2)
  304. _, _, _, _, _ = env.step(1)
  305. _, reward, term, *_ = env.step(2)
  306. env_wrap = NoDeath(env, ("lava",), death_cost)
  307. _, _ = env_wrap.reset(seed=2)
  308. _, _, _, _, _ = env_wrap.step(1)
  309. _, reward_wrap, term_wrap, *_ = env_wrap.step(2)
  310. assert term and not term_wrap
  311. assert reward_wrap == reward + death_cost
  312. env.close()
  313. env_wrap.close()
  314. env = gym.make("MiniGrid-Dynamic-Obstacles-5x5-v0")
  315. _, _ = env.reset(seed=2)
  316. _, reward, term, *_ = env.step(2)
  317. env = NoDeath(env, ("ball",), death_cost)
  318. _, _ = env.reset(seed=2)
  319. _, reward_wrap, term_wrap, *_ = env.step(2)
  320. assert term and not term_wrap
  321. assert reward_wrap == reward + death_cost
  322. env.close()
  323. env_wrap.close()
  324. def test_non_square_RGBIMgObsWrapper():
  325. """
  326. Add test for non-square dimensions with RGBImgObsWrapper
  327. (https://github.com/Farama-Foundation/Minigrid/issues/444).
  328. """
  329. env = RGBImgObsWrapper(gym.make("MiniGrid-BlockedUnlockPickup-v0"))
  330. obs, info = env.reset()
  331. assert env.observation_space["image"].shape == obs["image"].shape