patterns.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. "Extract patterns from grids of tiles. Implementation based on https://github.com/ikarth/wfc_2019f"
  2. from __future__ import annotations
  3. import logging
  4. from collections import Counter
  5. from typing import Any, Mapping
  6. import numpy as np
  7. from numpy.typing import NDArray
  8. from minigrid.envs.wfc.wfclogic.utilities import hash_downto
  9. logger = logging.getLogger(__name__)
  10. def unique_patterns_2d(
  11. agrid: NDArray[np.int64], ksize: int, periodic_input: bool
  12. ) -> tuple[NDArray[np.int64], NDArray[np.int64], NDArray[np.int64]]:
  13. assert ksize >= 1
  14. if periodic_input:
  15. agrid = np.pad(
  16. agrid,
  17. ((0, ksize - 1), (0, ksize - 1), *(((0, 0),) * (len(agrid.shape) - 2))),
  18. mode="wrap",
  19. )
  20. else:
  21. # TODO: implement non-wrapped image handling
  22. # a = np.pad(a, ((0,k-1),(0,k-1),*(((0,0),)*(len(a.shape)-2))), mode='constant', constant_values=None)
  23. agrid = np.pad(
  24. agrid,
  25. ((0, ksize - 1), (0, ksize - 1), *(((0, 0),) * (len(agrid.shape) - 2))),
  26. mode="wrap",
  27. )
  28. patches: NDArray[np.int64] = np.lib.stride_tricks.as_strided(
  29. agrid,
  30. (
  31. agrid.shape[0] - ksize + 1,
  32. agrid.shape[1] - ksize + 1,
  33. ksize,
  34. ksize,
  35. *agrid.shape[2:],
  36. ),
  37. agrid.strides[:2] + agrid.strides[:2] + agrid.strides[2:],
  38. writeable=False,
  39. )
  40. patch_codes = hash_downto(patches, 2)
  41. uc, ui = np.unique(patch_codes, return_index=True)
  42. locs = np.unravel_index(ui, patch_codes.shape)
  43. up: NDArray[np.int64] = patches[locs[0], locs[1]]
  44. ids: NDArray[np.int64] = np.vectorize(
  45. {code: ind for ind, code in enumerate(uc)}.get
  46. )(patch_codes)
  47. return ids, up, patch_codes
  48. def unique_patterns_brute_force(grid, size, periodic_input):
  49. padded_grid = np.pad(
  50. grid,
  51. ((0, size - 1), (0, size - 1), *(((0, 0),) * (len(grid.shape) - 2))),
  52. mode="wrap",
  53. )
  54. patches = []
  55. for x in range(grid.shape[0]):
  56. row_patches = []
  57. for y in range(grid.shape[1]):
  58. row_patches.append(
  59. np.ndarray.tolist(padded_grid[x : x + size, y : y + size])
  60. )
  61. patches.append(row_patches)
  62. patches = np.array(patches)
  63. patch_codes = hash_downto(patches, 2)
  64. uc, ui = np.unique(patch_codes, return_index=True)
  65. locs = np.unravel_index(ui, patch_codes.shape)
  66. up = patches[locs[0], locs[1]]
  67. ids = np.vectorize({c: i for i, c in enumerate(uc)}.get)(patch_codes)
  68. return ids, up
  69. def make_pattern_catalog(
  70. tile_grid: NDArray[np.int64], pattern_width: int, input_is_periodic: bool = True
  71. ) -> tuple[dict[int, NDArray[np.int64]], Counter, NDArray[np.int64], NDArray[np.int64]]:
  72. """Returns a pattern catalog (dictionary of pattern hashes to constituent tiles),
  73. an ordered list of pattern weights, and an ordered list of pattern contents."""
  74. _patterns_in_grid, pattern_contents_list, patch_codes = unique_patterns_2d(
  75. tile_grid, pattern_width, input_is_periodic
  76. )
  77. dict_of_pattern_contents: dict[int, NDArray[np.int64]] = {}
  78. for pat_idx in range(pattern_contents_list.shape[0]):
  79. p_hash = hash_downto(pattern_contents_list[pat_idx], 0)
  80. dict_of_pattern_contents.update({p_hash.item(): pattern_contents_list[pat_idx]})
  81. pattern_frequency = Counter(hash_downto(pattern_contents_list, 1))
  82. return (
  83. dict_of_pattern_contents,
  84. pattern_frequency,
  85. hash_downto(pattern_contents_list, 1),
  86. patch_codes,
  87. )
  88. def identity_grid(grid):
  89. """Do nothing to the grid"""
  90. # return np.array([[7,5,5,5],[5,0,0,0],[5,0,1,0],[5,0,0,0]])
  91. return grid
  92. def reflect_grid(grid):
  93. """Reflect the grid left/right"""
  94. return np.fliplr(grid)
  95. def rotate_grid(grid):
  96. """Rotate the grid"""
  97. return np.rot90(grid, axes=(1, 0))
  98. def make_pattern_catalog_with_rotations(
  99. tile_grid: NDArray[np.int64],
  100. pattern_width: int,
  101. rotations: int = 7,
  102. input_is_periodic: bool = True,
  103. ) -> tuple[dict[int, NDArray[np.int64]], Counter, NDArray[np.int64], NDArray[np.int64]]:
  104. rotated_tile_grid = tile_grid.copy()
  105. merged_dict_of_pattern_contents: dict[int, NDArray[np.int64]] = {}
  106. merged_pattern_frequency: Counter = Counter()
  107. merged_pattern_contents_list: NDArray[np.int64] | None = None
  108. merged_patch_codes: NDArray[np.int64] | None = None
  109. def _make_catalog() -> None:
  110. nonlocal rotated_tile_grid, merged_dict_of_pattern_contents, merged_pattern_contents_list, merged_pattern_frequency, merged_patch_codes
  111. (
  112. dict_of_pattern_contents,
  113. pattern_frequency,
  114. pattern_contents_list,
  115. patch_codes,
  116. ) = make_pattern_catalog(rotated_tile_grid, pattern_width, input_is_periodic)
  117. merged_dict_of_pattern_contents.update(dict_of_pattern_contents)
  118. merged_pattern_frequency.update(pattern_frequency)
  119. if merged_pattern_contents_list is None:
  120. merged_pattern_contents_list = pattern_contents_list.copy()
  121. else:
  122. merged_pattern_contents_list = np.unique(
  123. np.concatenate((merged_pattern_contents_list, pattern_contents_list))
  124. )
  125. if merged_patch_codes is None:
  126. merged_patch_codes = patch_codes.copy()
  127. counter = 0
  128. grid_ops = [
  129. identity_grid,
  130. reflect_grid,
  131. rotate_grid,
  132. reflect_grid,
  133. rotate_grid,
  134. reflect_grid,
  135. rotate_grid,
  136. reflect_grid,
  137. ]
  138. while counter <= (rotations):
  139. # logger.debug(rotated_tile_grid.shape)
  140. # logger.debug(np.array_equiv(reflect_grid(rotated_tile_grid.copy()), rotate_grid(rotated_tile_grid.copy())))
  141. # logger.debug(counter)
  142. # logger.debug(grid_ops[counter].__name__)
  143. rotated_tile_grid = grid_ops[counter](rotated_tile_grid.copy())
  144. # logger.debug(rotated_tile_grid)
  145. # logger.debug("---")
  146. _make_catalog()
  147. counter += 1
  148. # assert False
  149. assert merged_pattern_contents_list is not None
  150. assert merged_patch_codes is not None
  151. return (
  152. merged_dict_of_pattern_contents,
  153. merged_pattern_frequency,
  154. merged_pattern_contents_list,
  155. merged_patch_codes,
  156. )
  157. def pattern_grid_to_tiles(
  158. pattern_grid: NDArray[np.int64], pattern_catalog: Mapping[int, NDArray[np.int64]]
  159. ) -> NDArray[np.int64]:
  160. anchor_x = 0
  161. anchor_y = 0
  162. def pattern_to_tile(pattern: int) -> Any:
  163. # if isinstance(pattern, list):
  164. # ptrns = []
  165. # for p in pattern:
  166. # logger.debug(p)
  167. # ptrns.push(pattern_to_tile(p))
  168. # logger.debug(ptrns)
  169. # assert False
  170. # return ptrns
  171. return pattern_catalog[pattern][anchor_x][anchor_y]
  172. return np.vectorize(pattern_to_tile)(pattern_grid)