Parcourir la source

Merge branch 'master' into pre-commit

Mark Towers il y a 2 ans
Parent
commit
1017d91e1c
1 fichiers modifiés avec 11 ajouts et 8 suppressions
  1. 11 8
      gym_minigrid/envs/doorkey.py

+ 11 - 8
gym_minigrid/envs/doorkey.py

@@ -7,8 +7,11 @@ class DoorKeyEnv(MiniGridEnv):
     Environment with a door and key, sparse reward
     Environment with a door and key, sparse reward
     """
     """
 
 
-    def __init__(self, size=8):
-        super().__init__(grid_size=size, max_steps=10 * size * size)
+    def __init__(self, size=8, max_steps=None):
+        super().__init__(
+            grid_size=size,
+            max_steps=10*size*size if max_steps is None else max_steps
+        )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
         # Create an empty grid
         # Create an empty grid
@@ -39,18 +42,18 @@ class DoorKeyEnv(MiniGridEnv):
 
 
 
 
 class DoorKeyEnv5x5(DoorKeyEnv):
 class DoorKeyEnv5x5(DoorKeyEnv):
-    def __init__(self):
-        super().__init__(size=5)
+    def __init__(self, max_steps=None):
+        super().__init__(size=5, max_steps=max_steps)
 
 
 
 
 class DoorKeyEnv6x6(DoorKeyEnv):
 class DoorKeyEnv6x6(DoorKeyEnv):
-    def __init__(self):
-        super().__init__(size=6)
+    def __init__(self, max_steps=None):
+        super().__init__(size=6, max_steps=max_steps)
 
 
 
 
 class DoorKeyEnv16x16(DoorKeyEnv):
 class DoorKeyEnv16x16(DoorKeyEnv):
-    def __init__(self):
-        super().__init__(size=16)
+    def __init__(self, max_steps=None):
+        super().__init__(size=16, max_steps=max_steps)
 
 
 
 
 register(id="MiniGrid-DoorKey-5x5-v0", entry_point="gym_minigrid.envs:DoorKeyEnv5x5")
 register(id="MiniGrid-DoorKey-5x5-v0", entry_point="gym_minigrid.envs:DoorKeyEnv5x5")