Browse Source

Add `max_steps` argument for all environments (#265)

Rodrigo de Lazcano 2 years ago
parent
commit
7751a5cbeb

+ 20 - 3
minigrid/envs/babyai/core/roomgrid_level.py

@@ -1,6 +1,8 @@
 """
 Copied and adapted from https://github.com/mila-iqia/babyai
 """
+from typing import Optional
+
 from minigrid.core.roomgrid import RoomGrid
 from minigrid.envs.babyai.core.verifier import (
     ActionInstr,
@@ -44,9 +46,22 @@ class RoomGridLevel(RoomGrid):
     of approximately similar difficulty.
     """
 
-    def __init__(self, room_size=8, **kwargs):
+    def __init__(self, room_size=8, max_steps: Optional[int] = None, **kwargs):
         mission_space = BabyAIMissionSpace()
-        super().__init__(room_size=room_size, mission_space=mission_space, **kwargs)
+
+        # If `max_steps` arg is passed it will be fixed for every episode,
+        # if not it will vary after reset depending on the maze size.
+        self.fixed_max_steps = False
+        if max_steps is not None:
+            self.fixed_max_steps = True
+        else:
+            max_steps = 0  # only for initialization
+        super().__init__(
+            room_size=room_size,
+            mission_space=mission_space,
+            max_steps=max_steps,
+            **kwargs
+        )
 
     def reset(self, **kwargs):
         obs = super().reset(**kwargs)
@@ -58,7 +73,9 @@ class RoomGridLevel(RoomGrid):
         nav_time_room = self.room_size**2
         nav_time_maze = nav_time_room * self.num_rows * self.num_cols
         num_navs = self.num_navs_needed(self.instrs)
-        self.max_steps = num_navs * nav_time_maze
+
+        if not self.fixed_max_steps:
+            self.max_steps = num_navs * nav_time_maze
 
         return obs
 

+ 20 - 4
minigrid/envs/babyai/open.py

@@ -2,6 +2,7 @@
 Copied and adapted from https://github.com/mila-iqia/babyai.
 Levels described in the Baby AI ICLR 2019 submission, with the `Open` instruction.
 """
+from typing import Optional
 
 from minigrid.core.constants import COLOR_NAMES
 from minigrid.envs.babyai.core.roomgrid_level import RoomGridLevel
@@ -97,13 +98,23 @@ class OpenTwoDoors(RoomGridLevel):
     This task requires memory (recurrent policy) to be solved effectively.
     """
 
-    def __init__(self, first_color=None, second_color=None, strict=False, **kwargs):
+    def __init__(
+        self,
+        first_color=None,
+        second_color=None,
+        strict=False,
+        max_steps: Optional[int] = None,
+        **kwargs
+    ):
         self.first_color = first_color
         self.second_color = second_color
         self.strict = strict
 
         room_size = 6
-        super().__init__(room_size=room_size, max_steps=20 * room_size**2, **kwargs)
+        if max_steps is None:
+            max_steps = 20 * room_size**2
+
+        super().__init__(room_size=room_size, max_steps=max_steps, **kwargs)
 
     def gen_mission(self):
         colors = self._rand_subset(COLOR_NAMES, 2)
@@ -131,13 +142,18 @@ class OpenDoorsOrder(RoomGridLevel):
     Open one or two doors in the order specified.
     """
 
-    def __init__(self, num_doors, debug=False, **kwargs):
+    def __init__(
+        self, num_doors, debug=False, max_steps: Optional[int] = None, **kwargs
+    ):
         assert num_doors >= 2
         self.num_doors = num_doors
         self.debug = debug
 
         room_size = 6
-        super().__init__(room_size=room_size, max_steps=20 * room_size**2, **kwargs)
+        if max_steps is None:
+            max_steps = 20 * room_size**2
+
+        super().__init__(room_size=room_size, max_steps=max_steps, **kwargs)
 
     def gen_mission(self):
         colors = self._rand_subset(COLOR_NAMES, self.num_doors)

+ 26 - 13
minigrid/envs/babyai/other.py

@@ -2,6 +2,7 @@
 Copied and adapted from https://github.com/mila-iqia/babyai.
 Levels described in the Baby AI ICLR 2019 submission, with different instructions than those in other files.
 """
+from typing import Optional
 
 from minigrid.envs.babyai.core.roomgrid_level import RoomGridLevel
 from minigrid.envs.babyai.core.verifier import (
@@ -55,8 +56,12 @@ class FindObjS5(RoomGridLevel):
     This level requires potentially exhaustive exploration
     """
 
-    def __init__(self, room_size=5, **kwargs):
-        super().__init__(room_size=room_size, max_steps=20 * room_size**2, **kwargs)
+    def __init__(self, room_size=5, max_steps: Optional[int] = None, **kwargs):
+
+        if max_steps is None:
+            max_steps = 20 * room_size**2
+
+        super().__init__(room_size=room_size, max_steps=max_steps, **kwargs)
 
     def gen_mission(self):
         # Add a random object to a random room
@@ -75,14 +80,21 @@ class KeyCorridor(RoomGridLevel):
     random room.
     """
 
-    def __init__(self, num_rows=3, obj_type="ball", room_size=6, **kwargs):
+    def __init__(
+        self,
+        num_rows=3,
+        obj_type="ball",
+        room_size=6,
+        max_steps: Optional[int] = None,
+        **kwargs
+    ):
         self.obj_type = obj_type
 
+        if max_steps is None:
+            max_steps = 30 * room_size**2
+
         super().__init__(
-            room_size=room_size,
-            num_rows=num_rows,
-            max_steps=30 * room_size**2,
-            **kwargs
+            room_size=room_size, num_rows=num_rows, max_steps=max_steps, **kwargs
         )
 
     def gen_mission(self):
@@ -130,16 +142,17 @@ class MoveTwoAcross(RoomGridLevel):
     instructions.
     """
 
-    def __init__(self, room_size, objs_per_room, **kwargs):
+    def __init__(
+        self, room_size, objs_per_room, max_steps: Optional[int] = None, **kwargs
+    ):
         assert objs_per_room <= 9
         self.objs_per_room = objs_per_room
 
+        if max_steps is None:
+            max_steps = 16 * room_size**2
+
         super().__init__(
-            num_rows=1,
-            num_cols=2,
-            room_size=room_size,
-            max_steps=16 * room_size**2,
-            **kwargs
+            num_rows=1, num_cols=2, room_size=room_size, max_steps=max_steps, **kwargs
         )
 
     def gen_mission(self):

+ 6 - 2
minigrid/envs/babyai/pickup.py

@@ -2,6 +2,7 @@
 Copied and adapted from https://github.com/mila-iqia/babyai.
 Levels described in the Baby AI ICLR 2019 submission, with the `Pick up` instruction.
 """
+from typing import Optional
 
 from minigrid.envs.babyai.core.levelgen import LevelGen
 from minigrid.envs.babyai.core.roomgrid_level import RejectSampling, RoomGridLevel
@@ -101,9 +102,12 @@ class PickupAbove(RoomGridLevel):
     This task requires to use the compass to be solved effectively.
     """
 
-    def __init__(self, **kwargs):
+    def __init__(self, max_steps: Optional[int] = None, **kwargs):
         room_size = 6
-        super().__init__(room_size=room_size, max_steps=8 * room_size**2, **kwargs)
+        if max_steps is None:
+            max_steps = 8 * room_size**2
+
+        super().__init__(room_size=room_size, max_steps=max_steps, **kwargs)
 
     def gen_mission(self):
         # Add a random object to the top-middle room

+ 13 - 6
minigrid/envs/babyai/putnext.py

@@ -2,6 +2,7 @@
 Copied and adapted from https://github.com/mila-iqia/babyai.
 Levels described in the Baby AI ICLR 2019 submission, with the `Put Next` instruction.
 """
+from typing import Optional
 
 from minigrid.envs.babyai.core.roomgrid_level import RoomGridLevel
 from minigrid.envs.babyai.core.verifier import ObjDesc, PutNextInstr
@@ -35,18 +36,24 @@ class PutNext(RoomGridLevel):
     instructions.
     """
 
-    def __init__(self, room_size, objs_per_room, start_carrying=False, **kwargs):
+    def __init__(
+        self,
+        room_size,
+        objs_per_room,
+        start_carrying=False,
+        max_steps: Optional[int] = None,
+        **kwargs
+    ):
         assert room_size >= 4
         assert objs_per_room <= 9
         self.objs_per_room = objs_per_room
         self.start_carrying = start_carrying
 
+        if max_steps is None:
+            max_steps = 8 * room_size**2
+
         super().__init__(
-            num_rows=1,
-            num_cols=2,
-            room_size=room_size,
-            max_steps=8 * room_size**2,
-            **kwargs
+            num_rows=1, num_cols=2, room_size=room_size, max_steps=max_steps, **kwargs
         )
 
     def gen_mission(self):

+ 16 - 19
minigrid/envs/babyai/unlock.py

@@ -2,6 +2,7 @@
 Copied and adapted from https://github.com/mila-iqia/babyai.
 Levels described in the Baby AI ICLR 2019 submission, with the `Unlock` instruction.
 """
+from typing import Optional
 
 from minigrid.core.constants import COLOR_NAMES
 from minigrid.envs.babyai.core.roomgrid_level import RoomGridLevel
@@ -109,16 +110,14 @@ class UnlockPickup(RoomGridLevel):
     Unlock a door, then pick up a box in another room
     """
 
-    def __init__(self, distractors=False, **kwargs):
+    def __init__(self, distractors=False, max_steps: Optional[int] = None, **kwargs):
         self.distractors = distractors
-
         room_size = 6
+        if max_steps is None:
+            max_steps = 8 * room_size**2
+
         super().__init__(
-            num_rows=1,
-            num_cols=2,
-            room_size=room_size,
-            max_steps=8 * room_size**2,
-            **kwargs
+            num_rows=1, num_cols=2, room_size=room_size, max_steps=max_steps, **kwargs
         )
 
     def gen_mission(self):
@@ -142,14 +141,13 @@ class BlockedUnlockPickup(RoomGridLevel):
     in another room
     """
 
-    def __init__(self, **kwargs):
+    def __init__(self, max_steps: Optional[int] = None, **kwargs):
         room_size = 6
+        if max_steps is None:
+            max_steps = 16 * room_size**2
+
         super().__init__(
-            num_rows=1,
-            num_cols=2,
-            room_size=room_size,
-            max_steps=16 * room_size**2,
-            **kwargs
+            num_rows=1, num_cols=2, room_size=room_size, max_steps=max_steps, **kwargs
         )
 
     def gen_mission(self):
@@ -173,14 +171,13 @@ class UnlockToUnlock(RoomGridLevel):
     Unlock a door A that requires to unlock a door B before
     """
 
-    def __init__(self, **kwargs):
+    def __init__(self, max_steps: Optional[int] = None, **kwargs):
         room_size = 6
+        if max_steps is None:
+            max_steps = 30 * room_size**2
+
         super().__init__(
-            num_rows=1,
-            num_cols=3,
-            room_size=room_size,
-            max_steps=30 * room_size**2,
-            **kwargs
+            num_rows=1, num_cols=3, room_size=room_size, max_steps=max_steps, **kwargs
         )
 
     def gen_mission(self):

+ 9 - 3
minigrid/envs/blockedunlockpickup.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.mission import MissionSpace
 from minigrid.core.roomgrid import RoomGrid
@@ -63,18 +65,22 @@ class BlockedUnlockPickupEnv(RoomGrid):
 
     """
 
-    def __init__(self, **kwargs):
-        room_size = 6
+    def __init__(self, max_steps: Optional[int] = None, **kwargs):
         mission_space = MissionSpace(
             mission_func=lambda color, type: f"pick up the {color} {type}",
             ordered_placeholders=[COLOR_NAMES, ["box", "key"]],
         )
+
+        room_size = 6
+        if max_steps is None:
+            max_steps = 16 * room_size**2
+
         super().__init__(
             mission_space=mission_space,
             num_rows=1,
             num_cols=2,
             room_size=room_size,
-            max_steps=16 * room_size**2,
+            max_steps=max_steps,
             **kwargs,
         )
 

+ 14 - 4
minigrid/envs/crossing.py

@@ -1,4 +1,5 @@
 import itertools as itt
+from typing import Optional
 
 import numpy as np
 
@@ -90,7 +91,14 @@ class CrossingEnv(MiniGridEnv):
 
     """
 
-    def __init__(self, size=9, num_crossings=1, obstacle_type=Lava, **kwargs):
+    def __init__(
+        self,
+        size=9,
+        num_crossings=1,
+        obstacle_type=Lava,
+        max_steps: Optional[int] = None,
+        **kwargs
+    ):
         self.num_crossings = num_crossings
         self.obstacle_type = obstacle_type
 
@@ -103,12 +111,14 @@ class CrossingEnv(MiniGridEnv):
                 mission_func=lambda: "find the opening and get to the green goal square"
             )
 
+        if max_steps is None:
+            max_steps = 4 * size**2
+
         super().__init__(
             mission_space=mission_space,
             grid_size=size,
-            max_steps=4 * size * size,
-            # Set this to True for maximum speed
-            see_through_walls=False,
+            see_through_walls=False,  # Set this to True for maximum speed
+            max_steps=max_steps,
             **kwargs
         )
 

+ 7 - 1
minigrid/envs/distshift.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
 from minigrid.core.world_object import Goal, Lava
@@ -70,6 +72,7 @@ class DistShiftEnv(MiniGridEnv):
         agent_start_pos=(1, 1),
         agent_start_dir=0,
         strip2_row=2,
+        max_steps: Optional[int] = None,
         **kwargs
     ):
         self.agent_start_pos = agent_start_pos
@@ -81,13 +84,16 @@ class DistShiftEnv(MiniGridEnv):
             mission_func=lambda: "get to the green goal square"
         )
 
+        if max_steps is None:
+            max_steps = 4 * width * height
+
         super().__init__(
             mission_space=mission_space,
             width=width,
             height=height,
-            max_steps=4 * width * height,
             # Set this to True for maximum speed
             see_through_walls=True,
+            max_steps=max_steps,
             **kwargs
         )
 

+ 8 - 4
minigrid/envs/doorkey.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
 from minigrid.core.world_object import Door, Goal, Key
@@ -60,13 +62,15 @@ class DoorKeyEnv(MiniGridEnv):
 
     """
 
-    def __init__(self, size=8, **kwargs):
-        if "max_steps" not in kwargs:
-            kwargs["max_steps"] = 10 * size * size
+    def __init__(self, size=8, max_steps: Optional[int] = None, **kwargs):
+        if max_steps is None:
+            max_steps = 10 * size**2
         mission_space = MissionSpace(
             mission_func=lambda: "use the key to open the door and then get to the goal"
         )
-        super().__init__(mission_space=mission_space, grid_size=size, **kwargs)
+        super().__init__(
+            mission_space=mission_space, grid_size=size, max_steps=max_steps, **kwargs
+        )
 
     def _gen_grid(self, width, height):
         # Create an empty grid

+ 12 - 2
minigrid/envs/dynamicobstacles.py

@@ -1,4 +1,5 @@
 from operator import add
+from typing import Optional
 
 from gymnasium.spaces import Discrete
 
@@ -70,7 +71,13 @@ class DynamicObstaclesEnv(MiniGridEnv):
     """
 
     def __init__(
-        self, size=8, agent_start_pos=(1, 1), agent_start_dir=0, n_obstacles=4, **kwargs
+        self,
+        size=8,
+        agent_start_pos=(1, 1),
+        agent_start_dir=0,
+        n_obstacles=4,
+        max_steps: Optional[int] = None,
+        **kwargs
     ):
         self.agent_start_pos = agent_start_pos
         self.agent_start_dir = agent_start_dir
@@ -85,12 +92,15 @@ class DynamicObstaclesEnv(MiniGridEnv):
             mission_func=lambda: "get to the green goal square"
         )
 
+        if max_steps is None:
+            max_steps = 4 * size**2
+
         super().__init__(
             mission_space=mission_space,
             grid_size=size,
-            max_steps=4 * size * size,
             # Set this to True for maximum speed
             see_through_walls=True,
+            max_steps=max_steps,
             **kwargs
         )
         # Allow only 3 actions permitted: left, right, forward

+ 14 - 2
minigrid/envs/empty.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
 from minigrid.core.world_object import Goal
@@ -65,7 +67,14 @@ class EmptyEnv(MiniGridEnv):
 
     """
 
-    def __init__(self, size=8, agent_start_pos=(1, 1), agent_start_dir=0, **kwargs):
+    def __init__(
+        self,
+        size=8,
+        agent_start_pos=(1, 1),
+        agent_start_dir=0,
+        max_steps: Optional[int] = None,
+        **kwargs
+    ):
         self.agent_start_pos = agent_start_pos
         self.agent_start_dir = agent_start_dir
 
@@ -73,12 +82,15 @@ class EmptyEnv(MiniGridEnv):
             mission_func=lambda: "get to the green goal square"
         )
 
+        if max_steps is None:
+            max_steps = 4 * size**2
+
         super().__init__(
             mission_space=mission_space,
             grid_size=size,
-            max_steps=4 * size * size,
             # Set this to True for maximum speed
             see_through_walls=True,
+            max_steps=max_steps,
             **kwargs
         )
 

+ 8 - 2
minigrid/envs/fetch.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
@@ -71,7 +73,7 @@ class FetchEnv(MiniGridEnv):
 
     """
 
-    def __init__(self, size=8, numObjs=3, **kwargs):
+    def __init__(self, size=8, numObjs=3, max_steps: Optional[int] = None, **kwargs):
         self.numObjs = numObjs
         self.obj_types = ["key", "ball"]
 
@@ -87,13 +89,17 @@ class FetchEnv(MiniGridEnv):
             mission_func=lambda syntax, color, type: f"{syntax} {color} {type}",
             ordered_placeholders=[MISSION_SYNTAX, COLOR_NAMES, self.obj_types],
         )
+
+        if max_steps is None:
+            max_steps = 5 * size**2
+
         super().__init__(
             mission_space=mission_space,
             width=size,
             height=size,
-            max_steps=5 * size**2,
             # Set this to True for maximum speed
             see_through_walls=True,
+            max_steps=max_steps,
             **kwargs,
         )
 

+ 2 - 2
minigrid/envs/fourrooms.py

@@ -57,7 +57,7 @@ class FourRoomsEnv(MiniGridEnv):
 
     """
 
-    def __init__(self, agent_pos=None, goal_pos=None, **kwargs):
+    def __init__(self, agent_pos=None, goal_pos=None, max_steps=100, **kwargs):
         self._agent_default_pos = agent_pos
         self._goal_default_pos = goal_pos
 
@@ -68,7 +68,7 @@ class FourRoomsEnv(MiniGridEnv):
             mission_space=mission_space,
             width=self.size,
             height=self.size,
-            max_steps=100,
+            max_steps=max_steps,
             **kwargs
         )
 

+ 8 - 2
minigrid/envs/gotodoor.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
@@ -63,20 +65,24 @@ class GoToDoorEnv(MiniGridEnv):
 
     """
 
-    def __init__(self, size=5, **kwargs):
+    def __init__(self, size=5, max_steps: Optional[int] = None, **kwargs):
         assert size >= 5
         self.size = size
         mission_space = MissionSpace(
             mission_func=lambda color: f"go to the {color} door",
             ordered_placeholders=[COLOR_NAMES],
         )
+
+        if max_steps is None:
+            max_steps = 5 * size**2
+
         super().__init__(
             mission_space=mission_space,
             width=size,
             height=size,
-            max_steps=5 * size**2,
             # Set this to True for maximum speed
             see_through_walls=True,
+            max_steps=max_steps,
             **kwargs,
         )
 

+ 9 - 2
minigrid/envs/gotoobject.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
@@ -11,7 +13,8 @@ class GoToObjectEnv(MiniGridEnv):
     named using an English text string
     """
 
-    def __init__(self, size=6, numObjs=2, **kwargs):
+    def __init__(self, size=6, numObjs=2, max_steps: Optional[int] = None, **kwargs):
+
         self.numObjs = numObjs
         self.size = size
         # Types of objects to be generated
@@ -21,13 +24,17 @@ class GoToObjectEnv(MiniGridEnv):
             mission_func=lambda color, type: f"go to the {color} {type}",
             ordered_placeholders=[COLOR_NAMES, self.obj_types],
         )
+
+        if max_steps is None:
+            max_steps = 5 * size**2
+
         super().__init__(
             mission_space=mission_space,
             width=size,
             height=size,
-            max_steps=5 * size**2,
             # Set this to True for maximum speed
             see_through_walls=True,
+            max_steps=max_steps,
             **kwargs,
         )
 

+ 15 - 2
minigrid/envs/keycorridor.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.mission import MissionSpace
 from minigrid.core.roomgrid import RoomGrid
@@ -77,17 +79,28 @@ class KeyCorridorEnv(RoomGrid):
 
     """
 
-    def __init__(self, num_rows=3, obj_type="ball", room_size=6, **kwargs):
+    def __init__(
+        self,
+        num_rows=3,
+        obj_type="ball",
+        room_size=6,
+        max_steps: Optional[int] = None,
+        **kwargs,
+    ):
         self.obj_type = obj_type
         mission_space = MissionSpace(
             mission_func=lambda color: f"pick up the {color} {obj_type}",
             ordered_placeholders=[COLOR_NAMES],
         )
+
+        if max_steps is None:
+            max_steps = 30 * room_size**2
+
         super().__init__(
             mission_space=mission_space,
             room_size=room_size,
             num_rows=num_rows,
-            max_steps=30 * room_size**2,
+            max_steps=max_steps,
             **kwargs,
         )
 

+ 9 - 2
minigrid/envs/lavagap.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 import numpy as np
 
 from minigrid.core.grid import Grid
@@ -66,7 +68,9 @@ class LavaGapEnv(MiniGridEnv):
 
     """
 
-    def __init__(self, size, obstacle_type=Lava, **kwargs):
+    def __init__(
+        self, size, obstacle_type=Lava, max_steps: Optional[int] = None, **kwargs
+    ):
         self.obstacle_type = obstacle_type
         self.size = size
 
@@ -79,13 +83,16 @@ class LavaGapEnv(MiniGridEnv):
                 mission_func=lambda: "find the opening and get to the green goal square"
             )
 
+        if max_steps is None:
+            max_steps = 4 * size**2
+
         super().__init__(
             mission_space=mission_space,
             width=size,
             height=size,
-            max_steps=4 * size * size,
             # Set this to True for maximum speed
             see_through_walls=False,
+            max_steps=max_steps,
             **kwargs
         )
 

+ 7 - 2
minigrid/envs/lockedroom.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
@@ -74,8 +76,11 @@ class LockedRoomEnv(MiniGridEnv):
 
     """
 
-    def __init__(self, size=19, **kwargs):
+    def __init__(self, size=19, max_steps: Optional[int] = None, **kwargs):
         self.size = size
+
+        if max_steps is None:
+            max_steps = 10 * size
         mission_space = MissionSpace(
             mission_func=lambda lockedroom_color, keyroom_color, door_color: f"get the {lockedroom_color} key from the {keyroom_color} room, unlock the {door_color} door and go to the goal",
             ordered_placeholders=[COLOR_NAMES] * 3,
@@ -84,7 +89,7 @@ class LockedRoomEnv(MiniGridEnv):
             mission_space=mission_space,
             width=size,
             height=size,
-            max_steps=10 * size,
+            max_steps=max_steps,
             **kwargs,
         )
 

+ 10 - 2
minigrid/envs/memory.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 import numpy as np
 
 from minigrid.core.actions import Actions
@@ -65,9 +67,15 @@ class MemoryEnv(MiniGridEnv):
 
     """
 
-    def __init__(self, size=8, random_length=False, **kwargs):
+    def __init__(
+        self, size=8, random_length=False, max_steps: Optional[int] = None, **kwargs
+    ):
         self.size = size
         self.random_length = random_length
+
+        if max_steps is None:
+            max_steps = 5 * size**2
+
         mission_space = MissionSpace(
             mission_func=lambda: "go to the matching object at the end of the hallway"
         )
@@ -75,9 +83,9 @@ class MemoryEnv(MiniGridEnv):
             mission_space=mission_space,
             width=size,
             height=size,
-            max_steps=5 * size**2,
             # Set this to True for maximum speed
             see_through_walls=False,
+            max_steps=max_steps,
             **kwargs
         )
 

+ 14 - 2
minigrid/envs/multiroom.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
@@ -72,7 +74,14 @@ class MultiRoomEnv(MiniGridEnv):
 
     """
 
-    def __init__(self, minNumRooms, maxNumRooms, maxRoomSize=10, **kwargs):
+    def __init__(
+        self,
+        minNumRooms,
+        maxNumRooms,
+        maxRoomSize=10,
+        max_steps: Optional[int] = None,
+        **kwargs
+    ):
         assert minNumRooms > 0
         assert maxNumRooms >= minNumRooms
         assert maxRoomSize >= 4
@@ -89,11 +98,14 @@ class MultiRoomEnv(MiniGridEnv):
 
         self.size = 25
 
+        if max_steps is None:
+            max_steps = maxNumRooms * 20
+
         super().__init__(
             mission_space=mission_space,
             width=self.size,
             height=self.size,
-            max_steps=self.maxNumRooms * 20,
+            max_steps=max_steps,
             **kwargs
         )
 

+ 13 - 2
minigrid/envs/obstructedmaze.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 from minigrid.core.constants import COLOR_NAMES, DIR_TO_VEC
 from minigrid.core.mission import MissionSpace
 from minigrid.core.roomgrid import RoomGrid
@@ -79,9 +81,18 @@ class ObstructedMazeEnv(RoomGrid):
 
     """
 
-    def __init__(self, num_rows, num_cols, num_rooms_visited, **kwargs):
+    def __init__(
+        self,
+        num_rows,
+        num_cols,
+        num_rooms_visited,
+        max_steps: Optional[int] = None,
+        **kwargs,
+    ):
         room_size = 6
-        max_steps = 4 * num_rooms_visited * room_size**2
+
+        if max_steps is None:
+            max_steps = 4 * num_rooms_visited * room_size**2
 
         mission_space = MissionSpace(
             mission_func=lambda: f"pick up the {COLOR_NAMES[0]} ball",

+ 2 - 2
minigrid/envs/playground.py

@@ -11,14 +11,14 @@ class PlaygroundEnv(MiniGridEnv):
     This environment has no specific goals or rewards.
     """
 
-    def __init__(self, **kwargs):
+    def __init__(self, max_steps=100, **kwargs):
         mission_space = MissionSpace(mission_func=lambda: "")
         self.size = 19
         super().__init__(
             mission_space=mission_space,
             width=self.size,
             height=self.size,
-            max_steps=100,
+            max_steps=max_steps,
             **kwargs
         )
 

+ 8 - 2
minigrid/envs/putnear.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
@@ -65,7 +67,7 @@ class PutNearEnv(MiniGridEnv):
 
     """
 
-    def __init__(self, size=6, numObjs=2, **kwargs):
+    def __init__(self, size=6, numObjs=2, max_steps: Optional[int] = None, **kwargs):
         self.size = size
         self.numObjs = numObjs
         self.obj_types = ["key", "ball", "box"]
@@ -78,13 +80,17 @@ class PutNearEnv(MiniGridEnv):
                 self.obj_types,
             ],
         )
+
+        if max_steps is None:
+            max_steps = 5 * size
+
         super().__init__(
             mission_space=mission_space,
             width=size,
             height=size,
-            max_steps=5 * size,
             # Set this to True for maximum speed
             see_through_walls=True,
+            max_steps=max_steps,
             **kwargs,
         )
 

+ 8 - 2
minigrid/envs/redbluedoors.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
 from minigrid.core.world_object import Door
@@ -57,16 +59,20 @@ class RedBlueDoorEnv(MiniGridEnv):
 
     """
 
-    def __init__(self, size=8, **kwargs):
+    def __init__(self, size=8, max_steps: Optional[int] = None, **kwargs):
         self.size = size
         mission_space = MissionSpace(
             mission_func=lambda: "open the red door then the blue door"
         )
+
+        if max_steps is None:
+            max_steps = 20 * size**2
+
         super().__init__(
             mission_space=mission_space,
             width=2 * size,
             height=size,
-            max_steps=20 * size * size,
+            max_steps=max_steps,
             **kwargs
         )
 

+ 8 - 2
minigrid/envs/unlock.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 from minigrid.core.mission import MissionSpace
 from minigrid.core.roomgrid import RoomGrid
 
@@ -53,15 +55,19 @@ class UnlockEnv(RoomGrid):
 
     """
 
-    def __init__(self, **kwargs):
+    def __init__(self, max_steps: Optional[int] = None, **kwargs):
         room_size = 6
         mission_space = MissionSpace(mission_func=lambda: "open the door")
+
+        if max_steps is None:
+            max_steps = 8 * room_size**2
+
         super().__init__(
             mission_space=mission_space,
             num_rows=1,
             num_cols=2,
             room_size=room_size,
-            max_steps=8 * room_size**2,
+            max_steps=max_steps,
             **kwargs
         )
 

+ 8 - 2
minigrid/envs/unlockpickup.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.mission import MissionSpace
 from minigrid.core.roomgrid import RoomGrid
@@ -57,18 +59,22 @@ class UnlockPickupEnv(RoomGrid):
 
     """
 
-    def __init__(self, **kwargs):
+    def __init__(self, max_steps: Optional[int] = None, **kwargs):
         room_size = 6
         mission_space = MissionSpace(
             mission_func=lambda color: f"pick up the {color} box",
             ordered_placeholders=[COLOR_NAMES],
         )
+
+        if max_steps is None:
+            max_steps = 8 * room_size**2
+
         super().__init__(
             mission_space=mission_space,
             num_rows=1,
             num_cols=2,
             room_size=room_size,
-            max_steps=8 * room_size**2,
+            max_steps=max_steps,
             **kwargs,
         )
 

+ 5 - 0
minigrid/minigrid_env.py

@@ -903,7 +903,12 @@ class MiniGridEnv(gym.Env):
         # Environment configuration
         self.width = width
         self.height = height
+
+        assert isinstance(
+            max_steps, int
+        ), f"The argument max_steps must be an integer, got: {type(max_steps)}"
         self.max_steps = max_steps
+
         self.see_through_walls = see_through_walls
 
         # Current position and direction of the agent

+ 23 - 0
tests/test_envs.py

@@ -142,6 +142,29 @@ def test_agent_sees_method(env_id):
 @pytest.mark.parametrize(
     "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
 )
+def test_max_steps_argument(env_spec):
+    """
+    Test that when initializing an environment with a fixed number of steps per episode (`max_steps` argument),
+    the episode will be truncated after taking that number of steps.
+    """
+    max_steps = 50
+    env = env_spec.make(max_steps=max_steps)
+    env.reset()
+    step_count = 0
+    while True:
+        _, _, terminated, truncated, _ = env.step(4)
+        step_count += 1
+        if truncated:
+            assert step_count == max_steps
+            step_count = 0
+            break
+
+    env.close()
+
+
+@pytest.mark.parametrize(
+    "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
+)
 def old_run_test(env_spec):
     # Load the gym environment
     env = env_spec.make()