Bläddra i källkod

Add `max_steps` argument for all environments (#265)

Rodrigo de Lazcano 2 år sedan
förälder
incheckning
7751a5cbeb

+ 20 - 3
minigrid/envs/babyai/core/roomgrid_level.py

@@ -1,6 +1,8 @@
 """
 """
 Copied and adapted from https://github.com/mila-iqia/babyai
 Copied and adapted from https://github.com/mila-iqia/babyai
 """
 """
+from typing import Optional
+
 from minigrid.core.roomgrid import RoomGrid
 from minigrid.core.roomgrid import RoomGrid
 from minigrid.envs.babyai.core.verifier import (
 from minigrid.envs.babyai.core.verifier import (
     ActionInstr,
     ActionInstr,
@@ -44,9 +46,22 @@ class RoomGridLevel(RoomGrid):
     of approximately similar difficulty.
     of approximately similar difficulty.
     """
     """
 
 
-    def __init__(self, room_size=8, **kwargs):
+    def __init__(self, room_size=8, max_steps: Optional[int] = None, **kwargs):
         mission_space = BabyAIMissionSpace()
         mission_space = BabyAIMissionSpace()
-        super().__init__(room_size=room_size, mission_space=mission_space, **kwargs)
+
+        # If `max_steps` arg is passed it will be fixed for every episode,
+        # if not it will vary after reset depending on the maze size.
+        self.fixed_max_steps = False
+        if max_steps is not None:
+            self.fixed_max_steps = True
+        else:
+            max_steps = 0  # only for initialization
+        super().__init__(
+            room_size=room_size,
+            mission_space=mission_space,
+            max_steps=max_steps,
+            **kwargs
+        )
 
 
     def reset(self, **kwargs):
     def reset(self, **kwargs):
         obs = super().reset(**kwargs)
         obs = super().reset(**kwargs)
@@ -58,7 +73,9 @@ class RoomGridLevel(RoomGrid):
         nav_time_room = self.room_size**2
         nav_time_room = self.room_size**2
         nav_time_maze = nav_time_room * self.num_rows * self.num_cols
         nav_time_maze = nav_time_room * self.num_rows * self.num_cols
         num_navs = self.num_navs_needed(self.instrs)
         num_navs = self.num_navs_needed(self.instrs)
-        self.max_steps = num_navs * nav_time_maze
+
+        if not self.fixed_max_steps:
+            self.max_steps = num_navs * nav_time_maze
 
 
         return obs
         return obs
 
 

+ 20 - 4
minigrid/envs/babyai/open.py

@@ -2,6 +2,7 @@
 Copied and adapted from https://github.com/mila-iqia/babyai.
 Copied and adapted from https://github.com/mila-iqia/babyai.
 Levels described in the Baby AI ICLR 2019 submission, with the `Open` instruction.
 Levels described in the Baby AI ICLR 2019 submission, with the `Open` instruction.
 """
 """
+from typing import Optional
 
 
 from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.constants import COLOR_NAMES
 from minigrid.envs.babyai.core.roomgrid_level import RoomGridLevel
 from minigrid.envs.babyai.core.roomgrid_level import RoomGridLevel
@@ -97,13 +98,23 @@ class OpenTwoDoors(RoomGridLevel):
     This task requires memory (recurrent policy) to be solved effectively.
     This task requires memory (recurrent policy) to be solved effectively.
     """
     """
 
 
-    def __init__(self, first_color=None, second_color=None, strict=False, **kwargs):
+    def __init__(
+        self,
+        first_color=None,
+        second_color=None,
+        strict=False,
+        max_steps: Optional[int] = None,
+        **kwargs
+    ):
         self.first_color = first_color
         self.first_color = first_color
         self.second_color = second_color
         self.second_color = second_color
         self.strict = strict
         self.strict = strict
 
 
         room_size = 6
         room_size = 6
-        super().__init__(room_size=room_size, max_steps=20 * room_size**2, **kwargs)
+        if max_steps is None:
+            max_steps = 20 * room_size**2
+
+        super().__init__(room_size=room_size, max_steps=max_steps, **kwargs)
 
 
     def gen_mission(self):
     def gen_mission(self):
         colors = self._rand_subset(COLOR_NAMES, 2)
         colors = self._rand_subset(COLOR_NAMES, 2)
@@ -131,13 +142,18 @@ class OpenDoorsOrder(RoomGridLevel):
     Open one or two doors in the order specified.
     Open one or two doors in the order specified.
     """
     """
 
 
-    def __init__(self, num_doors, debug=False, **kwargs):
+    def __init__(
+        self, num_doors, debug=False, max_steps: Optional[int] = None, **kwargs
+    ):
         assert num_doors >= 2
         assert num_doors >= 2
         self.num_doors = num_doors
         self.num_doors = num_doors
         self.debug = debug
         self.debug = debug
 
 
         room_size = 6
         room_size = 6
-        super().__init__(room_size=room_size, max_steps=20 * room_size**2, **kwargs)
+        if max_steps is None:
+            max_steps = 20 * room_size**2
+
+        super().__init__(room_size=room_size, max_steps=max_steps, **kwargs)
 
 
     def gen_mission(self):
     def gen_mission(self):
         colors = self._rand_subset(COLOR_NAMES, self.num_doors)
         colors = self._rand_subset(COLOR_NAMES, self.num_doors)

+ 26 - 13
minigrid/envs/babyai/other.py

@@ -2,6 +2,7 @@
 Copied and adapted from https://github.com/mila-iqia/babyai.
 Copied and adapted from https://github.com/mila-iqia/babyai.
 Levels described in the Baby AI ICLR 2019 submission, with different instructions than those in other files.
 Levels described in the Baby AI ICLR 2019 submission, with different instructions than those in other files.
 """
 """
+from typing import Optional
 
 
 from minigrid.envs.babyai.core.roomgrid_level import RoomGridLevel
 from minigrid.envs.babyai.core.roomgrid_level import RoomGridLevel
 from minigrid.envs.babyai.core.verifier import (
 from minigrid.envs.babyai.core.verifier import (
@@ -55,8 +56,12 @@ class FindObjS5(RoomGridLevel):
     This level requires potentially exhaustive exploration
     This level requires potentially exhaustive exploration
     """
     """
 
 
-    def __init__(self, room_size=5, **kwargs):
-        super().__init__(room_size=room_size, max_steps=20 * room_size**2, **kwargs)
+    def __init__(self, room_size=5, max_steps: Optional[int] = None, **kwargs):
+
+        if max_steps is None:
+            max_steps = 20 * room_size**2
+
+        super().__init__(room_size=room_size, max_steps=max_steps, **kwargs)
 
 
     def gen_mission(self):
     def gen_mission(self):
         # Add a random object to a random room
         # Add a random object to a random room
@@ -75,14 +80,21 @@ class KeyCorridor(RoomGridLevel):
     random room.
     random room.
     """
     """
 
 
-    def __init__(self, num_rows=3, obj_type="ball", room_size=6, **kwargs):
+    def __init__(
+        self,
+        num_rows=3,
+        obj_type="ball",
+        room_size=6,
+        max_steps: Optional[int] = None,
+        **kwargs
+    ):
         self.obj_type = obj_type
         self.obj_type = obj_type
 
 
+        if max_steps is None:
+            max_steps = 30 * room_size**2
+
         super().__init__(
         super().__init__(
-            room_size=room_size,
-            num_rows=num_rows,
-            max_steps=30 * room_size**2,
-            **kwargs
+            room_size=room_size, num_rows=num_rows, max_steps=max_steps, **kwargs
         )
         )
 
 
     def gen_mission(self):
     def gen_mission(self):
@@ -130,16 +142,17 @@ class MoveTwoAcross(RoomGridLevel):
     instructions.
     instructions.
     """
     """
 
 
-    def __init__(self, room_size, objs_per_room, **kwargs):
+    def __init__(
+        self, room_size, objs_per_room, max_steps: Optional[int] = None, **kwargs
+    ):
         assert objs_per_room <= 9
         assert objs_per_room <= 9
         self.objs_per_room = objs_per_room
         self.objs_per_room = objs_per_room
 
 
+        if max_steps is None:
+            max_steps = 16 * room_size**2
+
         super().__init__(
         super().__init__(
-            num_rows=1,
-            num_cols=2,
-            room_size=room_size,
-            max_steps=16 * room_size**2,
-            **kwargs
+            num_rows=1, num_cols=2, room_size=room_size, max_steps=max_steps, **kwargs
         )
         )
 
 
     def gen_mission(self):
     def gen_mission(self):

+ 6 - 2
minigrid/envs/babyai/pickup.py

@@ -2,6 +2,7 @@
 Copied and adapted from https://github.com/mila-iqia/babyai.
 Copied and adapted from https://github.com/mila-iqia/babyai.
 Levels described in the Baby AI ICLR 2019 submission, with the `Pick up` instruction.
 Levels described in the Baby AI ICLR 2019 submission, with the `Pick up` instruction.
 """
 """
+from typing import Optional
 
 
 from minigrid.envs.babyai.core.levelgen import LevelGen
 from minigrid.envs.babyai.core.levelgen import LevelGen
 from minigrid.envs.babyai.core.roomgrid_level import RejectSampling, RoomGridLevel
 from minigrid.envs.babyai.core.roomgrid_level import RejectSampling, RoomGridLevel
@@ -101,9 +102,12 @@ class PickupAbove(RoomGridLevel):
     This task requires to use the compass to be solved effectively.
     This task requires to use the compass to be solved effectively.
     """
     """
 
 
-    def __init__(self, **kwargs):
+    def __init__(self, max_steps: Optional[int] = None, **kwargs):
         room_size = 6
         room_size = 6
-        super().__init__(room_size=room_size, max_steps=8 * room_size**2, **kwargs)
+        if max_steps is None:
+            max_steps = 8 * room_size**2
+
+        super().__init__(room_size=room_size, max_steps=max_steps, **kwargs)
 
 
     def gen_mission(self):
     def gen_mission(self):
         # Add a random object to the top-middle room
         # Add a random object to the top-middle room

+ 13 - 6
minigrid/envs/babyai/putnext.py

@@ -2,6 +2,7 @@
 Copied and adapted from https://github.com/mila-iqia/babyai.
 Copied and adapted from https://github.com/mila-iqia/babyai.
 Levels described in the Baby AI ICLR 2019 submission, with the `Put Next` instruction.
 Levels described in the Baby AI ICLR 2019 submission, with the `Put Next` instruction.
 """
 """
+from typing import Optional
 
 
 from minigrid.envs.babyai.core.roomgrid_level import RoomGridLevel
 from minigrid.envs.babyai.core.roomgrid_level import RoomGridLevel
 from minigrid.envs.babyai.core.verifier import ObjDesc, PutNextInstr
 from minigrid.envs.babyai.core.verifier import ObjDesc, PutNextInstr
@@ -35,18 +36,24 @@ class PutNext(RoomGridLevel):
     instructions.
     instructions.
     """
     """
 
 
-    def __init__(self, room_size, objs_per_room, start_carrying=False, **kwargs):
+    def __init__(
+        self,
+        room_size,
+        objs_per_room,
+        start_carrying=False,
+        max_steps: Optional[int] = None,
+        **kwargs
+    ):
         assert room_size >= 4
         assert room_size >= 4
         assert objs_per_room <= 9
         assert objs_per_room <= 9
         self.objs_per_room = objs_per_room
         self.objs_per_room = objs_per_room
         self.start_carrying = start_carrying
         self.start_carrying = start_carrying
 
 
+        if max_steps is None:
+            max_steps = 8 * room_size**2
+
         super().__init__(
         super().__init__(
-            num_rows=1,
-            num_cols=2,
-            room_size=room_size,
-            max_steps=8 * room_size**2,
-            **kwargs
+            num_rows=1, num_cols=2, room_size=room_size, max_steps=max_steps, **kwargs
         )
         )
 
 
     def gen_mission(self):
     def gen_mission(self):

+ 16 - 19
minigrid/envs/babyai/unlock.py

@@ -2,6 +2,7 @@
 Copied and adapted from https://github.com/mila-iqia/babyai.
 Copied and adapted from https://github.com/mila-iqia/babyai.
 Levels described in the Baby AI ICLR 2019 submission, with the `Unlock` instruction.
 Levels described in the Baby AI ICLR 2019 submission, with the `Unlock` instruction.
 """
 """
+from typing import Optional
 
 
 from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.constants import COLOR_NAMES
 from minigrid.envs.babyai.core.roomgrid_level import RoomGridLevel
 from minigrid.envs.babyai.core.roomgrid_level import RoomGridLevel
@@ -109,16 +110,14 @@ class UnlockPickup(RoomGridLevel):
     Unlock a door, then pick up a box in another room
     Unlock a door, then pick up a box in another room
     """
     """
 
 
-    def __init__(self, distractors=False, **kwargs):
+    def __init__(self, distractors=False, max_steps: Optional[int] = None, **kwargs):
         self.distractors = distractors
         self.distractors = distractors
-
         room_size = 6
         room_size = 6
+        if max is None:
+            max_steps = 8 * room_size**2
+
         super().__init__(
         super().__init__(
-            num_rows=1,
-            num_cols=2,
-            room_size=room_size,
-            max_steps=8 * room_size**2,
-            **kwargs
+            num_rows=1, num_cols=2, room_size=6, max_steps=max_steps, **kwargs
         )
         )
 
 
     def gen_mission(self):
     def gen_mission(self):
@@ -142,14 +141,13 @@ class BlockedUnlockPickup(RoomGridLevel):
     in another room
     in another room
     """
     """
 
 
-    def __init__(self, **kwargs):
+    def __init__(self, max_steps: Optional[int] = None, **kwargs):
         room_size = 6
         room_size = 6
+        if max_steps is None:
+            max_steps = 16 * room_size**2
+
         super().__init__(
         super().__init__(
-            num_rows=1,
-            num_cols=2,
-            room_size=room_size,
-            max_steps=16 * room_size**2,
-            **kwargs
+            num_rows=1, num_cols=2, room_size=room_size, max_steps=max_steps, **kwargs
         )
         )
 
 
     def gen_mission(self):
     def gen_mission(self):
@@ -173,14 +171,13 @@ class UnlockToUnlock(RoomGridLevel):
     Unlock a door A that requires to unlock a door B before
     Unlock a door A that requires to unlock a door B before
     """
     """
 
 
-    def __init__(self, **kwargs):
+    def __init__(self, max_steps: Optional[int] = None, **kwargs):
         room_size = 6
         room_size = 6
+        if max_steps is None:
+            max_steps = 30 * room_size**2
+
         super().__init__(
         super().__init__(
-            num_rows=1,
-            num_cols=3,
-            room_size=room_size,
-            max_steps=30 * room_size**2,
-            **kwargs
+            num_rows=1, num_cols=3, room_size=room_size, max_steps=max_steps, **kwargs
         )
         )
 
 
     def gen_mission(self):
     def gen_mission(self):

+ 9 - 3
minigrid/envs/blockedunlockpickup.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.mission import MissionSpace
 from minigrid.core.mission import MissionSpace
 from minigrid.core.roomgrid import RoomGrid
 from minigrid.core.roomgrid import RoomGrid
@@ -63,18 +65,22 @@ class BlockedUnlockPickupEnv(RoomGrid):
 
 
     """
     """
 
 
-    def __init__(self, **kwargs):
-        room_size = 6
+    def __init__(self, max_steps: Optional[int] = None, **kwargs):
         mission_space = MissionSpace(
         mission_space = MissionSpace(
             mission_func=lambda color, type: f"pick up the {color} {type}",
             mission_func=lambda color, type: f"pick up the {color} {type}",
             ordered_placeholders=[COLOR_NAMES, ["box", "key"]],
             ordered_placeholders=[COLOR_NAMES, ["box", "key"]],
         )
         )
+
+        room_size = 6
+        if max_steps is None:
+            max_steps = 16 * room_size**2
+
         super().__init__(
         super().__init__(
             mission_space=mission_space,
             mission_space=mission_space,
             num_rows=1,
             num_rows=1,
             num_cols=2,
             num_cols=2,
             room_size=room_size,
             room_size=room_size,
-            max_steps=16 * room_size**2,
+            max_steps=max_steps,
             **kwargs,
             **kwargs,
         )
         )
 
 

+ 14 - 4
minigrid/envs/crossing.py

@@ -1,4 +1,5 @@
 import itertools as itt
 import itertools as itt
+from typing import Optional
 
 
 import numpy as np
 import numpy as np
 
 
@@ -90,7 +91,14 @@ class CrossingEnv(MiniGridEnv):
 
 
     """
     """
 
 
-    def __init__(self, size=9, num_crossings=1, obstacle_type=Lava, **kwargs):
+    def __init__(
+        self,
+        size=9,
+        num_crossings=1,
+        obstacle_type=Lava,
+        max_steps: Optional[int] = None,
+        **kwargs
+    ):
         self.num_crossings = num_crossings
         self.num_crossings = num_crossings
         self.obstacle_type = obstacle_type
         self.obstacle_type = obstacle_type
 
 
@@ -103,12 +111,14 @@ class CrossingEnv(MiniGridEnv):
                 mission_func=lambda: "find the opening and get to the green goal square"
                 mission_func=lambda: "find the opening and get to the green goal square"
             )
             )
 
 
+        if max_steps is None:
+            max_steps = 4 * size**2
+
         super().__init__(
         super().__init__(
             mission_space=mission_space,
             mission_space=mission_space,
             grid_size=size,
             grid_size=size,
-            max_steps=4 * size * size,
-            # Set this to True for maximum speed
-            see_through_walls=False,
+            see_through_walls=False,  # Set this to True for maximum speed
+            max_steps=max_steps,
             **kwargs
             **kwargs
         )
         )
 
 

+ 7 - 1
minigrid/envs/distshift.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 from minigrid.core.grid import Grid
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
 from minigrid.core.mission import MissionSpace
 from minigrid.core.world_object import Goal, Lava
 from minigrid.core.world_object import Goal, Lava
@@ -70,6 +72,7 @@ class DistShiftEnv(MiniGridEnv):
         agent_start_pos=(1, 1),
         agent_start_pos=(1, 1),
         agent_start_dir=0,
         agent_start_dir=0,
         strip2_row=2,
         strip2_row=2,
+        max_steps: Optional[int] = None,
         **kwargs
         **kwargs
     ):
     ):
         self.agent_start_pos = agent_start_pos
         self.agent_start_pos = agent_start_pos
@@ -81,13 +84,16 @@ class DistShiftEnv(MiniGridEnv):
             mission_func=lambda: "get to the green goal square"
             mission_func=lambda: "get to the green goal square"
         )
         )
 
 
+        if max_steps is None:
+            max_steps = 4 * width * height
+
         super().__init__(
         super().__init__(
             mission_space=mission_space,
             mission_space=mission_space,
             width=width,
             width=width,
             height=height,
             height=height,
-            max_steps=4 * width * height,
             # Set this to True for maximum speed
             # Set this to True for maximum speed
             see_through_walls=True,
             see_through_walls=True,
+            max_steps=max_steps,
             **kwargs
             **kwargs
         )
         )
 
 

+ 8 - 4
minigrid/envs/doorkey.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 from minigrid.core.grid import Grid
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
 from minigrid.core.mission import MissionSpace
 from minigrid.core.world_object import Door, Goal, Key
 from minigrid.core.world_object import Door, Goal, Key
@@ -60,13 +62,15 @@ class DoorKeyEnv(MiniGridEnv):
 
 
     """
     """
 
 
-    def __init__(self, size=8, **kwargs):
-        if "max_steps" not in kwargs:
-            kwargs["max_steps"] = 10 * size * size
+    def __init__(self, size=8, max_steps: Optional[int] = None, **kwargs):
+        if max_steps is None:
+            max_steps = 10 * size**2
         mission_space = MissionSpace(
         mission_space = MissionSpace(
             mission_func=lambda: "use the key to open the door and then get to the goal"
             mission_func=lambda: "use the key to open the door and then get to the goal"
         )
         )
-        super().__init__(mission_space=mission_space, grid_size=size, **kwargs)
+        super().__init__(
+            mission_space=mission_space, grid_size=size, max_steps=max_steps, **kwargs
+        )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
         # Create an empty grid
         # Create an empty grid

+ 12 - 2
minigrid/envs/dynamicobstacles.py

@@ -1,4 +1,5 @@
 from operator import add
 from operator import add
+from typing import Optional
 
 
 from gymnasium.spaces import Discrete
 from gymnasium.spaces import Discrete
 
 
@@ -70,7 +71,13 @@ class DynamicObstaclesEnv(MiniGridEnv):
     """
     """
 
 
     def __init__(
     def __init__(
-        self, size=8, agent_start_pos=(1, 1), agent_start_dir=0, n_obstacles=4, **kwargs
+        self,
+        size=8,
+        agent_start_pos=(1, 1),
+        agent_start_dir=0,
+        n_obstacles=4,
+        max_steps: Optional[int] = None,
+        **kwargs
     ):
     ):
         self.agent_start_pos = agent_start_pos
         self.agent_start_pos = agent_start_pos
         self.agent_start_dir = agent_start_dir
         self.agent_start_dir = agent_start_dir
@@ -85,12 +92,15 @@ class DynamicObstaclesEnv(MiniGridEnv):
             mission_func=lambda: "get to the green goal square"
             mission_func=lambda: "get to the green goal square"
         )
         )
 
 
+        if max_steps is None:
+            max_steps = 4 * size**2
+
         super().__init__(
         super().__init__(
             mission_space=mission_space,
             mission_space=mission_space,
             grid_size=size,
             grid_size=size,
-            max_steps=4 * size * size,
             # Set this to True for maximum speed
             # Set this to True for maximum speed
             see_through_walls=True,
             see_through_walls=True,
+            max_steps=max_steps,
             **kwargs
             **kwargs
         )
         )
         # Allow only 3 actions permitted: left, right, forward
         # Allow only 3 actions permitted: left, right, forward

+ 14 - 2
minigrid/envs/empty.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 from minigrid.core.grid import Grid
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
 from minigrid.core.mission import MissionSpace
 from minigrid.core.world_object import Goal
 from minigrid.core.world_object import Goal
@@ -65,7 +67,14 @@ class EmptyEnv(MiniGridEnv):
 
 
     """
     """
 
 
-    def __init__(self, size=8, agent_start_pos=(1, 1), agent_start_dir=0, **kwargs):
+    def __init__(
+        self,
+        size=8,
+        agent_start_pos=(1, 1),
+        agent_start_dir=0,
+        max_steps: Optional[int] = None,
+        **kwargs
+    ):
         self.agent_start_pos = agent_start_pos
         self.agent_start_pos = agent_start_pos
         self.agent_start_dir = agent_start_dir
         self.agent_start_dir = agent_start_dir
 
 
@@ -73,12 +82,15 @@ class EmptyEnv(MiniGridEnv):
             mission_func=lambda: "get to the green goal square"
             mission_func=lambda: "get to the green goal square"
         )
         )
 
 
+        if max_steps is None:
+            max_steps = 4 * size**2
+
         super().__init__(
         super().__init__(
             mission_space=mission_space,
             mission_space=mission_space,
             grid_size=size,
             grid_size=size,
-            max_steps=4 * size * size,
             # Set this to True for maximum speed
             # Set this to True for maximum speed
             see_through_walls=True,
             see_through_walls=True,
+            max_steps=max_steps,
             **kwargs
             **kwargs
         )
         )
 
 

+ 8 - 2
minigrid/envs/fetch.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.grid import Grid
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
 from minigrid.core.mission import MissionSpace
@@ -71,7 +73,7 @@ class FetchEnv(MiniGridEnv):
 
 
     """
     """
 
 
-    def __init__(self, size=8, numObjs=3, **kwargs):
+    def __init__(self, size=8, numObjs=3, max_steps: Optional[int] = None, **kwargs):
         self.numObjs = numObjs
         self.numObjs = numObjs
         self.obj_types = ["key", "ball"]
         self.obj_types = ["key", "ball"]
 
 
@@ -87,13 +89,17 @@ class FetchEnv(MiniGridEnv):
             mission_func=lambda syntax, color, type: f"{syntax} {color} {type}",
             mission_func=lambda syntax, color, type: f"{syntax} {color} {type}",
             ordered_placeholders=[MISSION_SYNTAX, COLOR_NAMES, self.obj_types],
             ordered_placeholders=[MISSION_SYNTAX, COLOR_NAMES, self.obj_types],
         )
         )
+
+        if max_steps is None:
+            max_steps = 5 * size**2
+
         super().__init__(
         super().__init__(
             mission_space=mission_space,
             mission_space=mission_space,
             width=size,
             width=size,
             height=size,
             height=size,
-            max_steps=5 * size**2,
             # Set this to True for maximum speed
             # Set this to True for maximum speed
             see_through_walls=True,
             see_through_walls=True,
+            max_steps=max_steps,
             **kwargs,
             **kwargs,
         )
         )
 
 

+ 2 - 2
minigrid/envs/fourrooms.py

@@ -57,7 +57,7 @@ class FourRoomsEnv(MiniGridEnv):
 
 
     """
     """
 
 
-    def __init__(self, agent_pos=None, goal_pos=None, **kwargs):
+    def __init__(self, agent_pos=None, goal_pos=None, max_steps=100, **kwargs):
         self._agent_default_pos = agent_pos
         self._agent_default_pos = agent_pos
         self._goal_default_pos = goal_pos
         self._goal_default_pos = goal_pos
 
 
@@ -68,7 +68,7 @@ class FourRoomsEnv(MiniGridEnv):
             mission_space=mission_space,
             mission_space=mission_space,
             width=self.size,
             width=self.size,
             height=self.size,
             height=self.size,
-            max_steps=100,
+            max_steps=max_steps,
             **kwargs
             **kwargs
         )
         )
 
 

+ 8 - 2
minigrid/envs/gotodoor.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.grid import Grid
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
 from minigrid.core.mission import MissionSpace
@@ -63,20 +65,24 @@ class GoToDoorEnv(MiniGridEnv):
 
 
     """
     """
 
 
-    def __init__(self, size=5, **kwargs):
+    def __init__(self, size=5, max_steps: Optional[int] = None, **kwargs):
         assert size >= 5
         assert size >= 5
         self.size = size
         self.size = size
         mission_space = MissionSpace(
         mission_space = MissionSpace(
             mission_func=lambda color: f"go to the {color} door",
             mission_func=lambda color: f"go to the {color} door",
             ordered_placeholders=[COLOR_NAMES],
             ordered_placeholders=[COLOR_NAMES],
         )
         )
+
+        if max_steps is None:
+            max_steps = 4 * size**2
+
         super().__init__(
         super().__init__(
             mission_space=mission_space,
             mission_space=mission_space,
             width=size,
             width=size,
             height=size,
             height=size,
-            max_steps=5 * size**2,
             # Set this to True for maximum speed
             # Set this to True for maximum speed
             see_through_walls=True,
             see_through_walls=True,
+            max_steps=max_steps,
             **kwargs,
             **kwargs,
         )
         )
 
 

+ 9 - 2
minigrid/envs/gotoobject.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.grid import Grid
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
 from minigrid.core.mission import MissionSpace
@@ -11,7 +13,8 @@ class GoToObjectEnv(MiniGridEnv):
     named using an English text string
     named using an English text string
     """
     """
 
 
-    def __init__(self, size=6, numObjs=2, **kwargs):
+    def __init__(self, size=6, numObjs=2, max_steps: Optional[int] = None, **kwargs):
+
         self.numObjs = numObjs
         self.numObjs = numObjs
         self.size = size
         self.size = size
         # Types of objects to be generated
         # Types of objects to be generated
@@ -21,13 +24,17 @@ class GoToObjectEnv(MiniGridEnv):
             mission_func=lambda color, type: f"go to the {color} {type}",
             mission_func=lambda color, type: f"go to the {color} {type}",
             ordered_placeholders=[COLOR_NAMES, self.obj_types],
             ordered_placeholders=[COLOR_NAMES, self.obj_types],
         )
         )
+
+        if max_steps is None:
+            max_steps = 5 * size**2
+
         super().__init__(
         super().__init__(
             mission_space=mission_space,
             mission_space=mission_space,
             width=size,
             width=size,
             height=size,
             height=size,
-            max_steps=5 * size**2,
             # Set this to True for maximum speed
             # Set this to True for maximum speed
             see_through_walls=True,
             see_through_walls=True,
+            max_steps=max_steps,
             **kwargs,
             **kwargs,
         )
         )
 
 

+ 15 - 2
minigrid/envs/keycorridor.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.mission import MissionSpace
 from minigrid.core.mission import MissionSpace
 from minigrid.core.roomgrid import RoomGrid
 from minigrid.core.roomgrid import RoomGrid
@@ -77,17 +79,28 @@ class KeyCorridorEnv(RoomGrid):
 
 
     """
     """
 
 
-    def __init__(self, num_rows=3, obj_type="ball", room_size=6, **kwargs):
+    def __init__(
+        self,
+        num_rows=3,
+        obj_type="ball",
+        room_size=6,
+        max_steps: Optional[int] = None,
+        **kwargs,
+    ):
         self.obj_type = obj_type
         self.obj_type = obj_type
         mission_space = MissionSpace(
         mission_space = MissionSpace(
             mission_func=lambda color: f"pick up the {color} {obj_type}",
             mission_func=lambda color: f"pick up the {color} {obj_type}",
             ordered_placeholders=[COLOR_NAMES],
             ordered_placeholders=[COLOR_NAMES],
         )
         )
+
+        if max_steps is None:
+            max_steps = 30 * room_size**2
+
         super().__init__(
         super().__init__(
             mission_space=mission_space,
             mission_space=mission_space,
             room_size=room_size,
             room_size=room_size,
             num_rows=num_rows,
             num_rows=num_rows,
-            max_steps=30 * room_size**2,
+            max_steps=max_steps,
             **kwargs,
             **kwargs,
         )
         )
 
 

+ 9 - 2
minigrid/envs/lavagap.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 import numpy as np
 import numpy as np
 
 
 from minigrid.core.grid import Grid
 from minigrid.core.grid import Grid
@@ -66,7 +68,9 @@ class LavaGapEnv(MiniGridEnv):
 
 
     """
     """
 
 
-    def __init__(self, size, obstacle_type=Lava, **kwargs):
+    def __init__(
+        self, size, obstacle_type=Lava, max_steps: Optional[int] = None, **kwargs
+    ):
         self.obstacle_type = obstacle_type
         self.obstacle_type = obstacle_type
         self.size = size
         self.size = size
 
 
@@ -79,13 +83,16 @@ class LavaGapEnv(MiniGridEnv):
                 mission_func=lambda: "find the opening and get to the green goal square"
                 mission_func=lambda: "find the opening and get to the green goal square"
             )
             )
 
 
+        if max_steps is None:
+            max_steps = 4 * size**2
+
         super().__init__(
         super().__init__(
             mission_space=mission_space,
             mission_space=mission_space,
             width=size,
             width=size,
             height=size,
             height=size,
-            max_steps=4 * size * size,
             # Set this to True for maximum speed
             # Set this to True for maximum speed
             see_through_walls=False,
             see_through_walls=False,
+            max_steps=max_steps,
             **kwargs
             **kwargs
         )
         )
 
 

+ 7 - 2
minigrid/envs/lockedroom.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.grid import Grid
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
 from minigrid.core.mission import MissionSpace
@@ -74,8 +76,11 @@ class LockedRoomEnv(MiniGridEnv):
 
 
     """
     """
 
 
-    def __init__(self, size=19, **kwargs):
+    def __init__(self, size=19, max_steps: Optional[int] = None, **kwargs):
         self.size = size
         self.size = size
+
+        if max_steps is None:
+            max_steps = 10 * size
         mission_space = MissionSpace(
         mission_space = MissionSpace(
             mission_func=lambda lockedroom_color, keyroom_color, door_color: f"get the {lockedroom_color} key from the {keyroom_color} room, unlock the {door_color} door and go to the goal",
             mission_func=lambda lockedroom_color, keyroom_color, door_color: f"get the {lockedroom_color} key from the {keyroom_color} room, unlock the {door_color} door and go to the goal",
             ordered_placeholders=[COLOR_NAMES] * 3,
             ordered_placeholders=[COLOR_NAMES] * 3,
@@ -84,7 +89,7 @@ class LockedRoomEnv(MiniGridEnv):
             mission_space=mission_space,
             mission_space=mission_space,
             width=size,
             width=size,
             height=size,
             height=size,
-            max_steps=10 * size,
+            max_steps=max_steps,
             **kwargs,
             **kwargs,
         )
         )
 
 

+ 10 - 2
minigrid/envs/memory.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 import numpy as np
 import numpy as np
 
 
 from minigrid.core.actions import Actions
 from minigrid.core.actions import Actions
@@ -65,9 +67,15 @@ class MemoryEnv(MiniGridEnv):
 
 
     """
     """
 
 
-    def __init__(self, size=8, random_length=False, **kwargs):
+    def __init__(
+        self, size=8, random_length=False, max_steps: Optional[int] = None, **kwargs
+    ):
         self.size = size
         self.size = size
         self.random_length = random_length
         self.random_length = random_length
+
+        if max_steps is None:
+            max_steps = 5 * size**2
+
         mission_space = MissionSpace(
         mission_space = MissionSpace(
             mission_func=lambda: "go to the matching object at the end of the hallway"
             mission_func=lambda: "go to the matching object at the end of the hallway"
         )
         )
@@ -75,9 +83,9 @@ class MemoryEnv(MiniGridEnv):
             mission_space=mission_space,
             mission_space=mission_space,
             width=size,
             width=size,
             height=size,
             height=size,
-            max_steps=5 * size**2,
             # Set this to True for maximum speed
             # Set this to True for maximum speed
             see_through_walls=False,
             see_through_walls=False,
+            max_steps=max_steps,
             **kwargs
             **kwargs
         )
         )
 
 

+ 14 - 2
minigrid/envs/multiroom.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.grid import Grid
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
 from minigrid.core.mission import MissionSpace
@@ -72,7 +74,14 @@ class MultiRoomEnv(MiniGridEnv):
 
 
     """
     """
 
 
-    def __init__(self, minNumRooms, maxNumRooms, maxRoomSize=10, **kwargs):
+    def __init__(
+        self,
+        minNumRooms,
+        maxNumRooms,
+        maxRoomSize=10,
+        max_steps: Optional[int] = None,
+        **kwargs
+    ):
         assert minNumRooms > 0
         assert minNumRooms > 0
         assert maxNumRooms >= minNumRooms
         assert maxNumRooms >= minNumRooms
         assert maxRoomSize >= 4
         assert maxRoomSize >= 4
@@ -89,11 +98,14 @@ class MultiRoomEnv(MiniGridEnv):
 
 
         self.size = 25
         self.size = 25
 
 
+        if max_steps is None:
+            max_steps = maxNumRooms * 20
+
         super().__init__(
         super().__init__(
             mission_space=mission_space,
             mission_space=mission_space,
             width=self.size,
             width=self.size,
             height=self.size,
             height=self.size,
-            max_steps=self.maxNumRooms * 20,
+            max_steps=max_steps,
             **kwargs
             **kwargs
         )
         )
 
 

+ 13 - 2
minigrid/envs/obstructedmaze.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 from minigrid.core.constants import COLOR_NAMES, DIR_TO_VEC
 from minigrid.core.constants import COLOR_NAMES, DIR_TO_VEC
 from minigrid.core.mission import MissionSpace
 from minigrid.core.mission import MissionSpace
 from minigrid.core.roomgrid import RoomGrid
 from minigrid.core.roomgrid import RoomGrid
@@ -79,9 +81,18 @@ class ObstructedMazeEnv(RoomGrid):
 
 
     """
     """
 
 
-    def __init__(self, num_rows, num_cols, num_rooms_visited, **kwargs):
+    def __init__(
+        self,
+        num_rows,
+        num_cols,
+        num_rooms_visited,
+        max_steps: Optional[int] = None,
+        **kwargs,
+    ):
         room_size = 6
         room_size = 6
-        max_steps = 4 * num_rooms_visited * room_size**2
+
+        if max_steps is None:
+            max_steps = 4 * num_rooms_visited * room_size**2
 
 
         mission_space = MissionSpace(
         mission_space = MissionSpace(
             mission_func=lambda: f"pick up the {COLOR_NAMES[0]} ball",
             mission_func=lambda: f"pick up the {COLOR_NAMES[0]} ball",

+ 2 - 2
minigrid/envs/playground.py

@@ -11,14 +11,14 @@ class PlaygroundEnv(MiniGridEnv):
     This environment has no specific goals or rewards.
     This environment has no specific goals or rewards.
     """
     """
 
 
-    def __init__(self, **kwargs):
+    def __init__(self, max_steps=100, **kwargs):
         mission_space = MissionSpace(mission_func=lambda: "")
         mission_space = MissionSpace(mission_func=lambda: "")
         self.size = 19
         self.size = 19
         super().__init__(
         super().__init__(
             mission_space=mission_space,
             mission_space=mission_space,
             width=self.size,
             width=self.size,
             height=self.size,
             height=self.size,
-            max_steps=100,
+            max_steps=max_steps,
             **kwargs
             **kwargs
         )
         )
 
 

+ 8 - 2
minigrid/envs/putnear.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.grid import Grid
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
 from minigrid.core.mission import MissionSpace
@@ -65,7 +67,7 @@ class PutNearEnv(MiniGridEnv):
 
 
     """
     """
 
 
-    def __init__(self, size=6, numObjs=2, **kwargs):
+    def __init__(self, size=6, numObjs=2, max_steps: Optional[int] = None, **kwargs):
         self.size = size
         self.size = size
         self.numObjs = numObjs
         self.numObjs = numObjs
         self.obj_types = ["key", "ball", "box"]
         self.obj_types = ["key", "ball", "box"]
@@ -78,13 +80,17 @@ class PutNearEnv(MiniGridEnv):
                 self.obj_types,
                 self.obj_types,
             ],
             ],
         )
         )
+
+        if max_steps is None:
+            max_steps = 5 * size
+
         super().__init__(
         super().__init__(
             mission_space=mission_space,
             mission_space=mission_space,
             width=size,
             width=size,
             height=size,
             height=size,
-            max_steps=5 * size,
             # Set this to True for maximum speed
             # Set this to True for maximum speed
             see_through_walls=True,
             see_through_walls=True,
+            max_steps=max_steps,
             **kwargs,
             **kwargs,
         )
         )
 
 

+ 8 - 2
minigrid/envs/redbluedoors.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 from minigrid.core.grid import Grid
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
 from minigrid.core.mission import MissionSpace
 from minigrid.core.world_object import Door
 from minigrid.core.world_object import Door
@@ -57,16 +59,20 @@ class RedBlueDoorEnv(MiniGridEnv):
 
 
     """
     """
 
 
-    def __init__(self, size=8, **kwargs):
+    def __init__(self, size=8, max_steps: Optional[int] = None, **kwargs):
         self.size = size
         self.size = size
         mission_space = MissionSpace(
         mission_space = MissionSpace(
             mission_func=lambda: "open the red door then the blue door"
             mission_func=lambda: "open the red door then the blue door"
         )
         )
+
+        if max_steps is None:
+            max_steps = 20 * size**2
+
         super().__init__(
         super().__init__(
             mission_space=mission_space,
             mission_space=mission_space,
             width=2 * size,
             width=2 * size,
             height=size,
             height=size,
-            max_steps=20 * size * size,
+            max_steps=max_steps,
             **kwargs
             **kwargs
         )
         )
 
 

+ 8 - 2
minigrid/envs/unlock.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 from minigrid.core.mission import MissionSpace
 from minigrid.core.mission import MissionSpace
 from minigrid.core.roomgrid import RoomGrid
 from minigrid.core.roomgrid import RoomGrid
 
 
@@ -53,15 +55,19 @@ class UnlockEnv(RoomGrid):
 
 
     """
     """
 
 
-    def __init__(self, **kwargs):
+    def __init__(self, max_steps: Optional[int] = None, **kwargs):
         room_size = 6
         room_size = 6
         mission_space = MissionSpace(mission_func=lambda: "open the door")
         mission_space = MissionSpace(mission_func=lambda: "open the door")
+
+        if max_steps is None:
+            max_steps = 8 * room_size**2
+
         super().__init__(
         super().__init__(
             mission_space=mission_space,
             mission_space=mission_space,
             num_rows=1,
             num_rows=1,
             num_cols=2,
             num_cols=2,
             room_size=room_size,
             room_size=room_size,
-            max_steps=8 * room_size**2,
+            max_steps=max_steps,
             **kwargs
             **kwargs
         )
         )
 
 

+ 8 - 2
minigrid/envs/unlockpickup.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.mission import MissionSpace
 from minigrid.core.mission import MissionSpace
 from minigrid.core.roomgrid import RoomGrid
 from minigrid.core.roomgrid import RoomGrid
@@ -57,18 +59,22 @@ class UnlockPickupEnv(RoomGrid):
 
 
     """
     """
 
 
-    def __init__(self, **kwargs):
+    def __init__(self, max_steps: Optional[int] = None, **kwargs):
         room_size = 6
         room_size = 6
         mission_space = MissionSpace(
         mission_space = MissionSpace(
             mission_func=lambda color: f"pick up the {color} box",
             mission_func=lambda color: f"pick up the {color} box",
             ordered_placeholders=[COLOR_NAMES],
             ordered_placeholders=[COLOR_NAMES],
         )
         )
+
+        if max_steps is None:
+            max_steps = 8 * room_size**2
+
         super().__init__(
         super().__init__(
             mission_space=mission_space,
             mission_space=mission_space,
             num_rows=1,
             num_rows=1,
             num_cols=2,
             num_cols=2,
             room_size=room_size,
             room_size=room_size,
-            max_steps=8 * room_size**2,
+            max_steps=max_steps,
             **kwargs,
             **kwargs,
         )
         )
 
 

+ 5 - 0
minigrid/minigrid_env.py

@@ -903,7 +903,12 @@ class MiniGridEnv(gym.Env):
         # Environment configuration
         # Environment configuration
         self.width = width
         self.width = width
         self.height = height
         self.height = height
+
+        assert isinstance(
+            max_steps, int
+        ), f"The argument max_steps must be an integer, got: {type(max_steps)}"
         self.max_steps = max_steps
         self.max_steps = max_steps
+
         self.see_through_walls = see_through_walls
         self.see_through_walls = see_through_walls
 
 
         # Current position and direction of the agent
         # Current position and direction of the agent

+ 23 - 0
tests/test_envs.py

@@ -142,6 +142,29 @@ def test_agent_sees_method(env_id):
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
     "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
     "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
 )
 )
+def test_max_steps_argument(env_spec):
+    """
+    Test that when initializing an environment with a fixed number of steps per episode (`max_steps` argument),
+    the episode will be truncated after taking that number of steps.
+    """
+    max_steps = 50
+    env = env_spec.make(max_steps=max_steps)
+    env.reset()
+    step_count = 0
+    while True:
+        _, _, terminated, truncated, _ = env.step(4)
+        step_count += 1
+        if truncated:
+            assert step_count == max_steps
+            step_count = 0
+            break
+
+    env.close()
+
+
+@pytest.mark.parametrize(
+    "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
+)
 def old_run_test(env_spec):
 def old_run_test(env_spec):
     # Load the gym environment
     # Load the gym environment
     env = env_spec.make()
     env = env_spec.make()