grid.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. import math
  2. import numpy as np
  3. from minigrid.core.constants import OBJECT_TO_IDX, TILE_PIXELS
  4. from minigrid.core.world_object import Wall, WorldObj
  5. from minigrid.utils.rendering import (
  6. downsample,
  7. fill_coords,
  8. highlight_img,
  9. point_in_rect,
  10. point_in_triangle,
  11. rotate_fn,
  12. )
  13. class Grid:
  14. """
  15. Represent a grid and operations on it
  16. """
  17. # Static cache of pre-renderer tiles
  18. tile_cache = {}
  19. def __init__(self, width, height):
  20. assert width >= 3
  21. assert height >= 3
  22. self.width = width
  23. self.height = height
  24. self.grid = [None] * width * height
  25. def __contains__(self, key):
  26. if isinstance(key, WorldObj):
  27. for e in self.grid:
  28. if e is key:
  29. return True
  30. elif isinstance(key, tuple):
  31. for e in self.grid:
  32. if e is None:
  33. continue
  34. if (e.color, e.type) == key:
  35. return True
  36. if key[0] is None and key[1] == e.type:
  37. return True
  38. return False
  39. def __eq__(self, other):
  40. grid1 = self.encode()
  41. grid2 = other.encode()
  42. return np.array_equal(grid2, grid1)
  43. def __ne__(self, other):
  44. return not self == other
  45. def copy(self):
  46. from copy import deepcopy
  47. return deepcopy(self)
  48. def set(self, i, j, v):
  49. assert i >= 0 and i < self.width
  50. assert j >= 0 and j < self.height
  51. self.grid[j * self.width + i] = v
  52. def get(self, i, j):
  53. assert i >= 0 and i < self.width
  54. assert j >= 0 and j < self.height
  55. return self.grid[j * self.width + i]
  56. def horz_wall(self, x, y, length=None, obj_type=Wall):
  57. if length is None:
  58. length = self.width - x
  59. for i in range(0, length):
  60. self.set(x + i, y, obj_type())
  61. def vert_wall(self, x, y, length=None, obj_type=Wall):
  62. if length is None:
  63. length = self.height - y
  64. for j in range(0, length):
  65. self.set(x, y + j, obj_type())
  66. def wall_rect(self, x, y, w, h):
  67. self.horz_wall(x, y, w)
  68. self.horz_wall(x, y + h - 1, w)
  69. self.vert_wall(x, y, h)
  70. self.vert_wall(x + w - 1, y, h)
  71. def rotate_left(self):
  72. """
  73. Rotate the grid to the left (counter-clockwise)
  74. """
  75. grid = Grid(self.height, self.width)
  76. for i in range(self.width):
  77. for j in range(self.height):
  78. v = self.get(i, j)
  79. grid.set(j, grid.height - 1 - i, v)
  80. return grid
  81. def slice(self, topX, topY, width, height):
  82. """
  83. Get a subset of the grid
  84. """
  85. grid = Grid(width, height)
  86. for j in range(0, height):
  87. for i in range(0, width):
  88. x = topX + i
  89. y = topY + j
  90. if x >= 0 and x < self.width and y >= 0 and y < self.height:
  91. v = self.get(x, y)
  92. else:
  93. v = Wall()
  94. grid.set(i, j, v)
  95. return grid
  96. @classmethod
  97. def render_tile(
  98. cls, obj, agent_dir=None, highlight=False, tile_size=TILE_PIXELS, subdivs=3
  99. ):
  100. """
  101. Render a tile and cache the result
  102. """
  103. # Hash map lookup key for the cache
  104. key = (agent_dir, highlight, tile_size)
  105. key = obj.encode() + key if obj else key
  106. if key in cls.tile_cache:
  107. return cls.tile_cache[key]
  108. img = np.zeros(
  109. shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8
  110. )
  111. # Draw the grid lines (top and left edges)
  112. fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
  113. fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100))
  114. if obj is not None:
  115. obj.render(img)
  116. # Overlay the agent on top
  117. if agent_dir is not None:
  118. tri_fn = point_in_triangle(
  119. (0.12, 0.19),
  120. (0.87, 0.50),
  121. (0.12, 0.81),
  122. )
  123. # Rotate the agent based on its direction
  124. tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5 * math.pi * agent_dir)
  125. fill_coords(img, tri_fn, (255, 0, 0))
  126. # Highlight the cell if needed
  127. if highlight:
  128. highlight_img(img)
  129. # Downsample the image to perform supersampling/anti-aliasing
  130. img = downsample(img, subdivs)
  131. # Cache the rendered tile
  132. cls.tile_cache[key] = img
  133. return img
  134. def render(self, tile_size, agent_pos, agent_dir=None, highlight_mask=None):
  135. """
  136. Render this grid at a given scale
  137. :param r: target renderer object
  138. :param tile_size: tile size in pixels
  139. """
  140. if highlight_mask is None:
  141. highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  142. # Compute the total grid size
  143. width_px = self.width * tile_size
  144. height_px = self.height * tile_size
  145. img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)
  146. # Render the grid
  147. for j in range(0, self.height):
  148. for i in range(0, self.width):
  149. cell = self.get(i, j)
  150. agent_here = np.array_equal(agent_pos, (i, j))
  151. tile_img = Grid.render_tile(
  152. cell,
  153. agent_dir=agent_dir if agent_here else None,
  154. highlight=highlight_mask[i, j],
  155. tile_size=tile_size,
  156. )
  157. ymin = j * tile_size
  158. ymax = (j + 1) * tile_size
  159. xmin = i * tile_size
  160. xmax = (i + 1) * tile_size
  161. img[ymin:ymax, xmin:xmax, :] = tile_img
  162. return img
  163. def encode(self, vis_mask=None):
  164. """
  165. Produce a compact numpy encoding of the grid
  166. """
  167. if vis_mask is None:
  168. vis_mask = np.ones((self.width, self.height), dtype=bool)
  169. array = np.zeros((self.width, self.height, 3), dtype="uint8")
  170. for i in range(self.width):
  171. for j in range(self.height):
  172. if vis_mask[i, j]:
  173. v = self.get(i, j)
  174. if v is None:
  175. array[i, j, 0] = OBJECT_TO_IDX["empty"]
  176. array[i, j, 1] = 0
  177. array[i, j, 2] = 0
  178. else:
  179. array[i, j, :] = v.encode()
  180. return array
  181. @staticmethod
  182. def decode(array):
  183. """
  184. Decode an array grid encoding back into a grid
  185. """
  186. width, height, channels = array.shape
  187. assert channels == 3
  188. vis_mask = np.ones(shape=(width, height), dtype=bool)
  189. grid = Grid(width, height)
  190. for i in range(width):
  191. for j in range(height):
  192. type_idx, color_idx, state = array[i, j]
  193. v = WorldObj.decode(type_idx, color_idx, state)
  194. grid.set(i, j, v)
  195. vis_mask[i, j] = type_idx != OBJECT_TO_IDX["unseen"]
  196. return grid, vis_mask
  197. def process_vis(self, agent_pos):
  198. mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  199. mask[agent_pos[0], agent_pos[1]] = True
  200. for j in reversed(range(0, self.height)):
  201. for i in range(0, self.width - 1):
  202. if not mask[i, j]:
  203. continue
  204. cell = self.get(i, j)
  205. if cell and not cell.see_behind():
  206. continue
  207. mask[i + 1, j] = True
  208. if j > 0:
  209. mask[i + 1, j - 1] = True
  210. mask[i, j - 1] = True
  211. for i in reversed(range(1, self.width)):
  212. if not mask[i, j]:
  213. continue
  214. cell = self.get(i, j)
  215. if cell and not cell.see_behind():
  216. continue
  217. mask[i - 1, j] = True
  218. if j > 0:
  219. mask[i - 1, j - 1] = True
  220. mask[i, j - 1] = True
  221. for j in range(0, self.height):
  222. for i in range(0, self.width):
  223. if not mask[i, j]:
  224. self.set(i, j, None)
  225. return mask