control.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. """Main WFC execution function. Implementation based on https://github.com/ikarth/wfc_2019f"""
  2. from __future__ import annotations
  3. import logging
  4. import time
  5. from typing import Any, Callable
  6. import numpy as np
  7. from numpy.typing import NDArray
  8. from typing_extensions import Literal
  9. from minigrid.envs.wfc.wfclogic.adjacency import adjacency_extraction
  10. from minigrid.envs.wfc.wfclogic.patterns import (
  11. make_pattern_catalog_with_rotations,
  12. pattern_grid_to_tiles,
  13. )
  14. from minigrid.envs.wfc.wfclogic.solver import (
  15. Contradiction,
  16. StopEarly,
  17. TimedOut,
  18. lexicalLocationHeuristic,
  19. lexicalPatternHeuristic,
  20. make_global_use_all_patterns,
  21. makeAdj,
  22. makeAntiEntropyLocationHeuristic,
  23. makeEntropyLocationHeuristic,
  24. makeHilbertLocationHeuristic,
  25. makeRandomLocationHeuristic,
  26. makeRandomPatternHeuristic,
  27. makeRarestPatternHeuristic,
  28. makeSpiralLocationHeuristic,
  29. makeWave,
  30. makeWeightedPatternHeuristic,
  31. run,
  32. simpleLocationHeuristic,
  33. )
  34. from .tiles import make_tile_catalog
  35. from .utilities import tile_grid_to_image
  36. logger = logging.getLogger(__name__)
  37. def make_log_stats() -> Callable[[dict[str, Any], str], None]:
  38. log_line = 0
  39. def log_stats(stats: dict[str, Any], filename: str) -> None:
  40. nonlocal log_line
  41. if stats:
  42. log_line += 1
  43. with open(filename, "a", encoding="utf_8") as logf:
  44. if log_line < 2:
  45. for s in stats.keys():
  46. print(str(s), end="\t", file=logf)
  47. print("", file=logf)
  48. for s in stats.keys():
  49. print(str(stats[s]), end="\t", file=logf)
  50. print("", file=logf)
  51. return log_stats
def execute_wfc(
    image: NDArray[np.integer],
    tile_size: int = 1,
    pattern_width: int = 2,
    rotations: int = 8,
    output_size: tuple[int, int] = (48, 48),
    ground: int | None = None,
    attempt_limit: int = 10,
    output_periodic: bool = True,
    input_periodic: bool = True,
    loc_heuristic: Literal[
        "lexical", "hilbert", "spiral", "entropy", "anti-entropy", "simple", "random"
    ] = "entropy",
    choice_heuristic: Literal["lexical", "rarest", "weighted", "random"] = "weighted",
    global_constraint: Literal[False, "allpatterns"] = False,
    backtracking: bool = False,
    log_filename: str = "log",
    logging: bool = False,  # NOTE(review): shadows the stdlib ``logging`` module inside this function and is never read here
    global_constraints: None = None,  # NOTE(review): accepted but never used in this body
    log_stats_to_output: Callable[[dict[str, Any], str], None] | None = None,
    np_random: np.random.Generator | None = None,
) -> tuple[NDArray[np.integer] | None, dict[str, Any]]:
    """Run Wave Function Collapse over ``image`` and synthesize an output.

    Carves ``image`` into tiles, extracts (optionally rotated) patterns and
    their adjacency relations, then asks the solver to fill an
    ``output_size`` grid using the selected location/pattern heuristics.

    Args:
        image: Source image as an integer array.
        tile_size: Edge length, in pixels, of one tile.
        pattern_width: Edge length, in tiles, of one pattern.
        rotations: Requested number of pattern orientations (1-based; it is
            converted to zero-based internally).
        output_size: (width, height) of the output wave, in patterns.
        ground: When truthy, patterns taken from
            ``pattern_grid.flat[ground - 1:]`` are pinned as "ground".
        attempt_limit: Upper bound on solver attempts.
        output_periodic: Whether the output wraps at its edges.
        input_periodic: Whether the input is treated as wrapping.
        loc_heuristic: Strategy for picking the next cell to collapse;
            unknown values fall back to the lexical heuristic.
        choice_heuristic: Strategy for picking a pattern at that cell;
            unknown values fall back to the lexical heuristic.
        global_constraint: ``"allpatterns"`` requires every pattern to be
            used; ``False`` disables global constraints.
        backtracking: Whether the solver may undo choices on contradiction.
        log_filename: Stem of the TSV stats file written under ``./output/``.
        log_stats_to_output: Optional callback receiving ``(stats, path)``
            after each attempt.
        np_random: Random generator; a fresh default generator is created
            when ``None``.

    Returns:
        ``(output_image, stats)`` on success, ``(None, stats)`` when the
        attempt ends without a solution.

    Raises:
        StopEarly: Re-raised from the solver.
        TimedOut: Declared for attempt-limit exhaustion — but see the
            NOTE(review) near the end of the loop body.
    """
    time_begin = time.perf_counter()
    output_destination = r"./output/"
    np_random: np.random.Generator = (
        np.random.default_rng() if np_random is None else np_random
    )
    rotations -= 1  # change to zero-based
    # Echo of the input configuration; merged into every stats record below.
    input_stats = {
        "tile_size": tile_size,
        "pattern_width": pattern_width,
        "rotations": rotations,
        "output_size": output_size,
        "ground": ground,
        "attempt_limit": attempt_limit,
        "output_periodic": output_periodic,
        "input_periodic": input_periodic,
        "location heuristic": loc_heuristic,
        "choice heuristic": choice_heuristic,
        "global constraint": global_constraint,
        "backtracking": backtracking,
    }
    # TODO: generalize this to more than the four cardinal directions
    direction_offsets = list(enumerate([(0, -1), (1, 0), (0, 1), (-1, 0)]))
    tile_catalog, tile_grid, _code_list, _unique_tiles = make_tile_catalog(
        image, tile_size
    )
    (
        pattern_catalog,
        pattern_weights,
        pattern_list,
        pattern_grid,
    ) = make_pattern_catalog_with_rotations(
        tile_grid, pattern_width, input_is_periodic=input_periodic, rotations=rotations
    )
    logger.debug("profiling adjacency relations")
    # Which pattern may sit next to which, per cardinal direction.
    adjacency_relations = adjacency_extraction(
        pattern_grid,
        pattern_catalog,
        direction_offsets,
        (pattern_width, pattern_width),
    )
    logger.debug("adjacency_relations")
    logger.debug(f"output size: {output_size}\noutput periodic: {output_periodic}")
    number_of_patterns = len(pattern_weights)
    logger.debug(f"# patterns: {number_of_patterns}")
    # Bidirectional mapping between pattern ids and dense solver indices.
    decode_patterns = dict(enumerate(pattern_list))
    encode_patterns = {x: i for i, x in enumerate(pattern_list)}
    # adjacency_list[direction][pattern_index] -> set of allowed neighbor indices.
    adjacency_list: dict[tuple[int, int], list[set[int]]] = {}
    for _, adjacency in direction_offsets:
        adjacency_list[adjacency] = [set() for _ in pattern_weights]
    # logger.debug(adjacency_list)
    for adjacency, pattern1, pattern2 in adjacency_relations:
        # logger.debug(adjacency)
        # logger.debug(decode_patterns[pattern1])
        adjacency_list[adjacency][encode_patterns[pattern1]].add(
            encode_patterns[pattern2]
        )
    logger.debug(f"adjacency: {len(adjacency_list)}")
    time_adjacency = time.perf_counter()
    # Ground #
    ground_list: NDArray[np.int64] | None = None
    if ground:
        # Encode the trailing slice of the pattern grid as the fixed "ground".
        ground_list = np.vectorize(lambda x: encode_patterns[x])(
            pattern_grid.flat[(ground - 1) :]
        )
    if ground_list is None or ground_list.size == 0:
        ground_list = None
    wave = makeWave(
        number_of_patterns, output_size[0], output_size[1], ground=ground_list
    )
    adjacency_matrix = makeAdj(adjacency_list)
    # Heuristics #
    # Per-pattern weights re-indexed to the dense pattern encoding.
    encoded_weights: NDArray[np.float64] = np.zeros(
        (number_of_patterns), dtype=np.float64
    )
    for w_id, w_val in pattern_weights.items():
        encoded_weights[encode_patterns[w_id]] = w_val
    # Small random noise handed to the location heuristics (tie-breaking,
    # presumably — the exact use is inside the solver module).
    choice_random_weighting: NDArray[np.float64] = (
        np_random.random(wave.shape[1:]) * 0.1
    )
    # Lexical heuristics are the fallback when no named option matches.
    pattern_heuristic: Callable[
        [NDArray[np.bool_], NDArray[np.bool_]], int
    ] = lexicalPatternHeuristic
    if choice_heuristic == "rarest":
        pattern_heuristic = makeRarestPatternHeuristic(encoded_weights, np_random)
    if choice_heuristic == "weighted":
        pattern_heuristic = makeWeightedPatternHeuristic(encoded_weights, np_random)
    if choice_heuristic == "random":
        pattern_heuristic = makeRandomPatternHeuristic(encoded_weights, np_random)
    logger.debug(loc_heuristic)
    location_heuristic: Callable[
        [NDArray[np.bool_]], tuple[int, int]
    ] = lexicalLocationHeuristic
    if loc_heuristic == "anti-entropy":
        location_heuristic = makeAntiEntropyLocationHeuristic(choice_random_weighting)
    if loc_heuristic == "entropy":
        location_heuristic = makeEntropyLocationHeuristic(choice_random_weighting)
    if loc_heuristic == "random":
        location_heuristic = makeRandomLocationHeuristic(choice_random_weighting)
    if loc_heuristic == "simple":
        location_heuristic = simpleLocationHeuristic
    if loc_heuristic == "spiral":
        location_heuristic = makeSpiralLocationHeuristic(choice_random_weighting)
    if loc_heuristic == "hilbert":
        # This requires hilbert_curve to be installed
        location_heuristic = makeHilbertLocationHeuristic(choice_random_weighting)
    # Global Constraints #
    if global_constraint == "allpatterns":
        active_global_constraint = make_global_use_all_patterns()
    else:

        def active_global_constraint(wave: NDArray[np.bool_]) -> bool:
            # No global constraint: any partial wave is acceptable.
            return True

    logger.debug(active_global_constraint)
    combined_constraints = [active_global_constraint]

    def combinedConstraints(wave: NDArray[np.bool_]) -> bool:
        # Feasibility check handed to the solver: all constraints must hold.
        return all(fn(wave) for fn in combined_constraints)

    # Solving #
    time_solve_start = None
    time_solve_end = None
    solution_tile_grid = None
    logger.debug("solving...")
    attempts = 0
    while attempts < attempt_limit:
        attempts += 1
        time_solve_start = time.perf_counter()
        stats = {}
        try:
            solution = run(
                wave.copy(),
                adjacency_matrix,
                locationHeuristic=location_heuristic,
                patternHeuristic=pattern_heuristic,
                periodic=output_periodic,
                backtracking=backtracking,
                checkFeasible=combinedConstraints,
            )
            # Translate solver indices back to pattern ids, then to tiles.
            solution_as_ids = np.vectorize(lambda x: decode_patterns[x])(solution)
            solution_tile_grid = pattern_grid_to_tiles(solution_as_ids, pattern_catalog)
            time_solve_end = time.perf_counter()
            stats.update({"outcome": "success"})
        except StopEarly:
            logger.debug("Skipping...")
            stats.update({"outcome": "skipped"})
            raise
        except TimedOut:
            logger.debug("Timed Out")
            stats.update({"outcome": "timed_out"})
        except Contradiction:
            # logger.warning(f"Contradiction: {exc}")
            stats.update({"outcome": "contradiction"})
        finally:
            # profiler.dump_stats(f"logs/profile_(unknown)_{timecode}.txt")
            # Assemble the per-attempt stats record (config + timings + outcome).
            outstats = {}
            outstats.update(input_stats)
            # Fall back to wall-clock-so-far when the solve did not finish.
            solve_duration = time.perf_counter() - time_solve_start
            if time_solve_end is not None:
                solve_duration = time_solve_end - time_solve_start
            adjacency_duration = time_solve_start - time_adjacency
            outstats.update(
                {
                    "attempts": attempts,
                    "time_start": time_begin,
                    "time_adjacency": time_adjacency,
                    "adjacency_duration": adjacency_duration,
                    "time solve start": time_solve_start,
                    "time solve end": time_solve_end,
                    "solve duration": solve_duration,
                    "pattern count": number_of_patterns,
                }
            )
            outstats.update(stats)
            if log_stats_to_output is not None:
                log_stats_to_output(
                    outstats, output_destination + log_filename + ".tsv"
                )
        # NOTE(review): both branches below return, so the loop body never
        # iterates a second time — ``attempt_limit`` > 1 has no effect and
        # the trailing ``raise TimedOut`` is unreachable. Confirm whether the
        # ``else`` branch was meant to retry (``continue``) instead.
        if solution_tile_grid is not None:
            return (
                tile_grid_to_image(
                    solution_tile_grid, tile_catalog, (tile_size, tile_size)
                ),
                outstats,
            )
        else:
            return None, outstats

    raise TimedOut("Attempt limit exceeded.")