浏览代码

add **kwargs to __init__ functions of all classes inhertting from MiniGridEnv, in order to be able to use gym.make the right way

saleml 2 年之前
父节点
当前提交
ce22f87a27

+ 3 - 2
gym_minigrid/envs/blockedunlockpickup.py

@@ -8,14 +8,15 @@ class BlockedUnlockPickup(RoomGrid):
     in another room
     in another room
     """
     """
 
 
-    def __init__(self, seed=None):
+    def __init__(self, seed=None, **kwargs):
         room_size = 6
         room_size = 6
         super().__init__(
         super().__init__(
             num_rows=1,
             num_rows=1,
             num_cols=2,
             num_cols=2,
             room_size=room_size,
             room_size=room_size,
             max_steps=16*room_size**2,
             max_steps=16*room_size**2,
-            seed=seed
+            seed=seed,
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):

+ 11 - 10
gym_minigrid/envs/crossing.py

@@ -9,7 +9,7 @@ class CrossingEnv(MiniGridEnv):
     Environment with wall or lava obstacles, sparse reward.
     Environment with wall or lava obstacles, sparse reward.
     """
     """
 
 
-    def __init__(self, size=9, num_crossings=1, obstacle_type=Lava, seed=None):
+    def __init__(self, size=9, num_crossings=1, obstacle_type=Lava, seed=None, **kwargs):
         self.num_crossings = num_crossings
         self.num_crossings = num_crossings
         self.obstacle_type = obstacle_type
         self.obstacle_type = obstacle_type
         super().__init__(
         super().__init__(
@@ -17,7 +17,8 @@ class CrossingEnv(MiniGridEnv):
             max_steps=4*size*size,
             max_steps=4*size*size,
             # Set this to True for maximum speed
             # Set this to True for maximum speed
             see_through_walls=False,
             see_through_walls=False,
-            seed=None
+            seed=None,
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
@@ -83,20 +84,20 @@ class CrossingEnv(MiniGridEnv):
         )
         )
 
 
 class LavaCrossingEnv(CrossingEnv):
 class LavaCrossingEnv(CrossingEnv):
-    def __init__(self):
-        super().__init__(size=9, num_crossings=1)
+    def __init__(self, **kwargs):
+        super().__init__(size=9, num_crossings=1, **kwargs)
 
 
 class LavaCrossingS9N2Env(CrossingEnv):
 class LavaCrossingS9N2Env(CrossingEnv):
-    def __init__(self):
-        super().__init__(size=9, num_crossings=2)
+    def __init__(self, **kwargs):
+        super().__init__(size=9, num_crossings=2, **kwargs)
 
 
 class LavaCrossingS9N3Env(CrossingEnv):
 class LavaCrossingS9N3Env(CrossingEnv):
-    def __init__(self):
-        super().__init__(size=9, num_crossings=3)
+    def __init__(self, **kwargs):
+        super().__init__(size=9, num_crossings=3, **kwargs)
 
 
 class LavaCrossingS11N5Env(CrossingEnv):
 class LavaCrossingS11N5Env(CrossingEnv):
-    def __init__(self):
-        super().__init__(size=11, num_crossings=5)
+    def __init__(self, **kwargs):
+        super().__init__(size=11, num_crossings=5, **kwargs)
 
 
 register(
 register(
     id='MiniGrid-LavaCrossingS9N1-v0',
     id='MiniGrid-LavaCrossingS9N1-v0',

+ 8 - 6
gym_minigrid/envs/distshift.py

@@ -12,7 +12,8 @@ class DistShiftEnv(MiniGridEnv):
         height=7,
         height=7,
         agent_start_pos=(1,1),
         agent_start_pos=(1,1),
         agent_start_dir=0,
         agent_start_dir=0,
-        strip2_row=2
+        strip2_row=2,
+        **kwargs
     ):
     ):
         self.agent_start_pos = agent_start_pos
         self.agent_start_pos = agent_start_pos
         self.agent_start_dir = agent_start_dir
         self.agent_start_dir = agent_start_dir
@@ -24,7 +25,8 @@ class DistShiftEnv(MiniGridEnv):
             height=height,
             height=height,
             max_steps=4*width*height,
             max_steps=4*width*height,
             # Set this to True for maximum speed
             # Set this to True for maximum speed
-            see_through_walls=True
+            see_through_walls=True,
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
@@ -52,12 +54,12 @@ class DistShiftEnv(MiniGridEnv):
         self.mission = "get to the green goal square"
         self.mission = "get to the green goal square"
 
 
 class DistShift1(DistShiftEnv):
 class DistShift1(DistShiftEnv):
-    def __init__(self):
-        super().__init__(strip2_row=2)
+    def __init__(self, **kwargs):
+        super().__init__(strip2_row=2, **kwargs)
 
 
 class DistShift2(DistShiftEnv):
 class DistShift2(DistShiftEnv):
-    def __init__(self):
-        super().__init__(strip2_row=5)
+    def __init__(self, **kwargs):
+        super().__init__(strip2_row=5, **kwargs)
 
 
 register(
 register(
     id='MiniGrid-DistShift1-v0',
     id='MiniGrid-DistShift1-v0',

+ 9 - 8
gym_minigrid/envs/doorkey.py

@@ -6,10 +6,11 @@ class DoorKeyEnv(MiniGridEnv):
     Environment with a door and key, sparse reward
     Environment with a door and key, sparse reward
     """
     """
 
 
-    def __init__(self, size=8):
+    def __init__(self, size=8, **kwargs):
         super().__init__(
         super().__init__(
             grid_size=size,
             grid_size=size,
-            max_steps=10*size*size
+            max_steps=10*size*size,
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
@@ -44,16 +45,16 @@ class DoorKeyEnv(MiniGridEnv):
         self.mission = "use the key to open the door and then get to the goal"
         self.mission = "use the key to open the door and then get to the goal"
 
 
 class DoorKeyEnv5x5(DoorKeyEnv):
 class DoorKeyEnv5x5(DoorKeyEnv):
-    def __init__(self):
-        super().__init__(size=5)
+    def __init__(self, **kwargs):
+        super().__init__(size=5, **kwargs)
 
 
 class DoorKeyEnv6x6(DoorKeyEnv):
 class DoorKeyEnv6x6(DoorKeyEnv):
-    def __init__(self):
-        super().__init__(size=6)
+    def __init__(self, **kwargs):
+        super().__init__(size=6, **kwargs)
 
 
 class DoorKeyEnv16x16(DoorKeyEnv):
 class DoorKeyEnv16x16(DoorKeyEnv):
-    def __init__(self):
-        super().__init__(size=16)
+    def __init__(self, **kwargs):
+        super().__init__(size=16, **kwargs)
 
 
 register(
 register(
     id='MiniGrid-DoorKey-5x5-v0',
     id='MiniGrid-DoorKey-5x5-v0',

+ 13 - 11
gym_minigrid/envs/dynamicobstacles.py

@@ -12,7 +12,8 @@ class DynamicObstaclesEnv(MiniGridEnv):
             size=8,
             size=8,
             agent_start_pos=(1, 1),
             agent_start_pos=(1, 1),
             agent_start_dir=0,
             agent_start_dir=0,
-            n_obstacles=4
+            n_obstacles=4,
+            **kwargs
     ):
     ):
         self.agent_start_pos = agent_start_pos
         self.agent_start_pos = agent_start_pos
         self.agent_start_dir = agent_start_dir
         self.agent_start_dir = agent_start_dir
@@ -27,6 +28,7 @@ class DynamicObstaclesEnv(MiniGridEnv):
             max_steps=4 * size * size,
             max_steps=4 * size * size,
             # Set this to True for maximum speed
             # Set this to True for maximum speed
             see_through_walls=True,
             see_through_walls=True,
+            **kwargs
         )
         )
         # Allow only 3 actions permitted: left, right, forward
         # Allow only 3 actions permitted: left, right, forward
         self.action_space = spaces.Discrete(self.actions.forward + 1)
         self.action_space = spaces.Discrete(self.actions.forward + 1)
@@ -89,24 +91,24 @@ class DynamicObstaclesEnv(MiniGridEnv):
         return obs, reward, done, info
         return obs, reward, done, info
 
 
 class DynamicObstaclesEnv5x5(DynamicObstaclesEnv):
 class DynamicObstaclesEnv5x5(DynamicObstaclesEnv):
-    def __init__(self):
-        super().__init__(size=5, n_obstacles=2)
+    def __init__(self, **kwargs):
+        super().__init__(size=5, n_obstacles=2, **kwargs)
 
 
 class DynamicObstaclesRandomEnv5x5(DynamicObstaclesEnv):
 class DynamicObstaclesRandomEnv5x5(DynamicObstaclesEnv):
-    def __init__(self):
-        super().__init__(size=5, agent_start_pos=None, n_obstacles=2)
+    def __init__(self, **kwargs):
+        super().__init__(size=5, agent_start_pos=None, n_obstacles=2, **kwargs)
 
 
 class DynamicObstaclesEnv6x6(DynamicObstaclesEnv):
 class DynamicObstaclesEnv6x6(DynamicObstaclesEnv):
-    def __init__(self):
-        super().__init__(size=6, n_obstacles=3)
+    def __init__(self, **kwargs):
+        super().__init__(size=6, n_obstacles=3, **kwargs)
 
 
 class DynamicObstaclesRandomEnv6x6(DynamicObstaclesEnv):
 class DynamicObstaclesRandomEnv6x6(DynamicObstaclesEnv):
-    def __init__(self):
-        super().__init__(size=6, agent_start_pos=None, n_obstacles=3)
+    def __init__(self, **kwargs):
+        super().__init__(size=6, agent_start_pos=None, n_obstacles=3, **kwargs)
 
 
 class DynamicObstaclesEnv16x16(DynamicObstaclesEnv):
 class DynamicObstaclesEnv16x16(DynamicObstaclesEnv):
-    def __init__(self):
-        super().__init__(size=16, n_obstacles=8)
+    def __init__(self, **kwargs):
+        super().__init__(size=16, n_obstacles=8, **kwargs)
 
 
 register(
 register(
     id='MiniGrid-Dynamic-Obstacles-5x5-v0',
     id='MiniGrid-Dynamic-Obstacles-5x5-v0',

+ 7 - 5
gym_minigrid/envs/empty.py

@@ -11,6 +11,7 @@ class EmptyEnv(MiniGridEnv):
         size=8,
         size=8,
         agent_start_pos=(1,1),
         agent_start_pos=(1,1),
         agent_start_dir=0,
         agent_start_dir=0,
+        **kwargs
     ):
     ):
         self.agent_start_pos = agent_start_pos
         self.agent_start_pos = agent_start_pos
         self.agent_start_dir = agent_start_dir
         self.agent_start_dir = agent_start_dir
@@ -19,7 +20,8 @@ class EmptyEnv(MiniGridEnv):
             grid_size=size,
             grid_size=size,
             max_steps=4*size*size,
             max_steps=4*size*size,
             # Set this to True for maximum speed
             # Set this to True for maximum speed
-            see_through_walls=True
+            see_through_walls=True,
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
@@ -46,16 +48,16 @@ class EmptyEnv5x5(EmptyEnv):
         super().__init__(size=5, **kwargs)
         super().__init__(size=5, **kwargs)
 
 
 class EmptyRandomEnv5x5(EmptyEnv):
 class EmptyRandomEnv5x5(EmptyEnv):
-    def __init__(self):
-        super().__init__(size=5, agent_start_pos=None)
+    def __init__(self, **kwargs):
+        super().__init__(size=5, agent_start_pos=None, **kwargs)
 
 
 class EmptyEnv6x6(EmptyEnv):
 class EmptyEnv6x6(EmptyEnv):
     def __init__(self, **kwargs):
     def __init__(self, **kwargs):
         super().__init__(size=6, **kwargs)
         super().__init__(size=6, **kwargs)
 
 
 class EmptyRandomEnv6x6(EmptyEnv):
 class EmptyRandomEnv6x6(EmptyEnv):
-    def __init__(self):
-        super().__init__(size=6, agent_start_pos=None)
+    def __init__(self, **kwargs):
+        super().__init__(size=6, agent_start_pos=None, **kwargs)
 
 
 class EmptyEnv16x16(EmptyEnv):
 class EmptyEnv16x16(EmptyEnv):
     def __init__(self, **kwargs):
     def __init__(self, **kwargs):

+ 8 - 6
gym_minigrid/envs/fetch.py

@@ -10,7 +10,8 @@ class FetchEnv(MiniGridEnv):
     def __init__(
     def __init__(
         self,
         self,
         size=8,
         size=8,
-        numObjs=3
+        numObjs=3,
+        **kwargs
     ):
     ):
         self.numObjs = numObjs
         self.numObjs = numObjs
 
 
@@ -18,7 +19,8 @@ class FetchEnv(MiniGridEnv):
             grid_size=size,
             grid_size=size,
             max_steps=5*size**2,
             max_steps=5*size**2,
             # Set this to True for maximum speed
             # Set this to True for maximum speed
-            see_through_walls=True
+            see_through_walls=True,
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
@@ -86,12 +88,12 @@ class FetchEnv(MiniGridEnv):
         return obs, reward, done, info
         return obs, reward, done, info
 
 
 class FetchEnv5x5N2(FetchEnv):
 class FetchEnv5x5N2(FetchEnv):
-    def __init__(self):
-        super().__init__(size=5, numObjs=2)
+    def __init__(self, **kwargs):
+        super().__init__(size=5, numObjs=2, **kwargs)
 
 
 class FetchEnv6x6N2(FetchEnv):
 class FetchEnv6x6N2(FetchEnv):
-    def __init__(self):
-        super().__init__(size=6, numObjs=2)
+    def __init__(self, **kwargs):
+        super().__init__(size=6, numObjs=2, **kwargs)
 
 
 register(
 register(
     id='MiniGrid-Fetch-5x5-N2-v0',
     id='MiniGrid-Fetch-5x5-N2-v0',

+ 2 - 2
gym_minigrid/envs/fourrooms.py

@@ -11,10 +11,10 @@ class FourRoomsEnv(MiniGridEnv):
     Can specify agent and goal position, if not it set at random.
     Can specify agent and goal position, if not it set at random.
     """
     """
 
 
-    def __init__(self, agent_pos=None, goal_pos=None):
+    def __init__(self, agent_pos=None, goal_pos=None, **kwargs):
         self._agent_default_pos = agent_pos
         self._agent_default_pos = agent_pos
         self._goal_default_pos = goal_pos
         self._goal_default_pos = goal_pos
-        super().__init__(grid_size=19, max_steps=100)
+        super().__init__(grid_size=19, max_steps=100, **kwargs)
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
         # Create the grid
         # Create the grid

+ 8 - 6
gym_minigrid/envs/gotodoor.py

@@ -9,7 +9,8 @@ class GoToDoorEnv(MiniGridEnv):
 
 
     def __init__(
     def __init__(
         self,
         self,
-        size=5
+        size=5,
+        **kwargs
     ):
     ):
         assert size >= 5
         assert size >= 5
 
 
@@ -17,7 +18,8 @@ class GoToDoorEnv(MiniGridEnv):
             grid_size=size,
             grid_size=size,
             max_steps=5*size**2,
             max_steps=5*size**2,
             # Set this to True for maximum speed
             # Set this to True for maximum speed
-            see_through_walls=True
+            see_through_walls=True,
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
@@ -81,12 +83,12 @@ class GoToDoorEnv(MiniGridEnv):
         return obs, reward, done, info
         return obs, reward, done, info
 
 
 class GoToDoor8x8Env(GoToDoorEnv):
 class GoToDoor8x8Env(GoToDoorEnv):
-    def __init__(self):
-        super().__init__(size=8)
+    def __init__(self, **kwargs):
+        super().__init__(size=8, **kwargs)
 
 
 class GoToDoor6x6Env(GoToDoorEnv):
 class GoToDoor6x6Env(GoToDoorEnv):
-    def __init__(self):
-        super().__init__(size=6)
+    def __init__(self, **kwargs):
+        super().__init__(size=6, **kwargs)
 
 
 register(
 register(
     id='MiniGrid-GoToDoor-5x5-v0',
     id='MiniGrid-GoToDoor-5x5-v0',

+ 6 - 4
gym_minigrid/envs/gotoobject.py

@@ -10,7 +10,8 @@ class GoToObjectEnv(MiniGridEnv):
     def __init__(
     def __init__(
         self,
         self,
         size=6,
         size=6,
-        numObjs=2
+        numObjs=2,
+        **kwargs
     ):
     ):
         self.numObjs = numObjs
         self.numObjs = numObjs
 
 
@@ -18,7 +19,8 @@ class GoToObjectEnv(MiniGridEnv):
             grid_size=size,
             grid_size=size,
             max_steps=5*size**2,
             max_steps=5*size**2,
             # Set this to True for maximum speed
             # Set this to True for maximum speed
-            see_through_walls=True
+            see_through_walls=True,
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
@@ -84,8 +86,8 @@ class GoToObjectEnv(MiniGridEnv):
         return obs, reward, done, info
         return obs, reward, done, info
 
 
 class GotoEnv8x8N2(GoToObjectEnv):
 class GotoEnv8x8N2(GoToObjectEnv):
-    def __init__(self):
-        super().__init__(size=8, numObjs=2)
+    def __init__(self, **kwargs):
+        super().__init__(size=8, numObjs=2, **kwargs)
 
 
 register(
 register(
     id='MiniGrid-GoToObject-6x6-N2-v0',
     id='MiniGrid-GoToObject-6x6-N2-v0',

+ 21 - 13
gym_minigrid/envs/keycorridor.py

@@ -12,7 +12,8 @@ class KeyCorridor(RoomGrid):
         num_rows=3,
         num_rows=3,
         obj_type="ball",
         obj_type="ball",
         room_size=6,
         room_size=6,
-        seed=None
+        seed=None,
+        **kwargs
     ):
     ):
         self.obj_type = obj_type
         self.obj_type = obj_type
 
 
@@ -21,6 +22,7 @@ class KeyCorridor(RoomGrid):
             num_rows=num_rows,
             num_rows=num_rows,
             max_steps=30*room_size**2,
             max_steps=30*room_size**2,
             seed=seed,
             seed=seed,
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
@@ -59,51 +61,57 @@ class KeyCorridor(RoomGrid):
         return obs, reward, done, info
         return obs, reward, done, info
 
 
 class KeyCorridorS3R1(KeyCorridor):
 class KeyCorridorS3R1(KeyCorridor):
-    def __init__(self, seed=None):
+    def __init__(self, seed=None, **kwargs):
         super().__init__(
         super().__init__(
             room_size=3,
             room_size=3,
             num_rows=1,
             num_rows=1,
-            seed=seed
+            seed=seed,
+            **kwargs
         )
         )
 
 
 class KeyCorridorS3R2(KeyCorridor):
 class KeyCorridorS3R2(KeyCorridor):
-    def __init__(self, seed=None):
+    def __init__(self, seed=None, **kwargs):
         super().__init__(
         super().__init__(
             room_size=3,
             room_size=3,
             num_rows=2,
             num_rows=2,
-            seed=seed
+            seed=seed,
+            **kwargs
         )
         )
 
 
 class KeyCorridorS3R3(KeyCorridor):
 class KeyCorridorS3R3(KeyCorridor):
-    def __init__(self, seed=None):
+    def __init__(self, seed=None, **kwargs):
         super().__init__(
         super().__init__(
             room_size=3,
             room_size=3,
             num_rows=3,
             num_rows=3,
-            seed=seed
+            seed=seed, 
+            **kwargs
         )
         )
 
 
 class KeyCorridorS4R3(KeyCorridor):
 class KeyCorridorS4R3(KeyCorridor):
-    def __init__(self, seed=None):
+    def __init__(self, seed=None, **kwargs):
         super().__init__(
         super().__init__(
             room_size=4,
             room_size=4,
             num_rows=3,
             num_rows=3,
-            seed=seed
+            seed=seed, 
+            **kwargs
         )
         )
 
 
 class KeyCorridorS5R3(KeyCorridor):
 class KeyCorridorS5R3(KeyCorridor):
-    def __init__(self, seed=None):
+    def __init__(self, seed=None, **kwargs):
         super().__init__(
         super().__init__(
             room_size=5,
             room_size=5,
             num_rows=3,
             num_rows=3,
-            seed=seed
+            seed=seed, 
+            **kwargs
         )
         )
 
 
 class KeyCorridorS6R3(KeyCorridor):
 class KeyCorridorS6R3(KeyCorridor):
-    def __init__(self, seed=None):
+    def __init__(self, seed=None, **kwargs):
         super().__init__(
         super().__init__(
             room_size=6,
             room_size=6,
             num_rows=3,
             num_rows=3,
-            seed=seed
+            seed=seed, 
+            **kwargs
         )
         )
 
 
 register(
 register(

+ 9 - 8
gym_minigrid/envs/lavagap.py

@@ -7,14 +7,15 @@ class LavaGapEnv(MiniGridEnv):
     This environment is similar to LavaCrossing but simpler in structure.
     This environment is similar to LavaCrossing but simpler in structure.
     """
     """
 
 
-    def __init__(self, size, obstacle_type=Lava, seed=None):
+    def __init__(self, size, obstacle_type=Lava, seed=None, **kwargs):
         self.obstacle_type = obstacle_type
         self.obstacle_type = obstacle_type
         super().__init__(
         super().__init__(
             grid_size=size,
             grid_size=size,
             max_steps=4*size*size,
             max_steps=4*size*size,
             # Set this to True for maximum speed
             # Set this to True for maximum speed
             see_through_walls=False,
             see_through_walls=False,
-            seed=None
+            seed=None,
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
@@ -53,16 +54,16 @@ class LavaGapEnv(MiniGridEnv):
         )
         )
 
 
 class LavaGapS5Env(LavaGapEnv):
 class LavaGapS5Env(LavaGapEnv):
-    def __init__(self):
-        super().__init__(size=5)
+    def __init__(self, **kwargs):
+        super().__init__(size=5, **kwargs)
 
 
 class LavaGapS6Env(LavaGapEnv):
 class LavaGapS6Env(LavaGapEnv):
-    def __init__(self):
-        super().__init__(size=6)
+    def __init__(self, **kwargs):
+        super().__init__(size=6, **kwargs)
 
 
 class LavaGapS7Env(LavaGapEnv):
 class LavaGapS7Env(LavaGapEnv):
-    def __init__(self):
-        super().__init__(size=7)
+    def __init__(self, **kwargs):
+        super().__init__(size=7, **kwargs)
 
 
 register(
 register(
     id='MiniGrid-LavaGapS5-v0',
     id='MiniGrid-LavaGapS5-v0',

+ 3 - 2
gym_minigrid/envs/lockedroom.py

@@ -30,9 +30,10 @@ class LockedRoom(MiniGridEnv):
 
 
     def __init__(
     def __init__(
         self,
         self,
-        size=19
+        size=19, 
+        **kwargs
     ):
     ):
-        super().__init__(grid_size=size, max_steps=10*size)
+        super().__init__(grid_size=size, max_steps=10*size, **kwargs)
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
         # Create the grid
         # Create the grid

+ 16 - 14
gym_minigrid/envs/memory.py

@@ -15,7 +15,8 @@ class MemoryEnv(MiniGridEnv):
         self,
         self,
         seed,
         seed,
         size=8,
         size=8,
-        random_length=False,
+        random_length=False, 
+        **kwargs
     ):
     ):
         self.random_length = random_length
         self.random_length = random_length
         super().__init__(
         super().__init__(
@@ -23,7 +24,8 @@ class MemoryEnv(MiniGridEnv):
             grid_size=size,
             grid_size=size,
             max_steps=5*size**2,
             max_steps=5*size**2,
             # Set this to True for maximum speed
             # Set this to True for maximum speed
-            see_through_walls=False,
+            see_through_walls=False, 
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
@@ -100,8 +102,8 @@ class MemoryEnv(MiniGridEnv):
         return obs, reward, done, info
         return obs, reward, done, info
 
 
 class MemoryS17Random(MemoryEnv):
 class MemoryS17Random(MemoryEnv):
-    def __init__(self, seed=None):
-        super().__init__(seed=seed, size=17, random_length=True)
+    def __init__(self, seed=None, **kwargs):
+        super().__init__(seed=seed, size=17, random_length=True, **kwargs)
 
 
 register(
 register(
     id='MiniGrid-MemoryS17Random-v0',
     id='MiniGrid-MemoryS17Random-v0',
@@ -109,8 +111,8 @@ register(
 )
 )
 
 
 class MemoryS13Random(MemoryEnv):
 class MemoryS13Random(MemoryEnv):
-    def __init__(self, seed=None):
-        super().__init__(seed=seed, size=13, random_length=True)
+    def __init__(self, seed=None, **kwargs):
+        super().__init__(seed=seed, size=13, random_length=True, **kwargs)
 
 
 register(
 register(
     id='MiniGrid-MemoryS13Random-v0',
     id='MiniGrid-MemoryS13Random-v0',
@@ -118,8 +120,8 @@ register(
 )
 )
 
 
 class MemoryS13(MemoryEnv):
 class MemoryS13(MemoryEnv):
-    def __init__(self, seed=None):
-        super().__init__(seed=seed, size=13)
+    def __init__(self, seed=None, **kwargs):
+        super().__init__(seed=seed, size=13, **kwargs)
 
 
 register(
 register(
     id='MiniGrid-MemoryS13-v0',
     id='MiniGrid-MemoryS13-v0',
@@ -127,8 +129,8 @@ register(
 )
 )
 
 
 class MemoryS11(MemoryEnv):
 class MemoryS11(MemoryEnv):
-    def __init__(self, seed=None):
-        super().__init__(seed=seed, size=11)
+    def __init__(self, seed=None, **kwargs):
+        super().__init__(seed=seed, size=11, **kwargs)
 
 
 register(
 register(
     id='MiniGrid-MemoryS11-v0',
     id='MiniGrid-MemoryS11-v0',
@@ -136,8 +138,8 @@ register(
 )
 )
 
 
 class MemoryS9(MemoryEnv):
 class MemoryS9(MemoryEnv):
-    def __init__(self, seed=None):
-        super().__init__(seed=seed, size=9)
+    def __init__(self, seed=None, **kwargs):
+        super().__init__(seed=seed, size=9, **kwargs)
 
 
 register(
 register(
     id='MiniGrid-MemoryS9-v0',
     id='MiniGrid-MemoryS9-v0',
@@ -145,8 +147,8 @@ register(
 )
 )
 
 
 class MemoryS7(MemoryEnv):
 class MemoryS7(MemoryEnv):
-    def __init__(self, seed=None):
-        super().__init__(seed=seed, size=7)
+    def __init__(self, seed=None, **kwargs):
+        super().__init__(seed=seed, size=7, **kwargs)
 
 
 register(
 register(
     id='MiniGrid-MemoryS7-v0',
     id='MiniGrid-MemoryS7-v0',

+ 13 - 8
gym_minigrid/envs/multiroom.py

@@ -21,7 +21,8 @@ class MultiRoomEnv(MiniGridEnv):
     def __init__(self,
     def __init__(self,
         minNumRooms,
         minNumRooms,
         maxNumRooms,
         maxNumRooms,
-        maxRoomSize=10
+        maxRoomSize=10,
+        **kwargs
     ):
     ):
         assert minNumRooms > 0
         assert minNumRooms > 0
         assert maxNumRooms >= minNumRooms
         assert maxNumRooms >= minNumRooms
@@ -35,7 +36,8 @@ class MultiRoomEnv(MiniGridEnv):
 
 
         super(MultiRoomEnv, self).__init__(
         super(MultiRoomEnv, self).__init__(
             grid_size=25,
             grid_size=25,
-            max_steps=self.maxNumRooms * 20
+            max_steps=self.maxNumRooms * 20,
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
@@ -237,26 +239,29 @@ class MultiRoomEnv(MiniGridEnv):
         return True
         return True
 
 
 class MultiRoomEnvN2S4(MultiRoomEnv):
 class MultiRoomEnvN2S4(MultiRoomEnv):
-    def __init__(self):
+    def __init__(self, **kwargs):
         super().__init__(
         super().__init__(
             minNumRooms=2,
             minNumRooms=2,
             maxNumRooms=2,
             maxNumRooms=2,
-            maxRoomSize=4
+            maxRoomSize=4,
+            **kwargs
         )
         )
 
 
 class MultiRoomEnvN4S5(MultiRoomEnv):
 class MultiRoomEnvN4S5(MultiRoomEnv):
-    def __init__(self):
+    def __init__(self, **kwargs):
         super().__init__(
         super().__init__(
             minNumRooms=4,
             minNumRooms=4,
             maxNumRooms=4,
             maxNumRooms=4,
-            maxRoomSize=5
+            maxRoomSize=5,
+            **kwargs
         )
         )
 
 
 class MultiRoomEnvN6(MultiRoomEnv):
 class MultiRoomEnvN6(MultiRoomEnv):
-    def __init__(self):
+    def __init__(self, **kwargs):
         super().__init__(
         super().__init__(
             minNumRooms=6,
             minNumRooms=6,
-            maxNumRooms=6
+            maxNumRooms=6,
+            **kwargs
         )
         )
 
 
 register(
 register(

+ 24 - 20
gym_minigrid/envs/obstructedmaze.py

@@ -12,7 +12,8 @@ class ObstructedMazeEnv(RoomGrid):
         num_rows,
         num_rows,
         num_cols,
         num_cols,
         num_rooms_visited,
         num_rooms_visited,
-        seed=None
+        seed=None, 
+        **kwargs
     ):
     ):
         room_size = 6
         room_size = 6
         max_steps = 4*num_rooms_visited*room_size**2
         max_steps = 4*num_rooms_visited*room_size**2
@@ -22,7 +23,8 @@ class ObstructedMazeEnv(RoomGrid):
             num_rows=num_rows,
             num_rows=num_rows,
             num_cols=num_cols,
             num_cols=num_cols,
             max_steps=max_steps,
             max_steps=max_steps,
-            seed=seed
+            seed=seed, 
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
@@ -79,7 +81,7 @@ class ObstructedMaze_1Dlhb(ObstructedMazeEnv):
     rooms. Doors are obstructed by a ball and keys are hidden in boxes.
     rooms. Doors are obstructed by a ball and keys are hidden in boxes.
     """
     """
 
 
-    def __init__(self, key_in_box=True, blocked=True, seed=None):
+    def __init__(self, key_in_box=True, blocked=True, seed=None, **kwargs):
         self.key_in_box = key_in_box
         self.key_in_box = key_in_box
         self.blocked = blocked
         self.blocked = blocked
 
 
@@ -87,7 +89,8 @@ class ObstructedMaze_1Dlhb(ObstructedMazeEnv):
             num_rows=1,
             num_rows=1,
             num_cols=2,
             num_cols=2,
             num_rooms_visited=2,
             num_rooms_visited=2,
-            seed=seed
+            seed=seed, 
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
@@ -102,12 +105,12 @@ class ObstructedMaze_1Dlhb(ObstructedMazeEnv):
         self.place_agent(0, 0)
         self.place_agent(0, 0)
 
 
 class ObstructedMaze_1Dl(ObstructedMaze_1Dlhb):
 class ObstructedMaze_1Dl(ObstructedMaze_1Dlhb):
-    def __init__(self, seed=None):
-        super().__init__(False, False, seed)
+    def __init__(self, seed=None, **kwargs):
+        super().__init__(False, False, seed, **kwargs)
 
 
 class ObstructedMaze_1Dlh(ObstructedMaze_1Dlhb):
 class ObstructedMaze_1Dlh(ObstructedMaze_1Dlhb):
-    def __init__(self, seed=None):
-        super().__init__(True, False, seed)
+    def __init__(self, seed=None, **kwargs):
+        super().__init__(True, False, seed, **kwargs)
 
 
 class ObstructedMaze_Full(ObstructedMazeEnv):
 class ObstructedMaze_Full(ObstructedMazeEnv):
     """
     """
@@ -117,7 +120,7 @@ class ObstructedMaze_Full(ObstructedMazeEnv):
     """
     """
 
 
     def __init__(self, agent_room=(1, 1), key_in_box=True, blocked=True,
     def __init__(self, agent_room=(1, 1), key_in_box=True, blocked=True,
-                 num_quarters=4, num_rooms_visited=25, seed=None):
+                 num_quarters=4, num_rooms_visited=25, seed=None, **kwargs):
         self.agent_room = agent_room
         self.agent_room = agent_room
         self.key_in_box = key_in_box
         self.key_in_box = key_in_box
         self.blocked = blocked
         self.blocked = blocked
@@ -127,7 +130,8 @@ class ObstructedMaze_Full(ObstructedMazeEnv):
             num_rows=3,
             num_rows=3,
             num_cols=3,
             num_cols=3,
             num_rooms_visited=num_rooms_visited,
             num_rooms_visited=num_rooms_visited,
-            seed=seed
+            seed=seed, 
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
@@ -158,25 +162,25 @@ class ObstructedMaze_Full(ObstructedMazeEnv):
         self.place_agent(*self.agent_room)
         self.place_agent(*self.agent_room)
 
 
 class ObstructedMaze_2Dl(ObstructedMaze_Full):
 class ObstructedMaze_2Dl(ObstructedMaze_Full):
-    def __init__(self, seed=None):
-        super().__init__((2, 1), False, False, 1, 4, seed)
+    def __init__(self, seed=None, **kwargs):
+        super().__init__((2, 1), False, False, 1, 4, seed, **kwargs)
 
 
 class ObstructedMaze_2Dlh(ObstructedMaze_Full):
 class ObstructedMaze_2Dlh(ObstructedMaze_Full):
-    def __init__(self, seed=None):
-        super().__init__((2, 1), True, False, 1, 4, seed)
+    def __init__(self, seed=None, **kwargs):
+        super().__init__((2, 1), True, False, 1, 4, seed, **kwargs)
 
 
 
 
 class ObstructedMaze_2Dlhb(ObstructedMaze_Full):
 class ObstructedMaze_2Dlhb(ObstructedMaze_Full):
-    def __init__(self, seed=None):
-        super().__init__((2, 1), True, True, 1, 4, seed)
+    def __init__(self, seed=None, **kwargs):
+        super().__init__((2, 1), True, True, 1, 4, seed, **kwargs)
 
 
 class ObstructedMaze_1Q(ObstructedMaze_Full):
 class ObstructedMaze_1Q(ObstructedMaze_Full):
-    def __init__(self, seed=None):
-        super().__init__((1, 1), True, True, 1, 5, seed)
+    def __init__(self, seed=None, **kwargs):
+        super().__init__((1, 1), True, True, 1, 5, seed, **kwargs)
 
 
 class ObstructedMaze_2Q(ObstructedMaze_Full):
 class ObstructedMaze_2Q(ObstructedMaze_Full):
-    def __init__(self, seed=None):
-        super().__init__((1, 1), True, True, 2, 11, seed)
+    def __init__(self, seed=None, **kwargs):
+        super().__init__((1, 1), True, True, 2, 11, seed, **kwargs)
 
 
 register(
 register(
     id="MiniGrid-ObstructedMaze-1Dl-v0",
     id="MiniGrid-ObstructedMaze-1Dl-v0",

+ 2 - 2
gym_minigrid/envs/playground_v0.py

@@ -7,8 +7,8 @@ class PlaygroundV0(MiniGridEnv):
     This environment has no specific goals or rewards.
     This environment has no specific goals or rewards.
     """
     """
 
 
-    def __init__(self):
-        super().__init__(grid_size=19, max_steps=100)
+    def __init__(self, **kwargs):
+        super().__init__(grid_size=19, max_steps=100, **kwargs)
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
         # Create the grid
         # Create the grid

+ 6 - 4
gym_minigrid/envs/putnear.py

@@ -10,7 +10,8 @@ class PutNearEnv(MiniGridEnv):
     def __init__(
     def __init__(
         self,
         self,
         size=6,
         size=6,
-        numObjs=2
+        numObjs=2, 
+        **kwargs
     ):
     ):
         self.numObjs = numObjs
         self.numObjs = numObjs
 
 
@@ -18,7 +19,8 @@ class PutNearEnv(MiniGridEnv):
             grid_size=size,
             grid_size=size,
             max_steps=5*size,
             max_steps=5*size,
             # Set this to True for maximum speed
             # Set this to True for maximum speed
-            see_through_walls=True
+            see_through_walls=True, 
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
@@ -112,8 +114,8 @@ class PutNearEnv(MiniGridEnv):
         return obs, reward, done, info
         return obs, reward, done, info
 
 
 class PutNear8x8N3(PutNearEnv):
 class PutNear8x8N3(PutNearEnv):
-    def __init__(self):
-        super().__init__(size=8, numObjs=3)
+    def __init__(self, **kwargs):
+        super().__init__(size=8, numObjs=3, **kwargs)
 
 
 register(
 register(
     id='MiniGrid-PutNear-6x6-N2-v0',
     id='MiniGrid-PutNear-6x6-N2-v0',

+ 5 - 4
gym_minigrid/envs/redbluedoors.py

@@ -8,13 +8,14 @@ class RedBlueDoorEnv(MiniGridEnv):
     obtain a reward.
     obtain a reward.
     """
     """
 
 
-    def __init__(self, size=8):
+    def __init__(self, size=8, **kwargs):
         self.size = size
         self.size = size
 
 
         super().__init__(
         super().__init__(
             width=2*size,
             width=2*size,
             height=size,
             height=size,
-            max_steps=20*size*size
+            max_steps=20*size*size,
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
@@ -66,8 +67,8 @@ class RedBlueDoorEnv(MiniGridEnv):
         return obs, reward, done, info
         return obs, reward, done, info
 
 
 class RedBlueDoorEnv6x6(RedBlueDoorEnv):
 class RedBlueDoorEnv6x6(RedBlueDoorEnv):
-    def __init__(self):
-        super().__init__(size=6)
+    def __init__(self, **kwargs):
+        super().__init__(size=6, **kwargs)
 
 
 register(
 register(
     id='MiniGrid-RedBlueDoors-6x6-v0',
     id='MiniGrid-RedBlueDoors-6x6-v0',

+ 3 - 2
gym_minigrid/envs/unlock.py

@@ -7,14 +7,15 @@ class Unlock(RoomGrid):
     Unlock a door
     Unlock a door
     """
     """
 
 
-    def __init__(self, seed=None):
+    def __init__(self, seed=None, **kwargs):
         room_size = 6
         room_size = 6
         super().__init__(
         super().__init__(
             num_rows=1,
             num_rows=1,
             num_cols=2,
             num_cols=2,
             room_size=room_size,
             room_size=room_size,
             max_steps=8*room_size**2,
             max_steps=8*room_size**2,
-            seed=seed
+            seed=seed, 
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):

+ 3 - 2
gym_minigrid/envs/unlockpickup.py

@@ -7,14 +7,15 @@ class UnlockPickup(RoomGrid):
     Unlock a door, then pick up a box in another room
     Unlock a door, then pick up a box in another room
     """
     """
 
 
-    def __init__(self, seed=None):
+    def __init__(self, seed=None, **kwargs):
         room_size = 6
         room_size = 6
         super().__init__(
         super().__init__(
             num_rows=1,
             num_rows=1,
             num_cols=2,
             num_cols=2,
             room_size=room_size,
             room_size=room_size,
             max_steps=8*room_size**2,
             max_steps=8*room_size**2,
-            seed=seed
+            seed=seed, 
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):

+ 4 - 2
gym_minigrid/roomgrid.py

@@ -73,7 +73,8 @@ class RoomGrid(MiniGridEnv):
         num_cols=3,
         num_cols=3,
         max_steps=100,
         max_steps=100,
         seed=0,
         seed=0,
-        agent_view_size=7
+        agent_view_size=7,
+        **kwargs
     ):
     ):
         assert room_size > 0
         assert room_size > 0
         assert room_size >= 3
         assert room_size >= 3
@@ -95,7 +96,8 @@ class RoomGrid(MiniGridEnv):
             max_steps=max_steps,
             max_steps=max_steps,
             see_through_walls=False,
             see_through_walls=False,
             seed=seed,
             seed=seed,
-            agent_view_size=agent_view_size
+            agent_view_size=agent_view_size,
+            **kwargs
         )
         )
 
 
     def room_from_pos(self, x, y):
     def room_from_pos(self, x, y):