Sfoglia il codice sorgente

Add SimpleCrossing environment. (#47)

* Add SimpleCrossing environment.

* Update comments.

* Update README.md
Florin Gogianu 6 anni fa
parent
commit
adcbf2cdd6

+ 20 - 0
README.md

@@ -372,3 +372,23 @@ Each lava stream runs across the room either horizontally or vertically, and
 has a single crossing point which can be safely used;  Luckily, a path to the
 goal is guaranteed to exist. This environment is useful for studying safety and
 safe exploration.
+
+## Simple crossing environment
+
+Registered configurations:
+- `MiniGrid-SimpleCrossingS9N1-v0`
+- `MiniGrid-SimpleCrossingS9N2-v0`
+- `MiniGrid-SimpleCrossingS9N3-v0`
+- `MiniGrid-SimpleCrossingS11N5-v0`
+
+<p align="center">
+  <img src="figures/SimpleCrossingS9N1.png" width="200">
+  <img src="figures/SimpleCrossingS9N2.png" width="200">
+  <img src="figures/SimpleCrossingS9N3.png" width="200">
+  <img src="figures/SimpleCrossingS11N5.png" width="250">
+</p>
+
+Similar to the `LavaCrossing` environment, the agent has to reach the green
+goal square on the other corner of the room, however lava is replaced by
+walls. This MDP is therefore much easier and and maybe useful for quickly
+testing your algorithms.

BIN
figures/SimpleCrossingS11N5.png


BIN
figures/SimpleCrossingS9N1.png


BIN
figures/SimpleCrossingS9N2.png


BIN
figures/SimpleCrossingS9N3.png


+ 1 - 1
gym_minigrid/envs/__init__.py

@@ -15,4 +15,4 @@ from gym_minigrid.envs.redbluedoors import *
 from gym_minigrid.envs.obstructedmaze import *
 from gym_minigrid.envs.memory import *
 from gym_minigrid.envs.fourrooms import *
-from gym_minigrid.envs.lavacrossing import *
+from gym_minigrid.envs.crossing import *

+ 58 - 13
gym_minigrid/envs/lavacrossing.py

@@ -4,13 +4,14 @@ from gym_minigrid.register import register
 import itertools as itt
 
 
-class LavaCrossingEnv(MiniGridEnv):
+class CrossingEnv(MiniGridEnv):
     """
-    Environment with lava obstacles, sparse reward
+    Environment with wall or lava obstacles, sparse reward.
     """
 
-    def __init__(self, size=9, num_crossings=1, seed=None):
+    def __init__(self, size=9, num_crossings=1, obstacle_type=Lava, seed=None):
         self.num_crossings = num_crossings
+        self.obstacle_type = obstacle_type
         super().__init__(
             grid_size=size,
             max_steps=4*size*size,
@@ -35,28 +36,28 @@ class LavaCrossingEnv(MiniGridEnv):
         # Place a goal square in the bottom-right corner
         self.grid.set(width - 2, height - 2, Goal())
 
-        # Place lava tiles
+        # Place obstacles (lava or walls)
         v, h = object(), object()  # singleton `vertical` and `horizontal` objects
 
-        # Lava river specified by direction and position in grid
+        # Lava rivers or walls specified by direction and position in grid
         rivers = [(v, i) for i in range(2, height - 2, 2)]
         rivers += [(h, j) for j in range(2, width - 2, 2)]
         self.np_random.shuffle(rivers)
         rivers = rivers[:self.num_crossings]  # sample random rivers
         rivers_v = sorted([pos for direction, pos in rivers if direction is v])
         rivers_h = sorted([pos for direction, pos in rivers if direction is h])
-        lava_pos = itt.chain(
+        obstacle_pos = itt.chain(
             itt.product(range(1, width - 1), rivers_h),
             itt.product(rivers_v, range(1, height - 1)),
         )
-        for i, j in lava_pos:
-            self.grid.set(i, j, Lava())
+        for i, j in obstacle_pos:
+            self.grid.set(i, j, self.obstacle_type())
 
         # Sample path to goal
         path = [h] * len(rivers_v) + [v] * len(rivers_h)
         self.np_random.shuffle(path)
 
-        # Create openings in lava rivers
+        # Create openings
         limits_v = [0] + rivers_v + [height - 1]
         limits_h = [0] + rivers_h + [width - 1]
         room_i, room_j = 0, 0
@@ -75,17 +76,25 @@ class LavaCrossingEnv(MiniGridEnv):
                 assert False
             self.grid.set(i, j, None)
 
-        self.mission = "avoid the lava and get to the green goal square"
+        self.mission = (
+            "avoid the lava and get to the green goal square"
+            if self.obstacle_type == Lava
+            else "find the opening and get to the green goal square"
+        )
+
+class LavaCrossingEnv(CrossingEnv):
+    def __init__(self):
+        super().__init__(size=9, num_crossings=1)
 
-class LavaCrossingS9N2Env(LavaCrossingEnv):
+class LavaCrossingS9N2Env(CrossingEnv):
     def __init__(self):
         super().__init__(size=9, num_crossings=2)
 
-class LavaCrossingS9N3Env(LavaCrossingEnv):
+class LavaCrossingS9N3Env(CrossingEnv):
     def __init__(self):
         super().__init__(size=9, num_crossings=3)
 
-class LavaCrossingS11N5Env(LavaCrossingEnv):
+class LavaCrossingS11N5Env(CrossingEnv):
     def __init__(self):
         super().__init__(size=11, num_crossings=5)
 
@@ -108,3 +117,39 @@ register(
     id='MiniGrid-LavaCrossingS11N5-v0',
     entry_point='gym_minigrid.envs:LavaCrossingS11N5Env'
 )
+
+class SimpleCrossingEnv(CrossingEnv):
+    def __init__(self):
+        super().__init__(size=9, num_crossings=1, obstacle_type=Wall)
+
+class SimpleCrossingS9N2Env(CrossingEnv):
+    def __init__(self):
+        super().__init__(size=9, num_crossings=2, obstacle_type=Wall)
+
+class SimpleCrossingS9N3Env(CrossingEnv):
+    def __init__(self):
+        super().__init__(size=9, num_crossings=3, obstacle_type=Wall)
+
+class SimpleCrossingS11N5Env(CrossingEnv):
+    def __init__(self):
+        super().__init__(size=11, num_crossings=5, obstacle_type=Wall)
+
+register(
+    id='MiniGrid-SimpleCrossingS9N1-v0',
+    entry_point='gym_minigrid.envs:SimpleCrossingEnv'
+)
+
+register(
+    id='MiniGrid-SimpleCrossingS9N2-v0',
+    entry_point='gym_minigrid.envs:SimpleCrossingS9N2Env'
+)
+
+register(
+    id='MiniGrid-SimpleCrossingS9N3-v0',
+    entry_point='gym_minigrid.envs:SimpleCrossingS9N3Env'
+)
+
+register(
+    id='MiniGrid-SimpleCrossingS11N5-v0',
+    entry_point='gym_minigrid.envs:SimpleCrossingS11N5Env'
+)