solver.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532
  1. """Wave Function Collapse solver. Implementation based on https://github.com/ikarth/wfc_2019f"""
  2. from __future__ import annotations
  3. import itertools
  4. import logging
  5. import math
  6. from typing import Any, Callable, Collection, Iterable, Iterator, Mapping, TypeVar
  7. # from scipy import sparse # type: ignore
  8. import numpy
  9. import numpy as np
  10. from numpy.typing import NBitBase, NDArray
  11. logger = logging.getLogger(__name__)
  12. T = TypeVar("T", bound=NBitBase)
  13. class Contradiction(Exception):
  14. """Solving could not proceed without backtracking/restarting."""
  15. pass
  16. class TimedOut(Exception):
  17. """Solve timed out."""
  18. pass
  19. class StopEarly(Exception):
  20. """Aborting solve early."""
  21. pass
  22. class Solver:
  23. """WFC Solver which can hold wave and backtracking state."""
  24. def __init__(
  25. self,
  26. *,
  27. wave: NDArray[np.bool_],
  28. adj: Mapping[tuple[int, int], NDArray[numpy.bool_]],
  29. periodic: bool = False,
  30. backtracking: bool = False,
  31. on_backtrack: Callable[[], None] | None = None,
  32. on_choice: Callable[[int, int, int], None] | None = None,
  33. on_observe: Callable[[NDArray[numpy.bool_]], None] | None = None,
  34. on_propagate: Callable[[NDArray[numpy.bool_]], None] | None = None,
  35. check_feasible: Callable[[NDArray[numpy.bool_]], bool] | None = None,
  36. ) -> None:
  37. self.wave = wave
  38. self.adj = adj
  39. self.periodic = periodic
  40. self.backtracking = backtracking
  41. self.history: list[NDArray[np.bool_]] = [] # An undo history for backtracking.
  42. self.on_backtrack = on_backtrack
  43. self.on_choice = on_choice
  44. self.on_observe = on_observe
  45. self.on_propagate = on_propagate
  46. self.check_feasible = check_feasible
  47. @property
  48. def is_solved(self) -> bool:
  49. """Is True if the wave has been fully resolved."""
  50. return (
  51. self.wave.sum() == self.wave.shape[1] * self.wave.shape[2]
  52. and (self.wave.sum(axis=0) == 1).all()
  53. )
  54. def solve_next(
  55. self,
  56. location_heuristic: Callable[[NDArray[numpy.bool_]], tuple[int, int]],
  57. pattern_heuristic: Callable[[NDArray[np.bool_], NDArray[np.bool_]], int],
  58. ) -> bool:
  59. """Attempt to collapse one wave. Returns True if no more steps remain."""
  60. if self.is_solved:
  61. return True
  62. if self.check_feasible and not self.check_feasible(self.wave):
  63. raise Contradiction("Not feasible.")
  64. if self.backtracking:
  65. self.history.append(self.wave.copy())
  66. propagate(
  67. self.wave, self.adj, periodic=self.periodic, onPropagate=self.on_propagate
  68. )
  69. pattern, i, j = None, None, None
  70. try:
  71. pattern, i, j = observe(self.wave, location_heuristic, pattern_heuristic)
  72. if self.on_choice:
  73. self.on_choice(pattern, i, j)
  74. self.wave[:, i, j] = False
  75. self.wave[pattern, i, j] = True
  76. if self.on_observe:
  77. self.on_observe(self.wave)
  78. propagate(
  79. self.wave,
  80. self.adj,
  81. periodic=self.periodic,
  82. onPropagate=self.on_propagate,
  83. )
  84. return False # Assume there is remaining steps, if not then the next call will return True.
  85. except Contradiction:
  86. if not self.backtracking:
  87. raise
  88. if not self.history:
  89. raise Contradiction("Every permutation has been attempted.")
  90. if self.on_backtrack:
  91. self.on_backtrack()
  92. self.wave = self.history.pop()
  93. self.wave[pattern, i, j] = False
  94. return False
  95. def solve(
  96. self,
  97. location_heuristic: Callable[[NDArray[numpy.bool_]], tuple[int, int]],
  98. pattern_heuristic: Callable[[NDArray[np.bool_], NDArray[np.bool_]], int],
  99. ) -> NDArray[np.int64]:
  100. """Attempts to solve all waves and returns the solution."""
  101. while not self.solve_next(
  102. location_heuristic=location_heuristic, pattern_heuristic=pattern_heuristic
  103. ):
  104. pass
  105. return numpy.argmax(self.wave, axis=0)
  106. def makeWave(
  107. n: int, w: int, h: int, ground: Iterable[int] | None = None
  108. ) -> NDArray[numpy.bool_]:
  109. wave: NDArray[numpy.bool_] = numpy.ones((n, w, h), dtype=numpy.bool_)
  110. if ground is not None:
  111. wave[:, :, h - 1] = False
  112. for g in ground:
  113. wave[
  114. g,
  115. :,
  116. ] = False
  117. wave[g, :, h - 1] = True
  118. # logger.debug(wave)
  119. # for i in range(wave.shape[0]):
  120. # logger.debug(wave[i])
  121. return wave
  122. def makeAdj(
  123. adjLists: Mapping[tuple[int, int], Collection[Iterable[int]]]
  124. ) -> dict[tuple[int, int], NDArray[numpy.bool_]]:
  125. adjMatrices = {}
  126. # logger.debug(adjLists)
  127. num_patterns = len(list(adjLists.values())[0])
  128. for d in adjLists:
  129. m = numpy.zeros((num_patterns, num_patterns), dtype=bool)
  130. for i, js in enumerate(adjLists[d]):
  131. # logger.debug(js)
  132. for j in js:
  133. m[i, j] = 1
  134. # If scipy is available, use sparse matrices.
  135. # adjMatrices[d] = sparse.csr_matrix(m)
  136. adjMatrices[d] = m
  137. return adjMatrices
  138. ######################################
  139. # Location Heuristics
  140. def makeRandomLocationHeuristic(
  141. preferences: NDArray[np.floating[Any]],
  142. ) -> Callable[[NDArray[np.bool_]], tuple[int, int]]:
  143. def randomLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]:
  144. unresolved_cell_mask = numpy.count_nonzero(wave, axis=0) > 1
  145. cell_weights = numpy.where(unresolved_cell_mask, preferences, numpy.inf)
  146. row, col = numpy.unravel_index(numpy.argmin(cell_weights), cell_weights.shape)
  147. return row.item(), col.item()
  148. return randomLocationHeuristic
  149. def makeEntropyLocationHeuristic(
  150. preferences: NDArray[np.floating[Any]],
  151. ) -> Callable[[NDArray[np.bool_]], tuple[int, int]]:
  152. def entropyLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]:
  153. unresolved_cell_mask = numpy.count_nonzero(wave, axis=0) > 1
  154. cell_weights = numpy.where(
  155. unresolved_cell_mask,
  156. preferences + numpy.count_nonzero(wave, axis=0),
  157. numpy.inf,
  158. )
  159. row, col = numpy.unravel_index(numpy.argmin(cell_weights), cell_weights.shape)
  160. return row.item(), col.item()
  161. return entropyLocationHeuristic
  162. def makeAntiEntropyLocationHeuristic(
  163. preferences: NDArray[np.floating[Any]],
  164. ) -> Callable[[NDArray[np.bool_]], tuple[int, int]]:
  165. def antiEntropyLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]:
  166. unresolved_cell_mask = numpy.count_nonzero(wave, axis=0) > 1
  167. cell_weights = numpy.where(
  168. unresolved_cell_mask,
  169. preferences + numpy.count_nonzero(wave, axis=0),
  170. -numpy.inf,
  171. )
  172. row, col = numpy.unravel_index(numpy.argmax(cell_weights), cell_weights.shape)
  173. return row.item(), col.item()
  174. return antiEntropyLocationHeuristic
  175. def spiral_transforms() -> Iterator[tuple[int, int]]:
  176. for N in itertools.count(start=1):
  177. if N % 2 == 0:
  178. yield (0, 1) # right
  179. for _ in range(N):
  180. yield (1, 0) # down
  181. for _ in range(N):
  182. yield (0, -1) # left
  183. else:
  184. yield (0, -1) # left
  185. for _ in range(N):
  186. yield (-1, 0) # up
  187. for _ in range(N):
  188. yield (0, 1) # right
  189. def spiral_coords(x: int, y: int) -> Iterator[tuple[int, int]]:
  190. yield x, y
  191. for transform in spiral_transforms():
  192. x += transform[0]
  193. y += transform[1]
  194. yield x, y
  195. def fill_with_curve(
  196. arr: NDArray[np.floating[T]], curve_gen: Iterable[Iterable[int]]
  197. ) -> NDArray[np.floating[T]]:
  198. arr_len = numpy.prod(arr.shape)
  199. fill = 0
  200. for coord in curve_gen:
  201. # logger.debug(fill, idx, coord)
  202. if fill < arr_len:
  203. try:
  204. arr[tuple(coord)] = fill / arr_len
  205. fill += 1
  206. except IndexError:
  207. pass
  208. else:
  209. break
  210. # logger.debug(arr)
  211. return arr
  212. def makeSpiralLocationHeuristic(
  213. preferences: NDArray[np.floating[Any]],
  214. ) -> Callable[[NDArray[np.bool_]], tuple[int, int]]:
  215. # https://stackoverflow.com/a/23707273/5562922
  216. spiral_gen = (
  217. sc for sc in spiral_coords(preferences.shape[0] // 2, preferences.shape[1] // 2)
  218. )
  219. cell_order = fill_with_curve(preferences, spiral_gen)
  220. def spiralLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]:
  221. unresolved_cell_mask = numpy.count_nonzero(wave, axis=0) > 1
  222. cell_weights = numpy.where(unresolved_cell_mask, cell_order, numpy.inf)
  223. row, col = numpy.unravel_index(numpy.argmin(cell_weights), cell_weights.shape)
  224. return row.item(), col.item()
  225. return spiralLocationHeuristic
  226. def makeHilbertLocationHeuristic(
  227. preferences: NDArray[np.floating[Any]],
  228. ) -> Callable[[NDArray[np.bool_]], tuple[int, int]]:
  229. from hilbertcurve.hilbertcurve import HilbertCurve # type: ignore
  230. curve_size = math.ceil(math.sqrt(max(preferences.shape[0], preferences.shape[1])))
  231. logger.debug(curve_size)
  232. curve_size = 4
  233. h_curve = HilbertCurve(curve_size, 2)
  234. h_coords = (h_curve.point_from_distance(i) for i in itertools.count())
  235. cell_order = fill_with_curve(preferences, h_coords)
  236. # logger.debug(cell_order)
  237. def hilbertLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]:
  238. unresolved_cell_mask = numpy.count_nonzero(wave, axis=0) > 1
  239. cell_weights = numpy.where(unresolved_cell_mask, cell_order, numpy.inf)
  240. row, col = numpy.unravel_index(numpy.argmin(cell_weights), cell_weights.shape)
  241. return row.item(), col.item()
  242. return hilbertLocationHeuristic
  243. def simpleLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]:
  244. unresolved_cell_mask = numpy.count_nonzero(wave, axis=0) > 1
  245. cell_weights = numpy.where(
  246. unresolved_cell_mask, numpy.count_nonzero(wave, axis=0), numpy.inf
  247. )
  248. row, col = numpy.unravel_index(numpy.argmin(cell_weights), cell_weights.shape)
  249. return row.item(), col.item()
  250. def lexicalLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]:
  251. unresolved_cell_mask = numpy.count_nonzero(wave, axis=0) > 1
  252. cell_weights = numpy.where(unresolved_cell_mask, 1.0, numpy.inf)
  253. row, col = numpy.unravel_index(numpy.argmin(cell_weights), cell_weights.shape)
  254. return row.item(), col.item()
  255. #####################################
  256. # Pattern Heuristics
  257. def lexicalPatternHeuristic(weights: NDArray[np.bool_], wave: NDArray[np.bool_]) -> int:
  258. return numpy.nonzero(weights)[0][0].item()
  259. def makeWeightedPatternHeuristic(
  260. weights: NDArray[np.floating[Any]],
  261. np_random: numpy.random.Generator | None = None,
  262. ):
  263. num_of_patterns = len(weights)
  264. np_random: numpy.random.Generator = (
  265. numpy.random.default_rng() if np_random is None else np_random
  266. )
  267. def weightedPatternHeuristic(wave: NDArray[np.bool_], _: NDArray[np.bool_]) -> int:
  268. # TODO: there's maybe a faster, more controlled way to do this sampling...
  269. weighted_wave: NDArray[np.floating[Any]] = weights * wave
  270. weighted_wave /= weighted_wave.sum()
  271. result = np_random.choice(num_of_patterns, p=weighted_wave)
  272. return result
  273. return weightedPatternHeuristic
  274. def makeRarestPatternHeuristic(
  275. weights: NDArray[np.floating[Any]],
  276. np_random: numpy.random.Generator | None = None,
  277. ) -> Callable[[NDArray[np.bool_], NDArray[np.bool_]], int]:
  278. """Return a function that chooses the rarest (currently least-used) pattern."""
  279. np_random: numpy.random.Generator = (
  280. numpy.random.default_rng() if np_random is None else np_random
  281. )
  282. def weightedPatternHeuristic(
  283. wave: NDArray[np.bool_], total_wave: NDArray[np.bool_]
  284. ) -> int:
  285. logger.debug(total_wave.shape)
  286. # [logger.debug(e) for e in wave]
  287. wave_sums = numpy.sum(total_wave, (1, 2))
  288. # logger.debug(wave_sums)
  289. selected_pattern = np_random.choice(
  290. numpy.where(wave_sums == wave_sums.max())[0]
  291. )
  292. return selected_pattern
  293. return weightedPatternHeuristic
  294. def makeMostCommonPatternHeuristic(
  295. weights: NDArray[np.floating[Any]],
  296. np_random: numpy.random.Generator | None = None,
  297. ) -> Callable[[NDArray[np.bool_], NDArray[np.bool_]], int]:
  298. """Return a function that chooses the most common (currently most-used) pattern."""
  299. np_random: numpy.random.Generator = (
  300. numpy.random.default_rng() if np_random is None else np_random
  301. )
  302. def weightedPatternHeuristic(
  303. wave: NDArray[np.bool_], total_wave: NDArray[np.bool_]
  304. ) -> int:
  305. logger.debug(total_wave.shape)
  306. # [logger.debug(e) for e in wave]
  307. wave_sums = numpy.sum(total_wave, (1, 2))
  308. selected_pattern = np_random.choice(
  309. numpy.where(wave_sums == wave_sums.min())[0]
  310. )
  311. return selected_pattern
  312. return weightedPatternHeuristic
  313. def makeRandomPatternHeuristic(
  314. weights: NDArray[np.floating[Any]],
  315. np_random: numpy.random.Generator | None = None,
  316. ) -> Callable[[NDArray[np.bool_], NDArray[np.bool_]], int]:
  317. num_of_patterns = len(weights)
  318. np_random: numpy.random.Generator = (
  319. numpy.random.default_rng() if np_random is None else np_random
  320. )
  321. def randomPatternHeuristic(wave: NDArray[np.bool_], _: NDArray[np.bool_]) -> int:
  322. # TODO: there's maybe a faster, more controlled way to do this sampling...
  323. weighted_wave = 1.0 * wave
  324. weighted_wave /= weighted_wave.sum()
  325. result = np_random.choice(num_of_patterns, p=weighted_wave)
  326. return result
  327. return randomPatternHeuristic
  328. ######################################
  329. # Global Constraints
  330. def make_global_use_all_patterns() -> Callable[[NDArray[np.bool_]], bool]:
  331. def global_use_all_patterns(wave: NDArray[np.bool_]) -> bool:
  332. """Returns true if at least one instance of each pattern is still possible."""
  333. return numpy.all(numpy.any(wave, axis=(1, 2))).item()
  334. return global_use_all_patterns
  335. #####################################
  336. # Solver
  337. def propagate(
  338. wave: NDArray[np.bool_],
  339. adj: Mapping[tuple[int, int], NDArray[numpy.bool_]],
  340. periodic: bool = False,
  341. onPropagate: Callable[[NDArray[numpy.bool_]], None] | None = None,
  342. ) -> None:
  343. """Completely probagate any newly collapsed waves to all areas."""
  344. last_count = wave.sum()
  345. while True:
  346. supports = {}
  347. if periodic:
  348. padded = numpy.pad(wave, ((0, 0), (1, 1), (1, 1)), mode="wrap")
  349. else:
  350. padded = numpy.pad(
  351. wave, ((0, 0), (1, 1), (1, 1)), mode="constant", constant_values=True
  352. )
  353. # adj is the list of adjacencies. For each direction d in adjacency,
  354. # check which patterns are still valid...
  355. for d in adj:
  356. dx, dy = d
  357. # padded[] is a version of the adjacency matrix with the values wrapped around
  358. # shifted[] is the padded version with the values shifted over in one direction
  359. # because my code stores the directions as relative (x,y) coordinates, we can find
  360. # the adjacent cell for each direction by simply shifting the matrix in that direction,
  361. # which allows for arbitrary adjacency directions. This is somewhat excessive, but elegant.
  362. shifted = padded[
  363. :, 1 + dx : 1 + wave.shape[1] + dx, 1 + dy : 1 + wave.shape[2] + dy
  364. ]
  365. # logger.debug(f"shifted: {shifted.shape} | adj[d]: {adj[d].shape} | d: {d}")
  366. # raise StopEarly
  367. # supports[d] = numpy.einsum('pwh,pq->qwh', shifted, adj[d]) > 0
  368. # The adjacency matrix is a boolean matrix, indexed by the direction and the two patterns.
  369. # If the value for (direction, pattern1, pattern2) is True, then this is a valid adjacency.
  370. # This gives us a rapid way to compare: True is 1, False is 0, so multiplying the matrices
  371. # gives us the adjacency compatibility.
  372. supports[d] = (adj[d] @ shifted.reshape(shifted.shape[0], -1)).reshape(
  373. shifted.shape
  374. ) > 0
  375. # supports[d] = ( <- for each cell in the matrix
  376. # adj[d] <- the adjacency matrix [sliced by the direction d]
  377. # @ <- Matrix multiplication
  378. # shifted.reshape(shifted.shape[0], -1)) <- change the shape of the shifted matrix to 2-dimensions, to make the matrix multiplication easier
  379. # .reshape( <- reshape our matrix-multiplied result...
  380. # shifted.shape) <- ...to match the original shape of the shifted matrix
  381. # > 0 <- is not false
  382. # multiply the wave matrix by the support matrix to find which patterns are still in the domain
  383. for d in adj:
  384. wave *= supports[d]
  385. if wave.sum() == last_count:
  386. break # No changes since the last loop, changed waves have been fully propagated.
  387. last_count = wave.sum()
  388. if onPropagate:
  389. onPropagate(wave)
  390. if (wave.sum(axis=0) == 0).any():
  391. raise Contradiction("Wave is in a contradictory state and can not be solved.")
  392. def observe(
  393. wave: NDArray[np.bool_],
  394. locationHeuristic: Callable[[NDArray[np.bool_]], tuple[int, int]],
  395. patternHeuristic: Callable[[NDArray[np.bool_], NDArray[np.bool_]], int],
  396. ) -> tuple[int, int, int]:
  397. """Return the next best wave to collapse based on the provided heuristics."""
  398. i, j = locationHeuristic(wave)
  399. pattern = patternHeuristic(wave[:, i, j], wave)
  400. return pattern, i, j
  401. def run(
  402. wave: NDArray[np.bool_],
  403. adj: Mapping[tuple[int, int], NDArray[numpy.bool_]],
  404. locationHeuristic: Callable[[NDArray[numpy.bool_]], tuple[int, int]],
  405. patternHeuristic: Callable[[NDArray[np.bool_], NDArray[np.bool_]], int],
  406. periodic: bool = False,
  407. backtracking: bool = False,
  408. onBacktrack: Callable[[], None] | None = None,
  409. onChoice: Callable[[int, int, int], None] | None = None,
  410. onObserve: Callable[[NDArray[numpy.bool_]], None] | None = None,
  411. onPropagate: Callable[[NDArray[numpy.bool_]], None] | None = None,
  412. checkFeasible: Callable[[NDArray[numpy.bool_]], bool] | None = None,
  413. onFinal: Callable[[NDArray[numpy.bool_]], None] | None = None,
  414. depth: int = 0,
  415. depth_limit: int | None = None,
  416. ) -> NDArray[numpy.int64]:
  417. solver = Solver(
  418. wave=wave,
  419. adj=adj,
  420. periodic=periodic,
  421. backtracking=backtracking,
  422. on_backtrack=onBacktrack,
  423. on_choice=onChoice,
  424. on_observe=onObserve,
  425. on_propagate=onPropagate,
  426. check_feasible=checkFeasible,
  427. )
  428. while not solver.solve_next(
  429. location_heuristic=locationHeuristic, pattern_heuristic=patternHeuristic
  430. ):
  431. pass
  432. if onFinal:
  433. onFinal(solver.wave)
  434. return numpy.argmax(solver.wave, axis=0)