Pārlūkot izejas kodu

Added encode and decode methods to WorldObj

Maxime Chevalier-Boisvert 5 gadi atpakaļ
vecāks
revīzija
638d5d21d5
3 mainītis faili ar 193 papildinājumiem un 107 dzēšanām
  1. 142 106
      gym_minigrid/minigrid.py
  2. 51 0
      gym_minigrid/rendering.py
  3. 0 1
      gym_minigrid/wrappers.py

+ 142 - 106
gym_minigrid/minigrid.py

@@ -5,8 +5,8 @@ import numpy as np
 from gym import error, spaces, utils
 from gym.utils import seeding
 
-# Size in pixels of a cell in the full-scale human view
-CELL_PIXELS = 32
+# Size in pixels of a tile in the full-scale human view
+TILE_PIXELS = 32
 
 # Map of color names to RGB values
 COLORS = {
@@ -106,6 +106,44 @@ class WorldObj:
         """Method to trigger/toggle an action this object performs"""
         return False
 
+    def encode(self):
+        """Encode the a description of this object as a 3-tuple of integers"""
+        return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0)
+
+    def decode(type_idx, color_idx, state):
+        """Create an object from a 3-tuple state description"""
+
+        obj_type = IDX_TO_OBJECT[type_idx]
+        color = IDX_TO_COLOR[color_idx]
+
+        if obj_type == 'empty' or obj_type == 'unseen':
+            return None
+
+        # State, 0: open, 1: closed, 2: locked
+        is_open = state == 0
+        is_locked = state == 2
+
+        if obj_type == 'wall':
+            v = Wall(color)
+        elif obj_type == 'floor':
+            v = Floor(color)
+        elif obj_type == 'ball':
+            v = Ball(color)
+        elif obj_type == 'key':
+            v = Key(color)
+        elif obj_type == 'box':
+            v = Box(color)
+        elif obj_type == 'door':
+            v = Door(color, is_open, is_locked)
+        elif obj_type == 'goal':
+            v = Goal()
+        elif obj_type == 'lava':
+            v = Lava()
+        else:
+            assert False, "unknown object type in decode '%s'" % objType
+
+        return v
+
     def render(self, r):
         """Draw this object with the given renderer"""
         raise NotImplementedError
@@ -126,9 +164,9 @@ class Goal(WorldObj):
     def render(self, r):
         self._set_color(r)
         r.drawPolygon([
-            (0          , CELL_PIXELS),
-            (CELL_PIXELS, CELL_PIXELS),
-            (CELL_PIXELS,           0),
+            (0          , TILE_PIXELS),
+            (TILE_PIXELS, TILE_PIXELS),
+            (TILE_PIXELS,           0),
             (0          ,           0)
         ])
 
@@ -149,9 +187,9 @@ class Floor(WorldObj):
         r.setLineColor(100, 100, 100, 0)
         r.setColor(*c/2)
         r.drawPolygon([
-            (1          , CELL_PIXELS),
-            (CELL_PIXELS, CELL_PIXELS),
-            (CELL_PIXELS,           1),
+            (1          , TILE_PIXELS),
+            (TILE_PIXELS, TILE_PIXELS),
+            (TILE_PIXELS,           1),
             (1          ,           1)
         ])
 
@@ -167,9 +205,9 @@ class Lava(WorldObj):
         r.setLineColor(*orange)
         r.setColor(*orange)
         r.drawPolygon([
-            (0          , CELL_PIXELS),
-            (CELL_PIXELS, CELL_PIXELS),
-            (CELL_PIXELS, 0),
+            (0          , TILE_PIXELS),
+            (TILE_PIXELS, TILE_PIXELS),
+            (TILE_PIXELS, 0),
             (0          , 0)
         ])
 
@@ -177,27 +215,27 @@ class Lava(WorldObj):
         r.setLineColor(0, 0, 0)
 
         r.drawPolyline([
-            (.1 * CELL_PIXELS, .3 * CELL_PIXELS),
-            (.3 * CELL_PIXELS, .4 * CELL_PIXELS),
-            (.5 * CELL_PIXELS, .3 * CELL_PIXELS),
-            (.7 * CELL_PIXELS, .4 * CELL_PIXELS),
-            (.9 * CELL_PIXELS, .3 * CELL_PIXELS),
+            (.1 * TILE_PIXELS, .3 * TILE_PIXELS),
+            (.3 * TILE_PIXELS, .4 * TILE_PIXELS),
+            (.5 * TILE_PIXELS, .3 * TILE_PIXELS),
+            (.7 * TILE_PIXELS, .4 * TILE_PIXELS),
+            (.9 * TILE_PIXELS, .3 * TILE_PIXELS),
         ])
 
         r.drawPolyline([
-            (.1 * CELL_PIXELS, .5 * CELL_PIXELS),
-            (.3 * CELL_PIXELS, .6 * CELL_PIXELS),
-            (.5 * CELL_PIXELS, .5 * CELL_PIXELS),
-            (.7 * CELL_PIXELS, .6 * CELL_PIXELS),
-            (.9 * CELL_PIXELS, .5 * CELL_PIXELS),
+            (.1 * TILE_PIXELS, .5 * TILE_PIXELS),
+            (.3 * TILE_PIXELS, .6 * TILE_PIXELS),
+            (.5 * TILE_PIXELS, .5 * TILE_PIXELS),
+            (.7 * TILE_PIXELS, .6 * TILE_PIXELS),
+            (.9 * TILE_PIXELS, .5 * TILE_PIXELS),
         ])
 
         r.drawPolyline([
-            (.1 * CELL_PIXELS, .7 * CELL_PIXELS),
-            (.3 * CELL_PIXELS, .8 * CELL_PIXELS),
-            (.5 * CELL_PIXELS, .7 * CELL_PIXELS),
-            (.7 * CELL_PIXELS, .8 * CELL_PIXELS),
-            (.9 * CELL_PIXELS, .7 * CELL_PIXELS),
+            (.1 * TILE_PIXELS, .7 * TILE_PIXELS),
+            (.3 * TILE_PIXELS, .8 * TILE_PIXELS),
+            (.5 * TILE_PIXELS, .7 * TILE_PIXELS),
+            (.7 * TILE_PIXELS, .8 * TILE_PIXELS),
+            (.9 * TILE_PIXELS, .7 * TILE_PIXELS),
         ])
 
 class Wall(WorldObj):
@@ -210,9 +248,9 @@ class Wall(WorldObj):
     def render(self, r):
         self._set_color(r)
         r.drawPolygon([
-            (0          , CELL_PIXELS),
-            (CELL_PIXELS, CELL_PIXELS),
-            (CELL_PIXELS,           0),
+            (0          , TILE_PIXELS),
+            (TILE_PIXELS, TILE_PIXELS),
+            (TILE_PIXELS,           0),
             (0          ,           0)
         ])
 
@@ -241,6 +279,19 @@ class Door(WorldObj):
         self.is_open = not self.is_open
         return True
 
+    def encode(self):
+        """Encode the a description of this object as a 3-tuple of integers"""
+
+        # State, 0: open, 1: closed, 2: locked
+        if self.is_open:
+            state = 0
+        elif self.is_locked:
+            state = 2
+        elif not self.is_open:
+            state = 1
+
+        return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state)
+
     def render(self, r):
         c = COLORS[self.color]
         r.setLineColor(c[0], c[1], c[2])
@@ -248,37 +299,37 @@ class Door(WorldObj):
 
         if self.is_open:
             r.drawPolygon([
-                (CELL_PIXELS-2, CELL_PIXELS),
-                (CELL_PIXELS  , CELL_PIXELS),
-                (CELL_PIXELS  ,           0),
-                (CELL_PIXELS-2,           0)
+                (TILE_PIXELS-2, TILE_PIXELS),
+                (TILE_PIXELS  , TILE_PIXELS),
+                (TILE_PIXELS  ,           0),
+                (TILE_PIXELS-2,           0)
             ])
             return
 
         r.drawPolygon([
-            (0          , CELL_PIXELS),
-            (CELL_PIXELS, CELL_PIXELS),
-            (CELL_PIXELS,           0),
+            (0          , TILE_PIXELS),
+            (TILE_PIXELS, TILE_PIXELS),
+            (TILE_PIXELS,           0),
             (0          ,           0)
         ])
         r.drawPolygon([
-            (2            , CELL_PIXELS-2),
-            (CELL_PIXELS-2, CELL_PIXELS-2),
-            (CELL_PIXELS-2,           2),
+            (2            , TILE_PIXELS-2),
+            (TILE_PIXELS-2, TILE_PIXELS-2),
+            (TILE_PIXELS-2,           2),
             (2            ,           2)
         ])
 
         if self.is_locked:
             # Draw key slot
             r.drawLine(
-                CELL_PIXELS * 0.55,
-                CELL_PIXELS * 0.5,
-                CELL_PIXELS * 0.75,
-                CELL_PIXELS * 0.5
+                TILE_PIXELS * 0.55,
+                TILE_PIXELS * 0.5,
+                TILE_PIXELS * 0.75,
+                TILE_PIXELS * 0.5
             )
         else:
             # Draw door handle
-            r.drawCircle(CELL_PIXELS * 0.75, CELL_PIXELS * 0.5, 2)
+            r.drawCircle(TILE_PIXELS * 0.75, TILE_PIXELS * 0.5, 2)
 
 class Key(WorldObj):
     def __init__(self, color='blue'):
@@ -326,7 +377,7 @@ class Ball(WorldObj):
 
     def render(self, r):
         self._set_color(r)
-        r.drawCircle(CELL_PIXELS * 0.5, CELL_PIXELS * 0.5, 10)
+        r.drawCircle(TILE_PIXELS * 0.5, TILE_PIXELS * 0.5, 10)
 
 class Box(WorldObj):
     def __init__(self, color, contains=None):
@@ -343,17 +394,17 @@ class Box(WorldObj):
         r.setLineWidth(2)
 
         r.drawPolygon([
-            (4            , CELL_PIXELS-4),
-            (CELL_PIXELS-4, CELL_PIXELS-4),
-            (CELL_PIXELS-4,             4),
+            (4            , TILE_PIXELS-4),
+            (TILE_PIXELS-4, TILE_PIXELS-4),
+            (TILE_PIXELS-4,             4),
             (4            ,             4)
         ])
 
         r.drawLine(
             4,
-            CELL_PIXELS / 2,
-            CELL_PIXELS - 4,
-            CELL_PIXELS / 2
+            TILE_PIXELS / 2,
+            TILE_PIXELS - 4,
+            TILE_PIXELS / 2
         )
 
         r.setLineWidth(1)
@@ -363,6 +414,27 @@ class Box(WorldObj):
         env.grid.set(*pos, self.contains)
         return True
 
+
+
+
+
+
+def render_tile(
+    obj,
+    agent_dir=None,
+    highlight=False,
+    tile_size=TILE_PIXELS
+):
+    """
+    Render a tile and cache the result
+    """
+
+    pass
+
+
+
+
+
 class Grid:
     """
     Represent a grid and operations on it
@@ -479,14 +551,14 @@ class Grid:
         assert r.height == self.height * tile_size
 
         # Total grid size at native scale
-        widthPx = self.width * CELL_PIXELS
-        heightPx = self.height * CELL_PIXELS
+        widthPx = self.width * TILE_PIXELS
+        heightPx = self.height * TILE_PIXELS
 
         r.push()
 
         # Internally, we draw at the "large" full-grid resolution, but we
         # use the renderer to scale back to the desired size
-        r.scale(tile_size / CELL_PIXELS, tile_size / CELL_PIXELS)
+        r.scale(tile_size / TILE_PIXELS, tile_size / TILE_PIXELS)
 
         # Draw the background of the in-world cells black
         r.fillRect(
@@ -500,10 +572,10 @@ class Grid:
         # Draw grid lines
         r.setLineColor(100, 100, 100)
         for rowIdx in range(0, self.height):
-            y = CELL_PIXELS * rowIdx
+            y = TILE_PIXELS * rowIdx
             r.drawLine(0, y, widthPx, y)
         for colIdx in range(0, self.width):
-            x = CELL_PIXELS * colIdx
+            x = TILE_PIXELS * colIdx
             r.drawLine(x, 0, x, heightPx)
 
         # Render the grid
@@ -513,7 +585,7 @@ class Grid:
                 if cell == None:
                     continue
                 r.push()
-                r.translate(i * CELL_PIXELS, j * CELL_PIXELS)
+                r.translate(i * TILE_PIXELS, j * TILE_PIXELS)
                 cell.render(r)
                 r.pop()
 
@@ -528,6 +600,7 @@ class Grid:
             vis_mask = np.ones((self.width, self.height), dtype=bool)
 
         array = np.zeros((self.width, self.height, 3), dtype='uint8')
+
         for i in range(self.width):
             for j in range(self.height):
                 if vis_mask[i, j]:
@@ -537,17 +610,9 @@ class Grid:
                         array[i, j, 0] = OBJECT_TO_IDX['empty']
                         array[i, j, 1] = 0
                         array[i, j, 2] = 0
-                    else:
-                        # State, 0: open, 1: closed, 2: locked
-                        state = 0
-                        if hasattr(v, 'is_open') and not v.is_open:
-                            state = 1
-                        if hasattr(v, 'is_locked') and v.is_locked:
-                            state = 2
 
-                        array[i, j, 0] = OBJECT_TO_IDX[v.type]
-                        array[i, j, 1] = COLOR_TO_IDX[v.color]
-                        array[i, j, 2] = state
+                    else:
+                        array[i, j, :] = v.encode()
 
         return array
 
@@ -563,37 +628,8 @@ class Grid:
         grid = Grid(width, height)
         for i in range(width):
             for j in range(height):
-                typeIdx, colorIdx, state = array[i, j]
-
-                if typeIdx == OBJECT_TO_IDX['unseen'] or \
-                        typeIdx == OBJECT_TO_IDX['empty']:
-                    continue
-
-                objType = IDX_TO_OBJECT[typeIdx]
-                color = IDX_TO_COLOR[colorIdx]
-                # State, 0: open, 1: closed, 2: locked
-                is_open = state == 0
-                is_locked = state == 2
-
-                if objType == 'wall':
-                    v = Wall(color)
-                elif objType == 'floor':
-                    v = Floor(color)
-                elif objType == 'ball':
-                    v = Ball(color)
-                elif objType == 'key':
-                    v = Key(color)
-                elif objType == 'box':
-                    v = Box(color)
-                elif objType == 'door':
-                    v = Door(color, is_open, is_locked)
-                elif objType == 'goal':
-                    v = Goal()
-                elif objType == 'lava':
-                    v = Lava()
-                else:
-                    assert False, "unknown obj type in decode '%s'" % objType
-
+                type_idx, color_idx, state = array[i, j]
+                v = WorldObj.decode(type_idx, color_idx, state)
                 grid.set(i, j, v)
 
         return grid
@@ -1229,7 +1265,7 @@ class MiniGridEnv(gym.Env):
 
         return obs
 
-    def get_obs_render(self, obs, tile_size=CELL_PIXELS//2, mode='pixmap'):
+    def get_obs_render(self, obs, tile_size=TILE_PIXELS//2, mode='pixmap'):
         """
         Render an agent observation for visualization
         """
@@ -1251,12 +1287,12 @@ class MiniGridEnv(gym.Env):
         grid.render(r, tile_size)
 
         # Draw the agent
-        ratio = tile_size / CELL_PIXELS
+        ratio = tile_size / TILE_PIXELS
         r.push()
         r.scale(ratio, ratio)
         r.translate(
-            CELL_PIXELS * (0.5 + self.agent_view_size // 2),
-            CELL_PIXELS * (self.agent_view_size - 0.5)
+            TILE_PIXELS * (0.5 + self.agent_view_size // 2),
+            TILE_PIXELS * (self.agent_view_size - 0.5)
         )
         r.rotate(3 * 90)
         r.setLineColor(255, 0, 0)
@@ -1276,7 +1312,7 @@ class MiniGridEnv(gym.Env):
             return r.getPixmap()
         return r
 
-    def render(self, mode='human', close=False, highlight=True, tile_size=CELL_PIXELS):
+    def render(self, mode='human', close=False, highlight=True, tile_size=TILE_PIXELS):
         """
         Render the whole-grid human view
         """
@@ -1305,12 +1341,12 @@ class MiniGridEnv(gym.Env):
         self.grid.render(r, tile_size)
 
         # Draw the agent
-        ratio = tile_size / CELL_PIXELS
+        ratio = tile_size / TILE_PIXELS
         r.push()
         r.scale(ratio, ratio)
         r.translate(
-            CELL_PIXELS * (self.agent_pos[0] + 0.5),
-            CELL_PIXELS * (self.agent_pos[1] + 0.5)
+            TILE_PIXELS * (self.agent_pos[0] + 0.5),
+            TILE_PIXELS * (self.agent_pos[1] + 0.5)
         )
         r.rotate(self.agent_dir * 90)
         r.setLineColor(255, 0, 0)

+ 51 - 0
gym_minigrid/rendering.py

@@ -193,3 +193,54 @@ class Renderer:
 
     def fillRect(self, x, y, width, height, r, g, b, a=255):
         self.painter.fillRect(QRect(x, y, width, height), QColor(r, g, b, a))
+
+
+
+
+
+
+def point_in_circle(cx, cy, r):
+    def fn(x, y):
+        return (x-cx)*(x-cx) + (y-cy)*(y-cy) < r * r
+    return fn
+
+def point_in_rect(cx, cy, rx, ry):
+    def fn(x, y):
+        return abs(x-cx) < rx and abs(y - cy) < ry
+    return fn
+
+def point_in_triangle(a, b, c):
+    a = np.array(a)
+    b = np.array(b)
+    c = np.array(c)
+
+    def fn(x, y):
+        v0 = c - a
+        v1 = b - a
+        v2 = np.array((x, y)) - a
+
+        # Compute dot products
+        dot00 = np.dot(v0, v0)
+        dot01 = np.dot(v0, v1)
+        dot02 = np.dot(v0, v2)
+        dot11 = np.dot(v1, v1)
+        dot12 = np.dot(v1, v2)
+
+        # Compute barycentric coordinates
+        inv_denom = 1 / (dot00 * dot11 - dot01 * dot01)
+        u = (dot11 * dot02 - dot01 * dot12) * inv_denom
+        v = (dot00 * dot12 - dot01 * dot02) * inv_denom
+
+        # Check if point is in triangle
+        return (u >= 0) and (v >= 0) and (u + v) < 1
+
+    return fn
+
+# TODO: anti-aliased version, fill_coords_aa?
+def fill_coords(img, fn, color):
+    for y in range(img.shape[0]):
+        for x in range(img.shape[1]):
+            yf = y / img.shape[0]
+            xf = x / img.shape[1]
+            if fn(xf, yf):
+                img[y, x] = color

+ 0 - 1
gym_minigrid/wrappers.py

@@ -6,7 +6,6 @@ import numpy as np
 import gym
 from gym import error, spaces, utils
 from .minigrid import OBJECT_TO_IDX, COLOR_TO_IDX, STATE_TO_IDX
-from .minigrid import CELL_PIXELS
 
 class ReseedWrapper(gym.core.Wrapper):
     """