瀏覽代碼

backward compatibility for environment metadata

saleml 2 年之前
父節點
當前提交
0832a47bc1
共有 1 個文件被更改,包括 80 次插入和 57 次刪除
  1. 80 57
      gym_minigrid/minigrid.py

+ 80 - 57
gym_minigrid/minigrid.py

@@ -12,48 +12,48 @@ TILE_PIXELS = 32
 
 # Map of color names to RGB values
 COLORS = {
-    'red'   : np.array([255, 0, 0]),
-    'green' : np.array([0, 255, 0]),
-    'blue'  : np.array([0, 0, 255]),
+    'red': np.array([255, 0, 0]),
+    'green': np.array([0, 255, 0]),
+    'blue': np.array([0, 0, 255]),
     'purple': np.array([112, 39, 195]),
     'yellow': np.array([255, 255, 0]),
-    'grey'  : np.array([100, 100, 100])
+    'grey': np.array([100, 100, 100])
 }
 
 COLOR_NAMES = sorted(list(COLORS.keys()))
 
 # Used to map colors to integers
 COLOR_TO_IDX = {
-    'red'   : 0,
-    'green' : 1,
-    'blue'  : 2,
+    'red': 0,
+    'green': 1,
+    'blue': 2,
     'purple': 3,
     'yellow': 4,
-    'grey'  : 5
+    'grey': 5
 }
 
 IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
 
 # Map of object type to integers
 OBJECT_TO_IDX = {
-    'unseen'        : 0,
-    'empty'         : 1,
-    'wall'          : 2,
-    'floor'         : 3,
-    'door'          : 4,
-    'key'           : 5,
-    'ball'          : 6,
-    'box'           : 7,
-    'goal'          : 8,
-    'lava'          : 9,
-    'agent'         : 10,
+    'unseen': 0,
+    'empty': 1,
+    'wall': 2,
+    'floor': 3,
+    'door': 4,
+    'key': 5,
+    'ball': 6,
+    'box': 7,
+    'goal': 8,
+    'lava': 9,
+    'agent': 10,
 }
 
 IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
 
 # Map of state names to integers
 STATE_TO_IDX = {
-    'open'  : 0,
+    'open': 0,
     'closed': 1,
     'locked': 2,
 }
@@ -70,6 +70,7 @@ DIR_TO_VEC = [
     np.array((0, -1)),
 ]
 
+
 class WorldObj:
     """
     Base class for grid world objects
@@ -151,6 +152,7 @@ class WorldObj:
         """Draw this object with the given renderer"""
         raise NotImplementedError
 
+
 class Goal(WorldObj):
     def __init__(self):
         super().__init__('goal', 'green')
@@ -161,6 +163,7 @@ class Goal(WorldObj):
     def render(self, img):
         fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
 
+
 class Floor(WorldObj):
     """
     Colored floor tile the agent can walk over
@@ -195,10 +198,15 @@ class Lava(WorldObj):
         for i in range(3):
             ylo = 0.3 + 0.2 * i
             yhi = 0.4 + 0.2 * i
-            fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0,0,0))
-            fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0,0,0))
-            fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0,0,0))
-            fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0,0,0))
+            fill_coords(img, point_in_line(
+                0.1, ylo, 0.3, yhi, r=0.03), (0, 0, 0))
+            fill_coords(img, point_in_line(
+                0.3, yhi, 0.5, ylo, r=0.03), (0, 0, 0))
+            fill_coords(img, point_in_line(
+                0.5, ylo, 0.7, yhi, r=0.03), (0, 0, 0))
+            fill_coords(img, point_in_line(
+                0.7, yhi, 0.9, ylo, r=0.03), (0, 0, 0))
+
 
 class Wall(WorldObj):
     def __init__(self, color='grey'):
@@ -210,6 +218,7 @@ class Wall(WorldObj):
     def render(self, img):
         fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
 
+
 class Door(WorldObj):
     def __init__(self, color, is_open=False, is_locked=False):
         super().__init__('door', color)
@@ -253,25 +262,27 @@ class Door(WorldObj):
 
         if self.is_open:
             fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c)
-            fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0,0,0))
+            fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0, 0, 0))
             return
 
         # Door frame and door
         if self.is_locked:
             fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
-            fill_coords(img, point_in_rect(0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))
+            fill_coords(img, point_in_rect(
+                0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))
 
             # Draw key slot
             fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)
         else:
             fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
-            fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0,0,0))
+            fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0, 0, 0))
             fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c)
-            fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0,0,0))
+            fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0, 0, 0))
 
             # Draw door handle
             fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c)
 
+
 class Key(WorldObj):
     def __init__(self, color='blue'):
         super(Key, self).__init__('key', color)
@@ -291,7 +302,8 @@ class Key(WorldObj):
 
         # Ring
         fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.190), c)
-        fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0,0,0))
+        fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0, 0, 0))
+
 
 class Ball(WorldObj):
     def __init__(self, color='blue'):
@@ -303,6 +315,7 @@ class Ball(WorldObj):
     def render(self, img):
         fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color])
 
+
 class Box(WorldObj):
     def __init__(self, color, contains=None):
         super(Box, self).__init__('box', color)
@@ -316,7 +329,7 @@ class Box(WorldObj):
 
         # Outline
         fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), c)
-        fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0,0,0))
+        fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0, 0, 0))
 
         # Horizontal slit
         fill_coords(img, point_in_rect(0.16, 0.84, 0.47, 0.53), c)
@@ -326,6 +339,7 @@ class Box(WorldObj):
         env.grid.set(*pos, self.contains)
         return True
 
+
 class Grid:
     """
     Represent a grid and operations on it
@@ -359,7 +373,7 @@ class Grid:
         return False
 
     def __eq__(self, other):
-        grid1  = self.encode()
+        grid1 = self.encode()
         grid2 = other.encode()
         return np.array_equal(grid2, grid1)
 
@@ -454,7 +468,8 @@ class Grid:
         if key in cls.tile_cache:
             return cls.tile_cache[key]
 
-        img = np.zeros(shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8)
+        img = np.zeros(shape=(tile_size * subdivs,
+                       tile_size * subdivs, 3), dtype=np.uint8)
 
         # Draw the grid lines (top and left edges)
         fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
@@ -472,7 +487,8 @@ class Grid:
             )
 
             # Rotate the agent based on its direction
-            tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5*math.pi*agent_dir)
+            tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5,
+                               theta=0.5*math.pi*agent_dir)
             fill_coords(img, tri_fn, (255, 0, 0))
 
         # Highlight the cell if needed
@@ -501,7 +517,8 @@ class Grid:
         """
 
         if highlight_mask is None:
-            highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
+            highlight_mask = np.zeros(
+                shape=(self.width, self.height), dtype=bool)
 
         # Compute the total grid size
         width_px = self.width * tile_size
@@ -615,14 +632,18 @@ class Grid:
 
         return mask
 
+
 class MiniGridEnv(gym.Env):
     """
     2D grid world game environment
     """
 
     metadata = {
+        # Deprecated: use 'render_modes' instead
+        'render.modes': ['human', 'rgb_array'],
+        'video.frames_per_second': 10,  # Deprecated: use 'render_fps' instead
         'render_modes': ['human', 'rgb_array'],
-        'render_fps' : 10
+        'render_fps': 10
     }
 
     # Enumeration of possible actions
@@ -682,7 +703,7 @@ class MiniGridEnv(gym.Env):
             'direction': spaces.Discrete(4),
             'mission': spaces.Text(max_length=200,
                                    charset=string.ascii_letters + string.digits + ' .,!- '
-                                  )
+                                   )
         })
 
         # render mode
@@ -704,7 +725,6 @@ class MiniGridEnv(gym.Env):
         self.agent_pos = None
         self.agent_dir = None
 
-
         # Initialize the state
         self.reset()
 
@@ -735,14 +755,14 @@ class MiniGridEnv(gym.Env):
         obs = self.gen_obs()
         return obs
 
-
     def hash(self, size=16):
         """Compute a hash that uniquely identifies the current state of the environment.
         :param size: Size of the hashing
         """
         sample_hash = hashlib.sha256()
 
-        to_encode = [self.grid.encode().tolist(), self.agent_pos, self.agent_dir]
+        to_encode = [self.grid.encode().tolist(), self.agent_pos,
+                     self.agent_dir]
         for item in to_encode:
             sample_hash.update(str(item).encode('utf8'))
 
@@ -761,14 +781,14 @@ class MiniGridEnv(gym.Env):
 
         # Map of object types to short string
         OBJECT_TO_STR = {
-            'wall'          : 'W',
-            'floor'         : 'F',
-            'door'          : 'D',
-            'key'           : 'K',
-            'ball'          : 'A',
-            'box'           : 'B',
-            'goal'          : 'G',
-            'lava'          : 'V',
+            'wall': 'W',
+            'floor': 'F',
+            'door': 'D',
+            'key': 'K',
+            'ball': 'A',
+            'box': 'B',
+            'goal': 'G',
+            'lava': 'V',
         }
 
         # Short string for opened door
@@ -888,12 +908,12 @@ class MiniGridEnv(gym.Env):
         )
 
     def place_obj(self,
-        obj,
-        top=None,
-        size=None,
-        reject_fn=None,
-        max_tries=math.inf
-    ):
+                  obj,
+                  top=None,
+                  size=None,
+                  reject_fn=None,
+                  max_tries=math.inf
+                  ):
         """
         Place an object at an empty position in the grid
 
@@ -1174,7 +1194,7 @@ class MiniGridEnv(gym.Env):
         """
 
         topX, topY, botX, botY = self.get_view_exts(agent_view_size)
-        
+
         agent_view_size = agent_view_size or self.agent_view_size
 
         grid = self.grid.slice(topX, topY, agent_view_size, agent_view_size)
@@ -1185,7 +1205,8 @@ class MiniGridEnv(gym.Env):
         # Process occluders and visibility
         # Note that this incurs some performance cost
         if not self.see_through_walls:
-            vis_mask = grid.process_vis(agent_pos=(agent_view_size // 2 , agent_view_size - 1))
+            vis_mask = grid.process_vis(agent_pos=(
+                agent_view_size // 2, agent_view_size - 1))
         else:
             vis_mask = np.ones(shape=(grid.width, grid.height), dtype=bool)
 
@@ -1210,7 +1231,8 @@ class MiniGridEnv(gym.Env):
         # Encode the partially observable view into a numpy array
         image = grid.encode(vis_mask)
 
-        assert hasattr(self, 'mission'), "environments must define a textual mission string"
+        assert hasattr(
+            self, 'mission'), "environments must define a textual mission string"
 
         # Observations are dictionaries containing:
         # - an image (partially observable view of the environment)
@@ -1264,7 +1286,8 @@ class MiniGridEnv(gym.Env):
         # of the agent's view area
         f_vec = self.dir_vec
         r_vec = self.right_vec
-        top_left = self.agent_pos + f_vec * (self.agent_view_size-1) - r_vec * (self.agent_view_size // 2)
+        top_left = self.agent_pos + f_vec * \
+            (self.agent_view_size-1) - r_vec * (self.agent_view_size // 2)
 
         # Mask of which cells to highlight
         highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)