Bladeren bron

Implemented LavaGap env (simpler LavaCrossing). Added screenshot command.

Maxime Chevalier-Boisvert 5 jaren geleden
bovenliggende
commit
32a5b207f7
4 gewijzigde bestanden met toevoegingen van 93 en 4 verwijderingen
  1. 1 0
      gym_minigrid/envs/__init__.py
  2. 80 0
      gym_minigrid/envs/lavagap.py
  3. 4 4
      gym_minigrid/minigrid.py
  4. 8 0
      manual_control.py

+ 1 - 0
gym_minigrid/envs/__init__.py

@@ -16,5 +16,6 @@ from gym_minigrid.envs.obstructedmaze import *
 from gym_minigrid.envs.memory import *
 from gym_minigrid.envs.fourrooms import *
 from gym_minigrid.envs.crossing import *
+from gym_minigrid.envs.lavagap import *
 from gym_minigrid.envs.dynamicobstacles import *
 from gym_minigrid.envs.distshift import *

+ 80 - 0
gym_minigrid/envs/lavagap.py

@@ -0,0 +1,80 @@
+from gym_minigrid.minigrid import *
+from gym_minigrid.register import register
+
+class LavaGapEnv(MiniGridEnv):
+    """
+    Environment with one wall of lava with a small gap to cross through
+    This environment is similar to LavaCrossing but simpler in structure.
+    """
+
+    def __init__(self, size, obstacle_type=Lava, seed=None):
+        self.obstacle_type = obstacle_type
+        super().__init__(
+            grid_size=size,
+            max_steps=4*size*size,
+            # Set this to True for maximum speed
+            see_through_walls=False,
+            seed=None
+        )
+
+    def _gen_grid(self, width, height):
+        assert width >= 5 and height >= 5
+
+        # Create an empty grid
+        self.grid = Grid(width, height)
+
+        # Generate the surrounding walls
+        self.grid.wall_rect(0, 0, width, height)
+
+        # Place the agent in the top-left corner
+        self.agent_pos = (1, 1)
+        self.agent_dir = 0
+
+        # Place a goal square in the bottom-right corner
+        self.goal_pos = np.array((width - 2, height - 2))
+        self.grid.set(*self.goal_pos, Goal())
+
+        # Generate and store random gap position
+        self.gap_pos = np.array((
+            self._rand_int(2, width - 2),
+            self._rand_int(1, height - 1),
+        ))
+
+        # Place the obstacle wall
+        self.grid.vert_wall(self.gap_pos[0], 1, height - 2, self.obstacle_type)
+
+        # Put a hole in the wall
+        self.grid.set(*self.gap_pos, None)
+
+        self.mission = (
+            "avoid the lava and get to the green goal square"
+            if self.obstacle_type == Lava
+            else "find the opening and get to the green goal square"
+        )
+
+class LavaGapS5Env(LavaGapEnv):
+    def __init__(self):
+        super().__init__(size=5)
+
+class LavaGapS6Env(LavaGapEnv):
+    def __init__(self):
+        super().__init__(size=6)
+
+class LavaGapS7Env(LavaGapEnv):
+    def __init__(self):
+        super().__init__(size=7)
+
+register(
+    id='MiniGrid-LavaGapS5-v0',
+    entry_point='gym_minigrid.envs:LavaGapS5Env'
+)
+
+register(
+    id='MiniGrid-LavaGapS6-v0',
+    entry_point='gym_minigrid.envs:LavaGapS6Env'
+)
+
+register(
+    id='MiniGrid-LavaGapS7-v0',
+    entry_point='gym_minigrid.envs:LavaGapS7Env'
+)

+ 4 - 4
gym_minigrid/minigrid.py

@@ -414,17 +414,17 @@ class Grid:
         assert j >= 0 and j < self.height
         return self.grid[j * self.width + i]
 
-    def horz_wall(self, x, y, length=None):
+    def horz_wall(self, x, y, length=None, obj_type=Wall):
         if length is None:
             length = self.width - x
         for i in range(0, length):
-            self.set(x + i, y, Wall())
+            self.set(x + i, y, obj_type())
 
-    def vert_wall(self, x, y, length=None):
+    def vert_wall(self, x, y, length=None, obj_type=Wall):
         if length is None:
             length = self.height - y
         for j in range(0, length):
-            self.set(x, y + j, Wall())
+            self.set(x, y + j, obj_type())
 
     def wall_rect(self, x, y, w, h):
         self.horz_wall(x, y, w)

+ 8 - 0
manual_control.py

@@ -61,6 +61,14 @@ def main():
         elif keyName == 'RETURN':
             action = env.actions.done
 
+        # Screenshot funcitonality
+        elif keyName == 'ALT':
+            screen_path = options.env_name + '.png'
+            print('saving screenshot "{}"'.format(screen_path))
+            pixmap = env.render('pixmap')
+            pixmap.save(screen_path)
+            return
+
         else:
             print("unknown key %s" % keyName)
             return