123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149 |
- from __future__ import annotations
- import numpy as np
- import pytest
- from numpy.typing import NDArray
- from minigrid.envs.wfc.wfclogic import solver as wfc_solver
def test_makeWave() -> None:
    """Check the total count and two individual entries of a grounded wave.

    Builds a wave with ``makeWave(3, 10, 20, ground=[-1])`` and verifies the
    number of set entries as well as the state of ``[2, 5, 19]`` (set) and
    ``[1, 5, 19]`` (cleared).
    """
    grounded = wfc_solver.makeWave(3, 10, 20, ground=[-1])

    # NOTE(review): the arithmetic suggests two patterns remain possible at
    # 19 of the 20 positions along the last axis and only one at the final
    # position -- confirm against makeWave's handling of ``ground``.
    unconstrained_entries = 2 * 10 * 19
    ground_entries = 1 * 10 * 1
    assert grounded.sum() == unconstrained_entries + ground_entries

    assert grounded[2, 5, 19]
    assert not grounded[1, 5, 19]
def test_entropyLocationHeuristic() -> None:
    """Check the cell chosen by the entropy-based location heuristic.

    The wave has 5 patterns over a 3x4 grid.  Cell (0, 0) is reduced to a
    single pattern, and pattern 4 is removed from every cell with second
    index 2.  With preferences of 0.5 everywhere except 0.3 at (1, 2) and
    0.1 at (1, 1), the heuristic is expected to return ``(1, 2)``.
    """
    possibilities = np.full((5, 3, 4), True, dtype=bool)  # everything is possible
    possibilities[1:, 0, 0] = False  # first cell is fully observed
    possibilities[4, :, 2] = False

    preferences: NDArray[np.float64] = np.full((3, 4), 0.5, dtype=np.float64)
    preferences[1, 1] = 0.1
    preferences[1, 2] = 0.3

    pick_location = wfc_solver.makeEntropyLocationHeuristic(preferences)
    chosen = pick_location(possibilities)

    assert chosen == (1, 2)
def test_observe() -> None:
    """Check that ``observe`` consults both heuristics and returns their picks.

    The stub location heuristic asserts it is handed an array equal to the
    fixture wave and answers ``(1, 2)``.  The stub pattern heuristic asserts
    its weights equal the wave's column at that cell and answers ``3``.
    ``observe`` must then return ``(3, 1, 2)``.
    """
    fixture_wave = np.ones((5, 3, 4), dtype=np.bool_)
    fixture_wave[0, 1, 2] = False

    def stub_location(wave: NDArray[np.bool_]) -> tuple[int, int]:
        # The solver must pass the wave through unchanged.
        assert np.array_equal(wave, fixture_wave)
        return (1, 2)

    def stub_pattern(weights: NDArray[np.bool_], wave: NDArray[np.bool_]) -> int:
        # Weights must be the per-pattern slice at the cell picked above.
        assert np.array_equal(weights, fixture_wave[:, 1, 2])
        return 3

    observed = wfc_solver.observe(
        fixture_wave,
        locationHeuristic=stub_location,
        patternHeuristic=stub_pattern,
    )

    assert observed == (3, 1, 2)
def test_propagate() -> None:
    """Check that propagation from one fixed cell yields a checkerboard.

    Three patterns live on a 3x4 grid with the same adjacency lists in all
    four directions: 0 next to 1, 1 next to 0, 2 next to 2.  Cell (0, 0) is
    fixed to pattern 0 before calling ``propagate`` with ``periodic=False``.
    The wave is then expected to hold pattern 0 where the index sum is even,
    pattern 1 where it is odd, and pattern 2 nowhere.
    """
    grid_shape = (3, 4)
    wave = np.ones((3, *grid_shape), dtype=bool)

    # checkerboard #0/#1 or solid fill #2
    allowed_neighbours = [[1], [0], [2]]
    directions = [(+1, 0), (-1, 0), (0, +1), (0, -1)]
    # fromkeys stores the very same list object under every direction.
    adjLists = dict.fromkeys(directions, allowed_neighbours)

    # Collapse the top-left cell to pattern 0 only.
    wave[1:, 0, 0] = False

    adj = wfc_solver.makeAdj(adjLists)
    wfc_solver.propagate(wave, adj, periodic=False)

    # Parity mask: True where (row + column) is even.
    even_parity = np.indices(grid_shape).sum(axis=0) % 2 == 0
    expected_result = np.array(
        [
            even_parity,  # pattern 0
            ~even_parity,  # pattern 1
            np.full(grid_shape, False),  # pattern 2 is ruled out everywhere
        ]
    )
    assert np.array_equal(wave, expected_result)
def test_run() -> None:
    """Exercise ``run`` in three configurations on a 3-pattern, 3x4 wave.

    1. Non-periodic, lexical heuristics: expects a 0/1 checkerboard.
    2. Periodic with backtracking: expects a solid fill of pattern 2 and an
       event log of one choice, one backtrack, then the final choice.
    3. Periodic with backtracking and a feasibility callback that raises
       once fewer than 20 wave entries remain set: expects
       ``wfc_solver.Contradiction``.
    """
    grid_shape = (3, 4)
    initial_wave = wfc_solver.makeWave(3, *grid_shape)

    # Same rule set in every direction: 0 <-> 1 alternate, 2 only beside 2.
    allowed_neighbours = [[1], [0], [2]]
    adjLists = dict.fromkeys(
        [(+1, 0), (-1, 0), (0, +1), (0, -1)], allowed_neighbours
    )
    adj = wfc_solver.makeAdj(adjLists)

    # Heuristics shared by all three runs.
    lexical = {
        "locationHeuristic": wfc_solver.lexicalLocationHeuristic,
        "patternHeuristic": wfc_solver.lexicalPatternHeuristic,
    }

    # --- 1. non-periodic run -------------------------------------------
    first_result = wfc_solver.run(
        initial_wave.copy(), adj, periodic=False, **lexical
    )
    # Checkerboard: 0 where (row + column) is even, 1 where it is odd.
    expected_first_result = np.indices(grid_shape).sum(axis=0) % 2
    assert np.array_equal(first_result, expected_first_result)

    # --- 2. periodic run with backtracking and event callbacks ---------
    event_log: list = []

    def record_choice(pattern: int, i: int, j: int) -> None:
        event_log.append((pattern, i, j))

    def record_backtrack() -> None:
        event_log.append("backtrack")

    second_result = wfc_solver.run(
        initial_wave.copy(),
        adj,
        periodic=True,
        backtracking=True,
        onChoice=record_choice,
        onBacktrack=record_backtrack,
        **lexical,
    )
    expected_second_result = np.full(grid_shape, 2)
    assert np.array_equal(second_result, expected_second_result)
    assert event_log == [(0, 0, 0), "backtrack", (2, 0, 0)]

    # --- 3. feasibility callback that aborts the search ----------------
    class Infeasible(Exception):
        pass

    def explode(wave: NDArray[np.bool_]) -> bool:
        if wave.sum() >= 20:
            return False
        raise Infeasible

    # NOTE(review): how the solver turns the callback's exception into a
    # Contradiction is not visible from this file -- confirm in run().
    with pytest.raises(wfc_solver.Contradiction):
        wfc_solver.run(
            initial_wave.copy(),
            adj,
            periodic=True,
            backtracking=True,
            checkFeasible=explode,
            **lexical,
        )
|