Browse Source

Implement Lava tiles and LavaCrossing environment (#7) (#29)

* Implement Lava tiles and LavaCrossing environment (#7)

 * Add new tile type Lava

 * Add new environment LavaCrossing

 * Add new drawPolyine method to Renderer

 * Add README description and figures for LavaCrossing

* Update lavacrossing.py
A. Baisero 6 years ago
parent
commit
0be0e71ba3

+ 21 - 0
README.md

@@ -349,3 +349,24 @@ by balls. This environment can be solved without relying on language.
 The agent has to pick up a box which is placed in a corner of a 3x3 maze.
 The agent has to pick up a box which is placed in a corner of a 3x3 maze.
 The doors are locked, the keys are hidden in boxes and doors are obstructed
 The doors are locked, the keys are hidden in boxes and doors are obstructed
 by balls. This environment can be solved without relying on language.
 by balls. This environment can be solved without relying on language.
+
+## Lava crossing environment
+
+Registered configurations:
+- `MiniGrid-LavaCrossing-S9N1-v0`
+- `MiniGrid-LavaCrossing-S9N2-v0`
+- `MiniGrid-LavaCrossing-S9N3-v0`
+- `MiniGrid-LavaCrossing-S11N5-v0`
+
+<p align="center">
+  <img src="figures/LavaCrossingS9N1.png" width="200">
+  <img src="figures/LavaCrossingS9N2.png" width="200">
+  <img src="figures/LavaCrossingS9N3.png" width="200">
+  <img src="figures/LavaCrossingS11N5.png" width="250">
+</p>
+
+The agent has to reach the green goal square on the other corner of the room
+while avoiding rivers of deadly lava which terminate the episode in failure.
+Each lava stream runs across the room either horizontally or vertically, and
+has a single crossing point which can be safely used;  Luckily, a path to the
+goal is guaranteed to exist.

BIN
figures/LavaCrossingS11N5.png


BIN
figures/LavaCrossingS9N1.png


BIN
figures/LavaCrossingS9N2.png


BIN
figures/LavaCrossingS9N3.png


+ 1 - 0
gym_minigrid/envs/__init__.py

@@ -15,3 +15,4 @@ from gym_minigrid.envs.redbluedoors import *
 from gym_minigrid.envs.obstructedmaze import *
 from gym_minigrid.envs.obstructedmaze import *
 from gym_minigrid.envs.memory import *
 from gym_minigrid.envs.memory import *
 from gym_minigrid.envs.fourrooms import *
 from gym_minigrid.envs.fourrooms import *
+from gym_minigrid.envs.lavacrossing import *

+ 110 - 0
gym_minigrid/envs/lavacrossing.py

@@ -0,0 +1,110 @@
+from gym_minigrid.minigrid import *
+from gym_minigrid.register import register
+
+import itertools as itt
+
+
+class LavaCrossingEnv(MiniGridEnv):
+    """
+    Environment with lava obstacles, sparse reward
+    """
+
+    def __init__(self, size=9, num_crossings=1, seed=None):
+        self.num_crossings = num_crossings
+        super().__init__(
+            grid_size=size,
+            max_steps=4*size*size,
+            # Set this to True for maximum speed
+            see_through_walls=False,
+            seed=None
+        )
+
+    def _gen_grid(self, width, height):
+        assert width % 2 == 1 and height % 2 == 1  # odd size
+
+        # Create an empty grid
+        self.grid = Grid(width, height)
+
+        # Generate the surrounding walls
+        self.grid.wall_rect(0, 0, width, height)
+
+        # Place the agent in the top-left corner
+        self.start_pos = (1, 1)
+        self.start_dir = 0
+
+        # Place a goal square in the bottom-right corner
+        self.grid.set(width - 2, height - 2, Goal())
+
+        # Place lava tiles
+        v, h = object(), object()  # singleton `vertical` and `horizontal` objects
+
+        # Lava river specified by direction and position in grid
+        rivers = [(v, i) for i in range(2, height - 2, 2)]
+        rivers += [(h, j) for j in range(2, width - 2, 2)]
+        self.np_random.shuffle(rivers)
+        rivers = rivers[:self.num_crossings]  # sample random rivers
+        rivers_v = sorted([pos for direction, pos in rivers if direction is v])
+        rivers_h = sorted([pos for direction, pos in rivers if direction is h])
+        lava_pos = itt.chain(
+            itt.product(range(1, width - 1), rivers_h),
+            itt.product(rivers_v, range(1, height - 1)),
+        )
+        for i, j in lava_pos:
+            self.grid.set(i, j, Lava())
+
+        # Sample path to goal
+        path = [h] * len(rivers_v) + [v] * len(rivers_h)
+        self.np_random.shuffle(path)
+
+        # Create openings in lava rivers
+        limits_v = [0] + rivers_v + [height - 1]
+        limits_h = [0] + rivers_h + [width - 1]
+        room_i, room_j = 0, 0
+        for direction in path:
+            if direction is h:
+                i = limits_v[room_i + 1]
+                j = self.np_random.choice(
+                    range(limits_h[room_j] + 1, limits_h[room_j + 1]))
+                room_i += 1
+            elif direction is v:
+                i = self.np_random.choice(
+                    range(limits_v[room_i] + 1, limits_v[room_i + 1]))
+                j = limits_h[room_j + 1]
+                room_j += 1
+            else:
+                assert False
+            self.grid.set(i, j, None)
+
+        self.mission = "avoid the lava and get to the green goal square"
+
+class LavaCrossingS9N2Env(LavaCrossingEnv):
+    def __init__(self):
+        super().__init__(size=9, num_crossings=2)
+
+class LavaCrossingS9N3Env(LavaCrossingEnv):
+    def __init__(self):
+        super().__init__(size=9, num_crossings=3)
+
+class LavaCrossingS11N5Env(LavaCrossingEnv):
+    def __init__(self):
+        super().__init__(size=11, num_crossings=5)
+
+register(
+    id='MiniGrid-LavaCrossingS9N1-v0',
+    entry_point='gym_minigrid.envs:LavaCrossingEnv'
+)
+
+register(
+    id='MiniGrid-LavaCrossingS9N2-v0',
+    entry_point='gym_minigrid.envs:LavaCrossingS9N2Env'
+)
+
+register(
+    id='MiniGrid-LavaCrossingS9N3-v0',
+    entry_point='gym_minigrid.envs:LavaCrossingS9N3Env'
+)
+
+register(
+    id='MiniGrid-LavaCrossingS11N5-v0',
+    entry_point='gym_minigrid.envs:LavaCrossingS11N5Env'
+)

+ 52 - 2
gym_minigrid/minigrid.py

@@ -48,7 +48,8 @@ OBJECT_TO_IDX = {
     'key'           : 5,
     'key'           : 5,
     'ball'          : 6,
     'ball'          : 6,
     'box'           : 7,
     'box'           : 7,
-    'goal'          : 8
+    'goal'          : 8,
+    'lava'          : 9
 }
 }
 
 
 IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
 IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
@@ -152,6 +153,51 @@ class Floor(WorldObj):
             (1          ,           1)
             (1          ,           1)
         ])
         ])
 
 
+class Lava(WorldObj):
+    def __init__(self):
+        super().__init__('lava', 'red')
+
+    def can_overlap(self):
+        return True
+
+    def render(self, r):
+        orange = 255, 128, 0
+        r.setLineColor(*orange)
+        r.setColor(*orange)
+        r.drawPolygon([
+            (0          , CELL_PIXELS),
+            (CELL_PIXELS, CELL_PIXELS),
+            (CELL_PIXELS, 0),
+            (0          , 0)
+        ])
+
+        # drawing the waves
+        r.setLineColor(0, 0, 0)
+
+        r.drawPolyline([
+            (.1 * CELL_PIXELS, .3 * CELL_PIXELS),
+            (.3 * CELL_PIXELS, .4 * CELL_PIXELS),
+            (.5 * CELL_PIXELS, .3 * CELL_PIXELS),
+            (.7 * CELL_PIXELS, .4 * CELL_PIXELS),
+            (.9 * CELL_PIXELS, .3 * CELL_PIXELS),
+        ])
+
+        r.drawPolyline([
+            (.1 * CELL_PIXELS, .5 * CELL_PIXELS),
+            (.3 * CELL_PIXELS, .6 * CELL_PIXELS),
+            (.5 * CELL_PIXELS, .5 * CELL_PIXELS),
+            (.7 * CELL_PIXELS, .6 * CELL_PIXELS),
+            (.9 * CELL_PIXELS, .5 * CELL_PIXELS),
+        ])
+
+        r.drawPolyline([
+            (.1 * CELL_PIXELS, .7 * CELL_PIXELS),
+            (.3 * CELL_PIXELS, .8 * CELL_PIXELS),
+            (.5 * CELL_PIXELS, .7 * CELL_PIXELS),
+            (.7 * CELL_PIXELS, .8 * CELL_PIXELS),
+            (.9 * CELL_PIXELS, .7 * CELL_PIXELS),
+        ])
+
 class Wall(WorldObj):
 class Wall(WorldObj):
     def __init__(self, color='grey'):
     def __init__(self, color='grey'):
         super().__init__('wall', color)
         super().__init__('wall', color)
@@ -571,6 +617,8 @@ class Grid:
                     v = LockedDoor(color, is_open)
                     v = LockedDoor(color, is_open)
                 elif objType == 'goal':
                 elif objType == 'goal':
                     v = Goal()
                     v = Goal()
+                elif objType == 'lava':
+                    v = Lava()
                 else:
                 else:
                     assert False, "unknown obj type in decode '%s'" % objType
                     assert False, "unknown obj type in decode '%s'" % objType
 
 
@@ -1117,6 +1165,8 @@ class MiniGridEnv(gym.Env):
             if fwd_cell != None and fwd_cell.type == 'goal':
             if fwd_cell != None and fwd_cell.type == 'goal':
                 done = True
                 done = True
                 reward = self._reward()
                 reward = self._reward()
+            if fwd_cell != None and fwd_cell.type == 'lava':
+                done = True
 
 
         # Pick up an object
         # Pick up an object
         elif action == self.actions.pickup:
         elif action == self.actions.pickup:
@@ -1327,4 +1377,4 @@ class MiniGridEnv(gym.Env):
         elif mode == 'pixmap':
         elif mode == 'pixmap':
             return r.getPixmap()
             return r.getPixmap()
 
 
-        return r
+        return r

+ 5 - 0
gym_minigrid/rendering.py

@@ -176,5 +176,10 @@ class Renderer:
         points = map(lambda p: QPoint(p[0], p[1]), points)
         points = map(lambda p: QPoint(p[0], p[1]), points)
         self.painter.drawPolygon(QPolygon(points))
         self.painter.drawPolygon(QPolygon(points))
 
 
+    def drawPolyline(self, points):
+        """Takes a list of points (tuples) as input"""
+        points = map(lambda p: QPoint(p[0], p[1]), points)
+        self.painter.drawPolyline(QPolygon(points))
+
     def fillRect(self, x, y, width, height, r, g, b, a=255):
     def fillRect(self, x, y, width, height, r, g, b, a=255):
         self.painter.fillRect(QRect(x, y, width, height), QColor(r, g, b, a))
         self.painter.fillRect(QRect(x, y, width, height), QColor(r, g, b, a))