123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200 |
- "Extract patterns from grids of tiles. Implementation based on https://github.com/ikarth/wfc_2019f"
- from __future__ import annotations
- import logging
- from collections import Counter
- from typing import Any, Mapping
- import numpy as np
- from numpy.typing import NDArray
- from minigrid.envs.wfc.wfclogic.utilities import hash_downto
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def unique_patterns_2d(
    agrid: NDArray[np.int64], ksize: int, periodic_input: bool
) -> tuple[NDArray[np.int64], NDArray[np.int64], NDArray[np.int64]]:
    """Find the distinct ``ksize`` x ``ksize`` patches of a grid.

    One patch is anchored at every cell of ``agrid``. The first two axes are the
    spatial axes; any trailing axes are carried along inside each patch. The
    grid is wrap-padded by ``ksize - 1`` cells at the bottom and right, so
    patches anchored near those edges continue on the opposite side.

    Args:
        agrid: Grid of tiles with at least two dimensions.
        ksize: Side length of the square patches; must be at least 1.
        periodic_input: Currently ignored. Non-periodic handling is not
            implemented, so the input is always wrapped.

    Returns:
        ``(ids, unique_patches, patch_codes)``:

        * ``patch_codes``: the ``hash_downto(..., 2)`` code of the patch
          anchored at each cell.
        * ``unique_patches``: one patch per distinct code, ordered by ascending
          code (the first occurrence of each code is used).
        * ``ids``: for each cell, the index into ``unique_patches`` of its patch.
    """
    assert ksize >= 1
    # TODO: implement non-wrapped image handling. Until then both settings of
    # ``periodic_input`` take this same wrap-padding path.
    pad_width = (
        (0, ksize - 1),
        (0, ksize - 1),
        *(((0, 0),) * (len(agrid.shape) - 2)),
    )
    agrid = np.pad(agrid, pad_width, mode="wrap")
    # Zero-copy, read-only view of shape (H, W, ksize, ksize, *rest). The patch
    # axes reuse the grid's row/column strides, so patches[i, j] is the window
    # of the padded grid anchored at (i, j).
    patches: NDArray[np.int64] = np.lib.stride_tricks.as_strided(
        agrid,
        (
            agrid.shape[0] - ksize + 1,
            agrid.shape[1] - ksize + 1,
            ksize,
            ksize,
            *agrid.shape[2:],
        ),
        agrid.strides[:2] + agrid.strides[:2] + agrid.strides[2:],
        writeable=False,
    )
    patch_codes = hash_downto(patches, 2)
    # ``inverse`` maps every cell to the position of its code among the sorted
    # unique codes, which is exactly the pattern id. It is reshaped explicitly
    # because its shape for multi-dimensional input differs across NumPy versions.
    _, first_index, inverse = np.unique(
        patch_codes, return_index=True, return_inverse=True
    )
    locs = np.unravel_index(first_index, patch_codes.shape)
    up: NDArray[np.int64] = patches[locs[0], locs[1]]
    ids: NDArray[np.int64] = inverse.reshape(patch_codes.shape).astype(
        np.int64, copy=False
    )
    return ids, up, patch_codes
def unique_patterns_brute_force(grid, size, periodic_input):
    """Loop-based counterpart of ``unique_patterns_2d``.

    Builds every ``size`` x ``size`` window of the wrap-padded ``grid`` with
    explicit Python loops instead of a strided view. ``periodic_input`` is
    accepted for signature parity but unused; the grid is always wrapped.

    Returns:
        ``(ids, unique_patches)``: per-cell pattern ids and one patch per
        distinct ``hash_downto`` code, ordered by ascending code.
    """
    trailing = ((0, 0),) * (len(grid.shape) - 2)
    wrapped = np.pad(
        grid, ((0, size - 1), (0, size - 1)) + trailing, mode="wrap"
    )
    n_rows, n_cols = grid.shape[0], grid.shape[1]
    # Round-trip each window through nested lists, then rebuild one array.
    windows = np.array(
        [
            [
                np.ndarray.tolist(wrapped[r : r + size, c : c + size])
                for c in range(n_cols)
            ]
            for r in range(n_rows)
        ]
    )
    codes = hash_downto(windows, 2)
    distinct_codes, first_seen = np.unique(codes, return_index=True)
    where = np.unravel_index(first_seen, codes.shape)
    representatives = windows[where[0], where[1]]
    code_to_id = dict(zip(distinct_codes, range(len(distinct_codes))))
    pattern_ids = np.vectorize(code_to_id.get)(codes)
    return pattern_ids, representatives
def make_pattern_catalog(
    tile_grid: NDArray[np.int64], pattern_width: int, input_is_periodic: bool = True
) -> tuple[dict[int, NDArray[np.int64]], Counter, NDArray[np.int64], NDArray[np.int64]]:
    """Extract the distinct patterns of ``tile_grid`` and index them by hash.

    Args:
        tile_grid: Grid of tiles to extract patterns from.
        pattern_width: Side length of the square patterns.
        input_is_periodic: Forwarded to ``unique_patterns_2d``.

    Returns:
        A 4-tuple ``(catalog, pattern_frequency, pattern_hashes, patch_codes)``:

        * ``catalog``: maps each distinct pattern's ``hash_downto(pattern, 0)``
          value (as a Python scalar) to that pattern's tiles.
        * ``pattern_frequency``: ``Counter`` over ``pattern_hashes``. That array
          holds each distinct pattern once, so these counts are expected to be 1
          rather than occurrence counts. NOTE(review): confirm this is the
          intended weighting.
        * ``pattern_hashes``: ``hash_downto(patterns, 1)``, one hash per distinct
          pattern, in the order returned by ``unique_patterns_2d``.
        * ``patch_codes``: per-cell codes as returned by ``unique_patterns_2d``.
    """
    _patterns_in_grid, unique_patterns, patch_codes = unique_patterns_2d(
        tile_grid, pattern_width, input_is_periodic
    )
    catalog: dict[int, NDArray[np.int64]] = {}
    for pattern in unique_patterns:
        catalog[hash_downto(pattern, 0).item()] = pattern
    # Hash all patterns once; the same array feeds the counter and is returned.
    pattern_hashes = hash_downto(unique_patterns, 1)
    pattern_frequency = Counter(pattern_hashes)
    return catalog, pattern_frequency, pattern_hashes, patch_codes
def identity_grid(grid):
    """Return ``grid`` itself, untouched (the no-op grid transform).

    No copy is made: the very object passed in is handed back.
    """
    unchanged = grid
    return unchanged
def reflect_grid(grid):
    """Mirror the grid left/right by reversing the order of its columns (axis 1)."""
    mirrored = np.flip(grid, axis=1)
    return mirrored
def rotate_grid(grid):
    """Turn the grid a quarter turn clockwise in its first two axes."""
    # A negative k rotates from axis 1 toward axis 0, i.e. clockwise.
    quarter_turn = np.rot90(grid, k=-1)
    return quarter_turn
def make_pattern_catalog_with_rotations(
    tile_grid: NDArray[np.int64],
    pattern_width: int,
    rotations: int = 7,
    input_is_periodic: bool = True,
) -> tuple[dict[int, NDArray[np.int64]], Counter, NDArray[np.int64], NDArray[np.int64]]:
    """Build a pattern catalog from ``tile_grid`` and its reflected/rotated copies.

    ``make_pattern_catalog`` is run on orientations ``0..rotations`` of the grid
    (listed in ``orientation_ops`` below) and the results are merged.

    Args:
        tile_grid: Input grid of tiles; reflections and rotations act on its
            first two axes.
        pattern_width: Side length of the square patterns to extract.
        rotations: Index of the last orientation to include, in ``[0, 7]``.
            The default ``7`` uses all eight symmetries of the square.
        input_is_periodic: Forwarded unchanged to ``make_pattern_catalog``.

    Returns:
        ``(pattern_catalog, pattern_frequency, pattern_hashes, patch_codes)``:

        * the per-orientation catalogs merged with ``dict.update``;
        * the per-orientation ``Counter`` objects summed together;
        * the distinct pattern hashes over all orientations (sorted via
          ``np.unique`` when more than one orientation is used);
        * the patch codes of the untransformed grid only.

    Raises:
        ValueError: If ``rotations`` is outside ``[0, 7]``.
    """
    # Transforms applied, left to right, to a fresh copy of ``tile_grid`` to
    # produce each orientation. Alternating reflect/rotate cumulatively would
    # return to the identity after four steps (reflect, rotate, reflect undoes
    # one rotation), so orientations 4-7 would merely repeat 0-3 and
    # double-count their patterns. Instead, every orientation is built from the
    # original grid; together the eight entries are all symmetries of the square.
    orientation_ops = (
        (),  # original grid
        (reflect_grid,),  # mirrored left/right
        (reflect_grid, rotate_grid),  # mirrored, then one clockwise turn
        (reflect_grid, rotate_grid, reflect_grid),  # three clockwise turns
        (rotate_grid,),  # one clockwise turn
        (rotate_grid, rotate_grid),  # half turn
        (rotate_grid, reflect_grid),  # one clockwise turn, then mirrored
        (rotate_grid, rotate_grid, reflect_grid),  # half turn, then mirrored
    )
    if not 0 <= rotations < len(orientation_ops):
        raise ValueError(
            f"rotations must be between 0 and {len(orientation_ops) - 1}, "
            f"got {rotations}"
        )

    merged_dict_of_pattern_contents: dict[int, NDArray[np.int64]] = {}
    merged_pattern_frequency: Counter = Counter()
    merged_pattern_hashes: NDArray[np.int64] | None = None
    merged_patch_codes: NDArray[np.int64] | None = None

    for index, ops in enumerate(orientation_ops):
        if index > rotations:
            break
        oriented_grid = tile_grid.copy()
        for op in ops:
            oriented_grid = op(oriented_grid)
        (
            dict_of_pattern_contents,
            pattern_frequency,
            pattern_hashes,
            patch_codes,
        ) = make_pattern_catalog(oriented_grid, pattern_width, input_is_periodic)
        merged_dict_of_pattern_contents.update(dict_of_pattern_contents)
        merged_pattern_frequency.update(pattern_frequency)
        if merged_pattern_hashes is None:
            # The first orientation is the untransformed grid; only its patch
            # codes are returned.
            merged_pattern_hashes = pattern_hashes.copy()
            merged_patch_codes = patch_codes.copy()
        else:
            merged_pattern_hashes = np.unique(
                np.concatenate((merged_pattern_hashes, pattern_hashes))
            )

    # At least orientation 0 was processed, so both are set.
    assert merged_pattern_hashes is not None
    assert merged_patch_codes is not None
    return (
        merged_dict_of_pattern_contents,
        merged_pattern_frequency,
        merged_pattern_hashes,
        merged_patch_codes,
    )
def pattern_grid_to_tiles(
    pattern_grid: NDArray[np.int64], pattern_catalog: Mapping[int, NDArray[np.int64]]
) -> NDArray[np.int64]:
    """Convert a grid of pattern hashes into a grid of tiles.

    Each cell is replaced by the tile at the anchor position ``[0][0]`` of the
    pattern it names in ``pattern_catalog``.

    Args:
        pattern_grid: Array (or array-like) of pattern hashes, of any shape.
        pattern_catalog: Mapping from pattern hash to that pattern's tiles.

    Returns:
        Array with the shape of ``pattern_grid`` holding the anchor tiles. An
        empty ``pattern_grid`` yields an empty int64 array of the same shape.

    Raises:
        KeyError: If a hash in ``pattern_grid`` is missing from ``pattern_catalog``.
    """
    anchor_x = 0
    anchor_y = 0
    grid = np.asarray(pattern_grid)
    if grid.size == 0:
        # Nothing to look up; a per-cell vectorized mapping cannot infer an
        # output dtype from zero cells, so return an explicitly typed array.
        return np.zeros(grid.shape, dtype=np.int64)
    # Look each distinct pattern up once, then scatter the tiles over the grid.
    codes, inverse = np.unique(grid, return_inverse=True)
    anchor_tiles = np.array(
        [pattern_catalog[code][anchor_x][anchor_y] for code in codes]
    )
    return anchor_tiles[inverse.reshape(grid.shape)]
|