Parcourir la source

Implement Lava tiles and LavaCrossing environment (#7) (#29)

* Implement Lava tiles and LavaCrossing environment (#7)

 * Add new tile type Lava

 * Add new environment LavaCrossing

 * Add new drawPolyline method to Renderer

 * Add README description and figures for LavaCrossing

* Update lavacrossing.py
A. Baisero il y a 6 ans
Parent
commit
0be0e71ba3

+ 21 - 0
README.md

@@ -349,3 +349,24 @@ by balls. This environment can be solved without relying on language.
 The agent has to pick up a box which is placed in a corner of a 3x3 maze.
 The doors are locked, the keys are hidden in boxes and doors are obstructed
 by balls. This environment can be solved without relying on language.
+
+## Lava crossing environment
+
+Registered configurations:
+- `MiniGrid-LavaCrossingS9N1-v0`
+- `MiniGrid-LavaCrossingS9N2-v0`
+- `MiniGrid-LavaCrossingS9N3-v0`
+- `MiniGrid-LavaCrossingS11N5-v0`
+
+<p align="center">
+  <img src="figures/LavaCrossingS9N1.png" width="200">
+  <img src="figures/LavaCrossingS9N2.png" width="200">
+  <img src="figures/LavaCrossingS9N3.png" width="200">
+  <img src="figures/LavaCrossingS11N5.png" width="250">
+</p>
+
+The agent has to reach the green goal square in the opposite corner of the room
+while avoiding rivers of deadly lava which terminate the episode in failure.
+Each lava stream runs across the room either horizontally or vertically, and
+has a single crossing point which can be safely used. Luckily, a path to the
+goal is guaranteed to exist.

BIN
figures/LavaCrossingS11N5.png


BIN
figures/LavaCrossingS9N1.png


BIN
figures/LavaCrossingS9N2.png


BIN
figures/LavaCrossingS9N3.png


+ 1 - 0
gym_minigrid/envs/__init__.py

@@ -15,3 +15,4 @@ from gym_minigrid.envs.redbluedoors import *
 from gym_minigrid.envs.obstructedmaze import *
 from gym_minigrid.envs.memory import *
 from gym_minigrid.envs.fourrooms import *
+from gym_minigrid.envs.lavacrossing import *

+ 110 - 0
gym_minigrid/envs/lavacrossing.py

@@ -0,0 +1,110 @@
+from gym_minigrid.minigrid import *
+from gym_minigrid.register import register
+
+import itertools as itt
+
+
+class LavaCrossingEnv(MiniGridEnv):
+    """
+    Square room crossed by `num_crossings` lava rivers, one gap each; sparse reward
+    """
+
+    def __init__(self, size=9, num_crossings=1, seed=None):
+        self.num_crossings = num_crossings  # number of lava rivers to place
+        super().__init__(
+            grid_size=size,
+            max_steps=4*size*size,
+            # Set this to True for maximum speed
+            see_through_walls=False,
+            seed=seed  # forward the caller-supplied seed to the base environment
+        )
+
+    def _gen_grid(self, width, height):
+        assert width % 2 == 1 and height % 2 == 1  # odd size
+
+        # Create an empty grid
+        self.grid = Grid(width, height)
+
+        # Generate the surrounding walls
+        self.grid.wall_rect(0, 0, width, height)
+
+        # Place the agent in the top-left corner
+        self.start_pos = (1, 1)
+        self.start_dir = 0
+
+        # Place a goal square in the bottom-right corner
+        self.grid.set(width - 2, height - 2, Goal())
+
+        # Place lava tiles
+        v, h = object(), object()  # singleton `vertical` and `horizontal` objects
+
+        # Candidate rivers as (direction, position) pairs, at even coordinates only
+        rivers = [(v, i) for i in range(2, height - 2, 2)]
+        rivers += [(h, j) for j in range(2, width - 2, 2)]
+        self.np_random.shuffle(rivers)
+        rivers = rivers[:self.num_crossings]  # sample random rivers
+        rivers_v = sorted([pos for direction, pos in rivers if direction is v])
+        rivers_h = sorted([pos for direction, pos in rivers if direction is h])
+        lava_pos = itt.chain(
+            itt.product(range(1, width - 1), rivers_h),  # horizontal rivers: fixed j
+            itt.product(rivers_v, range(1, height - 1)),  # vertical rivers: fixed i
+        )
+        for i, j in lava_pos:
+            self.grid.set(i, j, Lava())
+
+        # Sample path to goal: one `h` move per vertical river, one `v` per horizontal
+        path = [h] * len(rivers_v) + [v] * len(rivers_h)
+        self.np_random.shuffle(path)
+
+        # Create one opening per river along the sampled path, room by room
+        limits_v = [0] + rivers_v + [height - 1]
+        limits_h = [0] + rivers_h + [width - 1]
+        room_i, room_j = 0, 0
+        for direction in path:
+            if direction is h:
+                i = limits_v[room_i + 1]  # the vertical river bounding this room
+                j = self.np_random.choice(
+                    range(limits_h[room_j] + 1, limits_h[room_j + 1]))
+                room_i += 1
+            elif direction is v:
+                i = self.np_random.choice(
+                    range(limits_v[room_i] + 1, limits_v[room_i + 1]))
+                j = limits_h[room_j + 1]  # the horizontal river bounding this room
+                room_j += 1
+            else:
+                assert False
+            self.grid.set(i, j, None)  # clear the lava tile to open the crossing
+
+        self.mission = "avoid the lava and get to the green goal square"
+
+class LavaCrossingS9N2Env(LavaCrossingEnv):  # size-9 grid with 2 lava rivers
+    def __init__(self):
+        super().__init__(size=9, num_crossings=2)
+
+class LavaCrossingS9N3Env(LavaCrossingEnv):  # size-9 grid with 3 lava rivers
+    def __init__(self):
+        super().__init__(size=9, num_crossings=3)
+
+class LavaCrossingS11N5Env(LavaCrossingEnv):  # size-11 grid with 5 lava rivers
+    def __init__(self):
+        super().__init__(size=11, num_crossings=5)
+
+register(  # base class defaults: size=9, num_crossings=1
+    id='MiniGrid-LavaCrossingS9N1-v0',
+    entry_point='gym_minigrid.envs:LavaCrossingEnv'
+)
+
+register(
+    id='MiniGrid-LavaCrossingS9N2-v0',
+    entry_point='gym_minigrid.envs:LavaCrossingS9N2Env'
+)
+
+register(
+    id='MiniGrid-LavaCrossingS9N3-v0',
+    entry_point='gym_minigrid.envs:LavaCrossingS9N3Env'
+)
+
+register(
+    id='MiniGrid-LavaCrossingS11N5-v0',
+    entry_point='gym_minigrid.envs:LavaCrossingS11N5Env'
+)

+ 52 - 2
gym_minigrid/minigrid.py

@@ -48,7 +48,8 @@ OBJECT_TO_IDX = {
     'key'           : 5,
     'ball'          : 6,
     'box'           : 7,
-    'goal'          : 8
+    'goal'          : 8,
+    'lava'          : 9
 }
 
 IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
@@ -152,6 +153,51 @@ class Floor(WorldObj):
             (1          ,           1)
         ])
 
+class Lava(WorldObj):  # tile of type 'lava'; a forward move onto one sets done in MiniGridEnv
+    def __init__(self):
+        super().__init__('lava', 'red')
+
+    def can_overlap(self):
+        return True  # NOTE(review): presumably lets the agent enter the cell - confirm in WorldObj
+
+    def render(self, r):
+        orange = 255, 128, 0  # RGB used for drawing (the registered colour name is 'red')
+        r.setLineColor(*orange)
+        r.setColor(*orange)
+        r.drawPolygon([  # fill the whole cell
+            (0          , CELL_PIXELS),
+            (CELL_PIXELS, CELL_PIXELS),
+            (CELL_PIXELS, 0),
+            (0          , 0)
+        ])
+
+        # Draw three black zigzag lines (the waves) across the tile
+        r.setLineColor(0, 0, 0)
+
+        r.drawPolyline([
+            (.1 * CELL_PIXELS, .3 * CELL_PIXELS),
+            (.3 * CELL_PIXELS, .4 * CELL_PIXELS),
+            (.5 * CELL_PIXELS, .3 * CELL_PIXELS),
+            (.7 * CELL_PIXELS, .4 * CELL_PIXELS),
+            (.9 * CELL_PIXELS, .3 * CELL_PIXELS),
+        ])
+
+        r.drawPolyline([
+            (.1 * CELL_PIXELS, .5 * CELL_PIXELS),
+            (.3 * CELL_PIXELS, .6 * CELL_PIXELS),
+            (.5 * CELL_PIXELS, .5 * CELL_PIXELS),
+            (.7 * CELL_PIXELS, .6 * CELL_PIXELS),
+            (.9 * CELL_PIXELS, .5 * CELL_PIXELS),
+        ])
+
+        r.drawPolyline([
+            (.1 * CELL_PIXELS, .7 * CELL_PIXELS),
+            (.3 * CELL_PIXELS, .8 * CELL_PIXELS),
+            (.5 * CELL_PIXELS, .7 * CELL_PIXELS),
+            (.7 * CELL_PIXELS, .8 * CELL_PIXELS),
+            (.9 * CELL_PIXELS, .7 * CELL_PIXELS),
+        ])
+
 class Wall(WorldObj):
     def __init__(self, color='grey'):
         super().__init__('wall', color)
@@ -571,6 +617,8 @@ class Grid:
                     v = LockedDoor(color, is_open)
                 elif objType == 'goal':
                     v = Goal()
+                elif objType == 'lava':
+                    v = Lava()
                 else:
                     assert False, "unknown obj type in decode '%s'" % objType
 
@@ -1117,6 +1165,8 @@ class MiniGridEnv(gym.Env):
             if fwd_cell != None and fwd_cell.type == 'goal':
                 done = True
                 reward = self._reward()
+            if fwd_cell != None and fwd_cell.type == 'lava':
+                done = True
 
         # Pick up an object
         elif action == self.actions.pickup:
@@ -1327,4 +1377,4 @@ class MiniGridEnv(gym.Env):
         elif mode == 'pixmap':
             return r.getPixmap()
 
-        return r
+        return r

+ 5 - 0
gym_minigrid/rendering.py

@@ -176,5 +176,10 @@ class Renderer:
         points = map(lambda p: QPoint(p[0], p[1]), points)
         self.painter.drawPolygon(QPolygon(points))
 
+    def drawPolyline(self, points):
+        """Takes a list of points (tuples) as input"""
+        points = map(lambda p: QPoint(p[0], p[1]), points)  # (x, y) tuples -> QPoint; NOTE(review): Lava.render passes floats - confirm QPoint accepts them
+        self.painter.drawPolyline(QPolygon(points))  # same conversion pattern as drawPolygon, but drawn via the painter's drawPolyline
+
     def fillRect(self, x, y, width, height, r, g, b, a=255):
         self.painter.fillRect(QRect(x, y, width, height), QColor(r, g, b, a))