# grid.py

import math
from copy import deepcopy
from typing import Any, List, Optional, Tuple, Type

import numpy as np

from minigrid.core.constants import OBJECT_TO_IDX, TILE_PIXELS
from minigrid.core.world_object import Wall, WorldObj
from minigrid.utils.rendering import (
    downsample,
    fill_coords,
    highlight_img,
    point_in_rect,
    point_in_triangle,
    rotate_fn,
)


  14. class Grid:
  15. """
  16. Represent a grid and operations on it
  17. """
  18. # Static cache of pre-renderer tiles
  19. tile_cache = {}
  20. def __init__(self, width: int, height: int):
  21. assert width >= 3
  22. assert height >= 3
  23. self.width: int = width
  24. self.height: int = height
  25. self.grid: List[Optional[WorldObj]] = [None] * width * height
  26. def __contains__(self, key: Any) -> bool:
  27. if isinstance(key, WorldObj):
  28. for e in self.grid:
  29. if e is key:
  30. return True
  31. elif isinstance(key, tuple):
  32. for e in self.grid:
  33. if e is None:
  34. continue
  35. if (e.color, e.type) == key:
  36. return True
  37. if key[0] is None and key[1] == e.type:
  38. return True
  39. return False
  40. def __eq__(self, other: "Grid") -> bool:
  41. grid1 = self.encode()
  42. grid2 = other.encode()
  43. return np.array_equal(grid2, grid1)
  44. def __ne__(self, other: "Grid") -> bool:
  45. return not self == other
  46. def copy(self) -> "Grid":
  47. from copy import deepcopy
  48. return deepcopy(self)
  49. def set(self, i: int, j: int, v: Optional[WorldObj]):
  50. assert 0 <= i < self.width
  51. assert 0 <= j < self.height
  52. self.grid[j * self.width + i] = v
  53. def get(self, i: int, j: int) -> Optional[WorldObj]:
  54. assert 0 <= i < self.width
  55. assert 0 <= j < self.height
  56. assert self.grid is not None
  57. return self.grid[j * self.width + i]
  58. def horz_wall(
  59. self,
  60. x: int,
  61. y: int,
  62. length: Optional[int] = None,
  63. obj_type: Type[WorldObj] = Wall,
  64. ):
  65. if length is None:
  66. length = self.width - x
  67. for i in range(0, length):
  68. self.set(x + i, y, obj_type())
  69. def vert_wall(
  70. self,
  71. x: int,
  72. y: int,
  73. length: Optional[int] = None,
  74. obj_type: Type[WorldObj] = Wall,
  75. ):
  76. if length is None:
  77. length = self.height - y
  78. for j in range(0, length):
  79. self.set(x, y + j, obj_type())
  80. def wall_rect(self, x: int, y: int, w: int, h: int):
  81. self.horz_wall(x, y, w)
  82. self.horz_wall(x, y + h - 1, w)
  83. self.vert_wall(x, y, h)
  84. self.vert_wall(x + w - 1, y, h)
  85. def rotate_left(self) -> "Grid":
  86. """
  87. Rotate the grid to the left (counter-clockwise)
  88. """
  89. grid = Grid(self.height, self.width)
  90. for i in range(self.width):
  91. for j in range(self.height):
  92. v = self.get(i, j)
  93. grid.set(j, grid.height - 1 - i, v)
  94. return grid
  95. def slice(self, topX: int, topY: int, width: int, height: int) -> "Grid":
  96. """
  97. Get a subset of the grid
  98. """
  99. grid = Grid(width, height)
  100. for j in range(0, height):
  101. for i in range(0, width):
  102. x = topX + i
  103. y = topY + j
  104. if 0 <= x < self.width and 0 <= y < self.height:
  105. v = self.get(x, y)
  106. else:
  107. v = Wall()
  108. grid.set(i, j, v)
  109. return grid
  110. @classmethod
  111. def render_tile(
  112. cls,
  113. obj: WorldObj,
  114. agent_dir: Optional[int] = None,
  115. highlight: bool = False,
  116. tile_size: int = TILE_PIXELS,
  117. subdivs: int = 3,
  118. ) -> np.ndarray:
  119. """
  120. Render a tile and cache the result
  121. """
  122. # Hash map lookup key for the cache
  123. key = (agent_dir, highlight, tile_size)
  124. key = obj.encode() + key if obj else key
  125. if key in cls.tile_cache:
  126. return cls.tile_cache[key]
  127. img = np.zeros(
  128. shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8
  129. )
  130. # Draw the grid lines (top and left edges)
  131. fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
  132. fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100))
  133. if obj is not None:
  134. obj.render(img)
  135. # Overlay the agent on top
  136. if agent_dir is not None:
  137. tri_fn = point_in_triangle(
  138. (0.12, 0.19),
  139. (0.87, 0.50),
  140. (0.12, 0.81),
  141. )
  142. # Rotate the agent based on its direction
  143. tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5 * math.pi * agent_dir)
  144. fill_coords(img, tri_fn, (255, 0, 0))
  145. # Highlight the cell if needed
  146. if highlight:
  147. highlight_img(img)
  148. # Downsample the image to perform supersampling/anti-aliasing
  149. img = downsample(img, subdivs)
  150. # Cache the rendered tile
  151. cls.tile_cache[key] = img
  152. return img
  153. def render(
  154. self,
  155. tile_size: int,
  156. agent_pos: Tuple[int, int],
  157. agent_dir: Optional[int] = None,
  158. highlight_mask: Optional[np.ndarray] = None,
  159. ) -> np.ndarray:
  160. """
  161. Render this grid at a given scale
  162. :param r: target renderer object
  163. :param tile_size: tile size in pixels
  164. """
  165. if highlight_mask is None:
  166. highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  167. # Compute the total grid size
  168. width_px = self.width * tile_size
  169. height_px = self.height * tile_size
  170. img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)
  171. # Render the grid
  172. for j in range(0, self.height):
  173. for i in range(0, self.width):
  174. cell = self.get(i, j)
  175. agent_here = np.array_equal(agent_pos, (i, j))
  176. assert highlight_mask is not None
  177. tile_img = Grid.render_tile(
  178. cell,
  179. agent_dir=agent_dir if agent_here else None,
  180. highlight=highlight_mask[i, j],
  181. tile_size=tile_size,
  182. )
  183. ymin = j * tile_size
  184. ymax = (j + 1) * tile_size
  185. xmin = i * tile_size
  186. xmax = (i + 1) * tile_size
  187. img[ymin:ymax, xmin:xmax, :] = tile_img
  188. return img
  189. def encode(self, vis_mask: Optional[np.ndarray] = None) -> np.ndarray:
  190. """
  191. Produce a compact numpy encoding of the grid
  192. """
  193. if vis_mask is None:
  194. vis_mask = np.ones((self.width, self.height), dtype=bool)
  195. array = np.zeros((self.width, self.height, 3), dtype="uint8")
  196. for i in range(self.width):
  197. for j in range(self.height):
  198. assert vis_mask is not None
  199. if vis_mask[i, j]:
  200. v = self.get(i, j)
  201. if v is None:
  202. array[i, j, 0] = OBJECT_TO_IDX["empty"]
  203. array[i, j, 1] = 0
  204. array[i, j, 2] = 0
  205. else:
  206. array[i, j, :] = v.encode()
  207. return array
  208. @staticmethod
  209. def decode(array: np.ndarray) -> Tuple["Grid", np.ndarray]:
  210. """
  211. Decode an array grid encoding back into a grid
  212. """
  213. width, height, channels = array.shape
  214. assert channels == 3
  215. vis_mask = np.ones(shape=(width, height), dtype=bool)
  216. grid = Grid(width, height)
  217. for i in range(width):
  218. for j in range(height):
  219. type_idx, color_idx, state = array[i, j]
  220. v = WorldObj.decode(type_idx, color_idx, state)
  221. grid.set(i, j, v)
  222. vis_mask[i, j] = type_idx != OBJECT_TO_IDX["unseen"]
  223. return grid, vis_mask
  224. def process_vis(self, agent_pos: Tuple[int, int]) -> np.ndarray:
  225. mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  226. mask[agent_pos[0], agent_pos[1]] = True
  227. for j in reversed(range(0, self.height)):
  228. for i in range(0, self.width - 1):
  229. if not mask[i, j]:
  230. continue
  231. cell = self.get(i, j)
  232. if cell and not cell.see_behind():
  233. continue
  234. mask[i + 1, j] = True
  235. if j > 0:
  236. mask[i + 1, j - 1] = True
  237. mask[i, j - 1] = True
  238. for i in reversed(range(1, self.width)):
  239. if not mask[i, j]:
  240. continue
  241. cell = self.get(i, j)
  242. if cell and not cell.see_behind():
  243. continue
  244. mask[i - 1, j] = True
  245. if j > 0:
  246. mask[i - 1, j - 1] = True
  247. mask[i, j - 1] = True
  248. for j in range(0, self.height):
  249. for i in range(0, self.width):
  250. if not mask[i, j]:
  251. self.set(i, j, None)
  252. return mask