|
@@ -12,7 +12,8 @@ class ObstructedMazeEnv(RoomGrid):
|
|
num_rows,
|
|
num_rows,
|
|
num_cols,
|
|
num_cols,
|
|
num_rooms_visited,
|
|
num_rooms_visited,
|
|
- seed=None
|
|
|
|
|
|
+ seed=None,
|
|
|
|
+ **kwargs
|
|
):
|
|
):
|
|
room_size = 6
|
|
room_size = 6
|
|
max_steps = 4*num_rooms_visited*room_size**2
|
|
max_steps = 4*num_rooms_visited*room_size**2
|
|
@@ -22,7 +23,8 @@ class ObstructedMazeEnv(RoomGrid):
|
|
num_rows=num_rows,
|
|
num_rows=num_rows,
|
|
num_cols=num_cols,
|
|
num_cols=num_cols,
|
|
max_steps=max_steps,
|
|
max_steps=max_steps,
|
|
- seed=seed
|
|
|
|
|
|
+ seed=seed,
|
|
|
|
+ **kwargs
|
|
)
|
|
)
|
|
|
|
|
|
def _gen_grid(self, width, height):
|
|
def _gen_grid(self, width, height):
|
|
@@ -79,7 +81,7 @@ class ObstructedMaze_1Dlhb(ObstructedMazeEnv):
|
|
rooms. Doors are obstructed by a ball and keys are hidden in boxes.
|
|
rooms. Doors are obstructed by a ball and keys are hidden in boxes.
|
|
"""
|
|
"""
|
|
|
|
|
|
- def __init__(self, key_in_box=True, blocked=True, seed=None):
|
|
|
|
|
|
+ def __init__(self, key_in_box=True, blocked=True, seed=None, **kwargs):
|
|
self.key_in_box = key_in_box
|
|
self.key_in_box = key_in_box
|
|
self.blocked = blocked
|
|
self.blocked = blocked
|
|
|
|
|
|
@@ -87,7 +89,8 @@ class ObstructedMaze_1Dlhb(ObstructedMazeEnv):
|
|
num_rows=1,
|
|
num_rows=1,
|
|
num_cols=2,
|
|
num_cols=2,
|
|
num_rooms_visited=2,
|
|
num_rooms_visited=2,
|
|
- seed=seed
|
|
|
|
|
|
+ seed=seed,
|
|
|
|
+ **kwargs
|
|
)
|
|
)
|
|
|
|
|
|
def _gen_grid(self, width, height):
|
|
def _gen_grid(self, width, height):
|
|
@@ -102,12 +105,12 @@ class ObstructedMaze_1Dlhb(ObstructedMazeEnv):
|
|
self.place_agent(0, 0)
|
|
self.place_agent(0, 0)
|
|
|
|
|
|
class ObstructedMaze_1Dl(ObstructedMaze_1Dlhb):
|
|
class ObstructedMaze_1Dl(ObstructedMaze_1Dlhb):
|
|
- def __init__(self, seed=None):
|
|
|
|
- super().__init__(False, False, seed)
|
|
|
|
|
|
+ def __init__(self, seed=None, **kwargs):
|
|
|
|
+ super().__init__(False, False, seed, **kwargs)
|
|
|
|
|
|
class ObstructedMaze_1Dlh(ObstructedMaze_1Dlhb):
|
|
class ObstructedMaze_1Dlh(ObstructedMaze_1Dlhb):
|
|
- def __init__(self, seed=None):
|
|
|
|
- super().__init__(True, False, seed)
|
|
|
|
|
|
+ def __init__(self, seed=None, **kwargs):
|
|
|
|
+ super().__init__(True, False, seed, **kwargs)
|
|
|
|
|
|
class ObstructedMaze_Full(ObstructedMazeEnv):
|
|
class ObstructedMaze_Full(ObstructedMazeEnv):
|
|
"""
|
|
"""
|
|
@@ -117,7 +120,7 @@ class ObstructedMaze_Full(ObstructedMazeEnv):
|
|
"""
|
|
"""
|
|
|
|
|
|
def __init__(self, agent_room=(1, 1), key_in_box=True, blocked=True,
|
|
def __init__(self, agent_room=(1, 1), key_in_box=True, blocked=True,
|
|
- num_quarters=4, num_rooms_visited=25, seed=None):
|
|
|
|
|
|
+ num_quarters=4, num_rooms_visited=25, seed=None, **kwargs):
|
|
self.agent_room = agent_room
|
|
self.agent_room = agent_room
|
|
self.key_in_box = key_in_box
|
|
self.key_in_box = key_in_box
|
|
self.blocked = blocked
|
|
self.blocked = blocked
|
|
@@ -127,7 +130,8 @@ class ObstructedMaze_Full(ObstructedMazeEnv):
|
|
num_rows=3,
|
|
num_rows=3,
|
|
num_cols=3,
|
|
num_cols=3,
|
|
num_rooms_visited=num_rooms_visited,
|
|
num_rooms_visited=num_rooms_visited,
|
|
- seed=seed
|
|
|
|
|
|
+ seed=seed,
|
|
|
|
+ **kwargs
|
|
)
|
|
)
|
|
|
|
|
|
def _gen_grid(self, width, height):
|
|
def _gen_grid(self, width, height):
|
|
@@ -158,25 +162,25 @@ class ObstructedMaze_Full(ObstructedMazeEnv):
|
|
self.place_agent(*self.agent_room)
|
|
self.place_agent(*self.agent_room)
|
|
|
|
|
|
class ObstructedMaze_2Dl(ObstructedMaze_Full):
|
|
class ObstructedMaze_2Dl(ObstructedMaze_Full):
|
|
- def __init__(self, seed=None):
|
|
|
|
- super().__init__((2, 1), False, False, 1, 4, seed)
|
|
|
|
|
|
+ def __init__(self, seed=None, **kwargs):
|
|
|
|
+ super().__init__((2, 1), False, False, 1, 4, seed, **kwargs)
|
|
|
|
|
|
class ObstructedMaze_2Dlh(ObstructedMaze_Full):
|
|
class ObstructedMaze_2Dlh(ObstructedMaze_Full):
|
|
- def __init__(self, seed=None):
|
|
|
|
- super().__init__((2, 1), True, False, 1, 4, seed)
|
|
|
|
|
|
+ def __init__(self, seed=None, **kwargs):
|
|
|
|
+ super().__init__((2, 1), True, False, 1, 4, seed, **kwargs)
|
|
|
|
|
|
|
|
|
|
class ObstructedMaze_2Dlhb(ObstructedMaze_Full):
|
|
class ObstructedMaze_2Dlhb(ObstructedMaze_Full):
|
|
- def __init__(self, seed=None):
|
|
|
|
- super().__init__((2, 1), True, True, 1, 4, seed)
|
|
|
|
|
|
+ def __init__(self, seed=None, **kwargs):
|
|
|
|
+ super().__init__((2, 1), True, True, 1, 4, seed, **kwargs)
|
|
|
|
|
|
class ObstructedMaze_1Q(ObstructedMaze_Full):
|
|
class ObstructedMaze_1Q(ObstructedMaze_Full):
|
|
- def __init__(self, seed=None):
|
|
|
|
- super().__init__((1, 1), True, True, 1, 5, seed)
|
|
|
|
|
|
+ def __init__(self, seed=None, **kwargs):
|
|
|
|
+ super().__init__((1, 1), True, True, 1, 5, seed, **kwargs)
|
|
|
|
|
|
class ObstructedMaze_2Q(ObstructedMaze_Full):
|
|
class ObstructedMaze_2Q(ObstructedMaze_Full):
|
|
- def __init__(self, seed=None):
|
|
|
|
- super().__init__((1, 1), True, True, 2, 11, seed)
|
|
|
|
|
|
+ def __init__(self, seed=None, **kwargs):
|
|
|
|
+ super().__init__((1, 1), True, True, 2, 11, seed, **kwargs)
|
|
|
|
|
|
register(
|
|
register(
|
|
id="MiniGrid-ObstructedMaze-1Dl-v0",
|
|
id="MiniGrid-ObstructedMaze-1Dl-v0",
|