grid.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
from __future__ import annotations

import math
from typing import Any, Callable, ClassVar

import numpy as np

from minigrid.core.constants import OBJECT_TO_IDX, TILE_PIXELS
from minigrid.core.world_object import Wall, WorldObj
from minigrid.utils.rendering import (
    downsample,
    fill_coords,
    highlight_img,
    point_in_rect,
    point_in_triangle,
    rotate_fn,
)
  15. class Grid:
  16. """
  17. Represent a grid and operations on it
  18. """
  19. # Static cache of pre-renderer tiles
  20. tile_cache: dict[tuple[Any, ...], Any] = {}
  21. def __init__(self, width: int, height: int):
  22. assert width >= 3
  23. assert height >= 3
  24. self.width: int = width
  25. self.height: int = height
  26. self.grid: list[WorldObj | None] = [None] * (width * height)
  27. def __contains__(self, key: Any) -> bool:
  28. if isinstance(key, WorldObj):
  29. for e in self.grid:
  30. if e is key:
  31. return True
  32. elif isinstance(key, tuple):
  33. for e in self.grid:
  34. if e is None:
  35. continue
  36. if (e.color, e.type) == key:
  37. return True
  38. if key[0] is None and key[1] == e.type:
  39. return True
  40. return False
  41. def __eq__(self, other: Grid) -> bool:
  42. grid1 = self.encode()
  43. grid2 = other.encode()
  44. return np.array_equal(grid2, grid1)
  45. def __ne__(self, other: Grid) -> bool:
  46. return not self == other
  47. def copy(self) -> Grid:
  48. from copy import deepcopy
  49. return deepcopy(self)
  50. def set(self, i: int, j: int, v: WorldObj | None):
  51. assert (
  52. 0 <= i < self.width
  53. ), f"column index {i} outside of grid of width {self.width}"
  54. assert (
  55. 0 <= j < self.height
  56. ), f"row index {j} outside of grid of height {self.height}"
  57. self.grid[j * self.width + i] = v
  58. def get(self, i: int, j: int) -> WorldObj | None:
  59. assert 0 <= i < self.width
  60. assert 0 <= j < self.height
  61. assert self.grid is not None
  62. return self.grid[j * self.width + i]
  63. def horz_wall(
  64. self,
  65. x: int,
  66. y: int,
  67. length: int | None = None,
  68. obj_type: Callable[[], WorldObj] = Wall,
  69. ):
  70. if length is None:
  71. length = self.width - x
  72. for i in range(0, length):
  73. self.set(x + i, y, obj_type())
  74. def vert_wall(
  75. self,
  76. x: int,
  77. y: int,
  78. length: int | None = None,
  79. obj_type: Callable[[], WorldObj] = Wall,
  80. ):
  81. if length is None:
  82. length = self.height - y
  83. for j in range(0, length):
  84. self.set(x, y + j, obj_type())
  85. def wall_rect(self, x: int, y: int, w: int, h: int):
  86. self.horz_wall(x, y, w)
  87. self.horz_wall(x, y + h - 1, w)
  88. self.vert_wall(x, y, h)
  89. self.vert_wall(x + w - 1, y, h)
  90. def rotate_left(self) -> Grid:
  91. """
  92. Rotate the grid to the left (counter-clockwise)
  93. """
  94. grid = Grid(self.height, self.width)
  95. for i in range(self.width):
  96. for j in range(self.height):
  97. v = self.get(i, j)
  98. grid.set(j, grid.height - 1 - i, v)
  99. return grid
  100. def slice(self, topX: int, topY: int, width: int, height: int) -> Grid:
  101. """
  102. Get a subset of the grid
  103. """
  104. grid = Grid(width, height)
  105. for j in range(0, height):
  106. for i in range(0, width):
  107. x = topX + i
  108. y = topY + j
  109. if 0 <= x < self.width and 0 <= y < self.height:
  110. v = self.get(x, y)
  111. else:
  112. v = Wall()
  113. grid.set(i, j, v)
  114. return grid
  115. @classmethod
  116. def render_tile(
  117. cls,
  118. obj: WorldObj | None,
  119. agent_dir: int | None = None,
  120. highlight: bool = False,
  121. tile_size: int = TILE_PIXELS,
  122. subdivs: int = 3,
  123. ) -> np.ndarray:
  124. """
  125. Render a tile and cache the result
  126. """
  127. # Hash map lookup key for the cache
  128. key: tuple[Any, ...] = (agent_dir, highlight, tile_size)
  129. key = obj.encode() + key if obj else key
  130. if key in cls.tile_cache:
  131. return cls.tile_cache[key]
  132. img = np.zeros(
  133. shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8
  134. )
  135. # Draw the grid lines (top and left edges)
  136. fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
  137. fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100))
  138. if obj is not None:
  139. obj.render(img)
  140. # Overlay the agent on top
  141. if agent_dir is not None:
  142. tri_fn = point_in_triangle(
  143. (0.12, 0.19),
  144. (0.87, 0.50),
  145. (0.12, 0.81),
  146. )
  147. # Rotate the agent based on its direction
  148. tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5 * math.pi * agent_dir)
  149. fill_coords(img, tri_fn, (255, 0, 0))
  150. # Highlight the cell if needed
  151. if highlight:
  152. highlight_img(img)
  153. # Downsample the image to perform supersampling/anti-aliasing
  154. img = downsample(img, subdivs)
  155. # Cache the rendered tile
  156. cls.tile_cache[key] = img
  157. return img
  158. def render(
  159. self,
  160. tile_size: int,
  161. agent_pos: tuple[int, int],
  162. agent_dir: int | None = None,
  163. highlight_mask: np.ndarray | None = None,
  164. ) -> np.ndarray:
  165. """
  166. Render this grid at a given scale
  167. :param r: target renderer object
  168. :param tile_size: tile size in pixels
  169. """
  170. if highlight_mask is None:
  171. highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  172. # Compute the total grid size
  173. width_px = self.width * tile_size
  174. height_px = self.height * tile_size
  175. img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)
  176. # Render the grid
  177. for j in range(0, self.height):
  178. for i in range(0, self.width):
  179. cell = self.get(i, j)
  180. agent_here = np.array_equal(agent_pos, (i, j))
  181. assert highlight_mask is not None
  182. tile_img = Grid.render_tile(
  183. cell,
  184. agent_dir=agent_dir if agent_here else None,
  185. highlight=highlight_mask[i, j],
  186. tile_size=tile_size,
  187. )
  188. ymin = j * tile_size
  189. ymax = (j + 1) * tile_size
  190. xmin = i * tile_size
  191. xmax = (i + 1) * tile_size
  192. img[ymin:ymax, xmin:xmax, :] = tile_img
  193. return img
  194. def encode(self, vis_mask: np.ndarray | None = None) -> np.ndarray:
  195. """
  196. Produce a compact numpy encoding of the grid
  197. """
  198. if vis_mask is None:
  199. vis_mask = np.ones((self.width, self.height), dtype=bool)
  200. array = np.zeros((self.width, self.height, 3), dtype="uint8")
  201. for i in range(self.width):
  202. for j in range(self.height):
  203. assert vis_mask is not None
  204. if vis_mask[i, j]:
  205. v = self.get(i, j)
  206. if v is None:
  207. array[i, j, 0] = OBJECT_TO_IDX["empty"]
  208. array[i, j, 1] = 0
  209. array[i, j, 2] = 0
  210. else:
  211. array[i, j, :] = v.encode()
  212. return array
  213. @staticmethod
  214. def decode(array: np.ndarray) -> tuple[Grid, np.ndarray]:
  215. """
  216. Decode an array grid encoding back into a grid
  217. """
  218. width, height, channels = array.shape
  219. assert channels == 3
  220. vis_mask = np.ones(shape=(width, height), dtype=bool)
  221. grid = Grid(width, height)
  222. for i in range(width):
  223. for j in range(height):
  224. type_idx, color_idx, state = array[i, j]
  225. v = WorldObj.decode(type_idx, color_idx, state)
  226. grid.set(i, j, v)
  227. vis_mask[i, j] = type_idx != OBJECT_TO_IDX["unseen"]
  228. return grid, vis_mask
  229. def process_vis(self, agent_pos: tuple[int, int]) -> np.ndarray:
  230. mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  231. mask[agent_pos[0], agent_pos[1]] = True
  232. for j in reversed(range(0, self.height)):
  233. for i in range(0, self.width - 1):
  234. if not mask[i, j]:
  235. continue
  236. cell = self.get(i, j)
  237. if cell and not cell.see_behind():
  238. continue
  239. mask[i + 1, j] = True
  240. if j > 0:
  241. mask[i + 1, j - 1] = True
  242. mask[i, j - 1] = True
  243. for i in reversed(range(1, self.width)):
  244. if not mask[i, j]:
  245. continue
  246. cell = self.get(i, j)
  247. if cell and not cell.see_behind():
  248. continue
  249. mask[i - 1, j] = True
  250. if j > 0:
  251. mask[i - 1, j - 1] = True
  252. mask[i, j - 1] = True
  253. for j in range(0, self.height):
  254. for i in range(0, self.width):
  255. if not mask[i, j]:
  256. self.set(i, j, None)
  257. return mask