소스 검색

Add SimpleCrossing environment. (#47)

* Add SimpleCrossing environment.

* Update comments.

* Update README.md
Florin Gogianu 6 년 전
부모
커밋
adcbf2cdd6

+ 20 - 0
README.md

@@ -372,3 +372,23 @@ Each lava stream runs across the room either horizontally or vertically, and
 has a single crossing point which can be safely used;  Luckily, a path to the
 has a single crossing point which can be safely used;  Luckily, a path to the
 goal is guaranteed to exist. This environment is useful for studying safety and
 goal is guaranteed to exist. This environment is useful for studying safety and
 safe exploration.
 safe exploration.
+
+## Simple crossing environment
+
+Registered configurations:
+- `MiniGrid-SimpleCrossingS9N1-v0`
+- `MiniGrid-SimpleCrossingS9N2-v0`
+- `MiniGrid-SimpleCrossingS9N3-v0`
+- `MiniGrid-SimpleCrossingS11N5-v0`
+
+<p align="center">
+  <img src="figures/SimpleCrossingS9N1.png" width="200">
+  <img src="figures/SimpleCrossingS9N2.png" width="200">
+  <img src="figures/SimpleCrossingS9N3.png" width="200">
+  <img src="figures/SimpleCrossingS11N5.png" width="250">
+</p>
+
+Similar to the `LavaCrossing` environment, the agent has to reach the green
+goal square on the other corner of the room, however lava is replaced by
+walls. This MDP is therefore much easier and and maybe useful for quickly
+testing your algorithms.

BIN
figures/SimpleCrossingS11N5.png


BIN
figures/SimpleCrossingS9N1.png


BIN
figures/SimpleCrossingS9N2.png


BIN
figures/SimpleCrossingS9N3.png


+ 1 - 1
gym_minigrid/envs/__init__.py

@@ -15,4 +15,4 @@ from gym_minigrid.envs.redbluedoors import *
 from gym_minigrid.envs.obstructedmaze import *
 from gym_minigrid.envs.obstructedmaze import *
 from gym_minigrid.envs.memory import *
 from gym_minigrid.envs.memory import *
 from gym_minigrid.envs.fourrooms import *
 from gym_minigrid.envs.fourrooms import *
-from gym_minigrid.envs.lavacrossing import *
+from gym_minigrid.envs.crossing import *

+ 58 - 13
gym_minigrid/envs/lavacrossing.py

@@ -4,13 +4,14 @@ from gym_minigrid.register import register
 import itertools as itt
 import itertools as itt
 
 
 
 
-class LavaCrossingEnv(MiniGridEnv):
+class CrossingEnv(MiniGridEnv):
     """
     """
-    Environment with lava obstacles, sparse reward
+    Environment with wall or lava obstacles, sparse reward.
     """
     """
 
 
-    def __init__(self, size=9, num_crossings=1, seed=None):
+    def __init__(self, size=9, num_crossings=1, obstacle_type=Lava, seed=None):
         self.num_crossings = num_crossings
         self.num_crossings = num_crossings
+        self.obstacle_type = obstacle_type
         super().__init__(
         super().__init__(
             grid_size=size,
             grid_size=size,
             max_steps=4*size*size,
             max_steps=4*size*size,
@@ -35,28 +36,28 @@ class LavaCrossingEnv(MiniGridEnv):
         # Place a goal square in the bottom-right corner
         # Place a goal square in the bottom-right corner
         self.grid.set(width - 2, height - 2, Goal())
         self.grid.set(width - 2, height - 2, Goal())
 
 
-        # Place lava tiles
+        # Place obstacles (lava or walls)
         v, h = object(), object()  # singleton `vertical` and `horizontal` objects
         v, h = object(), object()  # singleton `vertical` and `horizontal` objects
 
 
-        # Lava river specified by direction and position in grid
+        # Lava rivers or walls specified by direction and position in grid
         rivers = [(v, i) for i in range(2, height - 2, 2)]
         rivers = [(v, i) for i in range(2, height - 2, 2)]
         rivers += [(h, j) for j in range(2, width - 2, 2)]
         rivers += [(h, j) for j in range(2, width - 2, 2)]
         self.np_random.shuffle(rivers)
         self.np_random.shuffle(rivers)
         rivers = rivers[:self.num_crossings]  # sample random rivers
         rivers = rivers[:self.num_crossings]  # sample random rivers
         rivers_v = sorted([pos for direction, pos in rivers if direction is v])
         rivers_v = sorted([pos for direction, pos in rivers if direction is v])
         rivers_h = sorted([pos for direction, pos in rivers if direction is h])
         rivers_h = sorted([pos for direction, pos in rivers if direction is h])
-        lava_pos = itt.chain(
+        obstacle_pos = itt.chain(
             itt.product(range(1, width - 1), rivers_h),
             itt.product(range(1, width - 1), rivers_h),
             itt.product(rivers_v, range(1, height - 1)),
             itt.product(rivers_v, range(1, height - 1)),
         )
         )
-        for i, j in lava_pos:
-            self.grid.set(i, j, Lava())
+        for i, j in obstacle_pos:
+            self.grid.set(i, j, self.obstacle_type())
 
 
         # Sample path to goal
         # Sample path to goal
         path = [h] * len(rivers_v) + [v] * len(rivers_h)
         path = [h] * len(rivers_v) + [v] * len(rivers_h)
         self.np_random.shuffle(path)
         self.np_random.shuffle(path)
 
 
-        # Create openings in lava rivers
+        # Create openings
         limits_v = [0] + rivers_v + [height - 1]
         limits_v = [0] + rivers_v + [height - 1]
         limits_h = [0] + rivers_h + [width - 1]
         limits_h = [0] + rivers_h + [width - 1]
         room_i, room_j = 0, 0
         room_i, room_j = 0, 0
@@ -75,17 +76,25 @@ class LavaCrossingEnv(MiniGridEnv):
                 assert False
                 assert False
             self.grid.set(i, j, None)
             self.grid.set(i, j, None)
 
 
-        self.mission = "avoid the lava and get to the green goal square"
+        self.mission = (
+            "avoid the lava and get to the green goal square"
+            if self.obstacle_type == Lava
+            else "find the opening and get to the green goal square"
+        )
+
+class LavaCrossingEnv(CrossingEnv):
+    def __init__(self):
+        super().__init__(size=9, num_crossings=1)
 
 
-class LavaCrossingS9N2Env(LavaCrossingEnv):
+class LavaCrossingS9N2Env(CrossingEnv):
     def __init__(self):
     def __init__(self):
         super().__init__(size=9, num_crossings=2)
         super().__init__(size=9, num_crossings=2)
 
 
-class LavaCrossingS9N3Env(LavaCrossingEnv):
+class LavaCrossingS9N3Env(CrossingEnv):
     def __init__(self):
     def __init__(self):
         super().__init__(size=9, num_crossings=3)
         super().__init__(size=9, num_crossings=3)
 
 
-class LavaCrossingS11N5Env(LavaCrossingEnv):
+class LavaCrossingS11N5Env(CrossingEnv):
     def __init__(self):
     def __init__(self):
         super().__init__(size=11, num_crossings=5)
         super().__init__(size=11, num_crossings=5)
 
 
@@ -108,3 +117,39 @@ register(
     id='MiniGrid-LavaCrossingS11N5-v0',
     id='MiniGrid-LavaCrossingS11N5-v0',
     entry_point='gym_minigrid.envs:LavaCrossingS11N5Env'
     entry_point='gym_minigrid.envs:LavaCrossingS11N5Env'
 )
 )
+
+class SimpleCrossingEnv(CrossingEnv):
+    def __init__(self):
+        super().__init__(size=9, num_crossings=1, obstacle_type=Wall)
+
+class SimpleCrossingS9N2Env(CrossingEnv):
+    def __init__(self):
+        super().__init__(size=9, num_crossings=2, obstacle_type=Wall)
+
+class SimpleCrossingS9N3Env(CrossingEnv):
+    def __init__(self):
+        super().__init__(size=9, num_crossings=3, obstacle_type=Wall)
+
+class SimpleCrossingS11N5Env(CrossingEnv):
+    def __init__(self):
+        super().__init__(size=11, num_crossings=5, obstacle_type=Wall)
+
+register(
+    id='MiniGrid-SimpleCrossingS9N1-v0',
+    entry_point='gym_minigrid.envs:SimpleCrossingEnv'
+)
+
+register(
+    id='MiniGrid-SimpleCrossingS9N2-v0',
+    entry_point='gym_minigrid.envs:SimpleCrossingS9N2Env'
+)
+
+register(
+    id='MiniGrid-SimpleCrossingS9N3-v0',
+    entry_point='gym_minigrid.envs:SimpleCrossingS9N3Env'
+)
+
+register(
+    id='MiniGrid-SimpleCrossingS11N5-v0',
+    entry_point='gym_minigrid.envs:SimpleCrossingS11N5Env'
+)