|
@@ -0,0 +1,530 @@
|
|
|
+"""Wave Function Collapse solver. Implementation based on https://github.com/ikarth/wfc_2019f"""
|
|
|
+from __future__ import annotations
|
|
|
+
|
|
|
+import itertools
|
|
|
+import logging
|
|
|
+import math
|
|
|
+from typing import Any, Callable, Collection, Iterable, Iterator, Mapping, TypeVar
|
|
|
+
|
|
|
+# from scipy import sparse # type: ignore
|
|
|
+import numpy
|
|
|
+import numpy as np
|
|
|
+from numpy.typing import NBitBase, NDArray
|
|
|
+
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
+
|
|
|
+T = TypeVar("T", bound=NBitBase)
|
|
|
+
|
|
|
+
|
|
|
+class Contradiction(Exception):
|
|
|
+ """Solving could not proceed without backtracking/restarting."""
|
|
|
+
|
|
|
+ pass
|
|
|
+
|
|
|
+
|
|
|
+class TimedOut(Exception):
|
|
|
+ """Solve timed out."""
|
|
|
+
|
|
|
+ pass
|
|
|
+
|
|
|
+
|
|
|
+class StopEarly(Exception):
|
|
|
+ """Aborting solve early."""
|
|
|
+
|
|
|
+ pass
|
|
|
+
|
|
|
+
|
|
|
+class Solver:
|
|
|
+ """WFC Solver which can hold wave and backtracking state."""
|
|
|
+
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ *,
|
|
|
+ wave: NDArray[np.bool_],
|
|
|
+ adj: Mapping[tuple[int, int], NDArray[numpy.bool_]],
|
|
|
+ periodic: bool = False,
|
|
|
+ backtracking: bool = False,
|
|
|
+ on_backtrack: Callable[[], None] | None = None,
|
|
|
+ on_choice: Callable[[int, int, int], None] | None = None,
|
|
|
+ on_observe: Callable[[NDArray[numpy.bool_]], None] | None = None,
|
|
|
+ on_propagate: Callable[[NDArray[numpy.bool_]], None] | None = None,
|
|
|
+ check_feasible: Callable[[NDArray[numpy.bool_]], bool] | None = None,
|
|
|
+ ) -> None:
|
|
|
+ self.wave = wave
|
|
|
+ self.adj = adj
|
|
|
+ self.periodic = periodic
|
|
|
+ self.backtracking = backtracking
|
|
|
+ self.history: list[NDArray[np.bool_]] = [] # An undo history for backtracking.
|
|
|
+ self.on_backtrack = on_backtrack
|
|
|
+ self.on_choice = on_choice
|
|
|
+ self.on_observe = on_observe
|
|
|
+ self.on_propagate = on_propagate
|
|
|
+ self.check_feasible = check_feasible
|
|
|
+
|
|
|
+ @property
|
|
|
+ def is_solved(self) -> bool:
|
|
|
+ """Is True if the wave has been fully resolved."""
|
|
|
+ return (
|
|
|
+ self.wave.sum() == self.wave.shape[1] * self.wave.shape[2]
|
|
|
+ and (self.wave.sum(axis=0) == 1).all()
|
|
|
+ )
|
|
|
+
|
|
|
+ def solve_next(
|
|
|
+ self,
|
|
|
+ location_heuristic: Callable[[NDArray[numpy.bool_]], tuple[int, int]],
|
|
|
+ pattern_heuristic: Callable[[NDArray[np.bool_], NDArray[np.bool_]], int],
|
|
|
+ ) -> bool:
|
|
|
+ """Attempt to collapse one wave. Returns True if no more steps remain."""
|
|
|
+ if self.is_solved:
|
|
|
+ return True
|
|
|
+ if self.check_feasible and not self.check_feasible(self.wave):
|
|
|
+ raise Contradiction("Not feasible.")
|
|
|
+ if self.backtracking:
|
|
|
+ self.history.append(self.wave.copy())
|
|
|
+ propagate(
|
|
|
+ self.wave, self.adj, periodic=self.periodic, onPropagate=self.on_propagate
|
|
|
+ )
|
|
|
+ pattern, i, j = None, None, None
|
|
|
+ try:
|
|
|
+ pattern, i, j = observe(self.wave, location_heuristic, pattern_heuristic)
|
|
|
+ if self.on_choice:
|
|
|
+ self.on_choice(pattern, i, j)
|
|
|
+ self.wave[:, i, j] = False
|
|
|
+ self.wave[pattern, i, j] = True
|
|
|
+ if self.on_observe:
|
|
|
+ self.on_observe(self.wave)
|
|
|
+ propagate(
|
|
|
+ self.wave,
|
|
|
+ self.adj,
|
|
|
+ periodic=self.periodic,
|
|
|
+ onPropagate=self.on_propagate,
|
|
|
+ )
|
|
|
+ return False # Assume there is remaining steps, if not then the next call will return True.
|
|
|
+ except Contradiction:
|
|
|
+ if not self.backtracking:
|
|
|
+ raise
|
|
|
+ if not self.history:
|
|
|
+ raise Contradiction("Every permutation has been attempted.")
|
|
|
+ if self.on_backtrack:
|
|
|
+ self.on_backtrack()
|
|
|
+ self.wave = self.history.pop()
|
|
|
+ self.wave[pattern, i, j] = False
|
|
|
+ return False
|
|
|
+
|
|
|
+ def solve(
|
|
|
+ self,
|
|
|
+ location_heuristic: Callable[[NDArray[numpy.bool_]], tuple[int, int]],
|
|
|
+ pattern_heuristic: Callable[[NDArray[np.bool_], NDArray[np.bool_]], int],
|
|
|
+ ) -> NDArray[np.int64]:
|
|
|
+ """Attempts to solve all waves and returns the solution."""
|
|
|
+ while not self.solve_next(
|
|
|
+ location_heuristic=location_heuristic, pattern_heuristic=pattern_heuristic
|
|
|
+ ):
|
|
|
+ pass
|
|
|
+ return numpy.argmax(self.wave, axis=0)
|
|
|
+
|
|
|
+
|
|
|
+def makeWave(
|
|
|
+ n: int, w: int, h: int, ground: Iterable[int] | None = None
|
|
|
+) -> NDArray[numpy.bool_]:
|
|
|
+ wave: NDArray[numpy.bool_] = numpy.ones((n, w, h), dtype=numpy.bool_)
|
|
|
+ if ground is not None:
|
|
|
+ wave[:, :, h - 1] = False
|
|
|
+ for g in ground:
|
|
|
+ wave[
|
|
|
+ g,
|
|
|
+ :,
|
|
|
+ ] = False
|
|
|
+ wave[g, :, h - 1] = True
|
|
|
+ # logger.debug(wave)
|
|
|
+ # for i in range(wave.shape[0]):
|
|
|
+ # logger.debug(wave[i])
|
|
|
+ return wave
|
|
|
+
|
|
|
+
|
|
|
+def makeAdj(
|
|
|
+ adjLists: Mapping[tuple[int, int], Collection[Iterable[int]]]
|
|
|
+) -> dict[tuple[int, int], NDArray[numpy.bool_]]:
|
|
|
+ adjMatrices = {}
|
|
|
+ # logger.debug(adjLists)
|
|
|
+ num_patterns = len(list(adjLists.values())[0])
|
|
|
+ for d in adjLists:
|
|
|
+ m = numpy.zeros((num_patterns, num_patterns), dtype=bool)
|
|
|
+ for i, js in enumerate(adjLists[d]):
|
|
|
+ # logger.debug(js)
|
|
|
+ for j in js:
|
|
|
+ m[i, j] = 1
|
|
|
+ # If scipy is available, use sparse matrices.
|
|
|
+ # adjMatrices[d] = sparse.csr_matrix(m)
|
|
|
+ adjMatrices[d] = m
|
|
|
+ return adjMatrices
|
|
|
+
|
|
|
+
|
|
|
+######################################
|
|
|
+# Location Heuristics
|
|
|
+
|
|
|
+
|
|
|
+def makeRandomLocationHeuristic(
|
|
|
+ preferences: NDArray[np.floating[Any]],
|
|
|
+) -> Callable[[NDArray[np.bool_]], tuple[int, int]]:
|
|
|
+ def randomLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]:
|
|
|
+ unresolved_cell_mask = numpy.count_nonzero(wave, axis=0) > 1
|
|
|
+ cell_weights = numpy.where(unresolved_cell_mask, preferences, numpy.inf)
|
|
|
+ row, col = numpy.unravel_index(numpy.argmin(cell_weights), cell_weights.shape)
|
|
|
+ return row.item(), col.item()
|
|
|
+
|
|
|
+ return randomLocationHeuristic
|
|
|
+
|
|
|
+
|
|
|
+def makeEntropyLocationHeuristic(
|
|
|
+ preferences: NDArray[np.floating[Any]],
|
|
|
+) -> Callable[[NDArray[np.bool_]], tuple[int, int]]:
|
|
|
+ def entropyLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]:
|
|
|
+ unresolved_cell_mask = numpy.count_nonzero(wave, axis=0) > 1
|
|
|
+ cell_weights = numpy.where(
|
|
|
+ unresolved_cell_mask,
|
|
|
+ preferences + numpy.count_nonzero(wave, axis=0),
|
|
|
+ numpy.inf,
|
|
|
+ )
|
|
|
+ row, col = numpy.unravel_index(numpy.argmin(cell_weights), cell_weights.shape)
|
|
|
+ return row.item(), col.item()
|
|
|
+
|
|
|
+ return entropyLocationHeuristic
|
|
|
+
|
|
|
+
|
|
|
+def makeAntiEntropyLocationHeuristic(
|
|
|
+ preferences: NDArray[np.floating[Any]],
|
|
|
+) -> Callable[[NDArray[np.bool_]], tuple[int, int]]:
|
|
|
+ def antiEntropyLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]:
|
|
|
+ unresolved_cell_mask = numpy.count_nonzero(wave, axis=0) > 1
|
|
|
+ cell_weights = numpy.where(
|
|
|
+ unresolved_cell_mask,
|
|
|
+ preferences + numpy.count_nonzero(wave, axis=0),
|
|
|
+ -numpy.inf,
|
|
|
+ )
|
|
|
+ row, col = numpy.unravel_index(numpy.argmax(cell_weights), cell_weights.shape)
|
|
|
+ return row.item(), col.item()
|
|
|
+
|
|
|
+ return antiEntropyLocationHeuristic
|
|
|
+
|
|
|
+
|
|
|
+def spiral_transforms() -> Iterator[tuple[int, int]]:
|
|
|
+ for N in itertools.count(start=1):
|
|
|
+ if N % 2 == 0:
|
|
|
+ yield (0, 1) # right
|
|
|
+ for _ in range(N):
|
|
|
+ yield (1, 0) # down
|
|
|
+ for _ in range(N):
|
|
|
+ yield (0, -1) # left
|
|
|
+ else:
|
|
|
+ yield (0, -1) # left
|
|
|
+ for _ in range(N):
|
|
|
+ yield (-1, 0) # up
|
|
|
+ for _ in range(N):
|
|
|
+ yield (0, 1) # right
|
|
|
+
|
|
|
+
|
|
|
+def spiral_coords(x: int, y: int) -> Iterator[tuple[int, int]]:
|
|
|
+ yield x, y
|
|
|
+ for transform in spiral_transforms():
|
|
|
+ x += transform[0]
|
|
|
+ y += transform[1]
|
|
|
+ yield x, y
|
|
|
+
|
|
|
+
|
|
|
+def fill_with_curve(
|
|
|
+ arr: NDArray[np.floating[T]], curve_gen: Iterable[Iterable[int]]
|
|
|
+) -> NDArray[np.floating[T]]:
|
|
|
+ arr_len = numpy.prod(arr.shape)
|
|
|
+ fill = 0
|
|
|
+ for coord in curve_gen:
|
|
|
+ # logger.debug(fill, idx, coord)
|
|
|
+ if fill < arr_len:
|
|
|
+ try:
|
|
|
+ arr[tuple(coord)] = fill / arr_len
|
|
|
+ fill += 1
|
|
|
+ except IndexError:
|
|
|
+ pass
|
|
|
+ else:
|
|
|
+ break
|
|
|
+ # logger.debug(arr)
|
|
|
+ return arr
|
|
|
+
|
|
|
+
|
|
|
+def makeSpiralLocationHeuristic(
|
|
|
+ preferences: NDArray[np.floating[Any]],
|
|
|
+) -> Callable[[NDArray[np.bool_]], tuple[int, int]]:
|
|
|
+ # https://stackoverflow.com/a/23707273/5562922
|
|
|
+
|
|
|
+ spiral_gen = (
|
|
|
+ sc for sc in spiral_coords(preferences.shape[0] // 2, preferences.shape[1] // 2)
|
|
|
+ )
|
|
|
+
|
|
|
+ cell_order = fill_with_curve(preferences, spiral_gen)
|
|
|
+
|
|
|
+ def spiralLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]:
|
|
|
+ unresolved_cell_mask = numpy.count_nonzero(wave, axis=0) > 1
|
|
|
+ cell_weights = numpy.where(unresolved_cell_mask, cell_order, numpy.inf)
|
|
|
+ row, col = numpy.unravel_index(numpy.argmin(cell_weights), cell_weights.shape)
|
|
|
+ return row.item(), col.item()
|
|
|
+
|
|
|
+ return spiralLocationHeuristic
|
|
|
+
|
|
|
+
|
|
|
+def makeHilbertLocationHeuristic(
|
|
|
+ preferences: NDArray[np.floating[Any]],
|
|
|
+) -> Callable[[NDArray[np.bool_]], tuple[int, int]]:
|
|
|
+ from hilbertcurve.hilbertcurve import HilbertCurve # type: ignore
|
|
|
+
|
|
|
+ curve_size = math.ceil(math.sqrt(max(preferences.shape[0], preferences.shape[1])))
|
|
|
+ logger.debug(curve_size)
|
|
|
+ curve_size = 4
|
|
|
+ h_curve = HilbertCurve(curve_size, 2)
|
|
|
+ h_coords = (h_curve.point_from_distance(i) for i in itertools.count())
|
|
|
+ cell_order = fill_with_curve(preferences, h_coords)
|
|
|
+ # logger.debug(cell_order)
|
|
|
+
|
|
|
+ def hilbertLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]:
|
|
|
+ unresolved_cell_mask = numpy.count_nonzero(wave, axis=0) > 1
|
|
|
+ cell_weights = numpy.where(unresolved_cell_mask, cell_order, numpy.inf)
|
|
|
+ row, col = numpy.unravel_index(numpy.argmin(cell_weights), cell_weights.shape)
|
|
|
+ return row.item(), col.item()
|
|
|
+
|
|
|
+ return hilbertLocationHeuristic
|
|
|
+
|
|
|
+
|
|
|
+def simpleLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]:
|
|
|
+ unresolved_cell_mask = numpy.count_nonzero(wave, axis=0) > 1
|
|
|
+ cell_weights = numpy.where(
|
|
|
+ unresolved_cell_mask, numpy.count_nonzero(wave, axis=0), numpy.inf
|
|
|
+ )
|
|
|
+ row, col = numpy.unravel_index(numpy.argmin(cell_weights), cell_weights.shape)
|
|
|
+ return row.item(), col.item()
|
|
|
+
|
|
|
+
|
|
|
+def lexicalLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]:
|
|
|
+ unresolved_cell_mask = numpy.count_nonzero(wave, axis=0) > 1
|
|
|
+ cell_weights = numpy.where(unresolved_cell_mask, 1.0, numpy.inf)
|
|
|
+ row, col = numpy.unravel_index(numpy.argmin(cell_weights), cell_weights.shape)
|
|
|
+ return row.item(), col.item()
|
|
|
+
|
|
|
+
|
|
|
+#####################################
|
|
|
+# Pattern Heuristics
|
|
|
+
|
|
|
+
|
|
|
+def lexicalPatternHeuristic(weights: NDArray[np.bool_], wave: NDArray[np.bool_]) -> int:
|
|
|
+ return numpy.nonzero(weights)[0][0].item()
|
|
|
+
|
|
|
+
|
|
|
+def makeWeightedPatternHeuristic(
|
|
|
+ weights: NDArray[np.floating[Any]],
|
|
|
+ np_random: numpy.random.Generator | None = None,
|
|
|
+):
|
|
|
+ num_of_patterns = len(weights)
|
|
|
+ np_random: numpy.random.Generator = (
|
|
|
+ numpy.random.default_rng() if np_random is None else np_random
|
|
|
+ )
|
|
|
+
|
|
|
+ def weightedPatternHeuristic(wave: NDArray[np.bool_], _: NDArray[np.bool_]) -> int:
|
|
|
+ # TODO: there's maybe a faster, more controlled way to do this sampling...
|
|
|
+ weighted_wave: NDArray[np.floating[Any]] = weights * wave
|
|
|
+ weighted_wave /= weighted_wave.sum()
|
|
|
+ result = np_random.choice(num_of_patterns, p=weighted_wave)
|
|
|
+ return result
|
|
|
+
|
|
|
+ return weightedPatternHeuristic
|
|
|
+
|
|
|
+
|
|
|
+def makeRarestPatternHeuristic(
|
|
|
+ weights: NDArray[np.floating[Any]],
|
|
|
+ np_random: numpy.random.Generator | None = None,
|
|
|
+) -> Callable[[NDArray[np.bool_], NDArray[np.bool_]], int]:
|
|
|
+ """Return a function that chooses the rarest (currently least-used) pattern."""
|
|
|
+ np_random: numpy.random.Generator = (
|
|
|
+ numpy.random.default_rng() if np_random is None else np_random
|
|
|
+ )
|
|
|
+
|
|
|
+ def weightedPatternHeuristic(
|
|
|
+ wave: NDArray[np.bool_], total_wave: NDArray[np.bool_]
|
|
|
+ ) -> int:
|
|
|
+ logger.debug(total_wave.shape)
|
|
|
+ # [logger.debug(e) for e in wave]
|
|
|
+ wave_sums = numpy.sum(total_wave, (1, 2))
|
|
|
+ # logger.debug(wave_sums)
|
|
|
+ selected_pattern = np_random.choice(
|
|
|
+ numpy.where(wave_sums == wave_sums.max())[0]
|
|
|
+ )
|
|
|
+ return selected_pattern
|
|
|
+
|
|
|
+ return weightedPatternHeuristic
|
|
|
+
|
|
|
+
|
|
|
+def makeMostCommonPatternHeuristic(
|
|
|
+ weights: NDArray[np.floating[Any]],
|
|
|
+ np_random: numpy.random.Generator | None = None,
|
|
|
+) -> Callable[[NDArray[np.bool_], NDArray[np.bool_]], int]:
|
|
|
+ """Return a function that chooses the most common (currently most-used) pattern."""
|
|
|
+ np_random: numpy.random.Generator = (
|
|
|
+ numpy.random.default_rng() if np_random is None else np_random
|
|
|
+ )
|
|
|
+
|
|
|
+ def weightedPatternHeuristic(
|
|
|
+ wave: NDArray[np.bool_], total_wave: NDArray[np.bool_]
|
|
|
+ ) -> int:
|
|
|
+ logger.debug(total_wave.shape)
|
|
|
+ # [logger.debug(e) for e in wave]
|
|
|
+ wave_sums = numpy.sum(total_wave, (1, 2))
|
|
|
+ selected_pattern = np_random.choice(
|
|
|
+ numpy.where(wave_sums == wave_sums.min())[0]
|
|
|
+ )
|
|
|
+ return selected_pattern
|
|
|
+
|
|
|
+ return weightedPatternHeuristic
|
|
|
+
|
|
|
+
|
|
|
+def makeRandomPatternHeuristic(
|
|
|
+ weights: NDArray[np.floating[Any]],
|
|
|
+ np_random: numpy.random.Generator | None = None,
|
|
|
+) -> Callable[[NDArray[np.bool_], NDArray[np.bool_]], int]:
|
|
|
+ num_of_patterns = len(weights)
|
|
|
+ np_random: numpy.random.Generator = (
|
|
|
+ numpy.random.default_rng() if np_random is None else np_random
|
|
|
+ )
|
|
|
+
|
|
|
+ def randomPatternHeuristic(wave: NDArray[np.bool_], _: NDArray[np.bool_]) -> int:
|
|
|
+ # TODO: there's maybe a faster, more controlled way to do this sampling...
|
|
|
+ weighted_wave = 1.0 * wave
|
|
|
+ weighted_wave /= weighted_wave.sum()
|
|
|
+ result = np_random.choice(num_of_patterns, p=weighted_wave)
|
|
|
+ return result
|
|
|
+
|
|
|
+ return randomPatternHeuristic
|
|
|
+
|
|
|
+
|
|
|
+######################################
|
|
|
+# Global Constraints
|
|
|
+
|
|
|
+
|
|
|
+def make_global_use_all_patterns() -> Callable[[NDArray[np.bool_]], bool]:
|
|
|
+ def global_use_all_patterns(wave: NDArray[np.bool_]) -> bool:
|
|
|
+ """Returns true if at least one instance of each pattern is still possible."""
|
|
|
+ return numpy.all(numpy.any(wave, axis=(1, 2))).item()
|
|
|
+
|
|
|
+ return global_use_all_patterns
|
|
|
+
|
|
|
+
|
|
|
+#####################################
|
|
|
+# Solver
|
|
|
+
|
|
|
+
|
|
|
+def propagate(
|
|
|
+ wave: NDArray[np.bool_],
|
|
|
+ adj: Mapping[tuple[int, int], NDArray[numpy.bool_]],
|
|
|
+ periodic: bool = False,
|
|
|
+ onPropagate: Callable[[NDArray[numpy.bool_]], None] | None = None,
|
|
|
+) -> None:
|
|
|
+ """Completely probagate any newly collapsed waves to all areas."""
|
|
|
+ last_count = wave.sum()
|
|
|
+
|
|
|
+ while True:
|
|
|
+ supports = {}
|
|
|
+ if periodic:
|
|
|
+ padded = numpy.pad(wave, ((0, 0), (1, 1), (1, 1)), mode="wrap")
|
|
|
+ else:
|
|
|
+ padded = numpy.pad(
|
|
|
+ wave, ((0, 0), (1, 1), (1, 1)), mode="constant", constant_values=True
|
|
|
+ )
|
|
|
+
|
|
|
+ # adj is the list of adjacencies. For each direction d in adjacency,
|
|
|
+ # check which patterns are still valid...
|
|
|
+ for d in adj:
|
|
|
+ dx, dy = d
|
|
|
+ # padded[] is a version of the adjacency matrix with the values wrapped around
|
|
|
+ # shifted[] is the padded version with the values shifted over in one direction
|
|
|
+ # because my code stores the directions as relative (x,y) coordinates, we can find
|
|
|
+ # the adjacent cell for each direction by simply shifting the matrix in that direction,
|
|
|
+ # which allows for arbitrary adjacency directions. This is somewhat excessive, but elegant.
|
|
|
+
|
|
|
+ shifted = padded[
|
|
|
+ :, 1 + dx : 1 + wave.shape[1] + dx, 1 + dy : 1 + wave.shape[2] + dy
|
|
|
+ ]
|
|
|
+ # logger.debug(f"shifted: {shifted.shape} | adj[d]: {adj[d].shape} | d: {d}")
|
|
|
+ # raise StopEarly
|
|
|
+ # supports[d] = numpy.einsum('pwh,pq->qwh', shifted, adj[d]) > 0
|
|
|
+
|
|
|
+ # The adjacency matrix is a boolean matrix, indexed by the direction and the two patterns.
|
|
|
+ # If the value for (direction, pattern1, pattern2) is True, then this is a valid adjacency.
|
|
|
+ # This gives us a rapid way to compare: True is 1, False is 0, so multiplying the matrices
|
|
|
+ # gives us the adjacency compatibility.
|
|
|
+ supports[d] = (adj[d] @ shifted.reshape(shifted.shape[0], -1)).reshape(
|
|
|
+ shifted.shape
|
|
|
+ ) > 0
|
|
|
+ # supports[d] = ( <- for each cell in the matrix
|
|
|
+ # adj[d] <- the adjacency matrix [sliced by the direction d]
|
|
|
+ # @ <- Matrix multiplication
|
|
|
+ # shifted.reshape(shifted.shape[0], -1)) <- change the shape of the shifted matrix to 2-dimensions, to make the matrix multiplication easier
|
|
|
+ # .reshape( <- reshape our matrix-multiplied result...
|
|
|
+ # shifted.shape) <- ...to match the original shape of the shifted matrix
|
|
|
+ # > 0 <- is not false
|
|
|
+
|
|
|
+ # multiply the wave matrix by the support matrix to find which patterns are still in the domain
|
|
|
+ for d in adj:
|
|
|
+ wave *= supports[d]
|
|
|
+
|
|
|
+ if wave.sum() == last_count:
|
|
|
+ break # No changes since the last loop, changed waves have been fully propagated.
|
|
|
+ last_count = wave.sum()
|
|
|
+
|
|
|
+ if onPropagate:
|
|
|
+ onPropagate(wave)
|
|
|
+
|
|
|
+ if (wave.sum(axis=0) == 0).any():
|
|
|
+ raise Contradiction("Wave is in a contradictory state and can not be solved.")
|
|
|
+
|
|
|
+
|
|
|
+def observe(
|
|
|
+ wave: NDArray[np.bool_],
|
|
|
+ locationHeuristic: Callable[[NDArray[np.bool_]], tuple[int, int]],
|
|
|
+ patternHeuristic: Callable[[NDArray[np.bool_], NDArray[np.bool_]], int],
|
|
|
+) -> tuple[int, int, int]:
|
|
|
+ """Return the next best wave to collapse based on the provided heuristics."""
|
|
|
+ i, j = locationHeuristic(wave)
|
|
|
+ pattern = patternHeuristic(wave[:, i, j], wave)
|
|
|
+ return pattern, i, j
|
|
|
+
|
|
|
+
|
|
|
+def run(
|
|
|
+ wave: NDArray[np.bool_],
|
|
|
+ adj: Mapping[tuple[int, int], NDArray[numpy.bool_]],
|
|
|
+ locationHeuristic: Callable[[NDArray[numpy.bool_]], tuple[int, int]],
|
|
|
+ patternHeuristic: Callable[[NDArray[np.bool_], NDArray[np.bool_]], int],
|
|
|
+ periodic: bool = False,
|
|
|
+ backtracking: bool = False,
|
|
|
+ onBacktrack: Callable[[], None] | None = None,
|
|
|
+ onChoice: Callable[[int, int, int], None] | None = None,
|
|
|
+ onObserve: Callable[[NDArray[numpy.bool_]], None] | None = None,
|
|
|
+ onPropagate: Callable[[NDArray[numpy.bool_]], None] | None = None,
|
|
|
+ checkFeasible: Callable[[NDArray[numpy.bool_]], bool] | None = None,
|
|
|
+ onFinal: Callable[[NDArray[numpy.bool_]], None] | None = None,
|
|
|
+ depth: int = 0,
|
|
|
+ depth_limit: int | None = None,
|
|
|
+) -> NDArray[numpy.int64]:
|
|
|
+ solver = Solver(
|
|
|
+ wave=wave,
|
|
|
+ adj=adj,
|
|
|
+ periodic=periodic,
|
|
|
+ backtracking=backtracking,
|
|
|
+ on_backtrack=onBacktrack,
|
|
|
+ on_choice=onChoice,
|
|
|
+ on_observe=onObserve,
|
|
|
+ on_propagate=onPropagate,
|
|
|
+ check_feasible=checkFeasible,
|
|
|
+ )
|
|
|
+ while not solver.solve_next(
|
|
|
+ location_heuristic=locationHeuristic, pattern_heuristic=patternHeuristic
|
|
|
+ ):
|
|
|
+ pass
|
|
|
+ if onFinal:
|
|
|
+ onFinal(solver.wave)
|
|
|
+ return numpy.argmax(solver.wave, axis=0)
|