
Wave Function Collapse Environments (#371)

Co-authored-by: Isaac Karth <isaac@isaackarth.com>
Co-authored-by: Valentin Valls <valentin.valls@gmail.com>
Co-authored-by: Isaac Karth <isaackarth@gmail.com>
Co-authored-by: Kyle Benesch <4b796c65+github@gmail.com>
Co-authored-by: Samuel Garcin <garcin.samuel@gmail.com>
Co-authored-by: Mark Towers <mark.m.towers@gmail.com>
James Doran · 1 year ago
commit e726259e86
49 changed files with 2421 additions and 3 deletions
   1. .pre-commit-config.yaml  (+1 -1)
   2. minigrid/__init__.py  (+10 -0)
   3. minigrid/core/world_object.py  (+1 -1)
   4. minigrid/envs/wfc/__init__.py  (+24 -0)
   5. minigrid/envs/wfc/config.py  (+220 -0)
   6. minigrid/envs/wfc/graphtransforms.py  (+396 -0)
   7. minigrid/envs/wfc/patterns/Angular.png  (BIN)
   8. minigrid/envs/wfc/patterns/Blackdots.png  (BIN)
   9. minigrid/envs/wfc/patterns/Cave.png  (BIN)
  10. minigrid/envs/wfc/patterns/City.png  (BIN)
  11. minigrid/envs/wfc/patterns/DungeonExtr.png  (BIN)
  12. minigrid/envs/wfc/patterns/Fabric.png  (BIN)
  13. minigrid/envs/wfc/patterns/Hogs.png  (BIN)
  14. minigrid/envs/wfc/patterns/Knot.png  (BIN)
  15. minigrid/envs/wfc/patterns/Lake.png  (BIN)
  16. minigrid/envs/wfc/patterns/LessRooms.png  (BIN)
  17. minigrid/envs/wfc/patterns/MagicOffice.png  (BIN)
  18. minigrid/envs/wfc/patterns/Maze.png  (BIN)
  19. minigrid/envs/wfc/patterns/Mazelike.png  (BIN)
  20. minigrid/envs/wfc/patterns/Office.png  (BIN)
  21. minigrid/envs/wfc/patterns/Paths.png  (BIN)
  22. minigrid/envs/wfc/patterns/RedMaze.png  (BIN)
  23. minigrid/envs/wfc/patterns/Rooms.png  (BIN)
  24. minigrid/envs/wfc/patterns/ScaledMaze.png  (BIN)
  25. minigrid/envs/wfc/patterns/SimpleKnot.png  (BIN)
  26. minigrid/envs/wfc/patterns/SimpleMaze.png  (BIN)
  27. minigrid/envs/wfc/patterns/SimpleWall.png  (BIN)
  28. minigrid/envs/wfc/patterns/Skew1.png  (BIN)
  29. minigrid/envs/wfc/patterns/Skew2.png  (BIN)
  30. minigrid/envs/wfc/patterns/Spirals.png  (BIN)
  31. minigrid/envs/wfc/patterns/SpiralsNeg.png  (BIN)
  32. minigrid/envs/wfc/wfcenv.py  (+226 -0)
  33. minigrid/envs/wfc/wfclogic/__init__.py  (+0 -0)
  34. minigrid/envs/wfc/wfclogic/adjacency.py  (+56 -0)
  35. minigrid/envs/wfc/wfclogic/control.py  (+295 -0)
  36. minigrid/envs/wfc/wfclogic/patterns.py  (+199 -0)
  37. minigrid/envs/wfc/wfclogic/solver.py  (+530 -0)
  38. minigrid/envs/wfc/wfclogic/tiles.py  (+64 -0)
  39. minigrid/envs/wfc/wfclogic/utilities.py  (+77 -0)
  40. minigrid/wrappers.py  (+1 -0)
  41. py.Dockerfile  (+1 -1)
  42. pyproject.toml  (+4 -0)
  43. tests/test_wfc/__init__.py  (+0 -0)
  44. tests/test_wfc/conftest.py  (+40 -0)
  45. tests/test_wfc/test_wfc_adjacency.py  (+41 -0)
  46. tests/test_wfc/test_wfc_patterns.py  (+60 -0)
  47. tests/test_wfc/test_wfc_solver.py  (+148 -0)
  48. tests/test_wfc/test_wfc_tiles.py  (+17 -0)
  49. tests/utils.py  (+10 -0)

+ 1 - 1
.pre-commit-config.yaml

@@ -16,7 +16,7 @@ repos:
       - id: flake8
         args:
           - '--per-file-ignores=*/__init__.py:F401'
-#          - --ignore=
+          - --ignore=E203, W503
           - --max-complexity=30
           - --max-line-length=456
           - --show-source

+ 10 - 0
minigrid/__init__.py

@@ -5,6 +5,7 @@ from gymnasium.envs.registration import register
 from minigrid import minigrid_env, wrappers
 from minigrid.core import roomgrid
 from minigrid.core.world_object import Wall
+from minigrid.envs.wfc.config import WFC_PRESETS
 
 __version__ = "2.3.1"
 
@@ -565,6 +566,15 @@ def register_minigrid_envs():
         entry_point="minigrid.envs:UnlockPickupEnv",
     )
 
+    # WaveFunctionCollapse
+    # ----------------------------------------
+    for name in WFC_PRESETS.keys():
+        register(
+            id=f"MiniGrid-WFC-{name}-v0",
+            entry_point="minigrid.envs.wfc:WFCEnv",
+            kwargs={"wfc_config": name},
+        )
+
     # BabyAI - Language based levels - GoTo
     # ----------------------------------------
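
As an illustrative usage sketch (not prescribed by the commit; id and seed are arbitrary), the registration loop above exposes one environment id per key in WFC_PRESETS, so a WFC level can be created like any other MiniGrid environment once the optional `minigrid[wfc]` dependencies are installed:

import gymnasium as gym
import minigrid  # noqa: F401 - importing minigrid ensures the MiniGrid-WFC-* ids are registered

# "MazeSimple" is one of the WFC_PRESETS keys registered above
env = gym.make("MiniGrid-WFC-MazeSimple-v0")
obs, info = env.reset(seed=42)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
env.close()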
 

+ 1 - 1
minigrid/core/world_object.py

@@ -74,7 +74,7 @@ class WorldObj:
         obj_type = IDX_TO_OBJECT[type_idx]
         color = IDX_TO_COLOR[color_idx]
 
-        if obj_type == "empty" or obj_type == "unseen":
+        if obj_type == "empty" or obj_type == "unseen" or obj_type == "agent":
             return None
 
         # State, 0: open, 1: closed, 2: locked

+ 24 - 0
minigrid/envs/wfc/__init__.py

@@ -0,0 +1,24 @@
+from __future__ import annotations
+
+from minigrid.envs.wfc.config import (
+    WFC_PRESETS,
+    WFC_PRESETS_INCONSISTENT,
+    WFC_PRESETS_SLOW,
+    WFCConfig,
+)
+
+# This is wrapped in a try-except block so the presets can be accessed for registration
+# Otherwise, importing here will fail when networkx is not installed
+try:
+    from minigrid.envs.wfc.wfcenv import WFCEnv
+except ImportError:
+
+    class WFCEnv:
+        """Dummy class to give a helpful error message when dependencies are missing"""
+
+        def __init__(self, *args, **kwargs):
+            from gymnasium.error import DependencyNotInstalled
+
+            raise DependencyNotInstalled(
+                'WFC dependencies are missing, please run `pip install "minigrid[wfc]"`'
+            )

+ 220 - 0
minigrid/envs/wfc/config.py

@@ -0,0 +1,220 @@
+from __future__ import annotations
+
+from dataclasses import asdict, dataclass
+from pathlib import Path
+
+from typing_extensions import Literal
+
+PATTERN_PATH = Path(__file__).parent / "patterns"
+
+
+@dataclass
+class WFCConfig:
+    """Dataclass for holding WFC configuration parameters.
+
+    This controls the behavior of the WFC algorithm. The parameters are passed directly to the WFC solver.
+
+    Attributes:
+        pattern_path: Path to the pattern image that will be automatically loaded.
+        tile_size: Size of the tiles in pixels to create from the pattern image.
+        pattern_width: Size of the patterns in tiles to take from the pattern image. (greater than 3 is quite slow)
+        rotations: Number of rotations for each tile.
+        output_periodic: Whether the output should be periodic (wraps over edges).
+        input_periodic: Whether the input should be periodic (wraps over edges).
+        loc_heuristic: Heuristic for choosing the next tile location to collapse.
+        choice_heuristic: Heuristic for choosing the next tile to use between possible tiles.
+        backtracking: Whether to backtrack when contradictions are discovered.
+    """
+
+    pattern_path: Path
+    tile_size: int = 1
+    pattern_width: int = 2
+    rotations: int = 8
+    output_periodic: bool = False
+    input_periodic: bool = False
+    loc_heuristic: Literal[
+        "lexical", "spiral", "entropy", "anti-entropy", "simple", "random"
+    ] = "entropy"
+    choice_heuristic: Literal["lexical", "rarest", "weighted", "random"] = "weighted"
+    backtracking: bool = False
+
+    @property
+    def wfc_kwargs(self):
+        try:
+            from imageio.v2 import imread
+        except ImportError as e:
+            from gymnasium.error import DependencyNotInstalled
+
+            raise DependencyNotInstalled(
+                'imageio is missing, please run `pip install "minigrid[wfc]"`'
+            ) from e
+        kwargs = asdict(self)
+        kwargs["image"] = imread(kwargs.pop("pattern_path"))[:, :, :3]
+        return kwargs
+
+
+# Basic presets for WFC configurations (that should generate in <1 min)
+WFC_PRESETS = {
+    "MazeSimple": WFCConfig(
+        pattern_path=PATTERN_PATH / "SimpleMaze.png",
+        tile_size=1,
+        pattern_width=2,
+        output_periodic=False,
+        input_periodic=False,
+    ),
+    "DungeonMazeScaled": WFCConfig(
+        pattern_path=PATTERN_PATH / "ScaledMaze.png",
+        tile_size=1,
+        pattern_width=2,
+        output_periodic=True,
+        input_periodic=True,
+    ),
+    "RoomsFabric": WFCConfig(
+        pattern_path=PATTERN_PATH / "Fabric.png",
+        tile_size=1,
+        pattern_width=3,
+        output_periodic=False,
+        input_periodic=False,
+    ),
+    "ObstaclesBlackdots": WFCConfig(
+        pattern_path=PATTERN_PATH / "Blackdots.png",
+        tile_size=1,
+        pattern_width=2,
+        output_periodic=False,
+        input_periodic=False,
+    ),
+    "ObstaclesAngular": WFCConfig(
+        pattern_path=PATTERN_PATH / "Angular.png",
+        tile_size=1,
+        pattern_width=3,
+        output_periodic=True,
+        input_periodic=True,
+    ),
+    "ObstaclesHogs3": WFCConfig(
+        pattern_path=PATTERN_PATH / "Hogs.png",
+        tile_size=1,
+        pattern_width=3,
+        output_periodic=True,
+        input_periodic=True,
+    ),
+}
+
+# Presets that take a large number of attempts to generate a consistent environment
+WFC_PRESETS_INCONSISTENT = {
+    "MazeKnot": WFCConfig(
+        pattern_path=PATTERN_PATH / "Knot.png",
+        tile_size=1,
+        pattern_width=3,
+        output_periodic=True,
+        input_periodic=True,
+    ),  # This is not too inconsistent (often 10 attempts is enough)
+    "MazeWall": WFCConfig(
+        pattern_path=PATTERN_PATH / "SimpleWall.png",
+        tile_size=1,
+        pattern_width=2,
+        output_periodic=True,
+        input_periodic=True,
+    ),
+    "RoomsOffice": WFCConfig(
+        pattern_path=PATTERN_PATH / "Office.png",
+        tile_size=1,
+        pattern_width=3,
+        output_periodic=True,
+        input_periodic=True,
+    ),
+    "ObstaclesHogs2": WFCConfig(
+        pattern_path=PATTERN_PATH / "Hogs.png",
+        tile_size=1,
+        pattern_width=2,
+        output_periodic=True,
+        input_periodic=True,
+    ),
+    "Skew2": WFCConfig(
+        pattern_path=PATTERN_PATH / "Skew2.png",
+        tile_size=1,
+        pattern_width=3,
+        output_periodic=True,
+        input_periodic=True,
+    ),
+}
+
+# Slow presets for WFC configurations (Most take about 2-4 min but some take 10+ min)
+WFC_PRESETS_SLOW = {
+    "Maze": WFCConfig(
+        pattern_path=PATTERN_PATH / "Maze.png",
+        tile_size=1,
+        pattern_width=3,
+        output_periodic=True,
+        input_periodic=True,
+    ),  # This is unusually slow: ~20min per 25x25 room
+    "MazeSpirals": WFCConfig(
+        pattern_path=PATTERN_PATH / "Spirals.png",
+        tile_size=1,
+        pattern_width=3,
+        output_periodic=True,
+        input_periodic=True,
+    ),
+    "MazePaths": WFCConfig(
+        pattern_path=PATTERN_PATH / "Paths.png",
+        tile_size=1,
+        pattern_width=3,
+        output_periodic=True,
+        input_periodic=True,
+    ),
+    "Mazelike": WFCConfig(
+        pattern_path=PATTERN_PATH / "Mazelike.png",
+        tile_size=1,
+        pattern_width=3,
+        output_periodic=True,
+        input_periodic=True,
+    ),
+    "Dungeon": WFCConfig(
+        pattern_path=PATTERN_PATH / "DungeonExtr.png",
+        tile_size=1,
+        pattern_width=3,
+        output_periodic=True,
+        input_periodic=True,
+    ),  # ~10 mins
+    "DungeonRooms": WFCConfig(
+        pattern_path=PATTERN_PATH / "Rooms.png",
+        tile_size=1,
+        pattern_width=3,
+        output_periodic=True,
+        input_periodic=True,
+    ),
+    "DungeonLessRooms": WFCConfig(
+        pattern_path=PATTERN_PATH / "LessRooms.png",
+        tile_size=1,
+        pattern_width=3,
+        output_periodic=True,
+        input_periodic=True,
+    ),
+    "DungeonSpirals": WFCConfig(
+        pattern_path=PATTERN_PATH / "SpiralsNeg.png",
+        tile_size=1,
+        pattern_width=3,
+        output_periodic=True,
+        input_periodic=True,
+    ),
+    "RoomsMagicOffice": WFCConfig(
+        pattern_path=PATTERN_PATH / "MagicOffice.png",
+        tile_size=1,
+        pattern_width=3,
+        output_periodic=True,
+        input_periodic=True,
+    ),
+    "SkewCave": WFCConfig(
+        pattern_path=PATTERN_PATH / "Cave.png",
+        tile_size=1,
+        pattern_width=3,
+        output_periodic=False,
+        input_periodic=False,
+    ),
+    "SkewLake": WFCConfig(
+        pattern_path=PATTERN_PATH / "Lake.png",
+        tile_size=1,
+        pattern_width=3,
+        output_periodic=True,
+        input_periodic=True,
+    ),  # ~10 mins
+}
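
Beyond the presets, a WFCConfig can be built directly and passed to WFCEnv; a minimal sketch using one of the bundled pattern images (parameter values here are illustrative and mirror the MazeSimple preset):

from minigrid.envs.wfc import WFCConfig, WFCEnv
from minigrid.envs.wfc.config import PATTERN_PATH

# Custom configuration; a larger pattern_width preserves more structure but generates more slowly.
config = WFCConfig(
    pattern_path=PATTERN_PATH / "SimpleMaze.png",
    tile_size=1,
    pattern_width=2,
    output_periodic=False,
    input_periodic=False,
)
env = WFCEnv(wfc_config=config, size=19, render_mode="rgb_array")
obs, info = env.reset(seed=0)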

+ 396 - 0
minigrid/envs/wfc/graphtransforms.py

@@ -0,0 +1,396 @@
+from __future__ import annotations
+
+from collections import OrderedDict, defaultdict
+from dataclasses import dataclass
+from itertools import product
+
+import networkx as nx
+import numpy as np
+
+from minigrid.core.constants import COLOR_TO_IDX, IDX_TO_OBJECT, OBJECT_TO_IDX
+from minigrid.minigrid_env import MiniGridEnv
+
+
+@dataclass
+class EdgeDescriptor:
+    between: tuple[str, str] | tuple[str]
+    structure: str | None = None
+
+
+# This is maybe general enough to be in utils
+class GraphTransforms:
+    OBJECT_TO_DENSE_GRAPH_ATTRIBUTE = {
+        "empty": ("navigable", "empty"),
+        "start": ("navigable", "start"),
+        "agent": ("navigable", "start"),
+        "goal": ("navigable", "goal"),
+        "moss": ("navigable", "moss"),
+        "wall": ("non_navigable", "wall"),
+        "lava": ("non_navigable", "lava"),
+    }
+
+    DENSE_GRAPH_ATTRIBUTE_TO_OBJECT = {
+        "empty": "empty",
+        "start": "start",
+        "goal": "goal",
+        "moss": "moss",
+        "wall": "wall",
+        "lava": "lava",
+        "navigable": None,
+        "non_navigable": None,
+    }
+
+    MINIGRID_COLOR_CONFIG = {
+        "empty": None,
+        "wall": "grey",
+        "agent": "blue",
+        "goal": "green",
+        "lava": "red",
+        "moss": "purple",
+    }
+
+    @staticmethod
+    def minigrid_to_bitmap(grids):
+
+        layout = grids[..., 0]
+        bitmap = np.zeros_like(layout)
+        bitmap[layout == 2] = 1
+        bitmap = list(bitmap)
+
+        start_pos_id = np.where(layout == 10)
+        goal_pos_id = np.where(layout == 8)
+
+        start_pos = []
+        goal_pos = []
+        for i in range(len(bitmap)):
+            bitmap[i] = bitmap[i][1:-1, 1:-1]
+            start_pos.append(np.array([start_pos_id[2][i], start_pos_id[1][i]]))
+            goal_pos.append(np.array([goal_pos_id[2][i], goal_pos_id[1][i]]))
+
+        return bitmap, start_pos, goal_pos
+
+    @staticmethod
+    def minigrid_to_dense_graph(
+        minigrids: np.ndarray | list[MiniGridEnv],
+        node_attr=None,
+        edge_config=None,
+    ) -> list[nx.Graph]:
+        if isinstance(minigrids[0], np.ndarray):
+            minigrids = np.array(minigrids)
+            layouts = minigrids[..., 0]
+        elif isinstance(minigrids[0], MiniGridEnv):
+            layouts = [minigrid.grid.encode()[..., 0] for minigrid in minigrids]
+            for i in range(len(minigrids)):
+                layouts[i][tuple(minigrids[i].agent_pos)] = OBJECT_TO_IDX["agent"]
+            layouts = np.array(layouts)
+        else:
+            raise TypeError(
+                f"minigrids must be of type List[np.ndarray], List[MiniGridEnv], "
+                f"List[MultiGridEnv], not {type(minigrids[0])}"
+            )
+        graphs, _ = GraphTransforms.minigrid_layout_to_dense_graph(
+            layouts, remove_border=True, node_attr=node_attr, edge_config=edge_config
+        )
+        return graphs
+
+    @staticmethod
+    def minigrid_layout_to_dense_graph(
+        layouts: np.ndarray, remove_border=True, node_attr=None, edge_config=None
+    ) -> tuple[list[nx.Graph], dict[str, list[nx.Graph]]]:
+
+        assert (
+            layouts.ndim == 3
+        ), f"Wrong dimensions for minigrid layout, expected 3 dimensions, got {layouts.ndim}."
+
+        node_attr = [] if node_attr is None else node_attr
+
+        # Remove borders
+        if remove_border:
+            layouts = layouts[:, 1:-1, 1:-1]  # remove edges
+        dim_grid = layouts.shape[1:]
+
+        # Get the objects present in the layout
+        objects_idx = np.unique(layouts)
+        object_instances = [IDX_TO_OBJECT[obj] for obj in objects_idx]
+        assert set(object_instances).issubset(
+            {"empty", "wall", "start", "goal", "agent", "lava", "moss"}
+        ), (
+            f"Unsupported object(s) in minigrid layout. Supported objects are: "
+            f"empty, wall, start, goal, agent, lava, moss. Got {object_instances}."
+        )
+
+        # Get location of each object in the layout
+        object_locations = {}
+        for obj in object_instances:
+            object_locations[obj] = defaultdict(list)
+            ids = list(zip(*np.where(layouts == OBJECT_TO_IDX[obj])))
+            for tup in ids:
+                object_locations[obj][tup[0]].append(tup[1:])
+            for m in range(layouts.shape[0]):
+                if m not in object_locations[obj]:
+                    object_locations[obj][m] = []
+            object_locations[obj] = OrderedDict(sorted(object_locations[obj].items()))
+        if "start" not in object_instances and "agent" in object_instances:
+            object_locations["start"] = object_locations["agent"]
+        if "agent" not in object_instances and "start" in object_instances:
+            object_locations["agent"] = object_locations["start"]
+
+        # Create one-hot graph feature tensor
+        graph_feats = {}
+        object_to_attr = GraphTransforms.OBJECT_TO_DENSE_GRAPH_ATTRIBUTE
+        for obj in object_instances:
+            for attr in object_to_attr[obj]:
+                if attr not in graph_feats and attr in node_attr:
+                    graph_feats[attr] = np.zeros(layouts.shape)
+                loc = list(object_locations[obj].values())
+                assert len(loc) == layouts.shape[0]
+                for m in range(layouts.shape[0]):
+                    if loc[m]:
+                        loc_m = np.array(loc[m])
+                        graph_feats[attr][m][loc_m[:, 0], loc_m[:, 1]] = 1
+        for attr in node_attr:
+            if attr not in graph_feats:
+                graph_feats[attr] = np.zeros(layouts.shape)
+            graph_feats[attr] = graph_feats[attr].reshape(layouts.shape[0], -1)
+
+        graphs, edge_graphs = GraphTransforms.features_to_dense_graph(
+            graph_feats, dim_grid, edge_config
+        )
+
+        return graphs, edge_graphs
+
+    @staticmethod
+    def features_to_dense_graph(
+        features: dict[str, np.ndarray],
+        dim_grid: tuple,
+        edge_config: dict[str, EdgeDescriptor] = None,
+    ) -> tuple[list[nx.Graph], dict[str, list[nx.Graph]]]:
+
+        graphs = []
+        edge_graphs = defaultdict(list)
+        for m in range(features[list(features.keys())[0]].shape[0]):
+            g_temp = nx.grid_2d_graph(*dim_grid)
+            g = nx.Graph()
+            g.add_nodes_from(sorted(g_temp.nodes(data=True)))
+            for attr in features:
+                nx.set_node_attributes(
+                    g, {k: v for k, v in zip(g.nodes, features[attr][m].tolist())}, attr
+                )
+            if edge_config is not None:
+                edge_layers = GraphTransforms.get_edge_layers(
+                    g, edge_config, list(features.keys()), dim_grid
+                )
+                for edge_n, edge_g in edge_layers.items():
+                    g.add_edges_from(edge_g.edges(data=True), label=edge_n)
+                    edge_graphs[edge_n].append(edge_g)
+            graphs.append(g)
+
+        return graphs, edge_graphs
+
+    @staticmethod
+    def graph_features_to_minigrid(
+        graph_features: dict[str, np.ndarray], shape: tuple[int, int], padding=1
+    ) -> np.ndarray:
+
+        features = graph_features.copy()
+        node_attributes = list(features.keys())
+
+        color_config = GraphTransforms.MINIGRID_COLOR_CONFIG
+
+        # shape_no_padding = (features[node_attributes[0]].shape[-2], shape[0] - 2, shape[1] - 2, 3)
+        shape_no_padding = (shape[0] - 2 * padding, shape[1] - 2 * padding, 3)
+        for attr in node_attributes:
+            features[attr] = features[attr].reshape(*shape_no_padding[:-1])
+        grids = np.ones(shape_no_padding, dtype=np.uint8) * OBJECT_TO_IDX["empty"]
+
+        minigrid_object_to_encoding_map = {}  # [object_id, color, state]
+        for feature in node_attributes:
+            obj_type = GraphTransforms.DENSE_GRAPH_ATTRIBUTE_TO_OBJECT[feature]
+            if (
+                obj_type is not None
+                and obj_type not in minigrid_object_to_encoding_map.keys()
+            ):
+                if obj_type == "empty":
+                    minigrid_object_to_encoding_map[obj_type] = [
+                        OBJECT_TO_IDX["empty"],
+                        0,
+                        0,
+                    ]
+                elif obj_type == "agent":
+                    minigrid_object_to_encoding_map[obj_type] = [
+                        OBJECT_TO_IDX["agent"],
+                        0,
+                        0,
+                    ]
+                elif obj_type == "start":
+                    color_str = color_config["agent"]
+                    minigrid_object_to_encoding_map[obj_type] = [
+                        OBJECT_TO_IDX["agent"],
+                        COLOR_TO_IDX[color_str],
+                        0,
+                    ]
+                else:
+                    color_str = color_config[obj_type]
+                    minigrid_object_to_encoding_map[obj_type] = [
+                        OBJECT_TO_IDX[obj_type],
+                        COLOR_TO_IDX[color_str],
+                        0,
+                    ]
+
+        if (
+            "start" not in minigrid_object_to_encoding_map.keys()
+            and "agent" in minigrid_object_to_encoding_map.keys()
+        ):
+            minigrid_object_to_encoding_map["start"] = minigrid_object_to_encoding_map[
+                "agent"
+            ]
+        if (
+            "agent" not in minigrid_object_to_encoding_map.keys()
+            and "start" in minigrid_object_to_encoding_map.keys()
+        ):
+            minigrid_object_to_encoding_map["agent"] = minigrid_object_to_encoding_map[
+                "start"
+            ]
+
+        for i, attr in enumerate(node_attributes):
+            if "wall" not in node_attributes:
+                if attr == "navigable" and "wall" not in node_attributes:
+                    mapping = minigrid_object_to_encoding_map["wall"]
+                    grids[features[attr] == 0] = np.array(mapping, dtype=np.uint8)
+                else:
+                    mapping = minigrid_object_to_encoding_map[attr]
+                    grids[features[attr] == 1] = np.array(mapping, dtype=np.uint8)
+            else:
+                try:
+                    mapping = minigrid_object_to_encoding_map[attr]
+                    grids[features[attr] == 1] = np.array(mapping, dtype=np.uint8)
+                except KeyError:
+                    pass
+
+        wall_encoding = np.array(
+            minigrid_object_to_encoding_map["wall"], dtype=np.uint8
+        )
+        padded_grid = np.pad(
+            grids,
+            ((padding, padding), (padding, padding), (0, 0)),
+            "constant",
+            constant_values=-1,
+        )
+        padded_grid = np.where(
+            padded_grid == -np.ones(3, dtype=np.uint8), wall_encoding, padded_grid
+        )
+        return padded_grid
+
+    @staticmethod
+    def get_node_features(
+        graph: nx.Graph, pattern_shape, node_attributes: list[str] = None, reshape=True
+    ) -> tuple[np.ndarray, list[str]]:
+
+        if node_attributes is None:
+            # Get node attributes from some node
+            node_attributes = list(next(iter(graph.nodes.data()))[1].keys())
+
+        # Get node features
+        Fx = []
+        for attr in node_attributes:
+            if attr == "non_navigable" or attr == "wall":
+                # The graph we are getting is only the navigable nodes so those that
+                # are not present should be assumed to be walls and non-navigable
+                f = np.ones(pattern_shape)
+            else:
+                f = np.zeros(pattern_shape)
+            for node, data in graph.nodes.data(attr):
+                f[node] = data
+            if reshape:
+                f = f.ravel()
+            Fx.append(f)
+        # Fx = torch.stack(Fx, dim=-1).to(device)
+        Fx = np.stack(Fx, axis=-1)
+
+        return Fx, node_attributes
+
+    @staticmethod
+    def dense_graph_to_minigrid(
+        graph: nx.Graph, shape: tuple[int, int], padding=1
+    ) -> np.ndarray:
+
+        pattern_shape = (shape[0] - 2 * padding, shape[1] - 2 * padding)
+        features, node_attributes = GraphTransforms.get_node_features(
+            graph, pattern_shape, node_attributes=None
+        )
+        # num_zeros = features[features == 0.0].numel()
+        # num_ones = features[features == 1.0].numel()
+        num_zeros = (features == 0.0).sum()
+        num_ones = (features == 1.0).sum()
+
+        assert num_zeros + num_ones == features.size, "Graph features should be binary"
+        features_dict = {}
+        for i, key in enumerate(node_attributes):
+            features_dict[key] = features[..., i]
+        grids = GraphTransforms.graph_features_to_minigrid(
+            features_dict, shape=shape, padding=padding
+        )
+
+        return grids
+
+    @staticmethod
+    def get_edge_layers(
+        graph: nx.Graph,
+        edge_config: dict[str, EdgeDescriptor],
+        node_attr: list[str],
+        dim_grid: tuple[int, int],
+    ) -> dict[str, nx.Graph]:
+
+        navigable_nodes = ["empty", "start", "goal", "moss"]
+        non_navigable_nodes = ["wall", "lava"]
+        assert all([isinstance(n, tuple) for n in graph.nodes])
+        assert all([len(n) == 2 for n in graph.nodes])
+
+        def partial_grid(graph, nodes, dim_grid):
+            non_grid_nodes = [n for n in graph.nodes if n not in nodes]
+            g_temp = nx.grid_2d_graph(*dim_grid)
+            g_temp.remove_nodes_from(non_grid_nodes)
+            g_temp.add_nodes_from(non_grid_nodes)
+            g = nx.Graph()
+            g.add_nodes_from(graph.nodes(data=True))
+            g.add_edges_from(g_temp.edges)
+            return g
+
+        def pair_edges(graph, node_types):
+            all_nodes = []
+            for n_type in node_types:
+                all_nodes.append(
+                    [n for n, a in graph.nodes.items() if a[n_type] >= 1.0]
+                )
+            edges = list(product(*all_nodes))
+            edged_graph = nx.create_empty_copy(graph, with_data=True)
+            edged_graph.add_edges_from(edges)
+            return edged_graph
+
+        edge_graphs = {}
+        for edge_ in edge_config.keys():
+            if edge_ == "navigable" and "navigable" not in node_attr:
+                edge_config[edge_].between = navigable_nodes
+            elif edge_ == "non_navigable" and "non_navigable" not in node_attr:
+                edge_config[edge_].between = non_navigable_nodes
+            elif not set(edge_config[edge_].between).issubset(set(node_attr)):
+                # TODO: remove
+                # logger.warning(f"Edge {edge_} not compatible with node attributes {node_attr}. Skipping.")
+                continue
+            if edge_config[edge_].structure is None:
+                edge_graphs[edge_] = pair_edges(graph, edge_config[edge_].between)
+            elif edge_config[edge_].structure == "grid":
+                nodes = []
+                for n_type in edge_config[edge_].between:
+                    nodes += [
+                        n
+                        for n, a in graph.nodes.items()
+                        if a[n_type] >= 1.0 and n not in nodes
+                    ]
+                edge_graphs[edge_] = partial_grid(graph, nodes, dim_grid)
+            else:
+                raise NotImplementedError(
+                    f"Edge structure {edge_config[edge_].structure} not supported."
+                )
+
+        return edge_graphs
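
As an illustration of the transforms above, a MiniGrid level can be converted into a dense networkx graph whose "navigable" edges follow the grid layout; this sketch uses a standard environment id and the same node attributes WFCEnv passes (illustrative, not part of the commit):

import gymnasium as gym

from minigrid.envs.wfc.graphtransforms import EdgeDescriptor, GraphTransforms

env = gym.make("MiniGrid-Empty-8x8-v0").unwrapped
env.reset(seed=0)

node_attr = ["empty", "wall", "lava", "start", "goal", "navigable", "non_navigable"]
edge_config = {"navigable": EdgeDescriptor(between=("navigable",), structure="grid")}

# One dense graph per environment; borders are stripped, so nodes cover interior cells only.
graphs = GraphTransforms.minigrid_to_dense_graph([env], node_attr=node_attr, edge_config=edge_config)
print(graphs[0].number_of_nodes())  # 36 == (8 - 2) * (8 - 2) interior cells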

Binary pattern images added (no text diff shown):

BIN  minigrid/envs/wfc/patterns/Angular.png
BIN  minigrid/envs/wfc/patterns/Blackdots.png
BIN  minigrid/envs/wfc/patterns/Cave.png
BIN  minigrid/envs/wfc/patterns/City.png
BIN  minigrid/envs/wfc/patterns/DungeonExtr.png
BIN  minigrid/envs/wfc/patterns/Fabric.png
BIN  minigrid/envs/wfc/patterns/Hogs.png
BIN  minigrid/envs/wfc/patterns/Knot.png
BIN  minigrid/envs/wfc/patterns/Lake.png
BIN  minigrid/envs/wfc/patterns/LessRooms.png
BIN  minigrid/envs/wfc/patterns/MagicOffice.png
BIN  minigrid/envs/wfc/patterns/Maze.png
BIN  minigrid/envs/wfc/patterns/Mazelike.png
BIN  minigrid/envs/wfc/patterns/Office.png
BIN  minigrid/envs/wfc/patterns/Paths.png
BIN  minigrid/envs/wfc/patterns/RedMaze.png
BIN  minigrid/envs/wfc/patterns/Rooms.png
BIN  minigrid/envs/wfc/patterns/ScaledMaze.png
BIN  minigrid/envs/wfc/patterns/SimpleKnot.png
BIN  minigrid/envs/wfc/patterns/SimpleMaze.png
BIN  minigrid/envs/wfc/patterns/SimpleWall.png
BIN  minigrid/envs/wfc/patterns/Skew1.png
BIN  minigrid/envs/wfc/patterns/Skew2.png
BIN  minigrid/envs/wfc/patterns/Spirals.png
BIN  minigrid/envs/wfc/patterns/SpiralsNeg.png


+ 226 - 0
minigrid/envs/wfc/wfcenv.py

@@ -0,0 +1,226 @@
+from __future__ import annotations
+
+import copy
+
+import networkx as nx
+import numpy as np
+
+from minigrid.core.constants import OBJECT_TO_IDX
+from minigrid.core.grid import Grid
+from minigrid.core.mission import MissionSpace
+from minigrid.envs.wfc.config import WFC_PRESETS, WFCConfig
+from minigrid.envs.wfc.graphtransforms import EdgeDescriptor, GraphTransforms
+from minigrid.envs.wfc.wfclogic.control import execute_wfc
+from minigrid.minigrid_env import MiniGridEnv
+
+FEATURE_DESCRIPTORS = {"empty", "wall", "lava", "start", "goal"} | {
+    "navigable",
+    "non_navigable",
+}
+
+EDGE_CONFIG = {
+    "navigable": EdgeDescriptor(between=("navigable",), structure="grid"),
+    "non_navigable": EdgeDescriptor(between=("non_navigable",), structure="grid"),
+    "start_goal": EdgeDescriptor(between=("start", "goal"), structure=None),
+    # "lava_goal": EdgeDescriptor(between=("lava", "goal"), weight="lava_prob"),
+    # "moss_goal": EdgeDescriptor(between=("moss", "goal"), weight="moss_prob"),
+}
+
+
+class WFCEnv(MiniGridEnv):
+    """
+    ## Description
+
+    This environment procedurally generates a level using the Wave Function Collapse algorithm.
+    The environment supports a variety of different level structures but the default is a simple maze.
+    Requires the optional dependencies `imageio` and `networkx` to be installed with `pip install minigrid[wfc]`.
+
+    ## Mission Space
+
+    "traverse the maze to get to the goal"
+
+    ## Action Space
+
+    | Num | Name         | Action                    |
+    |-----|--------------|---------------------------|
+    | 0   | left         | Turn left                 |
+    | 1   | right        | Turn right                |
+    | 2   | forward      | Move forward              |
+    | 3   | pickup       | Unused                    |
+    | 4   | drop         | Unused                    |
+    | 5   | toggle       | Unused                    |
+    | 6   | done         | Unused                    |
+
+    ## Observation Encoding
+
+    - Each tile is encoded as a 3 dimensional tuple:
+        `(OBJECT_IDX, COLOR_IDX, STATE)`
+    - `OBJECT_TO_IDX` and `COLOR_TO_IDX` mapping can be found in
+        [minigrid/minigrid.py](minigrid/minigrid.py)
+    - `STATE` refers to the door state with 0=open, 1=closed and 2=locked
+
+    ## Rewards
+
+    A reward of '1 - 0.9 * (step_count / max_steps)' is given for success, and '0' for failure.
+
+    ## Termination
+
+    The episode ends if any one of the following conditions is met:
+
+    1. The agent reaches the goal.
+    2. Timeout (see `max_steps`).
+
+    ## Registered Configurations
+
+    S: size of map SxS.
+
+    """
+
+    PATTERN_COLOR_CONFIG = {
+        "wall": (0, 0, 0),  # black
+        "empty": (255, 255, 255),  # white
+    }
+
+    def __init__(
+        self,
+        wfc_config: WFCConfig | str = "MazeSimple",
+        size: int = 25,
+        ensure_connected: bool = True,
+        max_steps: int | None = None,
+        **kwargs,
+    ):
+        self.config = (
+            wfc_config if isinstance(wfc_config, WFCConfig) else WFC_PRESETS[wfc_config]
+        )
+        self.padding = 1
+
+        # This controls whether to process the level such that there is only a single connected navigable area
+        self.ensure_connected = ensure_connected
+
+        mission_space = MissionSpace(mission_func=self._gen_mission)
+
+        if size < 3:
+            raise ValueError(f"Grid size must be at least 3 (currently {size})")
+        self.size = size
+        self.max_attempts = 1000
+
+        if max_steps is None:
+            max_steps = self.size * 20
+
+        super().__init__(
+            mission_space=mission_space,
+            width=self.size,
+            height=self.size,
+            max_steps=max_steps,
+            **kwargs,
+        )
+
+    @staticmethod
+    def _gen_mission():
+        return "traverse the maze to get to the goal"
+
+    def _gen_grid(self, width, height):
+        shape = (height, width)
+
+        # Main call to generate a black and white pattern with WFC
+        shape_unpadded = (shape[0] - 2 * self.padding, shape[1] - 2 * self.padding)
+        pattern, _stats = execute_wfc(
+            attempt_limit=self.max_attempts,
+            output_size=shape_unpadded,
+            np_random=self.np_random,
+            **self.config.wfc_kwargs,
+        )
+        if pattern is None:
+            raise RuntimeError(
+                f"Could not generate a valid pattern within {self.max_attempts} attempts"
+            )
+
+        grid_raw = self._pattern_to_minigrid_layout(pattern)
+
+        # Stage 1: Make a navigable graph with only one main cavern
+        stage1_edge_config = {k: v for k, v in EDGE_CONFIG.items() if k == "navigable"}
+        graph_raw, _edge_graphs = GraphTransforms.minigrid_layout_to_dense_graph(
+            grid_raw[np.newaxis],
+            remove_border=False,
+            node_attr=FEATURE_DESCRIPTORS,
+            edge_config=stage1_edge_config,
+        )
+        graph = graph_raw[0]
+
+        # Stage 2: Graph processing
+        # Retain only the largest connected graph component, fill in the rest with walls
+        if self.ensure_connected:
+            graph = self._get_largest_component(graph)
+
+        # Add start and goal nodes
+        graph = self._place_start_and_goal_random(graph)
+
+        # Convert graph back to grid
+        grid_array = GraphTransforms.dense_graph_to_minigrid(
+            graph, shape=shape, padding=self.padding
+        )
+
+        # Decode to minigrid and set variables
+        self.agent_dir = self._rand_int(0, 4)
+        self.agent_pos = next(
+            zip(*np.nonzero(grid_array[:, :, 0] == OBJECT_TO_IDX["agent"]))
+        )
+        self.grid, _vismask = Grid.decode(grid_array)
+        self.mission = self._gen_mission()
+
+    def _pattern_to_minigrid_layout(self, pattern: np.ndarray):
+        if pattern.ndim != 3:
+            raise ValueError(
+                f"Expected pattern to have 3 dimensions, but got {pattern.ndim}"
+            )
+        layout = np.ones(pattern.shape, dtype=np.uint8) * OBJECT_TO_IDX["empty"]
+
+        wall_ids = np.where(pattern == self.PATTERN_COLOR_CONFIG["wall"])
+        layout[wall_ids] = OBJECT_TO_IDX["wall"]
+        layout = layout[..., 0]
+
+        return layout
+
+    @staticmethod
+    def _get_largest_component(graph: nx.Graph) -> nx.Graph:
+        wall_graph_attr = GraphTransforms.OBJECT_TO_DENSE_GRAPH_ATTRIBUTE["wall"]
+        # Prepare graph
+        inactive_nodes = [x for x, y in graph.nodes(data=True) if y["navigable"] < 0.5]
+        graph.remove_nodes_from(inactive_nodes)
+
+        components = [
+            graph.subgraph(c).copy()
+            for c in sorted(nx.connected_components(graph), key=len, reverse=True)
+            if len(c) > 1
+        ]
+        component = components[0]
+        graph = graph.subgraph(component)
+
+        for node in graph.nodes():
+            if node not in component.nodes():
+                for feat in graph.nodes[node]:
+                    if feat in wall_graph_attr:
+                        graph.nodes[node][feat] = 1.0
+                    else:
+                        graph.nodes[node][feat] = 0.0
+        # TODO: Check if this is necessary
+        g = nx.Graph()
+        g.add_nodes_from(graph.nodes(data=True))
+        g.add_edges_from(component.edges(data=True))
+
+        g_out = copy.deepcopy(g)
+
+        return g_out
+
+    def _place_start_and_goal_random(self, graph: nx.Graph) -> nx.Graph:
+        node_set = "navigable"
+
+        # Get two random navigable nodes
+        possible_nodes = [n for n, d in graph.nodes(data=True) if d[node_set]]
+        inds = self.np_random.permutation(len(possible_nodes))[:2]
+        start_node, goal_node = possible_nodes[inds[0]], possible_nodes[inds[1]]
+
+        graph.nodes[start_node]["start"] = 1
+        graph.nodes[goal_node]["goal"] = 1
+
+        return graph
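
A direct-construction sketch of the environment above (preset name, size, and seed are arbitrary); ensure_connected=False keeps every cavern produced by WFC instead of retaining only the largest connected component:

from minigrid.envs.wfc import WFCEnv

env = WFCEnv(
    wfc_config="DungeonMazeScaled",  # any key of WFC_PRESETS, or a WFCConfig instance
    size=25,
    ensure_connected=False,
    render_mode="rgb_array",
)
obs, info = env.reset(seed=3)
frame = env.render()  # RGB array of the generated level
print(frame.shape, obs["image"].shape)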

+ 0 - 0
minigrid/envs/wfc/wfclogic/__init__.py


+ 56 - 0
minigrid/envs/wfc/wfclogic/adjacency.py

@@ -0,0 +1,56 @@
+"""Convert input data to adjacency information. Implementation based on https://github.com/ikarth/wfc_2019f"""
+from __future__ import annotations
+
+import numpy as np
+from numpy.typing import NDArray
+
+
+def adjacency_extraction(
+    pattern_grid: NDArray[np.int64],
+    pattern_catalog: dict[int, NDArray[np.int64]],
+    direction_offsets: list[tuple[int, tuple[int, int]]],
+    pattern_size: tuple[int, int] = (2, 2),
+) -> list[tuple[tuple[int, int], int, int]]:
+    """Takes a pattern grid and returns a list of all of the legal adjacencies found in it."""
+
+    def is_valid_overlap_xy(
+        adjacency_direction: tuple[int, int], pattern_1: int, pattern_2: int
+    ) -> bool:
+        """Given a direction and two patterns, find the overlap of the two patterns
+        and return True if the intersection matches."""
+        dimensions = (1, 0)
+        not_a_number = -1
+
+        # TODO: can probably speed this up by using the right slices, rather than rolling the whole pattern...
+        shifted = np.roll(
+            np.pad(
+                pattern_catalog[pattern_2],
+                max(pattern_size),
+                mode="constant",
+                constant_values=not_a_number,
+            ),
+            adjacency_direction,
+            dimensions,
+        )
+        compare = shifted[
+            pattern_size[0] : pattern_size[0] + pattern_size[0],
+            pattern_size[1] : pattern_size[1] + pattern_size[1],
+        ]
+
+        left = max(0, 0, +adjacency_direction[0])
+        right = min(pattern_size[0], pattern_size[0] + adjacency_direction[0])
+        top = max(0, 0 + adjacency_direction[1])
+        bottom = min(pattern_size[1], pattern_size[1] + adjacency_direction[1])
+        a = pattern_catalog[pattern_1][top:bottom, left:right]
+        b = compare[top:bottom, left:right]
+        res = np.array_equal(a, b)
+        return res
+
+    pattern_list = list(pattern_catalog.keys())
+    legal = []
+    for pattern_1 in pattern_list:
+        for pattern_2 in pattern_list:
+            for _direction_index, direction in direction_offsets:
+                if is_valid_overlap_xy(direction, pattern_1, pattern_2):
+                    legal.append((direction, pattern_1, pattern_2))
+    return legal
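
A toy sketch of the adjacency extraction, using hand-made 2x2 patterns instead of a real catalog (in the pipeline the catalog comes from make_pattern_catalog_with_rotations):

import numpy as np

from minigrid.envs.wfc.wfclogic.adjacency import adjacency_extraction

pattern_catalog = {
    0: np.array([[1, 1], [0, 0]]),  # wall row on top
    1: np.array([[0, 0], [1, 1]]),  # wall row on bottom
}
direction_offsets = list(enumerate([(0, -1), (1, 0), (0, 1), (-1, 0)]))
pattern_grid = np.zeros((4, 4), dtype=np.int64)  # only the catalog and directions matter for the overlap check

legal = adjacency_extraction(pattern_grid, pattern_catalog, direction_offsets, pattern_size=(2, 2))
for direction, p1, p2 in legal:
    print(direction, p1, p2)  # each triple means p2 may legally sit next to p1 in that direction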

+ 295 - 0
minigrid/envs/wfc/wfclogic/control.py

@@ -0,0 +1,295 @@
+"""Main WFC execution function. Implementation based on https://github.com/ikarth/wfc_2019f"""
+from __future__ import annotations
+
+import logging
+import time
+from typing import Any, Callable
+
+import numpy as np
+from numpy.typing import NDArray
+from typing_extensions import Literal
+
+from minigrid.envs.wfc.wfclogic.adjacency import adjacency_extraction
+from minigrid.envs.wfc.wfclogic.patterns import (
+    make_pattern_catalog_with_rotations,
+    pattern_grid_to_tiles,
+)
+from minigrid.envs.wfc.wfclogic.solver import (
+    Contradiction,
+    StopEarly,
+    TimedOut,
+    lexicalLocationHeuristic,
+    lexicalPatternHeuristic,
+    make_global_use_all_patterns,
+    makeAdj,
+    makeAntiEntropyLocationHeuristic,
+    makeEntropyLocationHeuristic,
+    makeHilbertLocationHeuristic,
+    makeRandomLocationHeuristic,
+    makeRandomPatternHeuristic,
+    makeRarestPatternHeuristic,
+    makeSpiralLocationHeuristic,
+    makeWave,
+    makeWeightedPatternHeuristic,
+    run,
+    simpleLocationHeuristic,
+)
+
+from .tiles import make_tile_catalog
+from .utilities import tile_grid_to_image
+
+logger = logging.getLogger(__name__)
+
+
+def make_log_stats() -> Callable[[dict[str, Any], str], None]:
+    log_line = 0
+
+    def log_stats(stats: dict[str, Any], filename: str) -> None:
+        nonlocal log_line
+        if stats:
+            log_line += 1
+            with open(filename, "a", encoding="utf_8") as logf:
+                if log_line < 2:
+                    for s in stats.keys():
+                        print(str(s), end="\t", file=logf)
+                    print("", file=logf)
+                for s in stats.keys():
+                    print(str(stats[s]), end="\t", file=logf)
+                print("", file=logf)
+
+    return log_stats
+
+
+def execute_wfc(
+    image: NDArray[np.integer],
+    tile_size: int = 1,
+    pattern_width: int = 2,
+    rotations: int = 8,
+    output_size: tuple[int, int] = (48, 48),
+    ground: int | None = None,
+    attempt_limit: int = 10,
+    output_periodic: bool = True,
+    input_periodic: bool = True,
+    loc_heuristic: Literal[
+        "lexical", "hilbert", "spiral", "entropy", "anti-entropy", "simple", "random"
+    ] = "entropy",
+    choice_heuristic: Literal["lexical", "rarest", "weighted", "random"] = "weighted",
+    global_constraint: Literal[False, "allpatterns"] = False,
+    backtracking: bool = False,
+    log_filename: str = "log",
+    logging: bool = False,
+    global_constraints: None = None,
+    log_stats_to_output: Callable[[dict[str, Any], str], None] | None = None,
+    np_random: np.random.Generator | None = None,
+) -> NDArray[np.integer]:
+    time_begin = time.perf_counter()
+    output_destination = r"./output/"
+    np_random: np.random.Generator = (
+        np.random.default_rng() if np_random is None else np_random
+    )
+
+    rotations -= 1  # change to zero-based
+
+    input_stats = {
+        "tile_size": tile_size,
+        "pattern_width": pattern_width,
+        "rotations": rotations,
+        "output_size": output_size,
+        "ground": ground,
+        "attempt_limit": attempt_limit,
+        "output_periodic": output_periodic,
+        "input_periodic": input_periodic,
+        "location heuristic": loc_heuristic,
+        "choice heuristic": choice_heuristic,
+        "global constraint": global_constraint,
+        "backtracking": backtracking,
+    }
+    # TODO: generalize this to more than the four cardinal directions
+    direction_offsets = list(enumerate([(0, -1), (1, 0), (0, 1), (-1, 0)]))
+
+    tile_catalog, tile_grid, _code_list, _unique_tiles = make_tile_catalog(
+        image, tile_size
+    )
+    (
+        pattern_catalog,
+        pattern_weights,
+        pattern_list,
+        pattern_grid,
+    ) = make_pattern_catalog_with_rotations(
+        tile_grid, pattern_width, input_is_periodic=input_periodic, rotations=rotations
+    )
+
+    logger.debug("profiling adjacency relations")
+
+    adjacency_relations = adjacency_extraction(
+        pattern_grid,
+        pattern_catalog,
+        direction_offsets,
+        (pattern_width, pattern_width),
+    )
+
+    logger.debug("adjacency_relations")
+
+    logger.debug(f"output size: {output_size}\noutput periodic: {output_periodic}")
+    number_of_patterns = len(pattern_weights)
+    logger.debug(f"# patterns: {number_of_patterns}")
+    decode_patterns = dict(enumerate(pattern_list))
+    encode_patterns = {x: i for i, x in enumerate(pattern_list)}
+
+    adjacency_list: dict[tuple[int, int], list[set[int]]] = {}
+    for _, adjacency in direction_offsets:
+        adjacency_list[adjacency] = [set() for _ in pattern_weights]
+    # logger.debug(adjacency_list)
+    for adjacency, pattern1, pattern2 in adjacency_relations:
+        # logger.debug(adjacency)
+        # logger.debug(decode_patterns[pattern1])
+        adjacency_list[adjacency][encode_patterns[pattern1]].add(
+            encode_patterns[pattern2]
+        )
+
+    logger.debug(f"adjacency: {len(adjacency_list)}")
+
+    time_adjacency = time.perf_counter()
+
+    # Ground #
+
+    ground_list: NDArray[np.int64] | None = None
+    if ground:
+        ground_list = np.vectorize(lambda x: encode_patterns[x])(
+            pattern_grid.flat[(ground - 1) :]
+        )
+    if ground_list is None or ground_list.size == 0:
+        ground_list = None
+
+    wave = makeWave(
+        number_of_patterns, output_size[0], output_size[1], ground=ground_list
+    )
+    adjacency_matrix = makeAdj(adjacency_list)
+
+    # Heuristics #
+
+    encoded_weights: NDArray[np.float64] = np.zeros(
+        (number_of_patterns), dtype=np.float64
+    )
+    for w_id, w_val in pattern_weights.items():
+        encoded_weights[encode_patterns[w_id]] = w_val
+    choice_random_weighting: NDArray[np.float64] = (
+        np_random.random(wave.shape[1:]) * 0.1
+    )
+
+    pattern_heuristic: Callable[
+        [NDArray[np.bool_], NDArray[np.bool_]], int
+    ] = lexicalPatternHeuristic
+    if choice_heuristic == "rarest":
+        pattern_heuristic = makeRarestPatternHeuristic(encoded_weights, np_random)
+    if choice_heuristic == "weighted":
+        pattern_heuristic = makeWeightedPatternHeuristic(encoded_weights, np_random)
+    if choice_heuristic == "random":
+        pattern_heuristic = makeRandomPatternHeuristic(encoded_weights, np_random)
+
+    logger.debug(loc_heuristic)
+    location_heuristic: Callable[
+        [NDArray[np.bool_]], tuple[int, int]
+    ] = lexicalLocationHeuristic
+    if loc_heuristic == "anti-entropy":
+        location_heuristic = makeAntiEntropyLocationHeuristic(choice_random_weighting)
+    if loc_heuristic == "entropy":
+        location_heuristic = makeEntropyLocationHeuristic(choice_random_weighting)
+    if loc_heuristic == "random":
+        location_heuristic = makeRandomLocationHeuristic(choice_random_weighting)
+    if loc_heuristic == "simple":
+        location_heuristic = simpleLocationHeuristic
+    if loc_heuristic == "spiral":
+        location_heuristic = makeSpiralLocationHeuristic(choice_random_weighting)
+    if loc_heuristic == "hilbert":
+        # This requires hilbert_curve to be installed
+        location_heuristic = makeHilbertLocationHeuristic(choice_random_weighting)
+
+    # Global Constraints #
+
+    if global_constraint == "allpatterns":
+        active_global_constraint = make_global_use_all_patterns()
+    else:
+
+        def active_global_constraint(wave) -> bool:
+            return True
+
+    logger.debug(active_global_constraint)
+    combined_constraints = [active_global_constraint]
+
+    def combinedConstraints(wave: NDArray[np.bool_]) -> bool:
+        return all(fn(wave) for fn in combined_constraints)
+
+    # Solving #
+
+    time_solve_start = None
+    time_solve_end = None
+
+    solution_tile_grid = None
+    logger.debug("solving...")
+    attempts = 0
+    while attempts < attempt_limit:
+        attempts += 1
+        time_solve_start = time.perf_counter()
+        stats = {}
+        try:
+            solution = run(
+                wave.copy(),
+                adjacency_matrix,
+                locationHeuristic=location_heuristic,
+                patternHeuristic=pattern_heuristic,
+                periodic=output_periodic,
+                backtracking=backtracking,
+                checkFeasible=combinedConstraints,
+            )
+            solution_as_ids = np.vectorize(lambda x: decode_patterns[x])(solution)
+            solution_tile_grid = pattern_grid_to_tiles(solution_as_ids, pattern_catalog)
+
+            time_solve_end = time.perf_counter()
+            stats.update({"outcome": "success"})
+        except StopEarly:
+            logger.debug("Skipping...")
+            stats.update({"outcome": "skipped"})
+            raise
+        except TimedOut:
+            logger.debug("Timed Out")
+            stats.update({"outcome": "timed_out"})
+        except Contradiction:
+            # logger.warning(f"Contradiction: {exc}")
+            stats.update({"outcome": "contradiction"})
+        finally:
+            # profiler.dump_stats(f"logs/profile_{filename}_{timecode}.txt")
+            outstats = {}
+            outstats.update(input_stats)
+            solve_duration = time.perf_counter() - time_solve_start
+            if time_solve_end is not None:
+                solve_duration = time_solve_end - time_solve_start
+            adjacency_duration = time_solve_start - time_adjacency
+            outstats.update(
+                {
+                    "attempts": attempts,
+                    "time_start": time_begin,
+                    "time_adjacency": time_adjacency,
+                    "adjacency_duration": adjacency_duration,
+                    "time solve start": time_solve_start,
+                    "time solve end": time_solve_end,
+                    "solve duration": solve_duration,
+                    "pattern count": number_of_patterns,
+                }
+            )
+            outstats.update(stats)
+            if log_stats_to_output is not None:
+                log_stats_to_output(
+                    outstats, output_destination + log_filename + ".tsv"
+                )
+        if solution_tile_grid is not None:
+            return (
+                tile_grid_to_image(
+                    solution_tile_grid, tile_catalog, (tile_size, tile_size)
+                ),
+                outstats,
+            )
+        else:
+            return None, outstats
+
+    raise TimedOut("Attempt limit exceeded.")
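
An end-to-end sketch of calling execute_wfc directly on a bundled pattern image (values mirror the MazeSimple preset; as the code above shows, the function returns a (generated image, stats) pair, or (None, stats) when no attempt succeeds):

import numpy as np
from imageio.v2 import imread

from minigrid.envs.wfc.config import PATTERN_PATH
from minigrid.envs.wfc.wfclogic.control import execute_wfc

image = imread(PATTERN_PATH / "SimpleMaze.png")[:, :, :3]  # drop any alpha channel
generated, stats = execute_wfc(
    image,
    tile_size=1,
    pattern_width=2,
    rotations=8,
    output_size=(23, 23),
    attempt_limit=10,
    output_periodic=False,
    input_periodic=False,
    np_random=np.random.default_rng(0),
)
if generated is not None:
    print(generated.shape, stats["outcome"])  # RGB array of the generated level, "success"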

+ 199 - 0
minigrid/envs/wfc/wfclogic/patterns.py

@@ -0,0 +1,199 @@
+"Extract patterns from grids of tiles. Implementation based on https://github.com/ikarth/wfc_2019f"
+from __future__ import annotations
+
+import logging
+from collections import Counter
+from typing import Any, Mapping
+
+import numpy as np
+from numpy.typing import NDArray
+
+from minigrid.envs.wfc.wfclogic.utilities import hash_downto
+
+logger = logging.getLogger(__name__)
+
+
+def unique_patterns_2d(
+    agrid: NDArray[np.int64], ksize: int, periodic_input: bool
+) -> tuple[NDArray[np.int64], NDArray[np.int64], NDArray[np.int64]]:
+    assert ksize >= 1
+    if periodic_input:
+        agrid = np.pad(
+            agrid,
+            ((0, ksize - 1), (0, ksize - 1), *(((0, 0),) * (len(agrid.shape) - 2))),
+            mode="wrap",
+        )
+    else:
+        # TODO: implement non-wrapped image handling
+        # a = np.pad(a, ((0,k-1),(0,k-1),*(((0,0),)*(len(a.shape)-2))), mode='constant', constant_values=None)
+        agrid = np.pad(
+            agrid,
+            ((0, ksize - 1), (0, ksize - 1), *(((0, 0),) * (len(agrid.shape) - 2))),
+            mode="wrap",
+        )
+
+    patches: NDArray[np.int64] = np.lib.stride_tricks.as_strided(
+        agrid,
+        (
+            agrid.shape[0] - ksize + 1,
+            agrid.shape[1] - ksize + 1,
+            ksize,
+            ksize,
+            *agrid.shape[2:],
+        ),
+        agrid.strides[:2] + agrid.strides[:2] + agrid.strides[2:],
+        writeable=False,
+    )
+    patch_codes = hash_downto(patches, 2)
+    uc, ui = np.unique(patch_codes, return_index=True)
+    locs = np.unravel_index(ui, patch_codes.shape)
+    up: NDArray[np.int64] = patches[locs[0], locs[1]]
+    ids: NDArray[np.int64] = np.vectorize(
+        {code: ind for ind, code in enumerate(uc)}.get
+    )(patch_codes)
+    return ids, up, patch_codes
+
+
+def unique_patterns_brute_force(grid, size, periodic_input):
+    padded_grid = np.pad(
+        grid,
+        ((0, size - 1), (0, size - 1), *(((0, 0),) * (len(grid.shape) - 2))),
+        mode="wrap",
+    )
+    patches = []
+    for x in range(grid.shape[0]):
+        row_patches = []
+        for y in range(grid.shape[1]):
+            row_patches.append(
+                np.ndarray.tolist(padded_grid[x : x + size, y : y + size])
+            )
+        patches.append(row_patches)
+    patches = np.array(patches)
+    patch_codes = hash_downto(patches, 2)
+    uc, ui = np.unique(patch_codes, return_index=True)
+    locs = np.unravel_index(ui, patch_codes.shape)
+    up = patches[locs[0], locs[1]]
+    ids = np.vectorize({c: i for i, c in enumerate(uc)}.get)(patch_codes)
+    return ids, up
+
+
+def make_pattern_catalog(
+    tile_grid: NDArray[np.int64], pattern_width: int, input_is_periodic: bool = True
+) -> tuple[dict[int, NDArray[np.int64]], Counter, NDArray[np.int64], NDArray[np.int64]]:
+    """Returns a pattern catalog (dictionary of pattern hashes to constituent tiles),
+    an ordered list of pattern weights, and an ordered list of pattern contents."""
+    _patterns_in_grid, pattern_contents_list, patch_codes = unique_patterns_2d(
+        tile_grid, pattern_width, input_is_periodic
+    )
+    dict_of_pattern_contents: dict[int, NDArray[np.int64]] = {}
+    for pat_idx in range(pattern_contents_list.shape[0]):
+        p_hash = hash_downto(pattern_contents_list[pat_idx], 0)
+        dict_of_pattern_contents.update({p_hash.item(): pattern_contents_list[pat_idx]})
+    pattern_frequency = Counter(hash_downto(pattern_contents_list, 1))
+    return (
+        dict_of_pattern_contents,
+        pattern_frequency,
+        hash_downto(pattern_contents_list, 1),
+        patch_codes,
+    )
+
+
+def identity_grid(grid):
+    """Do nothing to the grid"""
+    # return np.array([[7,5,5,5],[5,0,0,0],[5,0,1,0],[5,0,0,0]])
+    return grid
+
+
+def reflect_grid(grid):
+    """Reflect the grid left/right"""
+    return np.fliplr(grid)
+
+
+def rotate_grid(grid):
+    """Rotate the grid"""
+    return np.rot90(grid, axes=(1, 0))
+
+
+def make_pattern_catalog_with_rotations(
+    tile_grid: NDArray[np.int64],
+    pattern_width: int,
+    rotations: int = 7,
+    input_is_periodic: bool = True,
+) -> tuple[dict[int, NDArray[np.int64]], Counter, NDArray[np.int64], NDArray[np.int64]]:
+    rotated_tile_grid = tile_grid.copy()
+    merged_dict_of_pattern_contents: dict[int, NDArray[np.int64]] = {}
+    merged_pattern_frequency: Counter = Counter()
+    merged_pattern_contents_list: NDArray[np.int64] | None = None
+    merged_patch_codes: NDArray[np.int64] | None = None
+
+    def _make_catalog() -> None:
+        nonlocal rotated_tile_grid, merged_dict_of_pattern_contents, merged_pattern_contents_list, merged_pattern_frequency, merged_patch_codes
+        (
+            dict_of_pattern_contents,
+            pattern_frequency,
+            pattern_contents_list,
+            patch_codes,
+        ) = make_pattern_catalog(rotated_tile_grid, pattern_width, input_is_periodic)
+        merged_dict_of_pattern_contents.update(dict_of_pattern_contents)
+        merged_pattern_frequency.update(pattern_frequency)
+        if merged_pattern_contents_list is None:
+            merged_pattern_contents_list = pattern_contents_list.copy()
+        else:
+            merged_pattern_contents_list = np.unique(
+                np.concatenate((merged_pattern_contents_list, pattern_contents_list))
+            )
+        if merged_patch_codes is None:
+            merged_patch_codes = patch_codes.copy()
+
+    counter = 0
+    grid_ops = [
+        identity_grid,
+        reflect_grid,
+        rotate_grid,
+        reflect_grid,
+        rotate_grid,
+        reflect_grid,
+        rotate_grid,
+        reflect_grid,
+    ]
+    while counter <= (rotations):
+        # logger.debug(rotated_tile_grid.shape)
+        # logger.debug(np.array_equiv(reflect_grid(rotated_tile_grid.copy()), rotate_grid(rotated_tile_grid.copy())))
+
+        # logger.debug(counter)
+        # logger.debug(grid_ops[counter].__name__)
+        rotated_tile_grid = grid_ops[counter](rotated_tile_grid.copy())
+        # logger.debug(rotated_tile_grid)
+        # logger.debug("---")
+        _make_catalog()
+        counter += 1
+
+    # assert False
+    assert merged_pattern_contents_list is not None
+    assert merged_patch_codes is not None
+    return (
+        merged_dict_of_pattern_contents,
+        merged_pattern_frequency,
+        merged_pattern_contents_list,
+        merged_patch_codes,
+    )
+
+
+def pattern_grid_to_tiles(
+    pattern_grid: NDArray[np.int64], pattern_catalog: Mapping[int, NDArray[np.int64]]
+) -> NDArray[np.int64]:
+    anchor_x = 0
+    anchor_y = 0
+
+    def pattern_to_tile(pattern: int) -> Any:
+        # if isinstance(pattern, list):
+        #     ptrns = []
+        #     for p in pattern:
+        #         logger.debug(p)
+        #         ptrns.push(pattern_to_tile(p))
+        #     logger.debug(ptrns)
+        #     assert False
+        #     return ptrns
+        return pattern_catalog[pattern][anchor_x][anchor_y]
+
+    return np.vectorize(pattern_to_tile)(pattern_grid)
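Editorial note (not part of the diff): for orientation, a minimal sketch of how this module is typically driven, mirroring the flow exercised by the tests added later in this change. The 4x4 white/black/red image is the same one the new test fixtures fall back to; the variable names are illustrative.

import numpy as np

from minigrid.envs.wfc.wfclogic import patterns as wfc_patterns
from minigrid.envs.wfc.wfclogic import tiles as wfc_tiles

b, w, r = [0, 0, 0], [255, 255, 255], [255, 0, 0]
img = np.array(
    [[w, w, w, w], [w, b, b, b], [w, b, r, b], [w, b, b, b]], dtype=np.uint8
)

# Hash 1x1 tiles, then extract 2x2 patterns (the input is treated as periodic by default).
_catalog, tile_grid, _codes, _unique = wfc_tiles.make_tile_catalog(img, 1)
pattern_catalog, frequencies, pattern_list, pattern_grid = wfc_patterns.make_pattern_catalog(
    tile_grid, pattern_width=2
)

# The pattern grid maps back to tiles through the catalog.
assert np.array_equal(
    wfc_patterns.pattern_grid_to_tiles(pattern_grid, pattern_catalog), tile_grid
)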

+ 530 - 0
minigrid/envs/wfc/wfclogic/solver.py

@@ -0,0 +1,530 @@
+"""Wave Function Collapse solver. Implementation based on https://github.com/ikarth/wfc_2019f"""
+from __future__ import annotations
+
+import itertools
+import logging
+import math
+from typing import Any, Callable, Collection, Iterable, Iterator, Mapping, TypeVar
+
+# from scipy import sparse  # type: ignore
+import numpy
+import numpy as np
+from numpy.typing import NBitBase, NDArray
+
+logger = logging.getLogger(__name__)
+
+T = TypeVar("T", bound=NBitBase)
+
+
+class Contradiction(Exception):
+    """Solving could not proceed without backtracking/restarting."""
+
+    pass
+
+
+class TimedOut(Exception):
+    """Solve timed out."""
+
+    pass
+
+
+class StopEarly(Exception):
+    """Aborting solve early."""
+
+    pass
+
+
+class Solver:
+    """WFC Solver which can hold wave and backtracking state."""
+
+    def __init__(
+        self,
+        *,
+        wave: NDArray[np.bool_],
+        adj: Mapping[tuple[int, int], NDArray[numpy.bool_]],
+        periodic: bool = False,
+        backtracking: bool = False,
+        on_backtrack: Callable[[], None] | None = None,
+        on_choice: Callable[[int, int, int], None] | None = None,
+        on_observe: Callable[[NDArray[numpy.bool_]], None] | None = None,
+        on_propagate: Callable[[NDArray[numpy.bool_]], None] | None = None,
+        check_feasible: Callable[[NDArray[numpy.bool_]], bool] | None = None,
+    ) -> None:
+        self.wave = wave
+        self.adj = adj
+        self.periodic = periodic
+        self.backtracking = backtracking
+        self.history: list[NDArray[np.bool_]] = []  # An undo history for backtracking.
+        self.on_backtrack = on_backtrack
+        self.on_choice = on_choice
+        self.on_observe = on_observe
+        self.on_propagate = on_propagate
+        self.check_feasible = check_feasible
+
+    @property
+    def is_solved(self) -> bool:
+        """Is True if the wave has been fully resolved."""
+        return (
+            self.wave.sum() == self.wave.shape[1] * self.wave.shape[2]
+            and (self.wave.sum(axis=0) == 1).all()
+        )
+
+    def solve_next(
+        self,
+        location_heuristic: Callable[[NDArray[numpy.bool_]], tuple[int, int]],
+        pattern_heuristic: Callable[[NDArray[np.bool_], NDArray[np.bool_]], int],
+    ) -> bool:
+        """Attempt to collapse one wave.  Returns True if no more steps remain."""
+        if self.is_solved:
+            return True
+        if self.check_feasible and not self.check_feasible(self.wave):
+            raise Contradiction("Not feasible.")
+        if self.backtracking:
+            self.history.append(self.wave.copy())
+        propagate(
+            self.wave, self.adj, periodic=self.periodic, onPropagate=self.on_propagate
+        )
+        pattern, i, j = None, None, None
+        try:
+            pattern, i, j = observe(self.wave, location_heuristic, pattern_heuristic)
+            if self.on_choice:
+                self.on_choice(pattern, i, j)
+            self.wave[:, i, j] = False
+            self.wave[pattern, i, j] = True
+            if self.on_observe:
+                self.on_observe(self.wave)
+            propagate(
+                self.wave,
+                self.adj,
+                periodic=self.periodic,
+                onPropagate=self.on_propagate,
+            )
+            return False  # Assume there are remaining steps; if not, the next call will return True.
+        except Contradiction:
+            if not self.backtracking:
+                raise
+            if not self.history:
+                raise Contradiction("Every permutation has been attempted.")
+            if self.on_backtrack:
+                self.on_backtrack()
+            self.wave = self.history.pop()
+            self.wave[pattern, i, j] = False
+            return False
+
+    def solve(
+        self,
+        location_heuristic: Callable[[NDArray[numpy.bool_]], tuple[int, int]],
+        pattern_heuristic: Callable[[NDArray[np.bool_], NDArray[np.bool_]], int],
+    ) -> NDArray[np.int64]:
+        """Attempts to solve all waves and returns the solution."""
+        while not self.solve_next(
+            location_heuristic=location_heuristic, pattern_heuristic=pattern_heuristic
+        ):
+            pass
+        return numpy.argmax(self.wave, axis=0)
+
+
+def makeWave(
+    n: int, w: int, h: int, ground: Iterable[int] | None = None
+) -> NDArray[numpy.bool_]:
+    wave: NDArray[numpy.bool_] = numpy.ones((n, w, h), dtype=numpy.bool_)
+    if ground is not None:
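+        # Constrain the bottom row to the listed ground patterns and forbid those
+        # patterns everywhere else.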
+        wave[:, :, h - 1] = False
+        for g in ground:
+            wave[
+                g,
+                :,
+            ] = False
+            wave[g, :, h - 1] = True
+    # logger.debug(wave)
+    # for i in range(wave.shape[0]):
+    #  logger.debug(wave[i])
+    return wave
+
+
+def makeAdj(
+    adjLists: Mapping[tuple[int, int], Collection[Iterable[int]]]
+) -> dict[tuple[int, int], NDArray[numpy.bool_]]:
+    adjMatrices = {}
+    # logger.debug(adjLists)
+    num_patterns = len(list(adjLists.values())[0])
+    for d in adjLists:
+        m = numpy.zeros((num_patterns, num_patterns), dtype=bool)
+        for i, js in enumerate(adjLists[d]):
+            # logger.debug(js)
+            for j in js:
+                m[i, j] = 1
+        # If scipy is available, use sparse matrices.
+        # adjMatrices[d] = sparse.csr_matrix(m)
+        adjMatrices[d] = m
+    return adjMatrices
+
+
+######################################
+# Location Heuristics
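+#
+# Each location heuristic takes the current wave (a patterns x width x height boolean
+# array) and returns the (i, j) indices of an unresolved cell to collapse next; the
+# make* factories below bake per-cell preference arrays into the returned callable.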
+
+
+def makeRandomLocationHeuristic(
+    preferences: NDArray[np.floating[Any]],
+) -> Callable[[NDArray[np.bool_]], tuple[int, int]]:
+    def randomLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]:
+        unresolved_cell_mask = numpy.count_nonzero(wave, axis=0) > 1
+        cell_weights = numpy.where(unresolved_cell_mask, preferences, numpy.inf)
+        row, col = numpy.unravel_index(numpy.argmin(cell_weights), cell_weights.shape)
+        return row.item(), col.item()
+
+    return randomLocationHeuristic
+
+
+def makeEntropyLocationHeuristic(
+    preferences: NDArray[np.floating[Any]],
+) -> Callable[[NDArray[np.bool_]], tuple[int, int]]:
+    def entropyLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]:
+        unresolved_cell_mask = numpy.count_nonzero(wave, axis=0) > 1
+        cell_weights = numpy.where(
+            unresolved_cell_mask,
+            preferences + numpy.count_nonzero(wave, axis=0),
+            numpy.inf,
+        )
+        row, col = numpy.unravel_index(numpy.argmin(cell_weights), cell_weights.shape)
+        return row.item(), col.item()
+
+    return entropyLocationHeuristic
+
+
+def makeAntiEntropyLocationHeuristic(
+    preferences: NDArray[np.floating[Any]],
+) -> Callable[[NDArray[np.bool_]], tuple[int, int]]:
+    def antiEntropyLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]:
+        unresolved_cell_mask = numpy.count_nonzero(wave, axis=0) > 1
+        cell_weights = numpy.where(
+            unresolved_cell_mask,
+            preferences + numpy.count_nonzero(wave, axis=0),
+            -numpy.inf,
+        )
+        row, col = numpy.unravel_index(numpy.argmax(cell_weights), cell_weights.shape)
+        return row.item(), col.item()
+
+    return antiEntropyLocationHeuristic
+
+
+def spiral_transforms() -> Iterator[tuple[int, int]]:
+    for N in itertools.count(start=1):
+        if N % 2 == 0:
+            yield (0, 1)  # right
+            for _ in range(N):
+                yield (1, 0)  # down
+            for _ in range(N):
+                yield (0, -1)  # left
+        else:
+            yield (0, -1)  # left
+            for _ in range(N):
+                yield (-1, 0)  # up
+            for _ in range(N):
+                yield (0, 1)  # right
+
+
+def spiral_coords(x: int, y: int) -> Iterator[tuple[int, int]]:
+    yield x, y
+    for transform in spiral_transforms():
+        x += transform[0]
+        y += transform[1]
+        yield x, y
+
+
+def fill_with_curve(
+    arr: NDArray[np.floating[T]], curve_gen: Iterable[Iterable[int]]
+) -> NDArray[np.floating[T]]:
+    arr_len = numpy.prod(arr.shape)
+    fill = 0
+    for coord in curve_gen:
+        # logger.debug(fill, idx, coord)
+        if fill < arr_len:
+            try:
+                arr[tuple(coord)] = fill / arr_len
+                fill += 1
+            except IndexError:
+                pass
+        else:
+            break
+    # logger.debug(arr)
+    return arr
+
+
+def makeSpiralLocationHeuristic(
+    preferences: NDArray[np.floating[Any]],
+) -> Callable[[NDArray[np.bool_]], tuple[int, int]]:
+    # https://stackoverflow.com/a/23707273/5562922
+
+    spiral_gen = (
+        sc for sc in spiral_coords(preferences.shape[0] // 2, preferences.shape[1] // 2)
+    )
+
+    cell_order = fill_with_curve(preferences, spiral_gen)
+
+    def spiralLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]:
+        unresolved_cell_mask = numpy.count_nonzero(wave, axis=0) > 1
+        cell_weights = numpy.where(unresolved_cell_mask, cell_order, numpy.inf)
+        row, col = numpy.unravel_index(numpy.argmin(cell_weights), cell_weights.shape)
+        return row.item(), col.item()
+
+    return spiralLocationHeuristic
+
+
+def makeHilbertLocationHeuristic(
+    preferences: NDArray[np.floating[Any]],
+) -> Callable[[NDArray[np.bool_]], tuple[int, int]]:
+    from hilbertcurve.hilbertcurve import HilbertCurve  # type: ignore
+
+    curve_size = math.ceil(math.sqrt(max(preferences.shape[0], preferences.shape[1])))
+    logger.debug(curve_size)
+    curve_size = 4
+    h_curve = HilbertCurve(curve_size, 2)
+    h_coords = (h_curve.point_from_distance(i) for i in itertools.count())
+    cell_order = fill_with_curve(preferences, h_coords)
+    # logger.debug(cell_order)
+
+    def hilbertLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]:
+        unresolved_cell_mask = numpy.count_nonzero(wave, axis=0) > 1
+        cell_weights = numpy.where(unresolved_cell_mask, cell_order, numpy.inf)
+        row, col = numpy.unravel_index(numpy.argmin(cell_weights), cell_weights.shape)
+        return row.item(), col.item()
+
+    return hilbertLocationHeuristic
+
+
+def simpleLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]:
+    unresolved_cell_mask = numpy.count_nonzero(wave, axis=0) > 1
+    cell_weights = numpy.where(
+        unresolved_cell_mask, numpy.count_nonzero(wave, axis=0), numpy.inf
+    )
+    row, col = numpy.unravel_index(numpy.argmin(cell_weights), cell_weights.shape)
+    return row.item(), col.item()
+
+
+def lexicalLocationHeuristic(wave: NDArray[np.bool_]) -> tuple[int, int]:
+    unresolved_cell_mask = numpy.count_nonzero(wave, axis=0) > 1
+    cell_weights = numpy.where(unresolved_cell_mask, 1.0, numpy.inf)
+    row, col = numpy.unravel_index(numpy.argmin(cell_weights), cell_weights.shape)
+    return row.item(), col.item()
+
+
+#####################################
+# Pattern Heuristics
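+#
+# Each pattern heuristic receives the boolean pattern vector of the chosen cell plus the
+# full wave and returns the index of the pattern that the cell should collapse to.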
+
+
+def lexicalPatternHeuristic(weights: NDArray[np.bool_], wave: NDArray[np.bool_]) -> int:
+    return numpy.nonzero(weights)[0][0].item()
+
+
+def makeWeightedPatternHeuristic(
+    weights: NDArray[np.floating[Any]],
+    np_random: numpy.random.Generator | None = None,
+):
+    num_of_patterns = len(weights)
+    np_random: numpy.random.Generator = (
+        numpy.random.default_rng() if np_random is None else np_random
+    )
+
+    def weightedPatternHeuristic(wave: NDArray[np.bool_], _: NDArray[np.bool_]) -> int:
+        # TODO: there's maybe a faster, more controlled way to do this sampling...
+        weighted_wave: NDArray[np.floating[Any]] = weights * wave
+        weighted_wave /= weighted_wave.sum()
+        result = np_random.choice(num_of_patterns, p=weighted_wave)
+        return result
+
+    return weightedPatternHeuristic
+
+
+def makeRarestPatternHeuristic(
+    weights: NDArray[np.floating[Any]],
+    np_random: numpy.random.Generator | None = None,
+) -> Callable[[NDArray[np.bool_], NDArray[np.bool_]], int]:
+    """Return a function that chooses the rarest (currently least-used) pattern."""
+    np_random: numpy.random.Generator = (
+        numpy.random.default_rng() if np_random is None else np_random
+    )
+
+    def weightedPatternHeuristic(
+        wave: NDArray[np.bool_], total_wave: NDArray[np.bool_]
+    ) -> int:
+        logger.debug(total_wave.shape)
+        # [logger.debug(e) for e in wave]
+        wave_sums = numpy.sum(total_wave, (1, 2))
+        # logger.debug(wave_sums)
+        selected_pattern = np_random.choice(
+            numpy.where(wave_sums == wave_sums.max())[0]
+        )
+        return selected_pattern
+
+    return weightedPatternHeuristic
+
+
+def makeMostCommonPatternHeuristic(
+    weights: NDArray[np.floating[Any]],
+    np_random: numpy.random.Generator | None = None,
+) -> Callable[[NDArray[np.bool_], NDArray[np.bool_]], int]:
+    """Return a function that chooses the most common (currently most-used) pattern."""
+    np_random: numpy.random.Generator = (
+        numpy.random.default_rng() if np_random is None else np_random
+    )
+
+    def weightedPatternHeuristic(
+        wave: NDArray[np.bool_], total_wave: NDArray[np.bool_]
+    ) -> int:
+        logger.debug(total_wave.shape)
+        # [logger.debug(e) for e in wave]
+        wave_sums = numpy.sum(total_wave, (1, 2))
+        selected_pattern = np_random.choice(
+            numpy.where(wave_sums == wave_sums.min())[0]
+        )
+        return selected_pattern
+
+    return weightedPatternHeuristic
+
+
+def makeRandomPatternHeuristic(
+    weights: NDArray[np.floating[Any]],
+    np_random: numpy.random.Generator | None = None,
+) -> Callable[[NDArray[np.bool_], NDArray[np.bool_]], int]:
+    num_of_patterns = len(weights)
+    np_random: numpy.random.Generator = (
+        numpy.random.default_rng() if np_random is None else np_random
+    )
+
+    def randomPatternHeuristic(wave: NDArray[np.bool_], _: NDArray[np.bool_]) -> int:
+        # TODO: there's maybe a faster, more controlled way to do this sampling...
+        weighted_wave = 1.0 * wave
+        weighted_wave /= weighted_wave.sum()
+        result = np_random.choice(num_of_patterns, p=weighted_wave)
+        return result
+
+    return randomPatternHeuristic
+
+
+######################################
+# Global Constraints
+
+
+def make_global_use_all_patterns() -> Callable[[NDArray[np.bool_]], bool]:
+    def global_use_all_patterns(wave: NDArray[np.bool_]) -> bool:
+        """Returns true if at least one instance of each pattern is still possible."""
+        return numpy.all(numpy.any(wave, axis=(1, 2))).item()
+
+    return global_use_all_patterns
+
+
+#####################################
+# Solver
+
+
+def propagate(
+    wave: NDArray[np.bool_],
+    adj: Mapping[tuple[int, int], NDArray[numpy.bool_]],
+    periodic: bool = False,
+    onPropagate: Callable[[NDArray[numpy.bool_]], None] | None = None,
+) -> None:
+    """Completely probagate any newly collapsed waves to all areas."""
+    last_count = wave.sum()
+
+    while True:
+        supports = {}
+        if periodic:
+            padded = numpy.pad(wave, ((0, 0), (1, 1), (1, 1)), mode="wrap")
+        else:
+            padded = numpy.pad(
+                wave, ((0, 0), (1, 1), (1, 1)), mode="constant", constant_values=True
+            )
+
+        # adj is the list of adjacencies. For each direction d in adjacency,
+        # check which patterns are still valid...
+        for d in adj:
+            dx, dy = d
+            # padded[] is a version of the wave matrix with the edge values wrapped around
+            # (or padded with True), and shifted[] is that padded version shifted over in one direction.
+            # Because the directions are stored as relative (x, y) coordinates, we can find
+            # the adjacent cell for each direction by simply shifting the matrix in that direction,
+            # which allows for arbitrary adjacency directions. This is somewhat excessive, but elegant.
+
+            shifted = padded[
+                :, 1 + dx : 1 + wave.shape[1] + dx, 1 + dy : 1 + wave.shape[2] + dy
+            ]
+            # logger.debug(f"shifted: {shifted.shape} | adj[d]: {adj[d].shape} | d: {d}")
+            # raise StopEarly
+            # supports[d] = numpy.einsum('pwh,pq->qwh', shifted, adj[d]) > 0
+
+            # The adjacency matrix is a boolean matrix, indexed by the direction and the two patterns.
+            # If the value for (direction, pattern1, pattern2) is True, then this is a valid adjacency.
+            # This gives us a rapid way to compare: True is 1, False is 0, so multiplying the matrices
+            # gives us the adjacency compatibility.
+            supports[d] = (adj[d] @ shifted.reshape(shifted.shape[0], -1)).reshape(
+                shifted.shape
+            ) > 0
+            # supports[d] = ( <- for each cell in the matrix
+            # adj[d]  <- the adjacency matrix [sliced by the direction d]
+            # @       <- Matrix multiplication
+            # shifted.reshape(shifted.shape[0], -1)) <- change the shape of the shifted matrix to 2-dimensions, to make the matrix multiplication easier
+            # .reshape(           <- reshape our matrix-multiplied result...
+            #   shifted.shape)   <- ...to match the original shape of the shifted matrix
+            # > 0    <- convert the support counts back to booleans (any support at all)
+
+        # multiply the wave matrix by the support matrix to find which patterns are still in the domain
+        for d in adj:
+            wave *= supports[d]
+
+        if wave.sum() == last_count:
+            break  # No changes since the last loop, changed waves have been fully propagated.
+        last_count = wave.sum()
+
+    if onPropagate:
+        onPropagate(wave)
+
+    if (wave.sum(axis=0) == 0).any():
+        raise Contradiction("Wave is in a contradictory state and can not be solved.")
+
+
+def observe(
+    wave: NDArray[np.bool_],
+    locationHeuristic: Callable[[NDArray[np.bool_]], tuple[int, int]],
+    patternHeuristic: Callable[[NDArray[np.bool_], NDArray[np.bool_]], int],
+) -> tuple[int, int, int]:
+    """Return the next best wave to collapse based on the provided heuristics."""
+    i, j = locationHeuristic(wave)
+    pattern = patternHeuristic(wave[:, i, j], wave)
+    return pattern, i, j
+
+
+def run(
+    wave: NDArray[np.bool_],
+    adj: Mapping[tuple[int, int], NDArray[numpy.bool_]],
+    locationHeuristic: Callable[[NDArray[numpy.bool_]], tuple[int, int]],
+    patternHeuristic: Callable[[NDArray[np.bool_], NDArray[np.bool_]], int],
+    periodic: bool = False,
+    backtracking: bool = False,
+    onBacktrack: Callable[[], None] | None = None,
+    onChoice: Callable[[int, int, int], None] | None = None,
+    onObserve: Callable[[NDArray[numpy.bool_]], None] | None = None,
+    onPropagate: Callable[[NDArray[numpy.bool_]], None] | None = None,
+    checkFeasible: Callable[[NDArray[numpy.bool_]], bool] | None = None,
+    onFinal: Callable[[NDArray[numpy.bool_]], None] | None = None,
+    depth: int = 0,
+    depth_limit: int | None = None,
+) -> NDArray[numpy.int64]:
+    solver = Solver(
+        wave=wave,
+        adj=adj,
+        periodic=periodic,
+        backtracking=backtracking,
+        on_backtrack=onBacktrack,
+        on_choice=onChoice,
+        on_observe=onObserve,
+        on_propagate=onPropagate,
+        check_feasible=checkFeasible,
+    )
+    while not solver.solve_next(
+        location_heuristic=locationHeuristic, pattern_heuristic=patternHeuristic
+    ):
+        pass
+    if onFinal:
+        onFinal(solver.wave)
+    return numpy.argmax(solver.wave, axis=0)
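Editorial note (not part of the diff): to make the control flow above concrete, a short usage sketch along the lines of the solver tests added below. The three-pattern adjacency (checkerboard patterns 0/1, solid fill 2) and the variable names are illustrative.

from minigrid.envs.wfc.wfclogic import solver as wfc_solver

# 3 patterns on a 3x4 grid, everything initially possible.
wave = wfc_solver.makeWave(3, 3, 4)

# For every direction, pattern 0 may border 1, 1 may border 0, and 2 only borders itself.
adj_lists = {d: [[1], [0], [2]] for d in [(0, 1), (0, -1), (1, 0), (-1, 0)]}
adj = wfc_solver.makeAdj(adj_lists)

solution = wfc_solver.run(
    wave,
    adj,
    locationHeuristic=wfc_solver.lexicalLocationHeuristic,
    patternHeuristic=wfc_solver.lexicalPatternHeuristic,
    periodic=False,
)
# solution is a (3, 4) array of pattern indices; with these heuristics it is a 0/1 checkerboard.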

+ 64 - 0
minigrid/envs/wfc/wfclogic/tiles.py

@@ -0,0 +1,64 @@
+"""Breaks an image into consituant tiles. Implementation based on https://github.com/ikarth/wfc_2019f"""
+from __future__ import annotations
+
+import numpy as np
+from numpy.typing import NDArray
+
+from minigrid.envs.wfc.wfclogic.utilities import hash_downto
+
+
+def image_to_tiles(img: NDArray[np.integer], tile_size: int) -> NDArray[np.integer]:
+    """
+    Takes an image, divides it into tiles, and returns an array of tiles.
+    """
+    padding_argument = [(0, 0), (0, 0), (0, 0)]
+    for input_dim in [0, 1]:
+        padding_argument[input_dim] = (
+            0,
+            (tile_size - img.shape[input_dim]) % tile_size,
+        )
+    img = np.pad(img, padding_argument, mode="constant")
+    tiles = img.reshape(
+        (
+            img.shape[0] // tile_size,
+            tile_size,
+            img.shape[1] // tile_size,
+            tile_size,
+            img.shape[2],
+        )
+    ).swapaxes(1, 2)
+    return tiles
+
+
+def make_tile_catalog(
+    image_data: NDArray[np.integer], tile_size: int
+) -> tuple[
+    dict[int, NDArray[np.integer]],
+    NDArray[np.int64],
+    NDArray[np.int64],
+    tuple[NDArray[np.int64], NDArray[np.int64]],
+]:
+    """
+    Takes an image and tile size and returns the following:
+    tile_catalog is a dictionary of tiles, with the hashed tile ID as the key
+    tile_grid is the original image, expressed in terms of hashed tile IDs
+    code_list is the original image, expressed in terms of hashed tile IDs and reduced to one dimension
+    unique_tiles is the set of tiles, plus the frequency of their occurrence
+    """
+    channels = image_data.shape[2]  # Number of color channels in the image
+    tiles = image_to_tiles(image_data, tile_size)
+    tile_list: NDArray[np.integer] = tiles.reshape(
+        (tiles.shape[0] * tiles.shape[1], tile_size, tile_size, channels)
+    )
+    code_list: NDArray[np.int64] = hash_downto(tiles, 2).reshape(
+        tiles.shape[0] * tiles.shape[1]
+    )
+    tile_grid: NDArray[np.int64] = hash_downto(tiles, 2)
+    unique_tiles: tuple[NDArray[np.int64], NDArray[np.int64]] = np.unique(
+        tile_grid, return_counts=True
+    )
+
+    tile_catalog: dict[int, NDArray[np.integer]] = {}
+    for i, j in enumerate(code_list):
+        tile_catalog[j] = tile_list[i]
+    return tile_catalog, tile_grid, code_list, unique_tiles
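Editorial note (not part of the diff): a quick illustration of the returned structures, using the same 4x4 test image as the new test fixtures (white border, black interior, one red pixel); the count assertion follows from that image.

import numpy as np

from minigrid.envs.wfc.wfclogic import tiles as wfc_tiles

b, w, r = [0, 0, 0], [255, 255, 255], [255, 0, 0]
img = np.array(
    [[w, w, w, w], [w, b, b, b], [w, b, r, b], [w, b, b, b]], dtype=np.uint8
)

tile_catalog, tile_grid, code_list, unique_tiles = wfc_tiles.make_tile_catalog(img, 1)
assert tile_grid.shape == (4, 4)  # one hash per 1x1 tile
assert sorted(unique_tiles[1].tolist()) == [1, 7, 8]  # 1 red, 7 white and 8 black tiles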

+ 77 - 0
minigrid/envs/wfc/wfclogic/utilities.py

@@ -0,0 +1,77 @@
+"""Utility data and functions for WFC. Implementation based on https://github.com/ikarth/wfc_2019f"""
+from __future__ import annotations
+
+import collections
+import logging
+
+import numpy as np
+from numpy.typing import NDArray
+
+logger = logging.getLogger(__name__)
+
+CoordXY = collections.namedtuple("CoordXY", ["x", "y"])
+CoordRC = collections.namedtuple("CoordRC", ["row", "column"])
+
+
+def hash_downto(a: NDArray[np.integer], rank: int, seed=0) -> NDArray[np.int64]:
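+    # Collapse every axis beyond the first `rank` axes into a single int64 hash per
+    # remaining position by projecting the flattened trailing values onto a fixed
+    # vector of random 64-bit integers (seeded for reproducibility).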
+    state = np.random.RandomState(seed)
+    # np_random = np.random.default_rng(seed)
+    assert rank < len(a.shape)
+
+    u: NDArray[np.integer] = a.reshape((np.prod(a.shape[:rank], dtype=np.int64), -1))
+    v = state.randint(1 - (1 << 63), 1 << 63, np.prod(a.shape[rank:]), dtype=np.int64)
+    # v = np_random.integers(1 - (1 << 63), 1 << 63, np.prod(a.shape[rank:]), dtype=np.int64)
+    return np.asarray(np.inner(u, v).reshape(a.shape[:rank]), dtype=np.int64)
+
+
+def find_pattern_center(wfc_ns):
+    # wfc_ns.pattern_center = (math.floor((wfc_ns.pattern_width - 1) / 2), math.floor((wfc_ns.pattern_width - 1) / 2))
+    wfc_ns.pattern_center = (0, 0)
+    return wfc_ns
+
+
+def tile_grid_to_image(
+    tile_grid: NDArray[np.int64],
+    tile_catalog: dict[int, NDArray[np.integer]],
+    tile_size: tuple[int, int],
+    partial: bool = False,
+    color_channels: int = 3,
+) -> NDArray[np.integer]:
+    """
+    Takes a tile_grid and transforms it into an image, using the information
+    in tile_catalog. We use tile_size to figure out the size the new image
+    should be.
+    """
+    tile_dtype = next(iter(tile_catalog.values())).dtype
+    new_img = np.zeros(
+        (
+            tile_grid.shape[0] * tile_size[0],
+            tile_grid.shape[1] * tile_size[1],
+            color_channels,
+        ),
+        dtype=tile_dtype,
+    )
+    if partial and (len(tile_grid.shape)) > 2:
+        # TODO: implement rendering partially completed solution
+        # Call tile_grid_to_average() instead.
+        assert False
+    else:
+        for i in range(tile_grid.shape[0]):
+            for j in range(tile_grid.shape[1]):
+                tile = tile_grid[i, j]
+                for u in range(tile_size[0]):
+                    for v in range(tile_size[1]):
+                        pixel = [200, 0, 200]
+                        # If we want to display a partial pattern, it is helpful to
+                        # be able to show empty cells.
+                        pixel = tile_catalog[tile][u, v]
+                        # TODO: will need to change if using an image with more than 3 channels
+                        new_img[
+                            (i * tile_size[0]) + u, (j * tile_size[1]) + v
+                        ] = np.resize(
+                            pixel,
+                            new_img[
+                                (i * tile_size[0]) + u, (j * tile_size[1]) + v
+                            ].shape,
+                        )
+    return new_img

+ 1 - 0
minigrid/wrappers.py

@@ -522,6 +522,7 @@ class DictObservationSpaceWrapper(ObservationWrapper):
             "object",
             "from",
             "room",
+            "maze",
         ]
 
         all_words = colors + objects + verbs + extra_words

+ 1 - 1
py.Dockerfile

@@ -11,7 +11,7 @@ RUN apt-get -y update \
 COPY . /usr/local/minigrid/
 WORKDIR /usr/local/minigrid/
 
-RUN pip install .[testing] --no-cache-dir
+RUN pip install .[wfc,testing] --no-cache-dir
 
 RUN ["chmod", "+x", "/usr/local/minigrid/docker_entrypoint"]
 

+ 4 - 0
pyproject.toml

@@ -37,6 +37,10 @@ testing = [
     "pytest-mock>=3.10.0",
     "matplotlib>=3.0"
 ]
+wfc = [
+    "networkx",
+    "imageio>=2.31.1",
+]
 
 [project.urls]
 Homepage = "https://farama.org"
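Editorial note: these optional dependencies back the new WFC environments. The Dockerfile change above installs them with pip install .[wfc,testing], and the updated tests/utils.py below skips the WFC environments whenever imageio or networkx is not available.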

+ 0 - 0
tests/test_wfc/__init__.py


+ 40 - 0
tests/test_wfc/conftest.py

@@ -0,0 +1,40 @@
+from __future__ import annotations
+
+import pytest
+from numpy import array, uint8
+
+from minigrid.envs.wfc.config import PATTERN_PATH
+
+
+class Resources:
+    def get_pattern(self, image: str) -> str:
+        return PATTERN_PATH / image
+
+
+@pytest.fixture(scope="session")
+def resources() -> Resources:
+    return Resources()
+
+
+@pytest.fixture(scope="session")
+def img_redmaze(resources: Resources) -> array:
+    try:
+        import imageio  # type: ignore
+
+        pattern = resources.get_pattern("RedMaze.png")
+        img = imageio.v2.imread(pattern)
+    except ImportError:
+        b = [0, 0, 0]
+        w = [255, 255, 255]
+        r = [255, 0, 0]
+        img = array(
+            [
+                [w, w, w, w],
+                [w, b, b, b],
+                [w, b, r, b],
+                [w, b, b, b],
+            ],
+            dtype=uint8,
+        )
+
+    return img

+ 41 - 0
tests/test_wfc/test_wfc_adjacency.py

@@ -0,0 +1,41 @@
+"""Convert input data to adjacency information"""
+from __future__ import annotations
+
+import numpy as np
+
+from minigrid.envs.wfc.wfclogic import adjacency as wfc_adjacency
+from minigrid.envs.wfc.wfclogic import patterns as wfc_patterns
+from minigrid.envs.wfc.wfclogic import tiles as wfc_tiles
+
+
+def test_adjacency_extraction(img_redmaze: np.ndarray) -> None:
+    # TODO: generalize this to more than the four cardinal directions
+    direction_offsets = list(enumerate([(0, -1), (1, 0), (0, 1), (-1, 0)]))
+
+    img = img_redmaze
+    tile_size = 1
+    pattern_width = 2
+    periodic = False
+    _tile_catalog, tile_grid, _code_list, _unique_tiles = wfc_tiles.make_tile_catalog(
+        img, tile_size
+    )
+    (
+        pattern_catalog,
+        _pattern_weights,
+        _pattern_list,
+        pattern_grid,
+    ) = wfc_patterns.make_pattern_catalog(tile_grid, pattern_width, periodic)
+    adjacency_relations = wfc_adjacency.adjacency_extraction(
+        pattern_grid, pattern_catalog, direction_offsets
+    )
+    assert ((0, -1), -6150964001204120324, -4042134092912931260) in adjacency_relations
+    assert ((-1, 0), -4042134092912931260, 3069048847358774683) in adjacency_relations
+    assert ((1, 0), -3950451988873469076, -3950451988873469076) in adjacency_relations
+    assert ((-1, 0), -3950451988873469076, -3950451988873469076) in adjacency_relations
+    assert ((0, 1), -3950451988873469076, 3336256675067683735) in adjacency_relations
+    assert ((0, -1), -3950451988873469076, -3950451988873469076) not in adjacency_relations
+    assert ((0, 1), -3950451988873469076, -3950451988873469076) not in adjacency_relations

+ 60 - 0
tests/test_wfc/test_wfc_patterns.py

@@ -0,0 +1,60 @@
+from __future__ import annotations
+
+import numpy as np
+
+from minigrid.envs.wfc.wfclogic import patterns as wfc_patterns
+from minigrid.envs.wfc.wfclogic import tiles as wfc_tiles
+
+
+def test_unique_patterns_2d(img_redmaze) -> None:
+    img = img_redmaze
+    tile_size = 1
+    pattern_width = 2
+    _tile_catalog, tile_grid, _code_list, _unique_tiles = wfc_tiles.make_tile_catalog(
+        img, tile_size
+    )
+
+    (
+        _patterns_in_grid,
+        pattern_contents_list,
+        patch_codes,
+    ) = wfc_patterns.unique_patterns_2d(tile_grid, pattern_width, True)
+    assert patch_codes[1][2] == 4867810695119132864
+    assert pattern_contents_list[7][1][1] == 8253868773529191888
+
+
+def test_make_pattern_catalog(img_redmaze) -> None:
+    img = img_redmaze
+    tile_size = 1
+    pattern_width = 2
+    _tile_catalog, tile_grid, _code_list, _unique_tiles = wfc_tiles.make_tile_catalog(
+        img, tile_size
+    )
+
+    (
+        pattern_catalog,
+        pattern_weights,
+        pattern_list,
+        _pattern_grid,
+    ) = wfc_patterns.make_pattern_catalog(tile_grid, pattern_width)
+    assert pattern_weights[-6150964001204120324] == 1
+    assert pattern_list[3] == 2800765426490226432
+    assert pattern_catalog[5177878755649963747][0][1] == -8754995591521426669
+
+
+def test_pattern_to_tile(img_redmaze) -> None:
+    img = img_redmaze
+    tile_size = 1
+    pattern_width = 2
+    _tile_catalog, tile_grid, _code_list, _unique_tiles = wfc_tiles.make_tile_catalog(
+        img, tile_size
+    )
+
+    (
+        pattern_catalog,
+        _pattern_weights,
+        _pattern_list,
+        pattern_grid,
+    ) = wfc_patterns.make_pattern_catalog(tile_grid, pattern_width)
+    new_tile_grid = wfc_patterns.pattern_grid_to_tiles(pattern_grid, pattern_catalog)
+    assert np.array_equal(tile_grid, new_tile_grid)

+ 148 - 0
tests/test_wfc/test_wfc_solver.py

@@ -0,0 +1,148 @@
+from __future__ import annotations
+
+import numpy as np
+import pytest
+from numpy.typing import NDArray
+
+from minigrid.envs.wfc.wfclogic import solver as wfc_solver
+
+
+def test_makeWave() -> None:
+    wave = wfc_solver.makeWave(3, 10, 20, ground=[-1])
+    assert wave.sum() == (2 * 10 * 19) + (1 * 10 * 1)
+    assert wave[2, 5, 19]
+    assert not wave[1, 5, 19]
+
+
+def test_entropyLocationHeuristic() -> None:
+    wave = np.ones((5, 3, 4), dtype=bool)  # everything is possible
+    wave[1:, 0, 0] = False  # first cell is fully observed
+    wave[4, :, 2] = False
+    preferences: NDArray[np.float_] = np.ones((3, 4), dtype=np.float_) * 0.5
+    preferences[1, 2] = 0.3
+    preferences[1, 1] = 0.1
+    heu = wfc_solver.makeEntropyLocationHeuristic(preferences)
+    result = heu(wave)
+    assert (1, 2) == result
+
+
+def test_observe() -> None:
+    my_wave = np.ones((5, 3, 4), dtype=np.bool_)
+    my_wave[0, 1, 2] = False
+
+    def locHeu(wave: NDArray[np.bool_]) -> tuple[int, int]:
+        assert np.array_equal(wave, my_wave)
+        return 1, 2
+
+    def patHeu(weights: NDArray[np.bool_], wave: NDArray[np.bool_]) -> int:
+        assert np.array_equal(weights, my_wave[:, 1, 2])
+        return 3
+
+    assert wfc_solver.observe(
+        my_wave, locationHeuristic=locHeu, patternHeuristic=patHeu
+    ) == (
+        3,
+        1,
+        2,
+    )
+
+
+def test_propagate() -> None:
+    wave = np.ones((3, 3, 4), dtype=bool)
+    adjLists = {}
+    # checkerboard #0/#1 or solid fill #2
+    adjLists[(+1, 0)] = adjLists[(-1, 0)] = adjLists[(0, +1)] = adjLists[(0, -1)] = [
+        [1],
+        [0],
+        [2],
+    ]
+    wave[:, 0, 0] = False
+    wave[0, 0, 0] = True
+    adj = wfc_solver.makeAdj(adjLists)
+    wfc_solver.propagate(wave, adj, periodic=False)
+    expected_result = np.array(
+        [
+            [
+                [True, False, True, False],
+                [False, True, False, True],
+                [True, False, True, False],
+            ],
+            [
+                [False, True, False, True],
+                [True, False, True, False],
+                [False, True, False, True],
+            ],
+            [
+                [False, False, False, False],
+                [False, False, False, False],
+                [False, False, False, False],
+            ],
+        ]
+    )
+    assert np.array_equal(wave, expected_result)
+
+
+def test_run() -> None:
+    wave = wfc_solver.makeWave(3, 3, 4)
+    adjLists = {}
+    adjLists[(+1, 0)] = adjLists[(-1, 0)] = adjLists[(0, +1)] = adjLists[(0, -1)] = [
+        [1],
+        [0],
+        [2],
+    ]
+    adj = wfc_solver.makeAdj(adjLists)
+
+    first_result = wfc_solver.run(
+        wave.copy(),
+        adj,
+        locationHeuristic=wfc_solver.lexicalLocationHeuristic,
+        patternHeuristic=wfc_solver.lexicalPatternHeuristic,
+        periodic=False,
+    )
+
+    expected_first_result = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]])
+
+    assert np.array_equal(first_result, expected_first_result)
+
+    event_log: list = []
+
+    def onChoice(pattern: int, i: int, j: int) -> None:
+        event_log.append((pattern, i, j))
+
+    def onBacktrack() -> None:
+        event_log.append("backtrack")
+
+    second_result = wfc_solver.run(
+        wave.copy(),
+        adj,
+        locationHeuristic=wfc_solver.lexicalLocationHeuristic,
+        patternHeuristic=wfc_solver.lexicalPatternHeuristic,
+        periodic=True,
+        backtracking=True,
+        onChoice=onChoice,
+        onBacktrack=onBacktrack,
+    )
+
+    expected_second_result = np.array([[2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2]])
+
+    assert np.array_equal(second_result, expected_second_result)
+    assert event_log == [(0, 0, 0), "backtrack", (2, 0, 0)]
+
+    class Infeasible(Exception):
+        pass
+
+    def explode(wave: NDArray[np.bool_]) -> bool:
+        if wave.sum() < 20:
+            raise Infeasible
+        return False
+
+    with pytest.raises(wfc_solver.Contradiction):
+        wfc_solver.run(
+            wave.copy(),
+            adj,
+            locationHeuristic=wfc_solver.lexicalLocationHeuristic,
+            patternHeuristic=wfc_solver.lexicalPatternHeuristic,
+            periodic=True,
+            backtracking=True,
+            checkFeasible=explode,
+        )

+ 17 - 0
tests/test_wfc/test_wfc_tiles.py

@@ -0,0 +1,17 @@
+"""Breaks an image into consituant tiles."""
+from __future__ import annotations
+
+from minigrid.envs.wfc.wfclogic import tiles as wfc_tiles
+
+
+def test_image_to_tile(img_redmaze) -> None:
+    img = img_redmaze
+    tiles = wfc_tiles.image_to_tiles(img, 1)
+    assert tiles[2][2][0][0][0] == 255
+    assert tiles[2][2][0][0][1] == 0
+
+
+def test_make_tile_catalog(img_redmaze) -> None:
+    img = img_redmaze
+    tc, tg, cl, ut = wfc_tiles.make_tile_catalog(img, 1)
+    assert ut[1][0] == 7

+ 10 - 0
tests/utils.py

@@ -1,6 +1,8 @@
 """Finds all the specs that we can test with"""
 from __future__ import annotations
 
+from importlib.util import find_spec
+
 import gymnasium as gym
 import numpy as np
 
@@ -13,6 +15,14 @@ all_testing_env_specs = [
     )
 ]
 
+if find_spec("imageio") is None or find_spec("networkx") is None:
+    # Do not test WFC environments if dependencies are not installed
+    all_testing_env_specs = [
+        env_spec
+        for env_spec in all_testing_env_specs
+        if not env_spec.entry_point.startswith("minigrid.envs.wfc")
+    ]
+
 minigrid_testing_env_specs = [
     env_spec
     for env_spec in all_testing_env_specs