瀏覽代碼

🔧 Make env picklable (#269)

Mathïs Fédérico 2 年之前
父節點
當前提交
1670d01a0d

+ 5 - 1
.gitignore

@@ -17,4 +17,8 @@ __pycache__
 .vscode/
 /docs/environments/*.md
 !docs/environments/index.md
-!docs/environments/babyAI_index.md
+!docs/environments/babyAI_index.md
+
+# Virtual environments
+.env
+venv

+ 5 - 1
minigrid/envs/babyai/core/roomgrid_level.py

@@ -32,7 +32,11 @@ class BabyAIMissionSpace(MissionSpace):
     """
 
     def __init__(self):
-        super().__init__(mission_func=lambda: "go")
+        super().__init__(mission_func=self._gen_mission)
+
+    @staticmethod
+    def _gen_mission():
+        return "go"
 
     def contains(self, x: str):
         return True

+ 5 - 1
minigrid/envs/blockedunlockpickup.py

@@ -69,7 +69,7 @@ class BlockedUnlockPickupEnv(RoomGrid):
 
     def __init__(self, max_steps: Optional[int] = None, **kwargs):
         mission_space = MissionSpace(
-            mission_func=lambda color, type: f"pick up the {color} {type}",
+            mission_func=self._gen_mission,
             ordered_placeholders=[COLOR_NAMES, ["box", "key"]],
         )
 
@@ -86,6 +86,10 @@ class BlockedUnlockPickupEnv(RoomGrid):
             **kwargs,
         )
 
+    @staticmethod
+    def _gen_mission(color: str, obj_type: str):
+        return f"pick up the {color} {obj_type}"
+
     def _gen_grid(self, width, height):
         super()._gen_grid(width, height)
 

+ 10 - 6
minigrid/envs/crossing.py

@@ -109,13 +109,9 @@ class CrossingEnv(MiniGridEnv):
         self.obstacle_type = obstacle_type
 
         if obstacle_type == Lava:
-            mission_space = MissionSpace(
-                mission_func=lambda: "avoid the lava and get to the green goal square"
-            )
+            mission_space = MissionSpace(mission_func=self._gen_mission_lava)
         else:
-            mission_space = MissionSpace(
-                mission_func=lambda: "find the opening and get to the green goal square"
-            )
+            mission_space = MissionSpace(mission_func=self._gen_mission)
 
         if max_steps is None:
             max_steps = 4 * size**2
@@ -128,6 +124,14 @@ class CrossingEnv(MiniGridEnv):
             **kwargs
         )
 
+    @staticmethod
+    def _gen_mission_lava():
+        return "avoid the lava and get to the green goal square"
+
+    @staticmethod
+    def _gen_mission():
+        return "find the opening and get to the green goal square"
+
     def _gen_grid(self, width, height):
         assert width % 2 == 1 and height % 2 == 1  # odd size
 

+ 5 - 3
minigrid/envs/distshift.py

@@ -82,9 +82,7 @@ class DistShiftEnv(MiniGridEnv):
         self.goal_pos = (width - 2, 1)
         self.strip2_row = strip2_row
 
-        mission_space = MissionSpace(
-            mission_func=lambda: "get to the green goal square"
-        )
+        mission_space = MissionSpace(mission_func=self._gen_mission)
 
         if max_steps is None:
             max_steps = 4 * width * height
@@ -99,6 +97,10 @@ class DistShiftEnv(MiniGridEnv):
             **kwargs
         )
 
+    @staticmethod
+    def _gen_mission():
+        return "get to the green goal square"
+
     def _gen_grid(self, width, height):
         # Create an empty grid
         self.grid = Grid(width, height)

+ 5 - 3
minigrid/envs/doorkey.py

@@ -68,13 +68,15 @@ class DoorKeyEnv(MiniGridEnv):
     def __init__(self, size=8, max_steps: Optional[int] = None, **kwargs):
         if max_steps is None:
             max_steps = 10 * size**2
-        mission_space = MissionSpace(
-            mission_func=lambda: "use the key to open the door and then get to the goal"
-        )
+        mission_space = MissionSpace(mission_func=self._gen_mission)
         super().__init__(
             mission_space=mission_space, grid_size=size, max_steps=max_steps, **kwargs
         )
 
+    @staticmethod
+    def _gen_mission():
+        return "use the key to open the door and then get to the goal"
+
     def _gen_grid(self, width, height):
         # Create an empty grid
         self.grid = Grid(width, height)

+ 5 - 3
minigrid/envs/dynamicobstacles.py

@@ -90,9 +90,7 @@ class DynamicObstaclesEnv(MiniGridEnv):
         else:
             self.n_obstacles = int(size / 2)
 
-        mission_space = MissionSpace(
-            mission_func=lambda: "get to the green goal square"
-        )
+        mission_space = MissionSpace(mission_func=self._gen_mission)
 
         if max_steps is None:
             max_steps = 4 * size**2
@@ -109,6 +107,10 @@ class DynamicObstaclesEnv(MiniGridEnv):
         self.action_space = Discrete(self.actions.forward + 1)
         self.reward_range = (-1, 1)
 
+    @staticmethod
+    def _gen_mission():
+        return "get to the green goal square"
+
     def _gen_grid(self, width, height):
         # Create an empty grid
         self.grid = Grid(width, height)

+ 5 - 3
minigrid/envs/empty.py

@@ -80,9 +80,7 @@ class EmptyEnv(MiniGridEnv):
         self.agent_start_pos = agent_start_pos
         self.agent_start_dir = agent_start_dir
 
-        mission_space = MissionSpace(
-            mission_func=lambda: "get to the green goal square"
-        )
+        mission_space = MissionSpace(mission_func=self._gen_mission)
 
         if max_steps is None:
             max_steps = 4 * size**2
@@ -96,6 +94,10 @@ class EmptyEnv(MiniGridEnv):
             **kwargs
         )
 
+    @staticmethod
+    def _gen_mission():
+        return "get to the green goal square"
+
     def _gen_grid(self, width, height):
         # Create an empty grid
         self.grid = Grid(width, height)

+ 5 - 1
minigrid/envs/fetch.py

@@ -88,7 +88,7 @@ class FetchEnv(MiniGridEnv):
         ]
         self.size = size
         mission_space = MissionSpace(
-            mission_func=lambda syntax, color, type: f"{syntax} {color} {type}",
+            mission_func=self._gen_mission,
             ordered_placeholders=[MISSION_SYNTAX, COLOR_NAMES, self.obj_types],
         )
 
@@ -105,6 +105,10 @@ class FetchEnv(MiniGridEnv):
             **kwargs,
         )
 
+    @staticmethod
+    def _gen_mission(syntax: str, color: str, obj_type: str):
+        return f"{syntax} {color} {obj_type}"
+
     def _gen_grid(self, width, height):
         self.grid = Grid(width, height)
 

+ 5 - 1
minigrid/envs/fourrooms.py

@@ -64,7 +64,7 @@ class FourRoomsEnv(MiniGridEnv):
         self._goal_default_pos = goal_pos
 
         self.size = 19
-        mission_space = MissionSpace(mission_func=lambda: "reach the goal")
+        mission_space = MissionSpace(mission_func=self._gen_mission)
 
         super().__init__(
             mission_space=mission_space,
@@ -74,6 +74,10 @@ class FourRoomsEnv(MiniGridEnv):
             **kwargs
         )
 
+    @staticmethod
+    def _gen_mission():
+        return "reach the goal"
+
     def _gen_grid(self, width, height):
         # Create the grid
         self.grid = Grid(width, height)

+ 5 - 1
minigrid/envs/gotodoor.py

@@ -74,7 +74,7 @@ class GoToDoorEnv(MiniGridEnv):
         assert size >= 5
         self.size = size
         mission_space = MissionSpace(
-            mission_func=lambda color: f"go to the {color} door",
+            mission_func=self._gen_mission,
             ordered_placeholders=[COLOR_NAMES],
         )
 
@@ -91,6 +91,10 @@ class GoToDoorEnv(MiniGridEnv):
             **kwargs,
         )
 
+    @staticmethod
+    def _gen_mission(color: str):
+        return f"go to the {color} door"
+
     def _gen_grid(self, width, height):
         # Create the grid
         self.grid = Grid(width, height)

+ 5 - 1
minigrid/envs/gotoobject.py

@@ -21,7 +21,7 @@ class GoToObjectEnv(MiniGridEnv):
         self.obj_types = ["key", "ball", "box"]
 
         mission_space = MissionSpace(
-            mission_func=lambda color, type: f"go to the {color} {type}",
+            mission_func=self._gen_mission,
             ordered_placeholders=[COLOR_NAMES, self.obj_types],
         )
 
@@ -38,6 +38,10 @@ class GoToObjectEnv(MiniGridEnv):
             **kwargs,
         )
 
+    @staticmethod
+    def _gen_mission(color: str, obj_type: str):
+        return f"go to the {color} {obj_type}"
+
     def _gen_grid(self, width, height):
         self.grid = Grid(width, height)
 

+ 6 - 2
minigrid/envs/keycorridor.py

@@ -94,8 +94,8 @@ class KeyCorridorEnv(RoomGrid):
     ):
         self.obj_type = obj_type
         mission_space = MissionSpace(
-            mission_func=lambda color: f"pick up the {color} {obj_type}",
-            ordered_placeholders=[COLOR_NAMES],
+            mission_func=self._gen_mission,
+            ordered_placeholders=[COLOR_NAMES, [obj_type]],
         )
 
         if max_steps is None:
@@ -109,6 +109,10 @@ class KeyCorridorEnv(RoomGrid):
             **kwargs,
         )
 
+    @staticmethod
+    def _gen_mission(color: str, obj_type: str):
+        return f"pick up the {color} {obj_type}"
+
     def _gen_grid(self, width, height):
         super()._gen_grid(width, height)
 

+ 10 - 6
minigrid/envs/lavagap.py

@@ -77,13 +77,9 @@ class LavaGapEnv(MiniGridEnv):
         self.size = size
 
         if obstacle_type == Lava:
-            mission_space = MissionSpace(
-                mission_func=lambda: "avoid the lava and get to the green goal square"
-            )
+            mission_space = MissionSpace(mission_func=self._gen_mission_lava)
         else:
-            mission_space = MissionSpace(
-                mission_func=lambda: "find the opening and get to the green goal square"
-            )
+            mission_space = MissionSpace(mission_func=self._gen_mission)
 
         if max_steps is None:
             max_steps = 4 * size**2
@@ -98,6 +94,14 @@ class LavaGapEnv(MiniGridEnv):
             **kwargs
         )
 
+    @staticmethod
+    def _gen_mission_lava():
+        return "avoid the lava and get to the green goal square"
+
+    @staticmethod
+    def _gen_mission():
+        return "find the opening and get to the green goal square"
+
     def _gen_grid(self, width, height):
         assert width >= 5 and height >= 5
 

+ 8 - 1
minigrid/envs/lockedroom.py

@@ -82,7 +82,7 @@ class LockedRoomEnv(MiniGridEnv):
         if max_steps is None:
             max_steps = 10 * size
         mission_space = MissionSpace(
-            mission_func=lambda lockedroom_color, keyroom_color, door_color: f"get the {lockedroom_color} key from the {keyroom_color} room, unlock the {door_color} door and go to the goal",
+            mission_func=self._gen_mission,
             ordered_placeholders=[COLOR_NAMES] * 3,
         )
         super().__init__(
@@ -93,6 +93,13 @@ class LockedRoomEnv(MiniGridEnv):
             **kwargs,
         )
 
+    @staticmethod
+    def _gen_mission(lockedroom_color: str, keyroom_color: str, door_color: str):
+        return (
+            f"get the {lockedroom_color} key from the {keyroom_color} room,"
+            f" unlock the {door_color} door and go to the goal"
+        )
+
     def _gen_grid(self, width, height):
         # Create the grid
         self.grid = Grid(width, height)

+ 5 - 3
minigrid/envs/memory.py

@@ -76,9 +76,7 @@ class MemoryEnv(MiniGridEnv):
         if max_steps is None:
             max_steps = 5 * size**2
 
-        mission_space = MissionSpace(
-            mission_func=lambda: "go to the matching object at the end of the hallway"
-        )
+        mission_space = MissionSpace(mission_func=self._gen_mission)
         super().__init__(
             mission_space=mission_space,
             width=size,
@@ -89,6 +87,10 @@ class MemoryEnv(MiniGridEnv):
             **kwargs
         )
 
+    @staticmethod
+    def _gen_mission():
+        return "go to the matching object at the end of the hallway"
+
     def _gen_grid(self, width, height):
         self.grid = Grid(width, height)
 

+ 5 - 3
minigrid/envs/multiroom.py

@@ -94,9 +94,7 @@ class MultiRoomEnv(MiniGridEnv):
 
         self.rooms = []
 
-        mission_space = MissionSpace(
-            mission_func=lambda: "traverse the rooms to get to the goal"
-        )
+        mission_space = MissionSpace(mission_func=self._gen_mission)
 
         self.size = 25
 
@@ -111,6 +109,10 @@ class MultiRoomEnv(MiniGridEnv):
             **kwargs
         )
 
+    @staticmethod
+    def _gen_mission():
+        return "traverse the rooms to get to the goal"
+
     def _gen_grid(self, width, height):
         roomList = []
 

+ 6 - 1
minigrid/envs/obstructedmaze.py

@@ -103,7 +103,8 @@ class ObstructedMazeEnv(RoomGrid):
             max_steps = 4 * num_rooms_visited * room_size**2
 
         mission_space = MissionSpace(
-            mission_func=lambda: f"pick up the {COLOR_NAMES[0]} ball",
+            mission_func=self._gen_mission,
+            ordered_placeholders=[[COLOR_NAMES[0]]],
         )
         super().__init__(
             mission_space=mission_space,
@@ -115,6 +116,10 @@ class ObstructedMazeEnv(RoomGrid):
         )
         self.obj = Ball()  # initialize the obj attribute, that will be changed later on
 
+    @staticmethod
+    def _gen_mission(color: str):
+        return f"pick up the {color} ball"
+
     def _gen_grid(self, width, height):
         super()._gen_grid(width, height)
 

+ 5 - 1
minigrid/envs/playground.py

@@ -12,7 +12,7 @@ class PlaygroundEnv(MiniGridEnv):
     """
 
     def __init__(self, max_steps=100, **kwargs):
-        mission_space = MissionSpace(mission_func=lambda: "")
+        mission_space = MissionSpace(mission_func=self._gen_mission)
         self.size = 19
         super().__init__(
             mission_space=mission_space,
@@ -22,6 +22,10 @@ class PlaygroundEnv(MiniGridEnv):
             **kwargs
         )
 
+    @staticmethod
+    def _gen_mission():
+        return ""
+
     def _gen_grid(self, width, height):
         # Create the grid
         self.grid = Grid(width, height)

+ 7 - 1
minigrid/envs/putnear.py

@@ -72,7 +72,7 @@ class PutNearEnv(MiniGridEnv):
         self.numObjs = numObjs
         self.obj_types = ["key", "ball", "box"]
         mission_space = MissionSpace(
-            mission_func=lambda move_color, move_type, target_color, target_type: f"put the {move_color} {move_type} near the {target_color} {target_type}",
+            mission_func=self._gen_mission,
             ordered_placeholders=[
                 COLOR_NAMES,
                 self.obj_types,
@@ -94,6 +94,12 @@ class PutNearEnv(MiniGridEnv):
             **kwargs,
         )
 
+    @staticmethod
+    def _gen_mission(
+        move_color: str, move_type: str, target_color: str, target_type: str
+    ):
+        return f"put the {move_color} {move_type} near the {target_color} {target_type}"
+
     def _gen_grid(self, width, height):
         self.grid = Grid(width, height)
 

+ 6 - 4
minigrid/envs/redbluedoors.py

@@ -61,9 +61,7 @@ class RedBlueDoorEnv(MiniGridEnv):
 
     def __init__(self, size=8, max_steps: Optional[int] = None, **kwargs):
         self.size = size
-        mission_space = MissionSpace(
-            mission_func=lambda: "open the red door then the blue door"
-        )
+        mission_space = MissionSpace(mission_func=self._gen_mission)
 
         if max_steps is None:
             max_steps = 20 * size**2
@@ -73,9 +71,13 @@ class RedBlueDoorEnv(MiniGridEnv):
             width=2 * size,
             height=size,
             max_steps=max_steps,
-            **kwargs
+            **kwargs,
         )
 
+    @staticmethod
+    def _gen_mission():
+        return "open the red door then the blue door"
+
     def _gen_grid(self, width, height):
         # Create an empty grid
         self.grid = Grid(width, height)

+ 5 - 1
minigrid/envs/unlock.py

@@ -59,7 +59,7 @@ class UnlockEnv(RoomGrid):
 
     def __init__(self, max_steps: Optional[int] = None, **kwargs):
         room_size = 6
-        mission_space = MissionSpace(mission_func=lambda: "open the door")
+        mission_space = MissionSpace(mission_func=self._gen_mission)
 
         if max_steps is None:
             max_steps = 8 * room_size**2
@@ -73,6 +73,10 @@ class UnlockEnv(RoomGrid):
             **kwargs
         )
 
+    @staticmethod
+    def _gen_mission():
+        return "open the door"
+
     def _gen_grid(self, width, height):
         super()._gen_grid(width, height)
 

+ 5 - 1
minigrid/envs/unlockpickup.py

@@ -64,7 +64,7 @@ class UnlockPickupEnv(RoomGrid):
     def __init__(self, max_steps: Optional[int] = None, **kwargs):
         room_size = 6
         mission_space = MissionSpace(
-            mission_func=lambda color: f"pick up the {color} box",
+            mission_func=self._gen_mission,
             ordered_placeholders=[COLOR_NAMES],
         )
 
@@ -80,6 +80,10 @@ class UnlockPickupEnv(RoomGrid):
             **kwargs,
         )
 
+    @staticmethod
+    def _gen_mission(color: str):
+        return f"pick up the {color} box"
+
     def _gen_grid(self, width, height):
         super()._gen_grid(width, height)
 

+ 12 - 7
minigrid/minigrid_env.py

@@ -84,17 +84,22 @@ def check_if_no_duplicate(duplicate_list: list) -> bool:
 
 
 class MissionSpace(spaces.Space[str]):
-    r"""A space representing a mission for the Gym-Minigrid environments.
+    r"""A space representing a mission for the Minigrid environments.
     The space allows generating random mission strings constructed with an input placeholder list.
     Example Usage::
-        >>> observation_space = MissionSpace(mission_func=lambda color: f"Get the {color} ball.",
-                                                ordered_placeholders=[["green", "blue"]])
-        >>> observation_space.sample()
-            "Get the green ball."
-        >>> observation_space = MissionSpace(mission_func=lambda : "Get the ball.".,
-                                                ordered_placeholders=None)
+        >>> def _gen_mission() -> str:
+        ...     return "Get the ball."
+        >>> observation_space = MissionSpace(mission_func=_gen_mission)
         >>> observation_space.sample()
             "Get the ball."
+        >>> def _gen_mission(color: str, object_type: str) -> str:
+        ...     return f"Get the {color} {object_type}."
+        >>> observation_space = MissionSpace(
+        ...     mission_func=_gen_mission,
+        ...     ordered_placeholders=[["green", "blue"], ["ball", "key"]],
+        ... )
+        >>> observation_space.sample()
+            "Get the green ball."
     """
 
     def __init__(

+ 22 - 2
tests/test_envs.py

@@ -1,10 +1,11 @@
+import pickle
 import warnings
 
 import gymnasium as gym
 import numpy as np
 import pytest
 from gymnasium.envs.registration import EnvSpec
-from gymnasium.utils.env_checker import check_env
+from gymnasium.utils.env_checker import check_env, data_equivalence
 
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
@@ -163,7 +164,26 @@ def test_max_steps_argument(env_spec):
 
 
 @pytest.mark.parametrize(
-    "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
+    "env_spec",
+    all_testing_env_specs,
+    ids=[spec.id for spec in all_testing_env_specs],
+)
+def test_pickle_env(env_spec):
+    """Check that a pickled copy of the env resets and steps like the original."""
+    env: gym.Env = env_spec.make()
+    pickled_env: gym.Env = pickle.loads(pickle.dumps(env))
+    # data_equivalence returns a bool, so assert it; seed both resets to align RNGs.
+    assert data_equivalence(env.reset(seed=0), pickled_env.reset(seed=0))
+    action = env.action_space.sample()
+    assert data_equivalence(env.step(action), pickled_env.step(action))
+    env.close()
+    pickled_env.close()
+
+
+@pytest.mark.parametrize(
+    "env_spec",
+    all_testing_env_specs,
+    ids=[spec.id for spec in all_testing_env_specs],
 )
 def old_run_test(env_spec):
     # Load the gym environment