浏览代码

Sketched tile caching. Removed PyQT dependencies.

Maxime Chevalier-Boisvert 5 年之前
父节点
当前提交
c1d3cc5835
共有 3 个文件被更改,包括 87 次插入307 次删除
  1. 76 103
      gym_minigrid/minigrid.py
  2. 10 203
      gym_minigrid/rendering.py
  3. 1 1
      setup.py

+ 76 - 103
gym_minigrid/minigrid.py

@@ -4,6 +4,7 @@ from enum import IntEnum
 import numpy as np
 import numpy as np
 from gym import error, spaces, utils
 from gym import error, spaces, utils
 from gym.utils import seeding
 from gym.utils import seeding
+from .rendering import *
 
 
 # Size in pixels of a tile in the full-scale human view
 # Size in pixels of a tile in the full-scale human view
 TILE_PIXELS = 32
 TILE_PIXELS = 32
@@ -110,6 +111,7 @@ class WorldObj:
         """Encode the a description of this object as a 3-tuple of integers"""
         """Encode the a description of this object as a 3-tuple of integers"""
         return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0)
         return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0)
 
 
+    @staticmethod
     def decode(type_idx, color_idx, state):
     def decode(type_idx, color_idx, state):
         """Create an object from a 3-tuple state description"""
         """Create an object from a 3-tuple state description"""
 
 
@@ -148,12 +150,6 @@ class WorldObj:
         """Draw this object with the given renderer"""
         """Draw this object with the given renderer"""
         raise NotImplementedError
         raise NotImplementedError
 
 
-    def _set_color(self, r):
-        """Set the color of this object as the active drawing color"""
-        c = COLORS[self.color]
-        r.setLineColor(c[0], c[1], c[2])
-        r.setColor(c[0], c[1], c[2])
-
 class Goal(WorldObj):
 class Goal(WorldObj):
     def __init__(self):
     def __init__(self):
         super().__init__('goal', 'green')
         super().__init__('goal', 'green')
@@ -162,13 +158,7 @@ class Goal(WorldObj):
         return True
         return True
 
 
     def render(self, r):
     def render(self, r):
-        self._set_color(r)
-        r.drawPolygon([
-            (0          , TILE_PIXELS),
-            (TILE_PIXELS, TILE_PIXELS),
-            (TILE_PIXELS,           0),
-            (0          ,           0)
-        ])
+        fill_coords(img, point_in_rect(0.5, 0.5, 0.5, 0.5), COLORS[self.color])
 
 
 class Floor(WorldObj):
 class Floor(WorldObj):
     """
     """
@@ -246,13 +236,7 @@ class Wall(WorldObj):
         return False
         return False
 
 
     def render(self, r):
     def render(self, r):
-        self._set_color(r)
-        r.drawPolygon([
-            (0          , TILE_PIXELS),
-            (TILE_PIXELS, TILE_PIXELS),
-            (TILE_PIXELS,           0),
-            (0          ,           0)
-        ])
+        fill_coords(img, point_in_rect(0.5, 0.5, 0.5, 0.5), COLORS[self.color])
 
 
 class Door(WorldObj):
 class Door(WorldObj):
     def __init__(self, color, is_open=False, is_locked=False):
     def __init__(self, color, is_open=False, is_locked=False):
@@ -375,9 +359,8 @@ class Ball(WorldObj):
     def can_pickup(self):
     def can_pickup(self):
         return True
         return True
 
 
-    def render(self, r):
-        self._set_color(r)
-        r.drawCircle(TILE_PIXELS * 0.5, TILE_PIXELS * 0.5, 10)
+    def render(self, img):
+        fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color])
 
 
 class Box(WorldObj):
 class Box(WorldObj):
     def __init__(self, color, contains=None):
     def __init__(self, color, contains=None):
@@ -414,32 +397,14 @@ class Box(WorldObj):
         env.grid.set(*pos, self.contains)
         env.grid.set(*pos, self.contains)
         return True
         return True
 
 
-
-
-
-
-
-def render_tile(
-    obj,
-    agent_dir=None,
-    highlight=False,
-    tile_size=TILE_PIXELS
-):
-    """
-    Render a tile and cache the result
-    """
-
-    pass
-
-
-
-
-
 class Grid:
 class Grid:
     """
     """
     Represent a grid and operations on it
     Represent a grid and operations on it
     """
     """
 
 
+    # Static cache of pre-renderer tiles
+    tile_cache = {}
+
     def __init__(self, width, height):
     def __init__(self, width, height):
         assert width >= 3
         assert width >= 3
         assert height >= 3
         assert height >= 3
@@ -540,35 +505,55 @@ class Grid:
 
 
         return grid
         return grid
 
 
-    def render(self, r, tile_size):
+    @classmethod
+    def render_tile(
+        cls,
+        obj,
+        agent_dir=None,
+        highlight=False,
+        tile_size=TILE_PIXELS
+    ):
+        """
+        Render a tile and cache the result
+        """
+
+        # Hash map lookup key for the cache
+        key = obj.encode() + (agent_dir, highlight, tile_size)
+
+        if key in cls.tile_cache:
+            return tile_cache[key]
+
+        img = np.zeros(shape=(tile_size, tile_size, 3), dtype=np.uint8)
+
+        obj.render_tile(img)
+
+        # TODO: overlay agent on top
+        if agent_dir is not None:
+            pass
+
+        # TODO: highlighting
+        if highlight:
+            pass
+
+        # Cache the rendered tile
+        tile_cache[key] = img
+
+        return img
+
+    def render(self, tile_size):
         """
         """
         Render this grid at a given scale
         Render this grid at a given scale
         :param r: target renderer object
         :param r: target renderer object
         :param tile_size: tile size in pixels
         :param tile_size: tile size in pixels
         """
         """
 
 
-        assert r.width == self.width * tile_size
-        assert r.height == self.height * tile_size
+        # Compute the total grid size
+        width_px = self.width * TILE_PIXELS
+        height_px = self.height * TILE_PIXELS
 
 
-        # Total grid size at native scale
-        widthPx = self.width * TILE_PIXELS
-        heightPx = self.height * TILE_PIXELS
-
-        r.push()
-
-        # Internally, we draw at the "large" full-grid resolution, but we
-        # use the renderer to scale back to the desired size
-        r.scale(tile_size / TILE_PIXELS, tile_size / TILE_PIXELS)
-
-        # Draw the background of the in-world cells black
-        r.fillRect(
-            0,
-            0,
-            widthPx,
-            heightPx,
-            0, 0, 0
-        )
+        img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)
 
 
+        """
         # Draw grid lines
         # Draw grid lines
         r.setLineColor(100, 100, 100)
         r.setLineColor(100, 100, 100)
         for rowIdx in range(0, self.height):
         for rowIdx in range(0, self.height):
@@ -577,6 +562,7 @@ class Grid:
         for colIdx in range(0, self.width):
         for colIdx in range(0, self.width):
             x = TILE_PIXELS * colIdx
             x = TILE_PIXELS * colIdx
             r.drawLine(x, 0, x, heightPx)
             r.drawLine(x, 0, x, heightPx)
+        """
 
 
         # Render the grid
         # Render the grid
         for j in range(0, self.height):
         for j in range(0, self.height):
@@ -584,12 +570,18 @@ class Grid:
                 cell = self.get(i, j)
                 cell = self.get(i, j)
                 if cell == None:
                 if cell == None:
                     continue
                     continue
-                r.push()
-                r.translate(i * TILE_PIXELS, j * TILE_PIXELS)
-                cell.render(r)
-                r.pop()
 
 
-        r.pop()
+                """
+                tile_img = Grid.render_tile(
+                    cell,
+                    agent_dir=None,
+                    highlight=False,
+                    tile_size=tile_size
+                )
+                """
+
+
+        return img
 
 
     def encode(self, vis_mask=None):
     def encode(self, vis_mask=None):
         """
         """
@@ -1270,22 +1262,12 @@ class MiniGridEnv(gym.Env):
         Render an agent observation for visualization
         Render an agent observation for visualization
         """
         """
 
 
-        if self.obs_render == None:
-            from gym_minigrid.rendering import Renderer
-            self.obs_render = Renderer(
-                self.agent_view_size * tile_size,
-                self.agent_view_size * tile_size
-            )
-
-        r = self.obs_render
-
-        r.beginFrame()
-
         grid = Grid.decode(obs)
         grid = Grid.decode(obs)
 
 
         # Render the whole grid
         # Render the whole grid
-        grid.render(r, tile_size)
+        img = grid.render(r, tile_size)
 
 
+        """
         # Draw the agent
         # Draw the agent
         ratio = tile_size / TILE_PIXELS
         ratio = tile_size / TILE_PIXELS
         r.push()
         r.push()
@@ -1304,42 +1286,29 @@ class MiniGridEnv(gym.Env):
         ])
         ])
         r.pop()
         r.pop()
 
 
-        r.endFrame()
-
         if mode == 'rgb_array':
         if mode == 'rgb_array':
             return r.getArray()
             return r.getArray()
         elif mode == 'pixmap':
         elif mode == 'pixmap':
             return r.getPixmap()
             return r.getPixmap()
         return r
         return r
+        """
 
 
     def render(self, mode='human', close=False, highlight=True, tile_size=TILE_PIXELS):
     def render(self, mode='human', close=False, highlight=True, tile_size=TILE_PIXELS):
         """
         """
         Render the whole-grid human view
         Render the whole-grid human view
         """
         """
 
 
+        """
         if close:
         if close:
             if self.grid_render:
             if self.grid_render:
                 self.grid_render.close()
                 self.grid_render.close()
             return
             return
-
-        if self.grid_render is None or self.grid_render.window is None or (self.grid_render.width != self.width * tile_size):
-            from gym_minigrid.rendering import Renderer
-            self.grid_render = Renderer(
-                self.width * tile_size,
-                self.height * tile_size,
-                True if mode == 'human' else False
-            )
-
-        r = self.grid_render
-
-        if r.window:
-            r.window.setText(self.mission)
-
-        r.beginFrame()
+        """
 
 
         # Render the whole grid
         # Render the whole grid
-        self.grid.render(r, tile_size)
+        img = self.grid.render(tile_size)
 
 
+        """
         # Draw the agent
         # Draw the agent
         ratio = tile_size / TILE_PIXELS
         ratio = tile_size / TILE_PIXELS
         r.push()
         r.push()
@@ -1357,7 +1326,9 @@ class MiniGridEnv(gym.Env):
             (-12, -10)
             (-12, -10)
         ])
         ])
         r.pop()
         r.pop()
+        """
 
 
+        """
         # Compute which cells are visible to the agent
         # Compute which cells are visible to the agent
         _, vis_mask = self.gen_obs_grid()
         _, vis_mask = self.gen_obs_grid()
 
 
@@ -1386,11 +1357,13 @@ class MiniGridEnv(gym.Env):
                         tile_size,
                         tile_size,
                         255, 255, 255, 75
                         255, 255, 255, 75
                     )
                     )
+        """
 
 
-        r.endFrame()
-
+        """
         if mode == 'rgb_array':
         if mode == 'rgb_array':
             return r.getArray()
             return r.getArray()
         elif mode == 'pixmap':
         elif mode == 'pixmap':
             return r.getPixmap()
             return r.getPixmap()
-        return r
+        """
+
+        return img

+ 10 - 203
gym_minigrid/rendering.py

@@ -1,203 +1,19 @@
 import numpy as np
 import numpy as np
-from PyQt5.QtCore import Qt
-from PyQt5.QtGui import QImage, QPixmap, QPainter, QColor, QPolygon
-from PyQt5.QtCore import QPoint, QSize, QRect
-from PyQt5.QtWidgets import QApplication, QMainWindow, QWidget, QTextEdit
-from PyQt5.QtWidgets import QHBoxLayout, QVBoxLayout, QLabel, QFrame
 
 
-class Window(QMainWindow):
+# TODO: anti-aliased version, fill_coords_aa?
+def fill_coords(img, fn, color):
     """
     """
-    Simple application window to render the environment into
+    Fill pixels of an image with coordinates matching a filter function
     """
     """
 
 
-    def __init__(self):
-        super().__init__()
-
-        self.setWindowTitle('MiniGrid Gym Environment')
-
-        # Image label to display the rendering
-        self.imgLabel = QLabel()
-        self.imgLabel.setFrameStyle(QFrame.Panel | QFrame.Sunken)
-
-        # Text box for the mission
-        self.missionBox = QTextEdit()
-        self.missionBox.setReadOnly(True)
-        self.missionBox.setMinimumSize(400, 100)
-
-        # Center the image
-        hbox = QHBoxLayout()
-        hbox.addStretch(1)
-        hbox.addWidget(self.imgLabel)
-        hbox.addStretch(1)
-
-        # Arrange widgets vertically
-        vbox = QVBoxLayout()
-        vbox.addLayout(hbox)
-        vbox.addWidget(self.missionBox)
-
-        # Create a main widget for the window
-        mainWidget = QWidget(self)
-        self.setCentralWidget(mainWidget)
-        mainWidget.setLayout(vbox)
-
-        # Show the application window
-        self.show()
-        self.setFocus()
-
-        self.closed = False
-
-        # Callback for keyboard events
-        self.keyDownCb = None
-
-    def closeEvent(self, event):
-        self.closed = True
-
-    def setPixmap(self, pixmap):
-        self.imgLabel.setPixmap(pixmap)
-
-    def setText(self, text):
-        self.missionBox.setPlainText(text)
-
-    def setKeyDownCb(self, callback):
-        self.keyDownCb = callback
-
-    def keyPressEvent(self, e):
-        if self.keyDownCb == None:
-            return
-
-        keyName = None
-        if e.key() == Qt.Key_Left:
-            keyName = 'LEFT'
-        elif e.key() == Qt.Key_Right:
-            keyName = 'RIGHT'
-        elif e.key() == Qt.Key_Up:
-            keyName = 'UP'
-        elif e.key() == Qt.Key_Down:
-            keyName = 'DOWN'
-        elif e.key() == Qt.Key_Space:
-            keyName = 'SPACE'
-        elif e.key() == Qt.Key_Return:
-            keyName = 'RETURN'
-        elif e.key() == Qt.Key_Alt:
-            keyName = 'ALT'
-        elif e.key() == Qt.Key_Control:
-            keyName = 'CTRL'
-        elif e.key() == Qt.Key_PageUp:
-            keyName = 'PAGE_UP'
-        elif e.key() == Qt.Key_PageDown:
-            keyName = 'PAGE_DOWN'
-        elif e.key() == Qt.Key_Backspace:
-            keyName = 'BACKSPACE'
-        elif e.key() == Qt.Key_Escape:
-            keyName = 'ESCAPE'
-
-        if keyName == None:
-            return
-        self.keyDownCb(keyName)
-
-class Renderer:
-    def __init__(self, width, height, ownWindow=False):
-        self.width = width
-        self.height = height
-
-        self.img = QImage(width, height, QImage.Format_RGB888)
-        self.painter = QPainter()
-
-        self.window = None
-        if ownWindow:
-            self.app = QApplication([])
-            self.window = Window()
-
-    def close(self):
-        """
-        Deallocate resources used
-        """
-        pass
-
-    def beginFrame(self):
-        self.painter.begin(self.img)
-        self.painter.setRenderHint(QPainter.Antialiasing, False)
-
-        # Clear the background
-        self.painter.setBrush(QColor(0, 0, 0))
-        self.painter.drawRect(0, 0, self.width - 1, self.height - 1)
-
-    def endFrame(self):
-        self.painter.end()
-
-        if self.window:
-            if self.window.closed:
-                self.window = None
-            else:
-                self.window.setPixmap(self.getPixmap())
-                self.app.processEvents()
-
-    def getPixmap(self):
-        return QPixmap.fromImage(self.img)
-
-    def getArray(self):
-        """
-        Get a numpy array of RGB pixel values.
-        The array will have shape (height, width, 3)
-        """
-
-        numBytes = self.width * self.height * 3
-        buf = self.img.bits().asstring(numBytes)
-        output = np.frombuffer(buf, dtype='uint8')
-        output = output.reshape((self.height, self.width, 3))
-
-        return output
-
-    def push(self):
-        self.painter.save()
-
-    def pop(self):
-        self.painter.restore()
-
-    def rotate(self, degrees):
-        self.painter.rotate(degrees)
-
-    def translate(self, x, y):
-        self.painter.translate(x, y)
-
-    def scale(self, x, y):
-        self.painter.scale(x, y)
-
-    def setLineColor(self, r, g, b, a=255):
-        self.painter.setPen(QColor(r, g, b, a))
-
-    def setColor(self, r, g, b, a=255):
-        self.painter.setBrush(QColor(r, g, b, a))
-
-    def setLineWidth(self, width):
-        pen = self.painter.pen()
-        pen.setWidthF(width)
-        self.painter.setPen(pen)
-
-    def drawLine(self, x0, y0, x1, y1):
-        self.painter.drawLine(x0, y0, x1, y1)
-
-    def drawCircle(self, x, y, r):
-        center = QPoint(x, y)
-        self.painter.drawEllipse(center, r, r)
-
-    def drawPolygon(self, points):
-        """Takes a list of points (tuples) as input"""
-        points = map(lambda p: QPoint(p[0], p[1]), points)
-        self.painter.drawPolygon(QPolygon(points))
-
-    def drawPolyline(self, points):
-        """Takes a list of points (tuples) as input"""
-        points = map(lambda p: QPoint(p[0], p[1]), points)
-        self.painter.drawPolyline(QPolygon(points))
-
-    def fillRect(self, x, y, width, height, r, g, b, a=255):
-        self.painter.fillRect(QRect(x, y, width, height), QColor(r, g, b, a))
-
-
-
-
+    for y in range(img.shape[0]):
+        for x in range(img.shape[1]):
+            yf = y / img.shape[0]
+            xf = x / img.shape[1]
+            if fn(xf, yf):
+                img[y, x] = color
 
 
+    return img
 
 
 def point_in_circle(cx, cy, r):
 def point_in_circle(cx, cy, r):
     def fn(x, y):
     def fn(x, y):
@@ -235,12 +51,3 @@ def point_in_triangle(a, b, c):
         return (u >= 0) and (v >= 0) and (u + v) < 1
         return (u >= 0) and (v >= 0) and (u + v) < 1
 
 
     return fn
     return fn
-
-# TODO: anti-aliased version, fill_coords_aa?
-def fill_coords(img, fn, color):
-    for y in range(img.shape[0]):
-        for x in range(img.shape[1]):
-            yf = y / img.shape[0]
-            xf = x / img.shape[1]
-            if fn(xf, yf):
-                img[y, x] = color

+ 1 - 1
setup.py

@@ -10,6 +10,6 @@ setup(
     install_requires=[
     install_requires=[
         'gym>=0.9.6',
         'gym>=0.9.6',
         'numpy>=1.15.0',
         'numpy>=1.15.0',
-        'pyqt5>=5.10.1'
+        'matplotlib'
     ]
     ]
 )
 )