|
@@ -4,6 +4,7 @@ from enum import IntEnum
|
|
import numpy as np
|
|
import numpy as np
|
|
from gym import error, spaces, utils
|
|
from gym import error, spaces, utils
|
|
from gym.utils import seeding
|
|
from gym.utils import seeding
|
|
|
|
+from .rendering import *
|
|
|
|
|
|
# Size in pixels of a tile in the full-scale human view
|
|
# Size in pixels of a tile in the full-scale human view
|
|
TILE_PIXELS = 32
|
|
TILE_PIXELS = 32
|
|
@@ -110,6 +111,7 @@ class WorldObj:
|
|
"""Encode the a description of this object as a 3-tuple of integers"""
|
|
"""Encode the a description of this object as a 3-tuple of integers"""
|
|
return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0)
|
|
return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0)
|
|
|
|
|
|
|
|
+ @staticmethod
|
|
def decode(type_idx, color_idx, state):
|
|
def decode(type_idx, color_idx, state):
|
|
"""Create an object from a 3-tuple state description"""
|
|
"""Create an object from a 3-tuple state description"""
|
|
|
|
|
|
@@ -148,12 +150,6 @@ class WorldObj:
|
|
"""Draw this object with the given renderer"""
|
|
"""Draw this object with the given renderer"""
|
|
raise NotImplementedError
|
|
raise NotImplementedError
|
|
|
|
|
|
- def _set_color(self, r):
|
|
|
|
- """Set the color of this object as the active drawing color"""
|
|
|
|
- c = COLORS[self.color]
|
|
|
|
- r.setLineColor(c[0], c[1], c[2])
|
|
|
|
- r.setColor(c[0], c[1], c[2])
|
|
|
|
-
|
|
|
|
class Goal(WorldObj):
|
|
class Goal(WorldObj):
|
|
def __init__(self):
|
|
def __init__(self):
|
|
super().__init__('goal', 'green')
|
|
super().__init__('goal', 'green')
|
|
@@ -162,13 +158,7 @@ class Goal(WorldObj):
|
|
return True
|
|
return True
|
|
|
|
|
|
def render(self, r):
|
|
def render(self, r):
|
|
- self._set_color(r)
|
|
|
|
- r.drawPolygon([
|
|
|
|
- (0 , TILE_PIXELS),
|
|
|
|
- (TILE_PIXELS, TILE_PIXELS),
|
|
|
|
- (TILE_PIXELS, 0),
|
|
|
|
- (0 , 0)
|
|
|
|
- ])
|
|
|
|
|
|
+ fill_coords(img, point_in_rect(0.5, 0.5, 0.5, 0.5), COLORS[self.color])
|
|
|
|
|
|
class Floor(WorldObj):
|
|
class Floor(WorldObj):
|
|
"""
|
|
"""
|
|
@@ -246,13 +236,7 @@ class Wall(WorldObj):
|
|
return False
|
|
return False
|
|
|
|
|
|
def render(self, r):
|
|
def render(self, r):
|
|
- self._set_color(r)
|
|
|
|
- r.drawPolygon([
|
|
|
|
- (0 , TILE_PIXELS),
|
|
|
|
- (TILE_PIXELS, TILE_PIXELS),
|
|
|
|
- (TILE_PIXELS, 0),
|
|
|
|
- (0 , 0)
|
|
|
|
- ])
|
|
|
|
|
|
+ fill_coords(img, point_in_rect(0.5, 0.5, 0.5, 0.5), COLORS[self.color])
|
|
|
|
|
|
class Door(WorldObj):
|
|
class Door(WorldObj):
|
|
def __init__(self, color, is_open=False, is_locked=False):
|
|
def __init__(self, color, is_open=False, is_locked=False):
|
|
@@ -375,9 +359,8 @@ class Ball(WorldObj):
|
|
def can_pickup(self):
|
|
def can_pickup(self):
|
|
return True
|
|
return True
|
|
|
|
|
|
- def render(self, r):
|
|
|
|
- self._set_color(r)
|
|
|
|
- r.drawCircle(TILE_PIXELS * 0.5, TILE_PIXELS * 0.5, 10)
|
|
|
|
|
|
+ def render(self, img):
|
|
|
|
+ fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color])
|
|
|
|
|
|
class Box(WorldObj):
|
|
class Box(WorldObj):
|
|
def __init__(self, color, contains=None):
|
|
def __init__(self, color, contains=None):
|
|
@@ -414,32 +397,14 @@ class Box(WorldObj):
|
|
env.grid.set(*pos, self.contains)
|
|
env.grid.set(*pos, self.contains)
|
|
return True
|
|
return True
|
|
|
|
|
|
-
|
|
|
|
-
|
|
|
|
-
|
|
|
|
-
|
|
|
|
-
|
|
|
|
-def render_tile(
|
|
|
|
- obj,
|
|
|
|
- agent_dir=None,
|
|
|
|
- highlight=False,
|
|
|
|
- tile_size=TILE_PIXELS
|
|
|
|
-):
|
|
|
|
- """
|
|
|
|
- Render a tile and cache the result
|
|
|
|
- """
|
|
|
|
-
|
|
|
|
- pass
|
|
|
|
-
|
|
|
|
-
|
|
|
|
-
|
|
|
|
-
|
|
|
|
-
|
|
|
|
class Grid:
|
|
class Grid:
|
|
"""
|
|
"""
|
|
Represent a grid and operations on it
|
|
Represent a grid and operations on it
|
|
"""
|
|
"""
|
|
|
|
|
|
|
|
+ # Static cache of pre-renderer tiles
|
|
|
|
+ tile_cache = {}
|
|
|
|
+
|
|
def __init__(self, width, height):
|
|
def __init__(self, width, height):
|
|
assert width >= 3
|
|
assert width >= 3
|
|
assert height >= 3
|
|
assert height >= 3
|
|
@@ -540,35 +505,55 @@ class Grid:
|
|
|
|
|
|
return grid
|
|
return grid
|
|
|
|
|
|
- def render(self, r, tile_size):
|
|
|
|
|
|
+ @classmethod
|
|
|
|
+ def render_tile(
|
|
|
|
+ cls,
|
|
|
|
+ obj,
|
|
|
|
+ agent_dir=None,
|
|
|
|
+ highlight=False,
|
|
|
|
+ tile_size=TILE_PIXELS
|
|
|
|
+ ):
|
|
|
|
+ """
|
|
|
|
+ Render a tile and cache the result
|
|
|
|
+ """
|
|
|
|
+
|
|
|
|
+ # Hash map lookup key for the cache
|
|
|
|
+ key = obj.encode() + (agent_dir, highlight, tile_size)
|
|
|
|
+
|
|
|
|
+ if key in cls.tile_cache:
|
|
|
|
+ return tile_cache[key]
|
|
|
|
+
|
|
|
|
+ img = np.zeros(shape=(tile_size, tile_size, 3), dtype=np.uint8)
|
|
|
|
+
|
|
|
|
+ obj.render_tile(img)
|
|
|
|
+
|
|
|
|
+ # TODO: overlay agent on top
|
|
|
|
+ if agent_dir is not None:
|
|
|
|
+ pass
|
|
|
|
+
|
|
|
|
+ # TODO: highlighting
|
|
|
|
+ if highlight:
|
|
|
|
+ pass
|
|
|
|
+
|
|
|
|
+ # Cache the rendered tile
|
|
|
|
+ tile_cache[key] = img
|
|
|
|
+
|
|
|
|
+ return img
|
|
|
|
+
|
|
|
|
+ def render(self, tile_size):
|
|
"""
|
|
"""
|
|
Render this grid at a given scale
|
|
Render this grid at a given scale
|
|
:param r: target renderer object
|
|
:param r: target renderer object
|
|
:param tile_size: tile size in pixels
|
|
:param tile_size: tile size in pixels
|
|
"""
|
|
"""
|
|
|
|
|
|
- assert r.width == self.width * tile_size
|
|
|
|
- assert r.height == self.height * tile_size
|
|
|
|
|
|
+ # Compute the total grid size
|
|
|
|
+ width_px = self.width * TILE_PIXELS
|
|
|
|
+ height_px = self.height * TILE_PIXELS
|
|
|
|
|
|
- # Total grid size at native scale
|
|
|
|
- widthPx = self.width * TILE_PIXELS
|
|
|
|
- heightPx = self.height * TILE_PIXELS
|
|
|
|
-
|
|
|
|
- r.push()
|
|
|
|
-
|
|
|
|
- # Internally, we draw at the "large" full-grid resolution, but we
|
|
|
|
- # use the renderer to scale back to the desired size
|
|
|
|
- r.scale(tile_size / TILE_PIXELS, tile_size / TILE_PIXELS)
|
|
|
|
-
|
|
|
|
- # Draw the background of the in-world cells black
|
|
|
|
- r.fillRect(
|
|
|
|
- 0,
|
|
|
|
- 0,
|
|
|
|
- widthPx,
|
|
|
|
- heightPx,
|
|
|
|
- 0, 0, 0
|
|
|
|
- )
|
|
|
|
|
|
+ img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)
|
|
|
|
|
|
|
|
+ """
|
|
# Draw grid lines
|
|
# Draw grid lines
|
|
r.setLineColor(100, 100, 100)
|
|
r.setLineColor(100, 100, 100)
|
|
for rowIdx in range(0, self.height):
|
|
for rowIdx in range(0, self.height):
|
|
@@ -577,6 +562,7 @@ class Grid:
|
|
for colIdx in range(0, self.width):
|
|
for colIdx in range(0, self.width):
|
|
x = TILE_PIXELS * colIdx
|
|
x = TILE_PIXELS * colIdx
|
|
r.drawLine(x, 0, x, heightPx)
|
|
r.drawLine(x, 0, x, heightPx)
|
|
|
|
+ """
|
|
|
|
|
|
# Render the grid
|
|
# Render the grid
|
|
for j in range(0, self.height):
|
|
for j in range(0, self.height):
|
|
@@ -584,12 +570,18 @@ class Grid:
|
|
cell = self.get(i, j)
|
|
cell = self.get(i, j)
|
|
if cell == None:
|
|
if cell == None:
|
|
continue
|
|
continue
|
|
- r.push()
|
|
|
|
- r.translate(i * TILE_PIXELS, j * TILE_PIXELS)
|
|
|
|
- cell.render(r)
|
|
|
|
- r.pop()
|
|
|
|
|
|
|
|
- r.pop()
|
|
|
|
|
|
+ """
|
|
|
|
+ tile_img = Grid.render_tile(
|
|
|
|
+ cell,
|
|
|
|
+ agent_dir=None,
|
|
|
|
+ highlight=False,
|
|
|
|
+ tile_size=tile_size
|
|
|
|
+ )
|
|
|
|
+ """
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+ return img
|
|
|
|
|
|
def encode(self, vis_mask=None):
|
|
def encode(self, vis_mask=None):
|
|
"""
|
|
"""
|
|
@@ -1270,22 +1262,12 @@ class MiniGridEnv(gym.Env):
|
|
Render an agent observation for visualization
|
|
Render an agent observation for visualization
|
|
"""
|
|
"""
|
|
|
|
|
|
- if self.obs_render == None:
|
|
|
|
- from gym_minigrid.rendering import Renderer
|
|
|
|
- self.obs_render = Renderer(
|
|
|
|
- self.agent_view_size * tile_size,
|
|
|
|
- self.agent_view_size * tile_size
|
|
|
|
- )
|
|
|
|
-
|
|
|
|
- r = self.obs_render
|
|
|
|
-
|
|
|
|
- r.beginFrame()
|
|
|
|
-
|
|
|
|
grid = Grid.decode(obs)
|
|
grid = Grid.decode(obs)
|
|
|
|
|
|
# Render the whole grid
|
|
# Render the whole grid
|
|
- grid.render(r, tile_size)
|
|
|
|
|
|
+ img = grid.render(r, tile_size)
|
|
|
|
|
|
|
|
+ """
|
|
# Draw the agent
|
|
# Draw the agent
|
|
ratio = tile_size / TILE_PIXELS
|
|
ratio = tile_size / TILE_PIXELS
|
|
r.push()
|
|
r.push()
|
|
@@ -1304,42 +1286,29 @@ class MiniGridEnv(gym.Env):
|
|
])
|
|
])
|
|
r.pop()
|
|
r.pop()
|
|
|
|
|
|
- r.endFrame()
|
|
|
|
-
|
|
|
|
if mode == 'rgb_array':
|
|
if mode == 'rgb_array':
|
|
return r.getArray()
|
|
return r.getArray()
|
|
elif mode == 'pixmap':
|
|
elif mode == 'pixmap':
|
|
return r.getPixmap()
|
|
return r.getPixmap()
|
|
return r
|
|
return r
|
|
|
|
+ """
|
|
|
|
|
|
def render(self, mode='human', close=False, highlight=True, tile_size=TILE_PIXELS):
|
|
def render(self, mode='human', close=False, highlight=True, tile_size=TILE_PIXELS):
|
|
"""
|
|
"""
|
|
Render the whole-grid human view
|
|
Render the whole-grid human view
|
|
"""
|
|
"""
|
|
|
|
|
|
|
|
+ """
|
|
if close:
|
|
if close:
|
|
if self.grid_render:
|
|
if self.grid_render:
|
|
self.grid_render.close()
|
|
self.grid_render.close()
|
|
return
|
|
return
|
|
-
|
|
|
|
- if self.grid_render is None or self.grid_render.window is None or (self.grid_render.width != self.width * tile_size):
|
|
|
|
- from gym_minigrid.rendering import Renderer
|
|
|
|
- self.grid_render = Renderer(
|
|
|
|
- self.width * tile_size,
|
|
|
|
- self.height * tile_size,
|
|
|
|
- True if mode == 'human' else False
|
|
|
|
- )
|
|
|
|
-
|
|
|
|
- r = self.grid_render
|
|
|
|
-
|
|
|
|
- if r.window:
|
|
|
|
- r.window.setText(self.mission)
|
|
|
|
-
|
|
|
|
- r.beginFrame()
|
|
|
|
|
|
+ """
|
|
|
|
|
|
# Render the whole grid
|
|
# Render the whole grid
|
|
- self.grid.render(r, tile_size)
|
|
|
|
|
|
+ img = self.grid.render(tile_size)
|
|
|
|
|
|
|
|
+ """
|
|
# Draw the agent
|
|
# Draw the agent
|
|
ratio = tile_size / TILE_PIXELS
|
|
ratio = tile_size / TILE_PIXELS
|
|
r.push()
|
|
r.push()
|
|
@@ -1357,7 +1326,9 @@ class MiniGridEnv(gym.Env):
|
|
(-12, -10)
|
|
(-12, -10)
|
|
])
|
|
])
|
|
r.pop()
|
|
r.pop()
|
|
|
|
+ """
|
|
|
|
|
|
|
|
+ """
|
|
# Compute which cells are visible to the agent
|
|
# Compute which cells are visible to the agent
|
|
_, vis_mask = self.gen_obs_grid()
|
|
_, vis_mask = self.gen_obs_grid()
|
|
|
|
|
|
@@ -1386,11 +1357,13 @@ class MiniGridEnv(gym.Env):
|
|
tile_size,
|
|
tile_size,
|
|
255, 255, 255, 75
|
|
255, 255, 255, 75
|
|
)
|
|
)
|
|
|
|
+ """
|
|
|
|
|
|
- r.endFrame()
|
|
|
|
-
|
|
|
|
|
|
+ """
|
|
if mode == 'rgb_array':
|
|
if mode == 'rgb_array':
|
|
return r.getArray()
|
|
return r.getArray()
|
|
elif mode == 'pixmap':
|
|
elif mode == 'pixmap':
|
|
return r.getPixmap()
|
|
return r.getPixmap()
|
|
- return r
|
|
|
|
|
|
+ """
|
|
|
|
+
|
|
|
|
+ return img
|