|
@@ -4,13 +4,14 @@ from gym_minigrid.register import register
|
|
|
import itertools as itt
|
|
|
|
|
|
|
|
|
-class LavaCrossingEnv(MiniGridEnv):
|
|
|
+class CrossingEnv(MiniGridEnv):
|
|
|
"""
|
|
|
- Environment with lava obstacles, sparse reward
|
|
|
+ Environment with wall or lava obstacles, sparse reward.
|
|
|
"""
|
|
|
|
|
|
- def __init__(self, size=9, num_crossings=1, seed=None):
|
|
|
+ def __init__(self, size=9, num_crossings=1, obstacle_type=Lava, seed=None):
|
|
|
self.num_crossings = num_crossings
|
|
|
+ self.obstacle_type = obstacle_type
|
|
|
super().__init__(
|
|
|
grid_size=size,
|
|
|
max_steps=4*size*size,
|
|
@@ -35,28 +36,28 @@ class LavaCrossingEnv(MiniGridEnv):
|
|
|
# Place a goal square in the bottom-right corner
|
|
|
self.grid.set(width - 2, height - 2, Goal())
|
|
|
|
|
|
- # Place lava tiles
|
|
|
+ # Place obstacles (lava or walls)
|
|
|
v, h = object(), object() # singleton `vertical` and `horizontal` objects
|
|
|
|
|
|
- # Lava river specified by direction and position in grid
|
|
|
+ # Lava rivers or walls specified by direction and position in grid
|
|
|
rivers = [(v, i) for i in range(2, height - 2, 2)]
|
|
|
rivers += [(h, j) for j in range(2, width - 2, 2)]
|
|
|
self.np_random.shuffle(rivers)
|
|
|
rivers = rivers[:self.num_crossings] # sample random rivers
|
|
|
rivers_v = sorted([pos for direction, pos in rivers if direction is v])
|
|
|
rivers_h = sorted([pos for direction, pos in rivers if direction is h])
|
|
|
- lava_pos = itt.chain(
|
|
|
+ obstacle_pos = itt.chain(
|
|
|
itt.product(range(1, width - 1), rivers_h),
|
|
|
itt.product(rivers_v, range(1, height - 1)),
|
|
|
)
|
|
|
- for i, j in lava_pos:
|
|
|
- self.grid.set(i, j, Lava())
|
|
|
+ for i, j in obstacle_pos:
|
|
|
+ self.grid.set(i, j, self.obstacle_type())
|
|
|
|
|
|
# Sample path to goal
|
|
|
path = [h] * len(rivers_v) + [v] * len(rivers_h)
|
|
|
self.np_random.shuffle(path)
|
|
|
|
|
|
- # Create openings in lava rivers
|
|
|
+ # Create openings
|
|
|
limits_v = [0] + rivers_v + [height - 1]
|
|
|
limits_h = [0] + rivers_h + [width - 1]
|
|
|
room_i, room_j = 0, 0
|
|
@@ -75,17 +76,25 @@ class LavaCrossingEnv(MiniGridEnv):
|
|
|
assert False
|
|
|
self.grid.set(i, j, None)
|
|
|
|
|
|
- self.mission = "avoid the lava and get to the green goal square"
|
|
|
+ self.mission = (
|
|
|
+ "avoid the lava and get to the green goal square"
|
|
|
+ if self.obstacle_type == Lava
|
|
|
+ else "find the opening and get to the green goal square"
|
|
|
+ )
|
|
|
+
|
|
|
+class LavaCrossingEnv(CrossingEnv):
|
|
|
+ def __init__(self):
|
|
|
+ super().__init__(size=9, num_crossings=1)
|
|
|
|
|
|
-class LavaCrossingS9N2Env(LavaCrossingEnv):
|
|
|
+class LavaCrossingS9N2Env(CrossingEnv):
|
|
|
def __init__(self):
|
|
|
super().__init__(size=9, num_crossings=2)
|
|
|
|
|
|
-class LavaCrossingS9N3Env(LavaCrossingEnv):
|
|
|
+class LavaCrossingS9N3Env(CrossingEnv):
|
|
|
def __init__(self):
|
|
|
super().__init__(size=9, num_crossings=3)
|
|
|
|
|
|
-class LavaCrossingS11N5Env(LavaCrossingEnv):
|
|
|
+class LavaCrossingS11N5Env(CrossingEnv):
|
|
|
def __init__(self):
|
|
|
super().__init__(size=11, num_crossings=5)
|
|
|
|
|
@@ -108,3 +117,39 @@ register(
|
|
|
id='MiniGrid-LavaCrossingS11N5-v0',
|
|
|
entry_point='gym_minigrid.envs:LavaCrossingS11N5Env'
|
|
|
)
|
|
|
+
|
|
|
+class SimpleCrossingEnv(CrossingEnv):
|
|
|
+ def __init__(self):
|
|
|
+ super().__init__(size=9, num_crossings=1, obstacle_type=Wall)
|
|
|
+
|
|
|
+class SimpleCrossingS9N2Env(CrossingEnv):
|
|
|
+ def __init__(self):
|
|
|
+ super().__init__(size=9, num_crossings=2, obstacle_type=Wall)
|
|
|
+
|
|
|
+class SimpleCrossingS9N3Env(CrossingEnv):
|
|
|
+ def __init__(self):
|
|
|
+ super().__init__(size=9, num_crossings=3, obstacle_type=Wall)
|
|
|
+
|
|
|
+class SimpleCrossingS11N5Env(CrossingEnv):
|
|
|
+ def __init__(self):
|
|
|
+ super().__init__(size=11, num_crossings=5, obstacle_type=Wall)
|
|
|
+
|
|
|
+register(
|
|
|
+ id='MiniGrid-SimpleCrossingS9N1-v0',
|
|
|
+ entry_point='gym_minigrid.envs:SimpleCrossingEnv'
|
|
|
+)
|
|
|
+
|
|
|
+register(
|
|
|
+ id='MiniGrid-SimpleCrossingS9N2-v0',
|
|
|
+ entry_point='gym_minigrid.envs:SimpleCrossingS9N2Env'
|
|
|
+)
|
|
|
+
|
|
|
+register(
|
|
|
+ id='MiniGrid-SimpleCrossingS9N3-v0',
|
|
|
+ entry_point='gym_minigrid.envs:SimpleCrossingS9N3Env'
|
|
|
+)
|
|
|
+
|
|
|
+register(
|
|
|
+ id='MiniGrid-SimpleCrossingS11N5-v0',
|
|
|
+ entry_point='gym_minigrid.envs:SimpleCrossingS11N5Env'
|
|
|
+)
|