Simone Parisi · 1 year ago
commit 091eea611e
3 changed files, 115 additions and 0 deletions

  1. docs/api/wrappers.md   +6  -0
  2. minigrid/wrappers.py   +76 -0
  3. tests/test_wrappers.py +33 -0

+ 6 - 0
docs/api/wrappers.md

@@ -34,6 +34,12 @@ lastpage:
 .. autoclass:: minigrid.wrappers.FullyObsWrapper
 ```
 
+# No Death
+
+```{eval-rst}
+.. autoclass:: minigrid.wrappers.NoDeath
+```
+
 # Observation
 
 ```{eval-rst}

+ 76 - 0
minigrid/wrappers.py

@@ -788,3 +788,79 @@ class StochasticActionWrapper(ActionWrapper):
                 return self.np_random.integers(0, high=6)
             else:
                 return self.random_action
+
+
+class NoDeath(Wrapper):
+    """
+    Wrapper to prevent death in specific cells (e.g., lava cells).
+    Instead of dying, the agent will receive a negative reward.
+
+    Example:
+        >>> import gymnasium as gym
+        >>> from minigrid.wrappers import NoDeath
+        >>>
+        >>> env = gym.make("MiniGrid-LavaCrossingS9N1-v0")
+        >>> _, _ = env.reset(seed=2)
+        >>> _, _, _, _, _ = env.step(1)
+        >>> _, reward, term, *_ = env.step(2)
+        >>> reward, term
+        (0, True)
+        >>>
+        >>> env = NoDeath(env, no_death_types=("lava",), death_cost=-1.0)
+        >>> _, _ = env.reset(seed=2)
+        >>> _, _, _, _, _ = env.step(1)
+        >>> _, reward, term, *_ = env.step(2)
+        >>> reward, term
+        (-1.0, False)
+        >>>
+        >>>
+        >>> env = gym.make("MiniGrid-Dynamic-Obstacles-5x5-v0")
+        >>> _, _ = env.reset(seed=2)
+        >>> _, reward, term, *_ = env.step(2)
+        >>> reward, term
+        (-1, True)
+        >>>
+        >>> env = NoDeath(env, no_death_types=("ball",), death_cost=-1.0)
+        >>> _, _ = env.reset(seed=2)
+        >>> _, reward, term, *_ = env.step(2)
+        >>> reward, term
+        (-2.0, False)
+    """
+
+    def __init__(self, env, no_death_types: tuple[str, ...], death_cost: float = -1.0):
+        """A wrapper to prevent death in specific cells.
+
+        Args:
+            env: The environment to apply the wrapper
+            no_death_types: Tuple of cell type names (e.g., "lava") that identify death cells
+            death_cost: The negative reward received in death cells
+
+        """
+        assert "goal" not in no_death_types, "goal cannot be a death cell"
+
+        super().__init__(env)
+        self.death_cost = death_cost
+        self.no_death_types = no_death_types
+
+    def step(self, action):
+        # In Dynamic-Obstacles, obstacles move after the agent moves,
+        # so we need to check for collision before self.env.step()
+        front_cell = self.grid.get(*self.front_pos)
+        going_to_death = (
+            action == self.actions.forward
+            and front_cell is not None
+            and front_cell.type in self.no_death_types
+        )
+
+        obs, reward, terminated, truncated, info = self.env.step(action)
+
+        # We also check if the agent stays in death cells (e.g., lava)
+        # without moving
+        current_cell = self.grid.get(*self.agent_pos)
+        in_death = current_cell is not None and current_cell.type in self.no_death_types
+
+        if terminated and (going_to_death or in_death):
+            terminated = False
+            reward += self.death_cost
+
+        return obs, reward, terminated, truncated, info
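
A minimal usage sketch of the new wrapper, for readers skimming the diff (it assumes a stock `gymnasium` loop; the `death_cost=-0.5` value and the random policy are illustrative choices, not part of this commit):

```python
import gymnasium as gym

from minigrid.wrappers import NoDeath

# Wrap a lava environment so touching lava penalizes instead of terminating.
# death_cost=-0.5 is an arbitrary example value (the wrapper defaults to -1.0).
env = NoDeath(
    gym.make("MiniGrid-LavaCrossingS9N1-v0"),
    no_death_types=("lava",),
    death_cost=-0.5,
)

obs, info = env.reset(seed=0)
episode_return = 0.0
while True:
    action = env.action_space.sample()  # stand-in for a trained policy
    obs, reward, terminated, truncated, info = env.step(action)
    episode_return += reward
    # With NoDeath, lava no longer ends the episode; only reaching the
    # goal (terminated) or hitting max_steps (truncated) stops the loop.
    if terminated or truncated:
        break
env.close()
```

Note the design choice in `step()`: the `front_pos` check runs before `self.env.step()` because in Dynamic-Obstacles the obstacles move during the step, so the post-step `agent_pos` check alone could miss the collision that caused termination.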

+ 33 - 0
tests/test_wrappers.py

@@ -16,6 +16,7 @@ from minigrid.wrappers import (
     FlatObsWrapper,
     FullyObsWrapper,
     ImgObsWrapper,
+    NoDeath,
     OneHotPartialObsWrapper,
     PositionBonus,
     ReseedWrapper,
@@ -356,3 +357,35 @@ def test_dict_observation_space_doesnt_clash_with_one_hot():
     assert obs["image"].shape == (7, 7, 20)
     assert env.observation_space["image"].shape == (7, 7, 20)
     env.close()
+
+
+def test_no_death_wrapper():
+    death_cost = -1
+
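+    # Unwrapped LavaCrossing: stepping into lava terminates with reward 0.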
+    env = gym.make("MiniGrid-LavaCrossingS9N1-v0")
+    _, _ = env.reset(seed=2)
+    _, _, _, _, _ = env.step(1)
+    _, reward, term, *_ = env.step(2)
+
+    env_wrap = NoDeath(env, ("lava",), death_cost)
+    _, _ = env_wrap.reset(seed=2)
+    _, _, _, _, _ = env_wrap.step(1)
+    _, reward_wrap, term_wrap, *_ = env_wrap.step(2)
+
+    assert term and not term_wrap
+    assert reward_wrap == reward + death_cost
+    env.close()
+    env_wrap.close()
+
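+    # Unwrapped Dynamic-Obstacles: colliding with a ball gives reward -1 and
+    # terminates, so the wrapped step should return reward -1 + death_cost.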
+    env = gym.make("MiniGrid-Dynamic-Obstacles-5x5-v0")
+    _, _ = env.reset(seed=2)
+    _, reward, term, *_ = env.step(2)
+
+    env_wrap = NoDeath(env, ("ball",), death_cost)
+    _, _ = env_wrap.reset(seed=2)
+    _, reward_wrap, term_wrap, *_ = env_wrap.step(2)
+
+    assert term and not term_wrap
+    assert reward_wrap == reward + death_cost
+    env.close()
+    env_wrap.close()