瀏覽代碼

Sketched tile caching. Removed PyQT dependencies.

Maxime Chevalier-Boisvert 5 年之前
父節點
當前提交
c1d3cc5835
共有 3 個文件被更改,包括 87 次插入307 次删除
  1. 76 103
      gym_minigrid/minigrid.py
  2. 10 203
      gym_minigrid/rendering.py
  3. 1 1
      setup.py

+ 76 - 103
gym_minigrid/minigrid.py

@@ -4,6 +4,7 @@ from enum import IntEnum
 import numpy as np
 from gym import error, spaces, utils
 from gym.utils import seeding
+from .rendering import *
 
 # Size in pixels of a tile in the full-scale human view
 TILE_PIXELS = 32
@@ -110,6 +111,7 @@ class WorldObj:
         """Encode the a description of this object as a 3-tuple of integers"""
         return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0)
 
+    @staticmethod
     def decode(type_idx, color_idx, state):
         """Create an object from a 3-tuple state description"""
 
@@ -148,12 +150,6 @@ class WorldObj:
         """Draw this object with the given renderer"""
         raise NotImplementedError
 
-    def _set_color(self, r):
-        """Set the color of this object as the active drawing color"""
-        c = COLORS[self.color]
-        r.setLineColor(c[0], c[1], c[2])
-        r.setColor(c[0], c[1], c[2])
-
 class Goal(WorldObj):
     def __init__(self):
         super().__init__('goal', 'green')
@@ -162,13 +158,7 @@ class Goal(WorldObj):
         return True
 
     def render(self, r):
-        self._set_color(r)
-        r.drawPolygon([
-            (0          , TILE_PIXELS),
-            (TILE_PIXELS, TILE_PIXELS),
-            (TILE_PIXELS,           0),
-            (0          ,           0)
-        ])
+        fill_coords(img, point_in_rect(0.5, 0.5, 0.5, 0.5), COLORS[self.color])
 
 class Floor(WorldObj):
     """
@@ -246,13 +236,7 @@ class Wall(WorldObj):
         return False
 
     def render(self, r):
-        self._set_color(r)
-        r.drawPolygon([
-            (0          , TILE_PIXELS),
-            (TILE_PIXELS, TILE_PIXELS),
-            (TILE_PIXELS,           0),
-            (0          ,           0)
-        ])
+        fill_coords(img, point_in_rect(0.5, 0.5, 0.5, 0.5), COLORS[self.color])
 
 class Door(WorldObj):
     def __init__(self, color, is_open=False, is_locked=False):
@@ -375,9 +359,8 @@ class Ball(WorldObj):
     def can_pickup(self):
         return True
 
-    def render(self, r):
-        self._set_color(r)
-        r.drawCircle(TILE_PIXELS * 0.5, TILE_PIXELS * 0.5, 10)
+    def render(self, img):
+        fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color])
 
 class Box(WorldObj):
     def __init__(self, color, contains=None):
@@ -414,32 +397,14 @@ class Box(WorldObj):
         env.grid.set(*pos, self.contains)
         return True
 
-
-
-
-
-
-def render_tile(
-    obj,
-    agent_dir=None,
-    highlight=False,
-    tile_size=TILE_PIXELS
-):
-    """
-    Render a tile and cache the result
-    """
-
-    pass
-
-
-
-
-
 class Grid:
     """
     Represent a grid and operations on it
     """
 
+    # Static cache of pre-renderer tiles
+    tile_cache = {}
+
     def __init__(self, width, height):
         assert width >= 3
         assert height >= 3
@@ -540,35 +505,55 @@ class Grid:
 
         return grid
 
-    def render(self, r, tile_size):
+    @classmethod
+    def render_tile(
+        cls,
+        obj,
+        agent_dir=None,
+        highlight=False,
+        tile_size=TILE_PIXELS
+    ):
+        """
+        Render a tile and cache the result
+        """
+
+        # Hash map lookup key for the cache
+        key = obj.encode() + (agent_dir, highlight, tile_size)
+
+        if key in cls.tile_cache:
+            return tile_cache[key]
+
+        img = np.zeros(shape=(tile_size, tile_size, 3), dtype=np.uint8)
+
+        obj.render_tile(img)
+
+        # TODO: overlay agent on top
+        if agent_dir is not None:
+            pass
+
+        # TODO: highlighting
+        if highlight:
+            pass
+
+        # Cache the rendered tile
+        tile_cache[key] = img
+
+        return img
+
+    def render(self, tile_size):
         """
         Render this grid at a given scale
         :param r: target renderer object
         :param tile_size: tile size in pixels
         """
 
-        assert r.width == self.width * tile_size
-        assert r.height == self.height * tile_size
+        # Compute the total grid size
+        width_px = self.width * TILE_PIXELS
+        height_px = self.height * TILE_PIXELS
 
-        # Total grid size at native scale
-        widthPx = self.width * TILE_PIXELS
-        heightPx = self.height * TILE_PIXELS
-
-        r.push()
-
-        # Internally, we draw at the "large" full-grid resolution, but we
-        # use the renderer to scale back to the desired size
-        r.scale(tile_size / TILE_PIXELS, tile_size / TILE_PIXELS)
-
-        # Draw the background of the in-world cells black
-        r.fillRect(
-            0,
-            0,
-            widthPx,
-            heightPx,
-            0, 0, 0
-        )
+        img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)
 
+        """
         # Draw grid lines
         r.setLineColor(100, 100, 100)
         for rowIdx in range(0, self.height):
@@ -577,6 +562,7 @@ class Grid:
         for colIdx in range(0, self.width):
             x = TILE_PIXELS * colIdx
             r.drawLine(x, 0, x, heightPx)
+        """
 
         # Render the grid
         for j in range(0, self.height):
@@ -584,12 +570,18 @@ class Grid:
                 cell = self.get(i, j)
                 if cell == None:
                     continue
-                r.push()
-                r.translate(i * TILE_PIXELS, j * TILE_PIXELS)
-                cell.render(r)
-                r.pop()
 
-        r.pop()
+                """
+                tile_img = Grid.render_tile(
+                    cell,
+                    agent_dir=None,
+                    highlight=False,
+                    tile_size=tile_size
+                )
+                """
+
+
+        return img
 
     def encode(self, vis_mask=None):
         """
@@ -1270,22 +1262,12 @@ class MiniGridEnv(gym.Env):
         Render an agent observation for visualization
         """
 
-        if self.obs_render == None:
-            from gym_minigrid.rendering import Renderer
-            self.obs_render = Renderer(
-                self.agent_view_size * tile_size,
-                self.agent_view_size * tile_size
-            )
-
-        r = self.obs_render
-
-        r.beginFrame()
-
         grid = Grid.decode(obs)
 
         # Render the whole grid
-        grid.render(r, tile_size)
+        img = grid.render(r, tile_size)
 
+        """
         # Draw the agent
         ratio = tile_size / TILE_PIXELS
         r.push()
@@ -1304,42 +1286,29 @@ class MiniGridEnv(gym.Env):
         ])
         r.pop()
 
-        r.endFrame()
-
         if mode == 'rgb_array':
             return r.getArray()
         elif mode == 'pixmap':
             return r.getPixmap()
         return r
+        """
 
     def render(self, mode='human', close=False, highlight=True, tile_size=TILE_PIXELS):
         """
         Render the whole-grid human view
         """
 
+        """
         if close:
             if self.grid_render:
                 self.grid_render.close()
             return
-
-        if self.grid_render is None or self.grid_render.window is None or (self.grid_render.width != self.width * tile_size):
-            from gym_minigrid.rendering import Renderer
-            self.grid_render = Renderer(
-                self.width * tile_size,
-                self.height * tile_size,
-                True if mode == 'human' else False
-            )
-
-        r = self.grid_render
-
-        if r.window:
-            r.window.setText(self.mission)
-
-        r.beginFrame()
+        """
 
         # Render the whole grid
-        self.grid.render(r, tile_size)
+        img = self.grid.render(tile_size)
 
+        """
         # Draw the agent
         ratio = tile_size / TILE_PIXELS
         r.push()
@@ -1357,7 +1326,9 @@ class MiniGridEnv(gym.Env):
             (-12, -10)
         ])
         r.pop()
+        """
 
+        """
         # Compute which cells are visible to the agent
         _, vis_mask = self.gen_obs_grid()
 
@@ -1386,11 +1357,13 @@ class MiniGridEnv(gym.Env):
                         tile_size,
                         255, 255, 255, 75
                     )
+        """
 
-        r.endFrame()
-
+        """
         if mode == 'rgb_array':
             return r.getArray()
         elif mode == 'pixmap':
             return r.getPixmap()
-        return r
+        """
+
+        return img

+ 10 - 203
gym_minigrid/rendering.py

@@ -1,203 +1,19 @@
 import numpy as np
-from PyQt5.QtCore import Qt
-from PyQt5.QtGui import QImage, QPixmap, QPainter, QColor, QPolygon
-from PyQt5.QtCore import QPoint, QSize, QRect
-from PyQt5.QtWidgets import QApplication, QMainWindow, QWidget, QTextEdit
-from PyQt5.QtWidgets import QHBoxLayout, QVBoxLayout, QLabel, QFrame
 
-class Window(QMainWindow):
+# TODO: anti-aliased version, fill_coords_aa?
+def fill_coords(img, fn, color):
     """
-    Simple application window to render the environment into
+    Fill pixels of an image with coordinates matching a filter function
     """
 
-    def __init__(self):
-        super().__init__()
-
-        self.setWindowTitle('MiniGrid Gym Environment')
-
-        # Image label to display the rendering
-        self.imgLabel = QLabel()
-        self.imgLabel.setFrameStyle(QFrame.Panel | QFrame.Sunken)
-
-        # Text box for the mission
-        self.missionBox = QTextEdit()
-        self.missionBox.setReadOnly(True)
-        self.missionBox.setMinimumSize(400, 100)
-
-        # Center the image
-        hbox = QHBoxLayout()
-        hbox.addStretch(1)
-        hbox.addWidget(self.imgLabel)
-        hbox.addStretch(1)
-
-        # Arrange widgets vertically
-        vbox = QVBoxLayout()
-        vbox.addLayout(hbox)
-        vbox.addWidget(self.missionBox)
-
-        # Create a main widget for the window
-        mainWidget = QWidget(self)
-        self.setCentralWidget(mainWidget)
-        mainWidget.setLayout(vbox)
-
-        # Show the application window
-        self.show()
-        self.setFocus()
-
-        self.closed = False
-
-        # Callback for keyboard events
-        self.keyDownCb = None
-
-    def closeEvent(self, event):
-        self.closed = True
-
-    def setPixmap(self, pixmap):
-        self.imgLabel.setPixmap(pixmap)
-
-    def setText(self, text):
-        self.missionBox.setPlainText(text)
-
-    def setKeyDownCb(self, callback):
-        self.keyDownCb = callback
-
-    def keyPressEvent(self, e):
-        if self.keyDownCb == None:
-            return
-
-        keyName = None
-        if e.key() == Qt.Key_Left:
-            keyName = 'LEFT'
-        elif e.key() == Qt.Key_Right:
-            keyName = 'RIGHT'
-        elif e.key() == Qt.Key_Up:
-            keyName = 'UP'
-        elif e.key() == Qt.Key_Down:
-            keyName = 'DOWN'
-        elif e.key() == Qt.Key_Space:
-            keyName = 'SPACE'
-        elif e.key() == Qt.Key_Return:
-            keyName = 'RETURN'
-        elif e.key() == Qt.Key_Alt:
-            keyName = 'ALT'
-        elif e.key() == Qt.Key_Control:
-            keyName = 'CTRL'
-        elif e.key() == Qt.Key_PageUp:
-            keyName = 'PAGE_UP'
-        elif e.key() == Qt.Key_PageDown:
-            keyName = 'PAGE_DOWN'
-        elif e.key() == Qt.Key_Backspace:
-            keyName = 'BACKSPACE'
-        elif e.key() == Qt.Key_Escape:
-            keyName = 'ESCAPE'
-
-        if keyName == None:
-            return
-        self.keyDownCb(keyName)
-
-class Renderer:
-    def __init__(self, width, height, ownWindow=False):
-        self.width = width
-        self.height = height
-
-        self.img = QImage(width, height, QImage.Format_RGB888)
-        self.painter = QPainter()
-
-        self.window = None
-        if ownWindow:
-            self.app = QApplication([])
-            self.window = Window()
-
-    def close(self):
-        """
-        Deallocate resources used
-        """
-        pass
-
-    def beginFrame(self):
-        self.painter.begin(self.img)
-        self.painter.setRenderHint(QPainter.Antialiasing, False)
-
-        # Clear the background
-        self.painter.setBrush(QColor(0, 0, 0))
-        self.painter.drawRect(0, 0, self.width - 1, self.height - 1)
-
-    def endFrame(self):
-        self.painter.end()
-
-        if self.window:
-            if self.window.closed:
-                self.window = None
-            else:
-                self.window.setPixmap(self.getPixmap())
-                self.app.processEvents()
-
-    def getPixmap(self):
-        return QPixmap.fromImage(self.img)
-
-    def getArray(self):
-        """
-        Get a numpy array of RGB pixel values.
-        The array will have shape (height, width, 3)
-        """
-
-        numBytes = self.width * self.height * 3
-        buf = self.img.bits().asstring(numBytes)
-        output = np.frombuffer(buf, dtype='uint8')
-        output = output.reshape((self.height, self.width, 3))
-
-        return output
-
-    def push(self):
-        self.painter.save()
-
-    def pop(self):
-        self.painter.restore()
-
-    def rotate(self, degrees):
-        self.painter.rotate(degrees)
-
-    def translate(self, x, y):
-        self.painter.translate(x, y)
-
-    def scale(self, x, y):
-        self.painter.scale(x, y)
-
-    def setLineColor(self, r, g, b, a=255):
-        self.painter.setPen(QColor(r, g, b, a))
-
-    def setColor(self, r, g, b, a=255):
-        self.painter.setBrush(QColor(r, g, b, a))
-
-    def setLineWidth(self, width):
-        pen = self.painter.pen()
-        pen.setWidthF(width)
-        self.painter.setPen(pen)
-
-    def drawLine(self, x0, y0, x1, y1):
-        self.painter.drawLine(x0, y0, x1, y1)
-
-    def drawCircle(self, x, y, r):
-        center = QPoint(x, y)
-        self.painter.drawEllipse(center, r, r)
-
-    def drawPolygon(self, points):
-        """Takes a list of points (tuples) as input"""
-        points = map(lambda p: QPoint(p[0], p[1]), points)
-        self.painter.drawPolygon(QPolygon(points))
-
-    def drawPolyline(self, points):
-        """Takes a list of points (tuples) as input"""
-        points = map(lambda p: QPoint(p[0], p[1]), points)
-        self.painter.drawPolyline(QPolygon(points))
-
-    def fillRect(self, x, y, width, height, r, g, b, a=255):
-        self.painter.fillRect(QRect(x, y, width, height), QColor(r, g, b, a))
-
-
-
-
+    for y in range(img.shape[0]):
+        for x in range(img.shape[1]):
+            yf = y / img.shape[0]
+            xf = x / img.shape[1]
+            if fn(xf, yf):
+                img[y, x] = color
 
+    return img
 
 def point_in_circle(cx, cy, r):
     def fn(x, y):
@@ -235,12 +51,3 @@ def point_in_triangle(a, b, c):
         return (u >= 0) and (v >= 0) and (u + v) < 1
 
     return fn
-
-# TODO: anti-aliased version, fill_coords_aa?
-def fill_coords(img, fn, color):
-    for y in range(img.shape[0]):
-        for x in range(img.shape[1]):
-            yf = y / img.shape[0]
-            xf = x / img.shape[1]
-            if fn(xf, yf):
-                img[y, x] = color

+ 1 - 1
setup.py

@@ -10,6 +10,6 @@ setup(
     install_requires=[
         'gym>=0.9.6',
         'numpy>=1.15.0',
-        'pyqt5>=5.10.1'
+        'matplotlib'
     ]
 )