123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296 |
- """Main WFC execution function. Implementation based on https://github.com/ikarth/wfc_2019f"""
- from __future__ import annotations
- import logging
- import time
- from typing import Any, Callable
- import numpy as np
- from numpy.typing import NDArray
- from typing_extensions import Literal
- from minigrid.envs.wfc.wfclogic.adjacency import adjacency_extraction
- from minigrid.envs.wfc.wfclogic.patterns import (
- make_pattern_catalog_with_rotations,
- pattern_grid_to_tiles,
- )
- from minigrid.envs.wfc.wfclogic.solver import (
- Contradiction,
- StopEarly,
- TimedOut,
- lexicalLocationHeuristic,
- lexicalPatternHeuristic,
- make_global_use_all_patterns,
- makeAdj,
- makeAntiEntropyLocationHeuristic,
- makeEntropyLocationHeuristic,
- makeHilbertLocationHeuristic,
- makeRandomLocationHeuristic,
- makeRandomPatternHeuristic,
- makeRarestPatternHeuristic,
- makeSpiralLocationHeuristic,
- makeWave,
- makeWeightedPatternHeuristic,
- run,
- simpleLocationHeuristic,
- )
- from .tiles import make_tile_catalog
- from .utilities import tile_grid_to_image
# Module-scoped logger, keyed on this module's import name.
logger: logging.Logger = logging.getLogger(__name__)
def make_log_stats() -> Callable[[dict[str, Any], str], None]:
    """Build a stateful callback that appends statistics rows to a TSV file.

    The returned ``log_stats(stats, filename)`` ignores empty ``stats``. For
    the first non-empty ``stats`` it receives, it writes a header row made of
    the keys before the row of values; later calls write only value rows.
    Every field (including the last one) is terminated by a tab character and
    every row by a newline. The file is opened in append mode as UTF-8.

    Returns:
        The ``log_stats`` callback, which keeps its own row counter.
    """
    rows_written = 0

    def log_stats(stats: dict[str, Any], filename: str) -> None:
        nonlocal rows_written
        if not stats:
            return
        rows_written += 1
        with open(filename, "a", encoding="utf_8") as out:
            if rows_written == 1:
                # First row from this callback: emit the column names.
                out.write("".join(str(key) + "\t" for key in stats) + "\n")
            out.write("".join(str(stats[key]) + "\t" for key in stats) + "\n")

    return log_stats
def _build_adjacency_list(
    direction_offsets: list[tuple[int, tuple[int, int]]],
    adjacency_relations: Any,
    encode_patterns: dict[Any, int],
    number_of_patterns: int,
) -> dict[tuple[int, int], list[set[int]]]:
    """Index extracted adjacency relations by offset and encoded pattern id.

    Args:
        direction_offsets: ``(index, offset)`` pairs; every offset gets an
            entry in the result, even if no relation uses it.
        adjacency_relations: iterable of ``(offset, pattern1, pattern2)``
            triples, as unpacked from ``adjacency_extraction``'s result.
        encode_patterns: maps a pattern identifier to its dense integer id.
        number_of_patterns: number of dense pattern ids.

    Returns:
        ``{offset: [ids allowed next to pattern i at that offset, for each i]}``.
    """
    adjacency_list: dict[tuple[int, int], list[set[int]]] = {
        offset: [set() for _ in range(number_of_patterns)]
        for _, offset in direction_offsets
    }
    for offset, pattern1, pattern2 in adjacency_relations:
        adjacency_list[offset][encode_patterns[pattern1]].add(
            encode_patterns[pattern2]
        )
    return adjacency_list


def _select_pattern_heuristic(
    choice_heuristic: str,
    encoded_weights: NDArray[np.float64],
    np_random: np.random.Generator,
) -> Callable[[NDArray[np.bool_], NDArray[np.bool_]], int]:
    """Return the pattern-choice heuristic named by ``choice_heuristic``.

    Unrecognised names fall back to ``lexicalPatternHeuristic`` (with a warning).
    """
    factories = {
        "rarest": makeRarestPatternHeuristic,
        "weighted": makeWeightedPatternHeuristic,
        "random": makeRandomPatternHeuristic,
    }
    factory = factories.get(choice_heuristic)
    if factory is None:
        if choice_heuristic != "lexical":
            logger.warning(
                "Unknown choice heuristic %r; using lexical", choice_heuristic
            )
        return lexicalPatternHeuristic
    return factory(encoded_weights, np_random)


def _select_location_heuristic(
    loc_heuristic: str,
    choice_random_weighting: NDArray[np.float64],
) -> Callable[[NDArray[np.bool_]], tuple[int, int]]:
    """Return the cell-selection heuristic named by ``loc_heuristic``.

    Factories are only invoked for the requested name, so optional
    dependencies (the Hilbert heuristic needs ``hilbert_curve``) are only
    touched when asked for. Unrecognised names fall back to
    ``lexicalLocationHeuristic`` (with a warning).
    """
    if loc_heuristic == "simple":
        return simpleLocationHeuristic
    factories = {
        "anti-entropy": makeAntiEntropyLocationHeuristic,
        "entropy": makeEntropyLocationHeuristic,
        "random": makeRandomLocationHeuristic,
        "spiral": makeSpiralLocationHeuristic,
        "hilbert": makeHilbertLocationHeuristic,
    }
    factory = factories.get(loc_heuristic)
    if factory is None:
        if loc_heuristic != "lexical":
            logger.warning(
                "Unknown location heuristic %r; using lexical", loc_heuristic
            )
        return lexicalLocationHeuristic
    return factory(choice_random_weighting)


def execute_wfc(
    image: NDArray[np.integer],
    tile_size: int = 1,
    pattern_width: int = 2,
    rotations: int = 8,
    output_size: tuple[int, int] = (48, 48),
    ground: int | None = None,
    attempt_limit: int = 10,
    output_periodic: bool = True,
    input_periodic: bool = True,
    loc_heuristic: Literal[
        "lexical", "hilbert", "spiral", "entropy", "anti-entropy", "simple", "random"
    ] = "entropy",
    choice_heuristic: Literal["lexical", "rarest", "weighted", "random"] = "weighted",
    global_constraint: Literal[False, "allpatterns"] = False,
    backtracking: bool = False,
    log_filename: str = "log",
    logging: bool = False,
    log_stats_to_output: Callable[[dict[str, Any], str], None] | None = None,
    np_random: np.random.Generator | None = None,
) -> tuple[NDArray[np.integer] | None, dict[str, Any]]:
    """Run Wave Function Collapse on ``image`` and return a generated image.

    Args:
        image: sample input, split into tiles by ``make_tile_catalog``.
        tile_size: tile edge length used to split ``image`` and to render the
            output with ``tile_grid_to_image``.
        pattern_width: edge length of the extracted patterns.
        rotations: rotation count; decremented by one ("zero-based") before
            being passed to ``make_pattern_catalog_with_rotations`` (the
            decremented value is what appears in the stats).
        output_size: the two spatial dimensions passed to ``makeWave``.
        ground: if truthy, the patterns at ``pattern_grid.flat[ground - 1:]``
            are passed to ``makeWave`` as the ground constraint.
        attempt_limit: maximum number of solve attempts.
        output_periodic: whether the solver treats the output as periodic.
        input_periodic: whether pattern extraction wraps around the input.
        loc_heuristic: cell-selection heuristic name.
        choice_heuristic: pattern-selection heuristic name.
        global_constraint: ``"allpatterns"`` enables the use-all-patterns
            feasibility check; anything else disables global constraints.
        backtracking: forwarded to the solver.
        log_filename: stats are reported for ``./output/<log_filename>.tsv``.
        logging: unused by this function (note: shadows the ``logging`` module
            name inside the body).
        log_stats_to_output: if given, called after every attempt with the
            attempt's stats dict and the TSV path.
        np_random: random generator; a fresh ``default_rng()`` when ``None``.

    Returns:
        ``(image, stats)`` for the first successful attempt, or
        ``(None, stats)`` with the last attempt's stats when every attempt
        ended in ``TimedOut`` or ``Contradiction``.

    Raises:
        StopEarly: propagated from the solver unchanged.
        TimedOut: if ``attempt_limit`` allows no attempt at all.
    """
    time_begin = time.perf_counter()
    output_destination = r"./output/"
    np_random = np.random.default_rng() if np_random is None else np_random
    rotations -= 1  # change to zero-based
    input_stats = {
        "tile_size": tile_size,
        "pattern_width": pattern_width,
        "rotations": rotations,
        "output_size": output_size,
        "ground": ground,
        "attempt_limit": attempt_limit,
        "output_periodic": output_periodic,
        "input_periodic": input_periodic,
        "location heuristic": loc_heuristic,
        "choice heuristic": choice_heuristic,
        "global constraint": global_constraint,
        "backtracking": backtracking,
    }

    # --- Pattern and adjacency extraction ---------------------------------
    # TODO: generalize this to more than the four cardinal directions
    direction_offsets = list(enumerate([(0, -1), (1, 0), (0, 1), (-1, 0)]))
    tile_catalog, tile_grid, _code_list, _unique_tiles = make_tile_catalog(
        image, tile_size
    )
    (
        pattern_catalog,
        pattern_weights,
        pattern_list,
        pattern_grid,
    ) = make_pattern_catalog_with_rotations(
        tile_grid, pattern_width, input_is_periodic=input_periodic, rotations=rotations
    )
    logger.debug("profiling adjacency relations")
    adjacency_relations = adjacency_extraction(
        pattern_grid,
        pattern_catalog,
        direction_offsets,
        (pattern_width, pattern_width),
    )
    logger.debug("adjacency_relations")
    logger.debug(
        "output size: %s\noutput periodic: %s", output_size, output_periodic
    )
    number_of_patterns = len(pattern_weights)
    logger.debug("# patterns: %s", number_of_patterns)
    # Dense integer ids <-> pattern identifiers.
    decode_patterns = dict(enumerate(pattern_list))
    encode_patterns = {x: i for i, x in enumerate(pattern_list)}
    adjacency_list = _build_adjacency_list(
        direction_offsets, adjacency_relations, encode_patterns, number_of_patterns
    )
    logger.debug("adjacency: %s", len(adjacency_list))
    time_adjacency = time.perf_counter()

    # --- Ground constraint and initial wave --------------------------------
    ground_list: NDArray[np.int64] | None = None
    if ground:
        ground_patterns = pattern_grid.flat[(ground - 1) :]
        # np.vectorize raises ValueError on size-0 input (no otypes given),
        # so an empty slice must be detected before vectorizing.
        if ground_patterns.size > 0:
            ground_list = np.vectorize(lambda x: encode_patterns[x])(
                ground_patterns
            )
    wave = makeWave(
        number_of_patterns, output_size[0], output_size[1], ground=ground_list
    )
    adjacency_matrix = makeAdj(adjacency_list)

    # --- Heuristics ----------------------------------------------------------
    encoded_weights: NDArray[np.float64] = np.zeros(
        number_of_patterns, dtype=np.float64
    )
    for pattern_id, weight in pattern_weights.items():
        encoded_weights[encode_patterns[pattern_id]] = weight
    # Drawn before any heuristic is built so np_random is consumed in the
    # same order regardless of which heuristics are selected.
    choice_random_weighting: NDArray[np.float64] = (
        np_random.random(wave.shape[1:]) * 0.1
    )
    pattern_heuristic = _select_pattern_heuristic(
        choice_heuristic, encoded_weights, np_random
    )
    logger.debug(loc_heuristic)
    location_heuristic = _select_location_heuristic(
        loc_heuristic, choice_random_weighting
    )

    # --- Global constraints --------------------------------------------------
    if global_constraint == "allpatterns":
        active_global_constraint = make_global_use_all_patterns()
    else:

        def active_global_constraint(_wave: NDArray[np.bool_]) -> bool:
            return True

    logger.debug(active_global_constraint)
    combined_constraints = [active_global_constraint]

    def check_feasible(wave: NDArray[np.bool_]) -> bool:
        return all(constraint(wave) for constraint in combined_constraints)

    # --- Solving -------------------------------------------------------------
    log_path = output_destination + log_filename + ".tsv"
    logger.debug("solving...")
    outstats: dict[str, Any] = {}
    attempts = 0
    while attempts < attempt_limit:
        attempts += 1
        time_solve_start = time.perf_counter()
        time_solve_end: float | None = None
        solution_tile_grid = None
        stats: dict[str, Any] = {}
        try:
            solution = run(
                wave.copy(),  # each attempt starts from an untouched wave
                adjacency_matrix,
                locationHeuristic=location_heuristic,
                patternHeuristic=pattern_heuristic,
                periodic=output_periodic,
                backtracking=backtracking,
                checkFeasible=check_feasible,
            )
            solution_as_ids = np.vectorize(lambda x: decode_patterns[x])(solution)
            solution_tile_grid = pattern_grid_to_tiles(solution_as_ids, pattern_catalog)
            time_solve_end = time.perf_counter()
            stats["outcome"] = "success"
        except StopEarly:
            logger.debug("Skipping...")
            stats["outcome"] = "skipped"
            raise
        except TimedOut:
            logger.debug("Timed Out")
            stats["outcome"] = "timed_out"
        except Contradiction:
            stats["outcome"] = "contradiction"
        finally:
            # Report every attempt, including ones that propagate an exception.
            # No `return` in here: it would swallow the re-raised StopEarly.
            if time_solve_end is not None:
                solve_duration = time_solve_end - time_solve_start
            else:
                solve_duration = time.perf_counter() - time_solve_start
            outstats = {
                **input_stats,
                "attempts": attempts,
                "time_start": time_begin,
                "time_adjacency": time_adjacency,
                "adjacency_duration": time_solve_start - time_adjacency,
                "time solve start": time_solve_start,
                "time solve end": time_solve_end,
                "solve duration": solve_duration,
                "pattern count": number_of_patterns,
                **stats,
            }
            if log_stats_to_output is not None:
                log_stats_to_output(outstats, log_path)
        if solution_tile_grid is not None:
            return (
                tile_grid_to_image(
                    solution_tile_grid, tile_catalog, (tile_size, tile_size)
                ),
                outstats,
            )
        # Timed out or contradicted: fall through and retry.

    if attempts == 0:
        raise TimedOut("Attempt limit exceeded.")
    # Every attempt failed; keep the "None on failure" contract callers use.
    return None, outstats
|