Browse files

use new step api in new tests

saleml 2 years ago
parent
commit
5b71a38e9e
3 changed files with 36 additions and 35 deletions
  1. gym_minigrid/wrappers.py (+0, -4)
  2. tests/test_envs.py (+13, -12)
  3. tests/test_wrappers.py (+23, -19)

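This commit migrates the new tests from Gym's old 4-tuple step API (obs, reward, done, info) to the new 5-tuple API (obs, reward, terminated, truncated, info), which splits episode end into natural termination and time-limit truncation. A minimal sketch of the pattern the updated tests follow (the env id is illustrative):

    import gym

    # Opt in to the 5-tuple step API, as the updated tests do.
    env = gym.make("MiniGrid-Empty-8x8-v0", new_step_api=True)
    env.reset(seed=0)

    for _ in range(100):
        action = env.action_space.sample()
        # New API: episode end is split into two booleans.
        obs, reward, terminated, truncated, info = env.step(action)
        if terminated or truncated:  # old code checked a single `done`
            env.reset()

    env.close()
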
+ 0 - 4
gym_minigrid/wrappers.py

@@ -27,10 +27,6 @@ class ReseedWrapper(Wrapper):
         self.seed_idx = (self.seed_idx + 1) % len(self.seeds)
         return self.env.reset(seed=seed, **kwargs)
 
-    def step(self, action):
-        obs, reward, done, info = self.env.step(action)
-        return obs, reward, done, info
-
 
 class ActionBonus(gym.Wrapper):
     """

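The deleted ReseedWrapper.step was a pure pass-through: gym.Wrapper already forwards step to the wrapped env, so the override added nothing and hard-coded the old 4-tuple unpacking. With it gone, the wrapper returns whichever tuple shape the inner env produces. A sketch of the resulting behavior, assuming an illustrative env id:

    import gym
    from gym_minigrid.wrappers import ReseedWrapper

    # With no step() override, gym.Wrapper.step simply returns
    # self.env.step(action), so ReseedWrapper now stays agnostic
    # to 4- vs 5-tuple step returns.
    env = ReseedWrapper(gym.make("MiniGrid-Empty-8x8-v0", new_step_api=True))
    env.reset()
    result = env.step(env.action_space.sample())  # forwarded unchanged
    assert len(result) == 5
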
+ 13 - 12
tests/test_envs.py

@@ -54,7 +54,7 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
     - observations after the first reset are the same
     - same actions are sampled by the two envs
     - observations are contained in the observation space
-    - obs, rew, done and info are equal between the two envs
+    - obs, rew, terminated, truncated and info are equal between the two envs
     """
     # Don't check rollout equality if it's a nondeterministic environment.
     if env_spec.nondeterministic is True:
@@ -73,8 +73,8 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
         # We don't evaluate the determinism of actions
         action = env_1.action_space.sample()
 
-        obs_1, rew_1, done_1, info_1 = env_1.step(action)
-        obs_2, rew_2, done_2, info_2 = env_2.step(action)
+        obs_1, rew_1, terminated_1, truncated_1, info_1 = env_1.step(action)
+        obs_2, rew_2, terminated_2, truncated_2, info_2 = env_2.step(action)
 
         assert_equals(obs_1, obs_2, f"[{time_step}] ")
         assert env_1.observation_space.contains(
@@ -82,10 +82,11 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
         )  # obs_2 verified by previous assertion
 
         assert rew_1 == rew_2, f"[{time_step}] reward 1={rew_1}, reward 2={rew_2}"
-        assert done_1 == done_2, f"[{time_step}] done 1={done_1}, done 2={done_2}"
+        assert terminated_1 == terminated_2, f"[{time_step}] terminated 1={terminated_1}, terminated 2={terminated_2}"
+        assert truncated_1 == truncated_2, f"[{time_step}] truncated 1={truncated_1}, truncated 2={truncated_2}"
         assert_equals(info_1, info_2, f"[{time_step}] ")
 
-        if done_1:  # done_2 verified by previous assertion
+        if terminated_1 or truncated_1:  # terminated_2 and truncated_2 verified by previous assertion
             env_1.reset(seed=SEED)
             env_2.reset(seed=SEED)
 
@@ -110,7 +111,7 @@ def test_render_modes(spec):
 
 @pytest.mark.parametrize("env_id", ["MiniGrid-DoorKey-6x6-v0"])
 def test_agent_sees_method(env_id):
-    env = gym.make(env_id)
+    env = gym.make(env_id, new_step_api=True)
     goal_pos = (env.grid.width - 2, env.grid.height - 2)
 
     # Test the "in" operator on grid objects
@@ -121,14 +122,14 @@ def test_agent_sees_method(env_id):
     env.reset()
     for i in range(0, 500):
         action = env.action_space.sample()
-        obs, reward, done, info = env.step(action)
+        obs, reward, terminated, truncated, info = env.step(action)
 
         grid, _ = Grid.decode(obs["image"])
         goal_visible = ("green", "goal") in grid
 
         agent_sees_goal = env.agent_sees(*goal_pos)
         assert agent_sees_goal == goal_visible
-        if done:
+        if terminated or truncated:
             env.reset()
 
     env.close()
@@ -161,7 +162,7 @@ def old_run_test(env_spec):
         # Pick a random action
         action = env.action_space.sample()
 
-        obs, reward, done, info = env.step(action)
+        obs, reward, terminated, truncated, info = env.step(action)
 
         # Validate the agent position
         assert env.agent_pos[0] < env.width
@@ -180,7 +181,7 @@ def old_run_test(env_spec):
         assert reward >= env.reward_range[0], reward
         assert reward <= env.reward_range[1], reward
 
-        if done:
+        if terminated or truncated:
             num_episodes += 1
             env.reset()
 
@@ -192,7 +193,7 @@ def old_run_test(env_spec):
 
 @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-8x8-v0"])
 def test_interactive_mode(env_id):
-    env = gym.make(env_id)
+    env = gym.make(env_id, new_step_api=True)
     env.reset()
 
     for i in range(0, 100):
@@ -201,7 +202,7 @@ def test_interactive_mode(env_id):
         # Pick a random action
         action = env.action_space.sample()
 
-        obs, reward, done, info = env.step(action)
+        obs, reward, terminated, truncated, info = env.step(action)
 
     # Test the close method
     env.close()

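The determinism test above now compares the full 5-tuple between the two identically seeded environments. A condensed sketch of that rollout check (seed and env id illustrative):

    import gym

    SEED = 0
    env_1 = gym.make("MiniGrid-Empty-8x8-v0", new_step_api=True)
    env_2 = gym.make("MiniGrid-Empty-8x8-v0", new_step_api=True)
    env_1.reset(seed=SEED)
    env_2.reset(seed=SEED)

    for time_step in range(10):
        action = env_1.action_space.sample()
        obs_1, rew_1, term_1, trunc_1, info_1 = env_1.step(action)
        obs_2, rew_2, term_2, trunc_2, info_2 = env_2.step(action)
        # Both end-of-episode flags must match, not just a single `done`.
        assert rew_1 == rew_2 and term_1 == term_2 and trunc_1 == trunc_2
        if term_1 or trunc_1:
            env_1.reset(seed=SEED)
            env_2.reset(seed=SEED)
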
+ 23 - 19
tests/test_wrappers.py

@@ -43,11 +43,12 @@ def test_reseed_wrapper(env_spec):
         for time_step in range(NUM_STEPS):
             action = env.action_space.sample()
 
-            obs, rew, done, info = env.step(action)
+            obs, rew, terminated, truncated, info = env.step(action)
             (
                 unwrapped_obs,
                 unwrapped_rew,
-                unwrapped_done,
+                unwrapped_terminated,
+                unwrapped_truncated,
                 unwrapped_info,
             ) = unwrapped_env.step(action)
 
@@ -58,12 +59,15 @@ def test_reseed_wrapper(env_spec):
                 rew == unwrapped_rew
             ), f"[{time_step}] reward={rew}, unwrapped reward={unwrapped_rew}"
             assert (
-                done == unwrapped_done
-            ), f"[{time_step}] done={done}, unwrapped done={unwrapped_done}"
+                terminated == unwrapped_terminated
+            ), f"[{time_step}] terminated={terminated}, unwrapped terminated={unwrapped_terminated}"
+            assert (
+                truncated == unwrapped_truncated
+            ), f"[{time_step}] truncated={truncated}, unwrapped truncated={unwrapped_truncated}"
             assert_equals(info, unwrapped_info, f"[{time_step}] ")
 
             # Start the next seed
-            if done:
+            if terminated or truncated:
                 break
 
     env.close()
@@ -72,8 +76,8 @@ def test_reseed_wrapper(env_spec):
 
 @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-16x16-v0"])
 def test_state_bonus_wrapper(env_id):
-    env = gym.make(env_id)
-    wrapped_env = StateBonus(gym.make(env_id))
+    env = gym.make(env_id, new_step_api=True)
+    wrapped_env = StateBonus(gym.make(env_id, new_step_api=True))
 
     action_forward = MiniGridEnv.Actions.forward
     action_left = MiniGridEnv.Actions.left
@@ -86,14 +90,14 @@ def test_state_bonus_wrapper(env_id):
 
     # Turn left 3 times (check that actions don't influence the bonus)
     for _ in range(3):
-        _, wrapped_rew, _, _ = wrapped_env.step(action_left)
+        _, wrapped_rew, _, _, _ = wrapped_env.step(action_left)
 
     env.reset()
     for _ in range(5):
         env.step(action_forward)
     # Turn right 3 times
     for _ in range(3):
-        _, rew, _, _ = env.step(action_right)
+        _, rew, _, _, _ = env.step(action_right)
 
     expected_bonus_reward = rew + 1 / math.sqrt(13)
 
@@ -102,19 +106,19 @@ def test_state_bonus_wrapper(env_id):
 
 @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-16x16-v0"])
 def test_action_bonus_wrapper(env_id):
-    env = gym.make(env_id)
-    wrapped_env = ActionBonus(gym.make(env_id))
+    env = gym.make(env_id, new_step_api=True)
+    wrapped_env = ActionBonus(gym.make(env_id, new_step_api=True))
 
     action = MiniGridEnv.Actions.forward
 
     for _ in range(10):
         wrapped_env.reset()
         for _ in range(5):
-            _, wrapped_rew, _, _ = wrapped_env.step(action)
+            _, wrapped_rew, _, _, _ = wrapped_env.step(action)
 
     env.reset()
     for _ in range(5):
-        _, rew, _, _ = env.step(action)
+        _, rew, _, _, _ = env.step(action)
 
     expected_bonus_reward = rew + 1 / math.sqrt(10)
 
@@ -129,7 +133,7 @@ def test_dict_observation_space_wrapper(env_spec):
     env = DictObservationSpaceWrapper(env)
     env.reset()
     mission = env.mission
-    obs, _, _, _ = env.step(0)
+    obs, _, _, _, _ = env.step(0)
     assert env.string_to_indices(mission) == [
         value for value in obs["mission"] if value != 0
     ]
@@ -202,9 +206,9 @@ class EmptyEnvWithExtraObs(EmptyEnv):
         return obs
 
     def step(self, action):
-        obs, reward, done, info = super().step(action)
+        obs, reward, terminated, truncated, info = super().step(action)
         obs["size"] = np.array([self.width, self.height])
-        return obs, reward, done, info
+        return obs, reward, terminated, truncated, info
 
 
 @pytest.mark.parametrize(
@@ -218,7 +222,7 @@ class EmptyEnvWithExtraObs(EmptyEnv):
 )
 def test_agent_sees_method(wrapper):
     env1 = wrapper(EmptyEnvWithExtraObs())
-    env2 = wrapper(gym.make("MiniGrid-Empty-5x5-v0"))
+    env2 = wrapper(gym.make("MiniGrid-Empty-5x5-v0", new_step_api=True))
 
     obs1 = env1.reset(seed=0)
     obs2 = env2.reset(seed=0)
@@ -228,8 +232,8 @@ def test_agent_sees_method(wrapper):
     for key in obs2:
         assert np.array_equal(obs1[key], obs2[key])
 
-    obs1, reward1, done1, _ = env1.step(0)
-    obs2, reward2, done2, _ = env2.step(0)
+    obs1, reward1, terminated1, truncated1, _ = env1.step(0)
+    obs2, reward2, terminated2, truncated2, _ = env2.step(0)
     assert "size" in obs1
     assert obs1["size"].shape == (2,)
     assert (obs1["size"] == [5, 5]).all()
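
For the bonus-wrapper tests above, the expected values encode the count-based exploration bonus these wrappers add: 1/sqrt(N), where N is how often the current state (StateBonus) or (state, action) pair (ActionBonus) has been visited. A tiny check of the arithmetic behind the two expected_bonus_reward lines:

    import math

    def bonus_reward(base_reward: float, visit_count: int) -> float:
        # Count-based exploration bonus added on top of the env reward.
        return base_reward + 1 / math.sqrt(visit_count)

    assert bonus_reward(0.0, 13) == 1 / math.sqrt(13)  # StateBonus test
    assert bonus_reward(0.0, 10) == 1 / math.sqrt(10)  # ActionBonus test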