瀏覽代碼

allow passing kwargs to EmptyEnvNxN (#83)

(agent_start_pos and agent_dir)
rosea-tf 5 年之前
父節點
當前提交
ded28f95ae
共有 1 個文件被更改,包括 6 次插入6 次删除
  1. 6 6
      gym_minigrid/envs/empty.py

+ 6 - 6
gym_minigrid/envs/empty.py

@@ -42,24 +42,24 @@ class EmptyEnv(MiniGridEnv):
         self.mission = "get to the green goal square"
         self.mission = "get to the green goal square"
 
 
 class EmptyEnv5x5(EmptyEnv):
 class EmptyEnv5x5(EmptyEnv):
-    def __init__(self):
-        super().__init__(size=5)
+    def __init__(self, **kwargs):
+        super().__init__(size=5, **kwargs)
 
 
 class EmptyRandomEnv5x5(EmptyEnv):
 class EmptyRandomEnv5x5(EmptyEnv):
     def __init__(self):
     def __init__(self):
         super().__init__(size=5, agent_start_pos=None)
         super().__init__(size=5, agent_start_pos=None)
 
 
 class EmptyEnv6x6(EmptyEnv):
 class EmptyEnv6x6(EmptyEnv):
-    def __init__(self):
-        super().__init__(size=6)
+    def __init__(self, **kwargs):
+        super().__init__(size=6, **kwargs)
 
 
 class EmptyRandomEnv6x6(EmptyEnv):
 class EmptyRandomEnv6x6(EmptyEnv):
     def __init__(self):
     def __init__(self):
         super().__init__(size=6, agent_start_pos=None)
         super().__init__(size=6, agent_start_pos=None)
 
 
 class EmptyEnv16x16(EmptyEnv):
 class EmptyEnv16x16(EmptyEnv):
-    def __init__(self):
-        super().__init__(size=16)
+    def __init__(self, **kwargs):
+        super().__init__(size=16, **kwargs)
 
 
 register(
 register(
     id='MiniGrid-Empty-5x5-v0',
     id='MiniGrid-Empty-5x5-v0',