"""Main WFC execution function. Implementation based on https://github.com/ikarth/wfc_2019f"""
from __future__ import annotations

import logging
import time
from typing import Any, Callable

import numpy as np
from numpy.typing import NDArray
from typing_extensions import Literal

from minigrid.envs.wfc.wfclogic.adjacency import adjacency_extraction
from minigrid.envs.wfc.wfclogic.patterns import (
    make_pattern_catalog_with_rotations,
    pattern_grid_to_tiles,
)
from minigrid.envs.wfc.wfclogic.solver import (
    Contradiction,
    StopEarly,
    TimedOut,
    lexicalLocationHeuristic,
    lexicalPatternHeuristic,
    make_global_use_all_patterns,
    makeAdj,
    makeAntiEntropyLocationHeuristic,
    makeEntropyLocationHeuristic,
    makeHilbertLocationHeuristic,
    makeRandomLocationHeuristic,
    makeRandomPatternHeuristic,
    makeRarestPatternHeuristic,
    makeSpiralLocationHeuristic,
    makeWave,
    makeWeightedPatternHeuristic,
    run,
    simpleLocationHeuristic,
)

from .tiles import make_tile_catalog
from .utilities import tile_grid_to_image

logger = logging.getLogger(__name__)
  37. def make_log_stats() -> Callable[[dict[str, Any], str], None]:
  38. log_line = 0
  39. def log_stats(stats: dict[str, Any], filename: str) -> None:
  40. nonlocal log_line
  41. if stats:
  42. log_line += 1
  43. with open(filename, "a", encoding="utf_8") as logf:
  44. if log_line < 2:
  45. for s in stats.keys():
  46. print(str(s), end="\t", file=logf)
  47. print("", file=logf)
  48. for s in stats.keys():
  49. print(str(stats[s]), end="\t", file=logf)
  50. print("", file=logf)
  51. return log_stats
def execute_wfc(
    image: NDArray[np.integer],
    tile_size: int = 1,
    pattern_width: int = 2,
    rotations: int = 8,
    output_size: tuple[int, int] = (48, 48),
    ground: int | None = None,
    attempt_limit: int = 10,
    output_periodic: bool = True,
    input_periodic: bool = True,
    loc_heuristic: Literal[
        "lexical", "hilbert", "spiral", "entropy", "anti-entropy", "simple", "random"
    ] = "entropy",
    choice_heuristic: Literal["lexical", "rarest", "weighted", "random"] = "weighted",
    global_constraint: Literal[False, "allpatterns"] = False,
    backtracking: bool = False,
    log_filename: str = "log",
    # NOTE(review): this parameter shadows the module-level ``logging`` import
    # and is never read in this function body; kept for interface compatibility.
    logging: bool = False,
    log_stats_to_output: Callable[[dict[str, Any], str], None] | None = None,
    np_random: np.random.Generator | None = None,
) -> tuple[NDArray[np.integer] | None, dict[str, Any]]:
    """Run Wave Function Collapse over ``image`` and return ``(output, stats)``.

    Pipeline: build a tile catalog from the sample ``image``, extract
    ``pattern_width`` x ``pattern_width`` patterns (with rotation variants),
    derive per-direction pattern adjacency constraints, then run the solver.

    Args:
        image: Sample image the tiles and patterns are learned from.
        tile_size: Edge length of a tile, in image cells.
        pattern_width: Edge length of an extracted pattern, in tiles.
        rotations: Number of rotation variants (1-based on input; converted
            to zero-based below before use).
        output_size: Size of the generated wave.
        ground: If truthy, 1-based offset into the flattened pattern grid
            whose tail is encoded as fixed "ground" patterns for the wave.
        attempt_limit: Maximum number of solver attempts (see NOTE near the
            end of the loop about whether retries can actually occur).
        output_periodic: Whether the output wraps at its edges.
        input_periodic: Whether the input sample wraps at its edges.
        loc_heuristic: Strategy for picking the next cell to collapse.
        choice_heuristic: Strategy for picking a pattern in the chosen cell.
        global_constraint: ``"allpatterns"`` requires every pattern to appear;
            ``False`` disables the global check.
        backtracking: Whether the solver may undo choices on contradiction.
        log_filename: Base name of the stats TSV written under ``./output/``.
        logging: Unused in this body (see NOTE on the parameter).
        log_stats_to_output: Optional callback receiving ``(stats, path)``
            after every attempt.
        np_random: RNG to use; a fresh ``default_rng()`` is created if None.

    Returns:
        ``(generated image, stats dict)`` on success, or ``(None, stats dict)``
        when the attempt ends without a solution.

    Raises:
        StopEarly: re-raised from the solver when it requests an early stop.
        TimedOut: raised here when the attempt limit is exhausted.
    """
    time_begin = time.perf_counter()
    output_destination = r"./output/"
    np_random: np.random.Generator = (
        np.random.default_rng() if np_random is None else np_random
    )
    rotations -= 1  # change to zero-based
    # Echo of the input parameters; merged into every stats record below.
    input_stats = {
        "tile_size": tile_size,
        "pattern_width": pattern_width,
        "rotations": rotations,
        "output_size": output_size,
        "ground": ground,
        "attempt_limit": attempt_limit,
        "output_periodic": output_periodic,
        "input_periodic": input_periodic,
        "location heuristic": loc_heuristic,
        "choice heuristic": choice_heuristic,
        "global constraint": global_constraint,
        "backtracking": backtracking,
    }
    # TODO: generalize this to more than the four cardinal directions
    direction_offsets = list(enumerate([(0, -1), (1, 0), (0, 1), (-1, 0)]))
    tile_catalog, tile_grid, _code_list, _unique_tiles = make_tile_catalog(
        image, tile_size
    )
    (
        pattern_catalog,
        pattern_weights,
        pattern_list,
        pattern_grid,
    ) = make_pattern_catalog_with_rotations(
        tile_grid, pattern_width, input_is_periodic=input_periodic, rotations=rotations
    )
    logger.debug("profiling adjacency relations")
    # Which pattern pairs may sit next to each other, per direction offset.
    adjacency_relations = adjacency_extraction(
        pattern_grid,
        pattern_catalog,
        direction_offsets,
        (pattern_width, pattern_width),
    )
    logger.debug("adjacency_relations")
    logger.debug(f"output size: {output_size}\noutput periodic: {output_periodic}")
    number_of_patterns = len(pattern_weights)
    logger.debug(f"# patterns: {number_of_patterns}")
    # Bidirectional mapping between pattern ids and dense solver indices.
    decode_patterns = dict(enumerate(pattern_list))
    encode_patterns = {x: i for i, x in enumerate(pattern_list)}
    # adjacency_list[direction][pattern index] -> set of allowed neighbor indices
    adjacency_list: dict[tuple[int, int], list[set[int]]] = {}
    for _, adjacency in direction_offsets:
        adjacency_list[adjacency] = [set() for _ in pattern_weights]
    # logger.debug(adjacency_list)
    for adjacency, pattern1, pattern2 in adjacency_relations:
        # logger.debug(adjacency)
        # logger.debug(decode_patterns[pattern1])
        adjacency_list[adjacency][encode_patterns[pattern1]].add(
            encode_patterns[pattern2]
        )
    logger.debug(f"adjacency: {len(adjacency_list)}")
    time_adjacency = time.perf_counter()
    # Ground #
    ground_list: NDArray[np.int64] | None = None
    if ground:
        # Encode the tail of the flattened pattern grid (1-based ``ground``
        # offset) as the fixed "ground" patterns handed to makeWave.
        ground_list = np.vectorize(lambda x: encode_patterns[x])(
            pattern_grid.flat[(ground - 1) :]
        )
    if ground_list is None or ground_list.size == 0:
        ground_list = None
    wave = makeWave(
        number_of_patterns, output_size[0], output_size[1], ground=ground_list
    )
    adjacency_matrix = makeAdj(adjacency_list)
    # Heuristics #
    # Dense weight vector indexed by solver pattern index.
    encoded_weights: NDArray[np.float64] = np.zeros(
        (number_of_patterns), dtype=np.float64
    )
    for w_id, w_val in pattern_weights.items():
        encoded_weights[encode_patterns[w_id]] = w_val
    # Small random jitter used by several location heuristics to break ties.
    choice_random_weighting: NDArray[np.float64] = (
        np_random.random(wave.shape[1:]) * 0.1
    )
    # Default pattern choice is lexical; overridden by choice_heuristic below.
    pattern_heuristic: Callable[
        [NDArray[np.bool_], NDArray[np.bool_]], int
    ] = lexicalPatternHeuristic
    if choice_heuristic == "rarest":
        pattern_heuristic = makeRarestPatternHeuristic(encoded_weights, np_random)
    if choice_heuristic == "weighted":
        pattern_heuristic = makeWeightedPatternHeuristic(encoded_weights, np_random)
    if choice_heuristic == "random":
        pattern_heuristic = makeRandomPatternHeuristic(encoded_weights, np_random)
    logger.debug(loc_heuristic)
    # Default location choice is lexical; overridden by loc_heuristic below.
    location_heuristic: Callable[
        [NDArray[np.bool_]], tuple[int, int]
    ] = lexicalLocationHeuristic
    if loc_heuristic == "anti-entropy":
        location_heuristic = makeAntiEntropyLocationHeuristic(choice_random_weighting)
    if loc_heuristic == "entropy":
        location_heuristic = makeEntropyLocationHeuristic(choice_random_weighting)
    if loc_heuristic == "random":
        location_heuristic = makeRandomLocationHeuristic(choice_random_weighting)
    if loc_heuristic == "simple":
        location_heuristic = simpleLocationHeuristic
    if loc_heuristic == "spiral":
        location_heuristic = makeSpiralLocationHeuristic(choice_random_weighting)
    if loc_heuristic == "hilbert":
        # This requires hilbert_curve to be installed
        location_heuristic = makeHilbertLocationHeuristic(choice_random_weighting)
    # Global Constraints #
    if global_constraint == "allpatterns":
        active_global_constraint = make_global_use_all_patterns()
    else:

        def active_global_constraint(wave) -> bool:
            # No global constraint: accept any wave state.
            return True

    logger.debug(active_global_constraint)
    combined_constraints = [active_global_constraint]

    def combinedConstraints(wave: NDArray[np.bool_]) -> bool:
        # Feasibility check handed to the solver: all constraints must hold.
        return all(fn(wave) for fn in combined_constraints)

    # Solving #
    time_solve_start = None
    time_solve_end = None
    solution_tile_grid = None
    logger.debug("solving...")
    attempts = 0
    while attempts < attempt_limit:
        attempts += 1
        time_solve_start = time.perf_counter()
        stats = {}
        # NOTE(review): indentation was reconstructed for this try/finally;
        # the returns below were placed after the ``finally`` block so the
        # explicit ``raise`` in the StopEarly handler can still propagate —
        # confirm against upstream (ikarth/wfc_2019f) if exact behavior matters.
        try:
            solution = run(
                wave.copy(),
                adjacency_matrix,
                locationHeuristic=location_heuristic,
                patternHeuristic=pattern_heuristic,
                periodic=output_periodic,
                backtracking=backtracking,
                checkFeasible=combinedConstraints,
            )
            # Translate solver indices back to pattern ids, then to tiles.
            solution_as_ids = np.vectorize(lambda x: decode_patterns[x])(solution)
            solution_tile_grid = pattern_grid_to_tiles(solution_as_ids, pattern_catalog)
            time_solve_end = time.perf_counter()
            stats.update({"outcome": "success"})
        except StopEarly:
            logger.debug("Skipping...")
            stats.update({"outcome": "skipped"})
            raise
        except TimedOut:
            logger.debug("Timed Out")
            stats.update({"outcome": "timed_out"})
        except Contradiction:
            # logger.warning(f"Contradiction: {exc}")
            stats.update({"outcome": "contradiction"})
        finally:
            # profiler.dump_stats(f"logs/profile_(unknown)_{timecode}.txt")
            # Assemble the per-attempt stats record (inputs + timings + outcome).
            outstats = {}
            outstats.update(input_stats)
            solve_duration = time.perf_counter() - time_solve_start
            if time_solve_end is not None:
                solve_duration = time_solve_end - time_solve_start
            adjacency_duration = time_solve_start - time_adjacency
            outstats.update(
                {
                    "attempts": attempts,
                    "time_start": time_begin,
                    "time_adjacency": time_adjacency,
                    "adjacency_duration": adjacency_duration,
                    "time solve start": time_solve_start,
                    "time solve end": time_solve_end,
                    "solve duration": solve_duration,
                    "pattern count": number_of_patterns,
                }
            )
            outstats.update(stats)
            if log_stats_to_output is not None:
                log_stats_to_output(
                    outstats, output_destination + log_filename + ".tsv"
                )
        if solution_tile_grid is not None:
            return (
                tile_grid_to_image(
                    solution_tile_grid, tile_catalog, (tile_size, tile_size)
                ),
                outstats,
            )
        else:
            # NOTE(review): as written, a failed attempt returns immediately,
            # so ``attempt_limit`` never produces a retry and the TimedOut
            # below is only reachable when ``attempt_limit < 1`` — confirm
            # this matches the intended retry semantics.
            return None, outstats
    raise TimedOut("Attempt limit exceeded.")