123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329 |
- from __future__ import annotations
- import math
- from typing import Any, Callable
- import numpy as np
- from minigrid.core.constants import OBJECT_TO_IDX, TILE_PIXELS
- from minigrid.core.world_object import Wall, WorldObj
- from minigrid.utils.rendering import (
- downsample,
- fill_coords,
- highlight_img,
- point_in_rect,
- point_in_triangle,
- rotate_fn,
- )
class Grid:
    """
    Represent a grid and operations on it.

    Cells are stored row-major in a flat list: the object at column ``i``,
    row ``j`` lives at index ``j * width + i``. An empty cell is ``None``.
    """

    # Static cache of pre-rendered tiles, shared by all Grid instances.
    # The key must include every argument that affects the rendered pixels.
    tile_cache: dict[tuple[Any, ...], Any] = {}

    def __init__(self, width: int, height: int):
        """
        Create an empty grid.

        :param width: number of columns (must be at least 3)
        :param height: number of rows (must be at least 3)
        """
        assert width >= 3
        assert height >= 3

        self.width: int = width
        self.height: int = height
        self.grid: list[WorldObj | None] = [None] * (width * height)

    def __contains__(self, key: Any) -> bool:
        """
        Membership test.

        ``key`` may be a ``WorldObj`` (matched by identity) or a
        ``(color, type)`` tuple; a ``None`` color matches any color.
        Keys of any other type are never contained.
        """
        if isinstance(key, WorldObj):
            for e in self.grid:
                if e is key:
                    return True
        elif isinstance(key, tuple):
            for e in self.grid:
                if e is None:
                    continue
                if (e.color, e.type) == key:
                    return True
                if key[0] is None and key[1] == e.type:
                    return True
        return False

    def __eq__(self, other: object) -> bool:
        """Two grids are equal when their full encodings are identical."""
        if not isinstance(other, Grid):
            # Defer to Python's fallback (reflected op / identity) instead of
            # crashing on ``other.encode()`` for unrelated types.
            return NotImplemented
        grid1 = self.encode()
        grid2 = other.encode()
        return np.array_equal(grid2, grid1)

    def __ne__(self, other: object) -> bool:
        return not self == other

    def copy(self) -> Grid:
        """Return a deep copy of this grid, including copies of its cell objects."""
        from copy import deepcopy

        return deepcopy(self)

    def set(self, i: int, j: int, v: WorldObj | None):
        """
        Place ``v`` at column ``i``, row ``j`` (``None`` clears the cell).

        :raises AssertionError: if ``(i, j)`` is outside the grid
        """
        assert (
            0 <= i < self.width
        ), f"column index {i} outside of grid of width {self.width}"
        assert (
            0 <= j < self.height
        ), f"row index {j} outside of grid of height {self.height}"
        self.grid[j * self.width + i] = v

    def get(self, i: int, j: int) -> WorldObj | None:
        """
        Return the object at column ``i``, row ``j`` (``None`` if empty).

        :raises AssertionError: if ``(i, j)`` is outside the grid
        """
        assert (
            0 <= i < self.width
        ), f"column index {i} outside of grid of width {self.width}"
        assert (
            0 <= j < self.height
        ), f"row index {j} outside of grid of height {self.height}"
        assert self.grid is not None
        return self.grid[j * self.width + i]

    def horz_wall(
        self,
        x: int,
        y: int,
        length: int | None = None,
        obj_type: Callable[[], WorldObj] = Wall,
    ):
        """
        Fill a horizontal run of cells in row ``y`` starting at column ``x``.

        :param length: number of cells; defaults to reaching the right edge
        :param obj_type: factory called once per cell to create its object
        """
        if length is None:
            length = self.width - x
        for i in range(0, length):
            self.set(x + i, y, obj_type())

    def vert_wall(
        self,
        x: int,
        y: int,
        length: int | None = None,
        obj_type: Callable[[], WorldObj] = Wall,
    ):
        """
        Fill a vertical run of cells in column ``x`` starting at row ``y``.

        :param length: number of cells; defaults to reaching the bottom edge
        :param obj_type: factory called once per cell to create its object
        """
        if length is None:
            length = self.height - y
        for j in range(0, length):
            self.set(x, y + j, obj_type())

    def wall_rect(self, x: int, y: int, w: int, h: int):
        """Draw the outline of a ``w`` x ``h`` rectangle of walls with top-left corner ``(x, y)``."""
        self.horz_wall(x, y, w)
        self.horz_wall(x, y + h - 1, w)
        self.vert_wall(x, y, h)
        self.vert_wall(x + w - 1, y, h)

    def rotate_left(self) -> Grid:
        """
        Rotate the grid to the left (counter-clockwise).

        Returns a new grid with width and height swapped; the cell objects are
        shared with this grid, not copied.
        """
        grid = Grid(self.height, self.width)

        for i in range(self.width):
            for j in range(self.height):
                v = self.get(i, j)
                grid.set(j, grid.height - 1 - i, v)

        return grid

    def slice(self, topX: int, topY: int, width: int, height: int) -> Grid:
        """
        Get a ``width`` x ``height`` subset of the grid whose top-left corner
        is ``(topX, topY)``.

        Positions that fall outside this grid are filled with new ``Wall``
        objects; in-bounds cell objects are shared, not copied.
        """
        grid = Grid(width, height)

        for j in range(0, height):
            for i in range(0, width):
                x = topX + i
                y = topY + j

                if 0 <= x < self.width and 0 <= y < self.height:
                    v = self.get(x, y)
                else:
                    v = Wall()

                grid.set(i, j, v)

        return grid

    @classmethod
    def render_tile(
        cls,
        obj: WorldObj | None,
        agent_dir: int | None = None,
        highlight: bool = False,
        tile_size: int = TILE_PIXELS,
        subdivs: int = 3,
    ) -> np.ndarray:
        """
        Render a single tile and cache the result.

        :param obj: object occupying the cell, or ``None`` for an empty cell
        :param agent_dir: if not ``None``, draw the agent triangle rotated by
            ``agent_dir`` quarter turns
        :param highlight: whether to apply the highlight overlay
        :param tile_size: output tile size in pixels
        :param subdivs: supersampling factor used for anti-aliasing
        :returns: the rendered tile; the same array object is returned on
            cache hits, so callers must not modify it in place
        """
        # Hash map lookup key for the cache. Every argument that changes the
        # rendered pixels must be part of it, including ``subdivs``.
        key: tuple[Any, ...] = (agent_dir, highlight, tile_size, subdivs)
        if obj is not None:
            key = obj.encode() + key

        if key in cls.tile_cache:
            return cls.tile_cache[key]

        img = np.zeros(
            shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8
        )

        # Draw the grid lines (top and left edges)
        fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
        fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100))

        if obj is not None:
            obj.render(img)

        # Overlay the agent on top
        if agent_dir is not None:
            tri_fn = point_in_triangle(
                (0.12, 0.19),
                (0.87, 0.50),
                (0.12, 0.81),
            )

            # Rotate the agent based on its direction
            tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5 * math.pi * agent_dir)
            fill_coords(img, tri_fn, (255, 0, 0))

        # Highlight the cell if needed
        if highlight:
            highlight_img(img)

        # Downsample the image to perform supersampling/anti-aliasing
        img = downsample(img, subdivs)

        # Cache the rendered tile
        cls.tile_cache[key] = img

        return img

    def render(
        self,
        tile_size: int,
        agent_pos: tuple[int, int],
        agent_dir: int | None = None,
        highlight_mask: np.ndarray | None = None,
    ) -> np.ndarray:
        """
        Render this grid at a given scale.

        :param tile_size: tile size in pixels
        :param agent_pos: ``(column, row)`` of the agent; the agent is drawn
            only on the tile at this position
        :param agent_dir: agent direction passed to ``render_tile`` for the
            agent's tile
        :param highlight_mask: boolean array indexed ``[i, j]`` (column, row)
            marking cells to highlight; defaults to no highlighting
        :returns: ``uint8`` image of shape
            ``(height * tile_size, width * tile_size, 3)``
        """
        if highlight_mask is None:
            highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)

        # Compute the total grid size
        width_px = self.width * tile_size
        height_px = self.height * tile_size

        img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)

        # Render the grid
        for j in range(0, self.height):
            for i in range(0, self.width):
                cell = self.get(i, j)

                agent_here = np.array_equal(agent_pos, (i, j))
                tile_img = Grid.render_tile(
                    cell,
                    agent_dir=agent_dir if agent_here else None,
                    highlight=highlight_mask[i, j],
                    tile_size=tile_size,
                )

                ymin = j * tile_size
                ymax = (j + 1) * tile_size
                xmin = i * tile_size
                xmax = (i + 1) * tile_size
                img[ymin:ymax, xmin:xmax, :] = tile_img

        return img

    def encode(self, vis_mask: np.ndarray | None = None) -> np.ndarray:
        """
        Produce a compact numpy encoding of the grid.

        :param vis_mask: boolean array indexed ``[i, j]``; cells where it is
            false are left as all zeros. Defaults to every cell visible.
        :returns: ``uint8`` array of shape ``(width, height, 3)``. Each visible
            cell holds its object's ``encode()`` triple, or
            ``(OBJECT_TO_IDX["empty"], 0, 0)`` when empty.
        """
        if vis_mask is None:
            vis_mask = np.ones((self.width, self.height), dtype=bool)

        array = np.zeros((self.width, self.height, 3), dtype="uint8")

        for i in range(self.width):
            for j in range(self.height):
                if vis_mask[i, j]:
                    v = self.get(i, j)

                    if v is None:
                        array[i, j, 0] = OBJECT_TO_IDX["empty"]
                        array[i, j, 1] = 0
                        array[i, j, 2] = 0

                    else:
                        array[i, j, :] = v.encode()

        return array

    @staticmethod
    def decode(array: np.ndarray) -> tuple[Grid, np.ndarray]:
        """
        Decode an array grid encoding back into a grid.

        :param array: encoding of shape ``(width, height, 3)``, as produced by
            ``encode``
        :returns: ``(grid, vis_mask)``; ``vis_mask[i, j]`` is false exactly
            where the type index equals ``OBJECT_TO_IDX["unseen"]``
        """
        width, height, channels = array.shape
        assert channels == 3

        vis_mask = np.ones(shape=(width, height), dtype=bool)

        grid = Grid(width, height)
        for i in range(width):
            for j in range(height):
                type_idx, color_idx, state = array[i, j]
                v = WorldObj.decode(type_idx, color_idx, state)
                grid.set(i, j, v)
                vis_mask[i, j] = type_idx != OBJECT_TO_IDX["unseen"]

        return grid, vis_mask

    def process_vis(self, agent_pos: tuple[int, int]) -> np.ndarray:
        """
        Compute which cells are visible from ``agent_pos``, then clear every
        cell that is not visible (sets it to ``None``, mutating this grid).

        Visibility starts at the agent cell. Rows are scanned from the last
        row upward, and visibility spreads to horizontal neighbours and to the
        row above (``j - 1``). A visible cell whose object's ``see_behind()``
        is false stays visible but does not propagate further. Rows below the
        agent's row never become visible.

        NOTE(review): this presumably assumes an egocentric view with the
        agent on the bottom row — confirm against callers.

        :returns: boolean mask indexed ``[i, j]`` (column, row)
        """
        mask = np.zeros(shape=(self.width, self.height), dtype=bool)

        mask[agent_pos[0], agent_pos[1]] = True

        for j in reversed(range(0, self.height)):
            # Left-to-right sweep: spread to the right and diagonally up-right.
            for i in range(0, self.width - 1):
                if not mask[i, j]:
                    continue

                cell = self.get(i, j)
                if cell and not cell.see_behind():
                    continue

                mask[i + 1, j] = True
                # Guard against j - 1 == -1 wrapping to the last row.
                if j > 0:
                    mask[i + 1, j - 1] = True
                    mask[i, j - 1] = True

            # Right-to-left sweep: spread to the left and diagonally up-left.
            for i in reversed(range(1, self.width)):
                if not mask[i, j]:
                    continue

                cell = self.get(i, j)
                if cell and not cell.see_behind():
                    continue

                mask[i - 1, j] = True
                if j > 0:
                    mask[i - 1, j - 1] = True
                    mask[i, j - 1] = True

        for j in range(0, self.height):
            for i in range(0, self.width):
                if not mask[i, j]:
                    self.set(i, j, None)

        return mask
|