adjacency.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. """Convert input data to adjacency information. Implementation based on https://github.com/ikarth/wfc_2019f"""
  2. from __future__ import annotations
  3. import numpy as np
  4. from numpy.typing import NDArray
  5. def adjacency_extraction(
  6. pattern_grid: NDArray[np.int64],
  7. pattern_catalog: dict[int, NDArray[np.int64]],
  8. direction_offsets: list[tuple[int, tuple[int, int]]],
  9. pattern_size: tuple[int, int] = (2, 2),
  10. ) -> list[tuple[tuple[int, int], int, int]]:
  11. """Takes a pattern grid and returns a list of all of the legal adjacencies found in it."""
  12. def is_valid_overlap_xy(
  13. adjacency_direction: tuple[int, int], pattern_1: int, pattern_2: int
  14. ) -> bool:
  15. """Given a direction and two patterns, find the overlap of the two patterns
  16. and return True if the intersection matches."""
  17. dimensions = (1, 0)
  18. not_a_number = -1
  19. # TODO: can probably speed this up by using the right slices, rather than rolling the whole pattern...
  20. shifted = np.roll(
  21. np.pad(
  22. pattern_catalog[pattern_2],
  23. max(pattern_size),
  24. mode="constant",
  25. constant_values=not_a_number,
  26. ),
  27. adjacency_direction,
  28. dimensions,
  29. )
  30. compare = shifted[
  31. pattern_size[0] : pattern_size[0] + pattern_size[0],
  32. pattern_size[1] : pattern_size[1] + pattern_size[1],
  33. ]
  34. left = max(0, 0, +adjacency_direction[0])
  35. right = min(pattern_size[0], pattern_size[0] + adjacency_direction[0])
  36. top = max(0, 0 + adjacency_direction[1])
  37. bottom = min(pattern_size[1], pattern_size[1] + adjacency_direction[1])
  38. a = pattern_catalog[pattern_1][top:bottom, left:right]
  39. b = compare[top:bottom, left:right]
  40. res = np.array_equal(a, b)
  41. return res
  42. pattern_list = list(pattern_catalog.keys())
  43. legal = []
  44. for pattern_1 in pattern_list:
  45. for pattern_2 in pattern_list:
  46. for _direction_index, direction in direction_offsets:
  47. if is_valid_overlap_xy(direction, pattern_1, pattern_2):
  48. legal.append((direction, pattern_1, pattern_2))
  49. return legal