123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532 |
- """Wave Function Collapse solver. Implementation based on https://github.com/ikarth/wfc_2019f"""
- from __future__ import annotations
- import itertools
- import logging
- import math
- from typing import Any, Callable, Collection, Iterable, Iterator, Mapping, TypeVar
- # from scipy import sparse # type: ignore
- import numpy
- import numpy as np
- from numpy.typing import NBitBase, NDArray
# Module-level logger for solver diagnostics.
logger = logging.getLogger(__name__)
# Bit-width type variable for numpy floating dtypes; used to type fill_with_curve.
T = TypeVar("T", bound=NBitBase)
class Contradiction(Exception):
    """Raised when the wave cannot advance without backtracking or restarting."""
class TimedOut(Exception):
    """Raised when a solve exceeds its allotted time or step budget."""
class StopEarly(Exception):
    """Raised to abandon a solve before it finishes."""
class Solver:
    """WFC solver that owns a wave and, optionally, a backtracking history.

    ``wave[p, i, j]`` is True while pattern ``p`` is still possible at cell
    ``(i, j)``; solving collapses every cell to exactly one pattern.

    With ``backtracking`` enabled, each observation records a snapshot of the
    (already propagated) wave together with the choice made from it. When any
    later step hits a :class:`Contradiction` -- in ``check_feasible``, while
    propagating, or while applying a choice -- the most recent snapshot is
    restored and its choice is banned. The next step propagates that ban; if
    the ban itself leads to a contradiction, one more decision is unwound.
    This gives full chronological backtracking.
    """

    def __init__(
        self,
        *,
        wave: NDArray[np.bool_],
        adj: Mapping[tuple[int, int], NDArray[numpy.bool_]],
        periodic: bool = False,
        backtracking: bool = False,
        on_backtrack: Callable[[], None] | None = None,
        on_choice: Callable[[int, int, int], None] | None = None,
        on_observe: Callable[[NDArray[numpy.bool_]], None] | None = None,
        on_propagate: Callable[[NDArray[numpy.bool_]], None] | None = None,
        check_feasible: Callable[[NDArray[numpy.bool_]], bool] | None = None,
    ) -> None:
        """Store the wave, adjacency matrices, options and callbacks.

        Callbacks: ``on_choice(pattern, i, j)`` runs before a choice is applied.
        ``on_observe(wave)`` runs right after it is applied. ``on_propagate`` is
        forwarded to :func:`propagate`. ``check_feasible(wave)`` is consulted at
        the start of every step. ``on_backtrack()`` runs before a snapshot is
        restored.
        """
        self.wave = wave
        self.adj = adj
        self.periodic = periodic
        self.backtracking = backtracking
        # Undo history for backtracking: wave snapshots taken just before each choice.
        self.history: list[NDArray[np.bool_]] = []
        # The (pattern, i, j) choice made from each snapshot in ``history`` (kept aligned).
        self._choices: list[tuple[int, int, int]] = []
        self.on_backtrack = on_backtrack
        self.on_choice = on_choice
        self.on_observe = on_observe
        self.on_propagate = on_propagate
        self.check_feasible = check_feasible

    @property
    def is_solved(self) -> bool:
        """True when every cell has exactly one possible pattern."""
        return bool((self.wave.sum(axis=0) == 1).all())

    def solve_next(
        self,
        location_heuristic: Callable[[NDArray[numpy.bool_]], tuple[int, int]],
        pattern_heuristic: Callable[[NDArray[np.bool_], NDArray[np.bool_]], int],
    ) -> bool:
        """Collapse one cell and propagate the consequences.

        Returns:
            True if the wave was already solved. False otherwise: more steps may
            remain, and the following call returns True once nothing is left.

        Raises:
            Contradiction: without backtracking, as soon as the wave becomes
                contradictory or ``check_feasible`` rejects it. With
                backtracking, only once every alternative has been exhausted.
        """
        if self.is_solved:
            return True
        try:
            if self.check_feasible and not self.check_feasible(self.wave):
                raise Contradiction("Not feasible.")
            # Propagate first: after a backtrack the wave carries a freshly banned choice.
            propagate(
                self.wave,
                self.adj,
                periodic=self.periodic,
                onPropagate=self.on_propagate,
            )
            pattern, i, j = observe(self.wave, location_heuristic, pattern_heuristic)
            if self.backtracking:
                self.history.append(self.wave.copy())
                self._choices.append((pattern, i, j))
            if self.on_choice:
                self.on_choice(pattern, i, j)
            self.wave[:, i, j] = False
            self.wave[pattern, i, j] = True
            if self.on_observe:
                self.on_observe(self.wave)
            propagate(
                self.wave,
                self.adj,
                periodic=self.periodic,
                onPropagate=self.on_propagate,
            )
        except Contradiction as err:
            if not self.backtracking:
                raise
            self._backtrack(err)
        # Assume steps remain; if not, the next call returns True.
        return False

    def _backtrack(self, err: Contradiction) -> None:
        """Restore the latest snapshot and ban the choice that was made from it."""
        if not self.history or not self._choices:
            raise Contradiction("Every permutation has been attempted.") from err
        if self.on_backtrack:
            self.on_backtrack()
        self.wave = self.history.pop()
        pattern, i, j = self._choices.pop()
        self.wave[pattern, i, j] = False

    def solve(
        self,
        location_heuristic: Callable[[NDArray[numpy.bool_]], tuple[int, int]],
        pattern_heuristic: Callable[[NDArray[np.bool_], NDArray[np.bool_]], int],
    ) -> NDArray[np.int64]:
        """Step until solved and return the chosen pattern index for every cell."""
        while not self.solve_next(
            location_heuristic=location_heuristic, pattern_heuristic=pattern_heuristic
        ):
            pass
        return numpy.argmax(self.wave, axis=0)
def makeWave(
    n: int, w: int, h: int, ground: Iterable[int] | None = None
) -> NDArray[numpy.bool_]:
    """Create a wave of shape ``(n, w, h)`` in which every pattern is possible.

    When ``ground`` is given, index ``h - 1`` of the last axis is reserved for
    the listed ground patterns. Those patterns are allowed only there, and every
    other pattern is forbidden there.
    """
    wave: NDArray[numpy.bool_] = numpy.ones((n, w, h), dtype=numpy.bool_)
    if ground is None:
        return wave
    bottom = h - 1
    wave[:, :, bottom] = False
    for pattern in ground:
        wave[pattern, :] = False
        wave[pattern, :, bottom] = True
    return wave
def makeAdj(
    adjLists: Mapping[tuple[int, int], Collection[Iterable[int]]]
) -> dict[tuple[int, int], NDArray[numpy.bool_]]:
    """Convert per-direction adjacency lists into boolean adjacency matrices.

    For each direction ``d``, ``adjLists[d][i]`` lists pattern indices ``j``.
    The result maps ``d`` to a square matrix ``m`` with ``m[i, j] = True`` for
    exactly those pairs. The matrix size is the number of rows in the first
    direction's list. An empty mapping yields an empty dict instead of raising.
    """
    if not adjLists:
        return {}
    num_patterns = len(next(iter(adjLists.values())))
    adjMatrices: dict[tuple[int, int], NDArray[numpy.bool_]] = {}
    for direction, rows in adjLists.items():
        matrix = numpy.zeros((num_patterns, num_patterns), dtype=bool)
        for i, allowed in enumerate(rows):
            for j in allowed:
                matrix[i, j] = True
        adjMatrices[direction] = matrix
    return adjMatrices
- ######################################
- # Location Heuristics
def makeRandomLocationHeuristic(
    preferences: NDArray[np.floating[Any]],
) -> Callable[[NDArray[np.bool_]], tuple[int, int]]:
    """Build a location heuristic driven by a fixed per-cell preference map.

    The returned function picks the unresolved cell (more than one pattern
    still possible) with the lowest ``preferences`` value. Ties go to the first
    cell in row-major order. If no cell is unresolved, it returns ``(0, 0)``.
    """

    def randomLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]:
        options_per_cell = numpy.count_nonzero(wave, axis=0)
        scores = numpy.where(options_per_cell > 1, preferences, numpy.inf)
        flat_index = scores.argmin()
        row, col = numpy.unravel_index(flat_index, scores.shape)
        return row.item(), col.item()

    return randomLocationHeuristic
def makeEntropyLocationHeuristic(
    preferences: NDArray[np.floating[Any]],
) -> Callable[[NDArray[np.bool_]], tuple[int, int]]:
    """Build a heuristic choosing the unresolved cell with the fewest remaining patterns.

    Each cell scores its remaining-pattern count plus its ``preferences``
    value. When preferences are below 1 they only break ties between equal
    counts. Resolved cells are never picked; if none are unresolved the
    function returns ``(0, 0)``.
    """

    def entropyLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]:
        pattern_counts = numpy.count_nonzero(wave, axis=0)
        scores = numpy.where(
            pattern_counts > 1, preferences + pattern_counts, numpy.inf
        )
        row, col = numpy.unravel_index(scores.argmin(), scores.shape)
        return row.item(), col.item()

    return entropyLocationHeuristic
def makeAntiEntropyLocationHeuristic(
    preferences: NDArray[np.floating[Any]],
) -> Callable[[NDArray[np.bool_]], tuple[int, int]]:
    """Build a heuristic choosing the unresolved cell with the most remaining patterns.

    This mirrors the entropy heuristic. Each unresolved cell scores its
    remaining-pattern count plus its ``preferences`` value, and the highest
    score wins. If no cell is unresolved the function returns ``(0, 0)``.
    """

    def antiEntropyLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]:
        pattern_counts = numpy.count_nonzero(wave, axis=0)
        scores = numpy.where(
            pattern_counts > 1, preferences + pattern_counts, -numpy.inf
        )
        row, col = numpy.unravel_index(scores.argmax(), scores.shape)
        return row.item(), col.item()

    return antiEntropyLocationHeuristic
def spiral_transforms() -> Iterator[tuple[int, int]]:
    """Yield unit offsets that trace a square spiral outward, forever.

    Ring ``N`` (starting at 1) makes one step, then ``N`` steps, then ``N``
    more. Odd rings go left, up, right; even rings go right, down, left.
    """
    left, right, up, down = (0, -1), (0, 1), (-1, 0), (1, 0)
    for ring in itertools.count(start=1):
        if ring % 2:
            lead, second, third = left, up, right
        else:
            lead, second, third = right, down, left
        yield lead
        yield from itertools.repeat(second, ring)
        yield from itertools.repeat(third, ring)
def spiral_coords(x: int, y: int) -> Iterator[tuple[int, int]]:
    """Yield ``(x, y)`` and then every point visited by a spiral walk from it."""
    position = (x, y)
    yield position
    for dx, dy in spiral_transforms():
        position = (position[0] + dx, position[1] + dy)
        yield position
def fill_with_curve(
    arr: NDArray[np.floating[T]], curve_gen: Iterable[Iterable[int]]
) -> NDArray[np.floating[T]]:
    """Number the cells of ``arr`` in the order a curve visits them.

    Coordinates are consumed from ``curve_gen`` in order. The k-th accepted
    coordinate receives the value ``k / arr.size``. Coordinates are skipped if
    they lie outside the array, have the wrong number of components, or cannot
    be used as an index. Negative coordinates count as outside: plain numpy
    indexing would silently wrap them to the opposite edge and overwrite a
    cell that was already numbered.

    Stops after ``arr.size`` writes, so ``curve_gen`` may be infinite, or when
    the curve is exhausted. ``arr`` is modified in place and also returned.
    """
    total = arr.size
    filled = 0
    for coord in curve_gen:
        if filled >= total:
            break
        index = tuple(coord)
        if len(index) != arr.ndim or not all(
            0 <= c < extent for c, extent in zip(index, arr.shape)
        ):
            continue
        try:
            arr[index] = filled / total
        except IndexError:
            # e.g. non-integer coordinates; skip them.
            continue
        filled += 1
    return arr
def makeSpiralLocationHeuristic(
    preferences: NDArray[np.floating[Any]],
) -> Callable[[NDArray[np.bool_]], tuple[int, int]]:
    """Build a location heuristic that resolves cells in a spiral from the center.

    ``preferences`` is overwritten in place by ``fill_with_curve`` with each
    cell's position along the spiral. The returned function picks the earliest
    unresolved cell on that spiral, or ``(0, 0)`` if none remain.
    """
    # https://stackoverflow.com/a/23707273/5562922
    center_row = preferences.shape[0] // 2
    center_col = preferences.shape[1] // 2
    cell_order = fill_with_curve(preferences, spiral_coords(center_row, center_col))

    def spiralLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]:
        unresolved = numpy.count_nonzero(wave, axis=0) > 1
        ranks = numpy.where(unresolved, cell_order, numpy.inf)
        row, col = numpy.unravel_index(ranks.argmin(), ranks.shape)
        return row.item(), col.item()

    return spiralLocationHeuristic
def makeHilbertLocationHeuristic(
    preferences: NDArray[np.floating[Any]],
) -> Callable[[NDArray[np.bool_]], tuple[int, int]]:
    """Build a location heuristic that resolves cells along a 2-D Hilbert curve.

    Requires the third-party ``hilbertcurve`` package. ``preferences`` is
    overwritten in place by ``fill_with_curve`` with each cell's position along
    the curve. The returned function picks the earliest unresolved cell on the
    curve, or ``(0, 0)`` if none remain.
    """
    from hilbertcurve.hilbertcurve import HilbertCurve  # type: ignore

    # A 2-D Hilbert curve of order p spans a 2**p x 2**p square, so p must
    # satisfy 2**p >= the longest side of the grid. Order 4 (16 x 16) is kept
    # as a floor so small grids are numbered exactly as with the former fixed order.
    longest_side = max(preferences.shape[0], preferences.shape[1], 1)
    curve_order = max(4, (longest_side - 1).bit_length())
    logger.debug("Hilbert curve order: %d", curve_order)
    h_curve = HilbertCurve(curve_order, 2)
    # The curve has exactly 2**(2p) points; do not query distances beyond it.
    h_coords = (
        h_curve.point_from_distance(i) for i in range(2 ** (2 * curve_order))
    )
    cell_order = fill_with_curve(preferences, h_coords)

    def hilbertLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]:
        unresolved_cell_mask = numpy.count_nonzero(wave, axis=0) > 1
        cell_weights = numpy.where(unresolved_cell_mask, cell_order, numpy.inf)
        row, col = numpy.unravel_index(numpy.argmin(cell_weights), cell_weights.shape)
        return row.item(), col.item()

    return hilbertLocationHeuristic
def simpleLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]:
    """Pick the unresolved cell with the fewest remaining patterns.

    Ties go to the first cell in row-major order. If every cell is resolved,
    the function returns ``(0, 0)``.
    """
    remaining = numpy.count_nonzero(wave, axis=0)
    candidates = numpy.where(remaining > 1, remaining, numpy.inf)
    row, col = numpy.unravel_index(candidates.argmin(), candidates.shape)
    return row.item(), col.item()
def lexicalLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]:
    """Return the first unresolved cell in row-major order, or ``(0, 0)`` if none remain."""
    unresolved = numpy.count_nonzero(wave, axis=0) > 1
    # argmax over booleans returns the first True, or 0 when there is none.
    row, col = numpy.unravel_index(numpy.argmax(unresolved), unresolved.shape)
    return row.item(), col.item()
- #####################################
- # Pattern Heuristics
def lexicalPatternHeuristic(weights: NDArray[np.bool_], wave: NDArray[np.bool_]) -> int:
    """Return the lowest index whose entry in ``weights`` is nonzero.

    When called from ``observe``, ``weights`` is the chosen cell's pattern mask
    ``wave[:, i, j]``. ``wave`` is ignored.
    """
    return numpy.argwhere(weights)[0, 0].item()
def makeWeightedPatternHeuristic(
    weights: NDArray[np.floating[Any]],
    np_random: numpy.random.Generator | None = None,
) -> Callable[[NDArray[np.bool_], NDArray[np.bool_]], int]:
    """Build a pattern heuristic that samples in proportion to ``weights``.

    The returned function samples among the patterns allowed by the cell mask
    it receives, with probability proportional to ``weights``. ``weights`` may
    be integer (e.g. occurrence counts) or floating point. Randomness comes
    from ``np_random``, or a fresh ``default_rng()`` if omitted.
    """
    num_of_patterns = len(weights)
    rng: numpy.random.Generator = (
        numpy.random.default_rng() if np_random is None else np_random
    )
    # Convert once: integer weights would make the in-place normalisation below
    # fail with a numpy casting error.
    float_weights = numpy.asarray(weights, dtype=numpy.float64)

    def weightedPatternHeuristic(wave: NDArray[np.bool_], _: NDArray[np.bool_]) -> int:
        weighted_wave = float_weights * wave
        weighted_wave /= weighted_wave.sum()
        return int(rng.choice(num_of_patterns, p=weighted_wave))

    return weightedPatternHeuristic
def makeRarestPatternHeuristic(
    weights: NDArray[np.floating[Any]],
    np_random: numpy.random.Generator | None = None,
) -> Callable[[NDArray[np.bool_], NDArray[np.bool_]], int]:
    """Return a function that chooses the rarest (currently least-used) pattern.

    Concretely, among the patterns still allowed at the chosen cell, it picks
    the one that remains possible in the most cells of the whole wave. Ties are
    broken at random. ``weights`` is accepted for signature compatibility and
    is not used.

    NOTE(review): confirm that "most remaining possibilities" is the intended
    meaning of "rarest".
    """
    rng: numpy.random.Generator = (
        numpy.random.default_rng() if np_random is None else np_random
    )

    def weightedPatternHeuristic(
        wave: NDArray[np.bool_], total_wave: NDArray[np.bool_]
    ) -> int:
        # Only patterns still allowed at this cell are valid picks. Choosing any
        # other would overwrite the cell with an already-eliminated pattern.
        candidates = numpy.flatnonzero(wave)
        if candidates.size == 0:
            # Degenerate (contradictory) cell: fall back to considering all patterns.
            candidates = numpy.arange(len(wave))
        possible_cells = numpy.sum(total_wave, axis=(1, 2))[candidates]
        best = candidates[possible_cells == possible_cells.max()]
        return int(rng.choice(best))

    return weightedPatternHeuristic
def makeMostCommonPatternHeuristic(
    weights: NDArray[np.floating[Any]],
    np_random: numpy.random.Generator | None = None,
) -> Callable[[NDArray[np.bool_], NDArray[np.bool_]], int]:
    """Return a function that chooses the most common (currently most-used) pattern.

    Concretely, among the patterns still allowed at the chosen cell, it picks
    the one that remains possible in the fewest cells of the whole wave. Ties
    are broken at random. ``weights`` is accepted for signature compatibility
    and is not used.

    NOTE(review): confirm that "fewest remaining possibilities" is the intended
    meaning of "most common".
    """
    rng: numpy.random.Generator = (
        numpy.random.default_rng() if np_random is None else np_random
    )

    def weightedPatternHeuristic(
        wave: NDArray[np.bool_], total_wave: NDArray[np.bool_]
    ) -> int:
        # Only patterns still allowed at this cell are valid picks. Choosing any
        # other would overwrite the cell with an already-eliminated pattern.
        candidates = numpy.flatnonzero(wave)
        if candidates.size == 0:
            # Degenerate (contradictory) cell: fall back to considering all patterns.
            candidates = numpy.arange(len(wave))
        possible_cells = numpy.sum(total_wave, axis=(1, 2))[candidates]
        best = candidates[possible_cells == possible_cells.min()]
        return int(rng.choice(best))

    return weightedPatternHeuristic
def makeRandomPatternHeuristic(
    weights: NDArray[np.floating[Any]],
    np_random: numpy.random.Generator | None = None,
) -> Callable[[NDArray[np.bool_], NDArray[np.bool_]], int]:
    """Build a pattern heuristic that picks uniformly among the cell's allowed patterns.

    Only ``len(weights)`` is used (as the number of patterns). Randomness comes
    from ``np_random``, or a fresh ``default_rng()`` if omitted.
    """
    pattern_count = len(weights)
    rng: numpy.random.Generator = (
        numpy.random.default_rng() if np_random is None else np_random
    )

    def randomPatternHeuristic(wave: NDArray[np.bool_], _: NDArray[np.bool_]) -> int:
        probabilities = wave.astype(float)
        probabilities /= probabilities.sum()
        return rng.choice(pattern_count, p=probabilities)

    return randomPatternHeuristic
- ######################################
- # Global Constraints
def make_global_use_all_patterns() -> Callable[[NDArray[np.bool_]], bool]:
    """Build a feasibility check requiring every pattern to stay possible somewhere."""

    def global_use_all_patterns(wave: NDArray[np.bool_]) -> bool:
        """Returns true if at least one instance of each pattern is still possible."""
        still_possible = wave.any(axis=(1, 2))
        return bool(still_possible.all())

    return global_use_all_patterns
- #####################################
- # Solver
def propagate(
    wave: NDArray[np.bool_],
    adj: Mapping[tuple[int, int], NDArray[numpy.bool_]],
    periodic: bool = False,
    onPropagate: Callable[[NDArray[numpy.bool_]], None] | None = None,
) -> None:
    """Completely propagate any newly collapsed waves to all areas.

    The wave is modified in place. On each pass, pattern ``p`` is removed from
    cell ``(x, y)`` if, for some direction ``d = (dx, dy)`` in ``adj``, no
    pattern ``q`` still possible at ``(x + dx, y + dy)`` has ``adj[d][p, q]``
    set. Passes repeat until nothing changes.

    Neighbours outside the grid wrap around when ``periodic`` is True.
    Otherwise they count as having every pattern possible. Offsets may be
    larger than one cell; padding is sized to the largest offset in ``adj``.

    Args:
        wave: boolean possibility array indexed ``[pattern, x, y]``.
        adj: per-direction square boolean adjacency matrices.
        periodic: wrap the grid edges.
        onPropagate: called with the wave after every pass that removed at
            least one possibility.

    Raises:
        Contradiction: if, once propagation settles, some cell has no
            possible pattern.
    """
    # Pad by the largest offset so every direction's shifted view keeps the wave's shape.
    pad = max((max(abs(dx), abs(dy)) for dx, dy in adj), default=1)
    pad_width = ((0, 0), (pad, pad), (pad, pad))
    width, height = wave.shape[1], wave.shape[2]
    count = wave.sum()
    while True:
        if periodic:
            padded = numpy.pad(wave, pad_width, mode="wrap")
        else:
            padded = numpy.pad(wave, pad_width, mode="constant", constant_values=True)
        supports = {}
        for d in adj:
            dx, dy = d
            # shifted[:, x, y] is the (padded) neighbour of cell (x, y) in direction d,
            # so all cells are checked at once and any offset works.
            shifted = padded[
                :, pad + dx : pad + dx + width, pad + dy : pad + dy + height
            ]
            # Boolean matrix product: pattern p is supported iff adj[d][p, q] holds
            # for some pattern q still possible at the neighbour.
            supports[d] = (adj[d] @ shifted.reshape(shifted.shape[0], -1)).reshape(
                shifted.shape
            ) > 0
        # Supports were computed from a padded copy, so the order of removal is irrelevant.
        for d in adj:
            wave *= supports[d]
        new_count = wave.sum()
        if new_count == count:
            break  # No changes since the last pass: propagation is complete.
        count = new_count
        if onPropagate:
            onPropagate(wave)
    if (wave.sum(axis=0) == 0).any():
        raise Contradiction("Wave is in a contradictory state and can not be solved.")
def observe(
    wave: NDArray[np.bool_],
    locationHeuristic: Callable[[NDArray[np.bool_]], tuple[int, int]],
    patternHeuristic: Callable[[NDArray[np.bool_], NDArray[np.bool_]], int],
) -> tuple[int, int, int]:
    """Return the next best wave to collapse based on the provided heuristics.

    ``locationHeuristic(wave)`` chooses the cell ``(i, j)``. Then
    ``patternHeuristic(wave[:, i, j], wave)`` chooses the pattern. The result
    is ``(pattern, i, j)``.
    """
    location = locationHeuristic(wave)
    i, j = location
    cell_patterns = wave[:, i, j]
    return patternHeuristic(cell_patterns, wave), i, j
def run(
    wave: NDArray[np.bool_],
    adj: Mapping[tuple[int, int], NDArray[numpy.bool_]],
    locationHeuristic: Callable[[NDArray[numpy.bool_]], tuple[int, int]],
    patternHeuristic: Callable[[NDArray[np.bool_], NDArray[np.bool_]], int],
    periodic: bool = False,
    backtracking: bool = False,
    onBacktrack: Callable[[], None] | None = None,
    onChoice: Callable[[int, int, int], None] | None = None,
    onObserve: Callable[[NDArray[numpy.bool_]], None] | None = None,
    onPropagate: Callable[[NDArray[numpy.bool_]], None] | None = None,
    checkFeasible: Callable[[NDArray[numpy.bool_]], bool] | None = None,
    onFinal: Callable[[NDArray[numpy.bool_]], None] | None = None,
    depth: int = 0,
    depth_limit: int | None = None,
) -> NDArray[numpy.int64]:
    """Solve ``wave`` with a fresh :class:`Solver` and return each cell's pattern index.

    The camelCase callbacks are forwarded to the corresponding snake_case
    ``Solver`` arguments. ``onFinal`` receives the solved wave just before the
    result is computed; it is not called if solving raises. ``depth`` and
    ``depth_limit`` are accepted for interface compatibility but are not used.
    """
    solver = Solver(
        wave=wave,
        adj=adj,
        periodic=periodic,
        backtracking=backtracking,
        on_backtrack=onBacktrack,
        on_choice=onChoice,
        on_observe=onObserve,
        on_propagate=onPropagate,
        check_feasible=checkFeasible,
    )
    finished = False
    while not finished:
        finished = solver.solve_next(
            location_heuristic=locationHeuristic, pattern_heuristic=patternHeuristic
        )
    if onFinal:
        onFinal(solver.wave)
    return numpy.argmax(solver.wave, axis=0)
|