test_wfc_solver.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. from __future__ import annotations
  2. import numpy as np
  3. import pytest
  4. from numpy.typing import NDArray
  5. from minigrid.envs.wfc.wfclogic import solver as wfc_solver
  6. def test_makeWave() -> None:
  7. wave = wfc_solver.makeWave(3, 10, 20, ground=[-1])
  8. assert wave.sum() == (2 * 10 * 19) + (1 * 10 * 1)
  9. assert wave[2, 5, 19]
  10. assert not wave[1, 5, 19]
  11. def test_entropyLocationHeuristic() -> None:
  12. wave = np.ones((5, 3, 4), dtype=bool) # everything is possible
  13. wave[1:, 0, 0] = False # first cell is fully observed
  14. wave[4, :, 2] = False
  15. preferences: NDArray[np.float64] = np.ones((3, 4), dtype=np.float64) * 0.5
  16. preferences[1, 2] = 0.3
  17. preferences[1, 1] = 0.1
  18. heu = wfc_solver.makeEntropyLocationHeuristic(preferences)
  19. result = heu(wave)
  20. assert (1, 2) == result
  21. def test_observe() -> None:
  22. my_wave = np.ones((5, 3, 4), dtype=np.bool_)
  23. my_wave[0, 1, 2] = False
  24. def locHeu(wave: NDArray[np.bool_]) -> tuple[int, int]:
  25. assert np.array_equal(wave, my_wave)
  26. return 1, 2
  27. def patHeu(weights: NDArray[np.bool_], wave: NDArray[np.bool_]) -> int:
  28. assert np.array_equal(weights, my_wave[:, 1, 2])
  29. return 3
  30. assert wfc_solver.observe(
  31. my_wave, locationHeuristic=locHeu, patternHeuristic=patHeu
  32. ) == (
  33. 3,
  34. 1,
  35. 2,
  36. )
  37. def test_propagate() -> None:
  38. wave = np.ones((3, 3, 4), dtype=bool)
  39. adjLists = {}
  40. # checkerboard #0/#1 or solid fill #2
  41. adjLists[(+1, 0)] = adjLists[(-1, 0)] = adjLists[(0, +1)] = adjLists[(0, -1)] = [
  42. [1],
  43. [0],
  44. [2],
  45. ]
  46. wave[:, 0, 0] = False
  47. wave[0, 0, 0] = True
  48. adj = wfc_solver.makeAdj(adjLists)
  49. wfc_solver.propagate(wave, adj, periodic=False)
  50. expected_result = np.array(
  51. [
  52. [
  53. [True, False, True, False],
  54. [False, True, False, True],
  55. [True, False, True, False],
  56. ],
  57. [
  58. [False, True, False, True],
  59. [True, False, True, False],
  60. [False, True, False, True],
  61. ],
  62. [
  63. [False, False, False, False],
  64. [False, False, False, False],
  65. [False, False, False, False],
  66. ],
  67. ]
  68. )
  69. assert np.array_equal(wave, expected_result)
  70. def test_run() -> None:
  71. wave = wfc_solver.makeWave(3, 3, 4)
  72. adjLists = {}
  73. adjLists[(+1, 0)] = adjLists[(-1, 0)] = adjLists[(0, +1)] = adjLists[(0, -1)] = [
  74. [1],
  75. [0],
  76. [2],
  77. ]
  78. adj = wfc_solver.makeAdj(adjLists)
  79. first_result = wfc_solver.run(
  80. wave.copy(),
  81. adj,
  82. locationHeuristic=wfc_solver.lexicalLocationHeuristic,
  83. patternHeuristic=wfc_solver.lexicalPatternHeuristic,
  84. periodic=False,
  85. )
  86. expected_first_result = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]])
  87. assert np.array_equal(first_result, expected_first_result)
  88. event_log: list = []
  89. def onChoice(pattern: int, i: int, j: int) -> None:
  90. event_log.append((pattern, i, j))
  91. def onBacktrack() -> None:
  92. event_log.append("backtrack")
  93. second_result = wfc_solver.run(
  94. wave.copy(),
  95. adj,
  96. locationHeuristic=wfc_solver.lexicalLocationHeuristic,
  97. patternHeuristic=wfc_solver.lexicalPatternHeuristic,
  98. periodic=True,
  99. backtracking=True,
  100. onChoice=onChoice,
  101. onBacktrack=onBacktrack,
  102. )
  103. expected_second_result = np.array([[2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2]])
  104. assert np.array_equal(second_result, expected_second_result)
  105. assert event_log == [(0, 0, 0), "backtrack", (2, 0, 0)]
  106. class Infeasible(Exception):
  107. pass
  108. def explode(wave: NDArray[np.bool_]) -> bool:
  109. if wave.sum() < 20:
  110. raise Infeasible
  111. return False
  112. with pytest.raises(wfc_solver.Contradiction):
  113. wfc_solver.run(
  114. wave.copy(),
  115. adj,
  116. locationHeuristic=wfc_solver.lexicalLocationHeuristic,
  117. patternHeuristic=wfc_solver.lexicalPatternHeuristic,
  118. periodic=True,
  119. backtracking=True,
  120. checkFeasible=explode,
  121. )