tiles.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. """Breaks an image into constituent tiles. Implementation based on https://github.com/ikarth/wfc_2019f"""
  2. from __future__ import annotations
  3. import numpy as np
  4. from numpy.typing import NDArray
  5. from minigrid.envs.wfc.wfclogic.utilities import hash_downto
  6. def image_to_tiles(img: NDArray[np.integer], tile_size: int) -> NDArray[np.integer]:
  7. """
  8. Takes an images, divides it into tiles, return an array of tiles.
  9. """
  10. padding_argument = [(0, 0), (0, 0), (0, 0)]
  11. for input_dim in [0, 1]:
  12. padding_argument[input_dim] = (
  13. 0,
  14. (tile_size - img.shape[input_dim]) % tile_size,
  15. )
  16. img = np.pad(img, padding_argument, mode="constant")
  17. tiles = img.reshape(
  18. (
  19. img.shape[0] // tile_size,
  20. tile_size,
  21. img.shape[1] // tile_size,
  22. tile_size,
  23. img.shape[2],
  24. )
  25. ).swapaxes(1, 2)
  26. return tiles
  27. def make_tile_catalog(image_data: NDArray[np.integer], tile_size: int) -> tuple[
  28. dict[int, NDArray[np.integer]],
  29. NDArray[np.int64],
  30. NDArray[np.int64],
  31. tuple[NDArray[np.int64], NDArray[np.int64]],
  32. ]:
  33. """
  34. Takes an image and tile size and returns the following:
  35. tile_catalog is a dictionary tiles, with the hashed ID as the key
  36. tile_grid is the original image, expressed in terms of hashed tile IDs
  37. code_list is the original image, expressed in terms of hashed tile IDs and reduced to one dimension
  38. unique_tiles is the set of tiles, plus the frequency of their occurrence
  39. """
  40. channels = image_data.shape[2] # Number of color channels in the image
  41. tiles = image_to_tiles(image_data, tile_size)
  42. tile_list: NDArray[np.integer] = tiles.reshape(
  43. (tiles.shape[0] * tiles.shape[1], tile_size, tile_size, channels)
  44. )
  45. code_list: NDArray[np.int64] = hash_downto(tiles, 2).reshape(
  46. tiles.shape[0] * tiles.shape[1]
  47. )
  48. tile_grid: NDArray[np.int64] = hash_downto(tiles, 2)
  49. unique_tiles: tuple[NDArray[np.int64], NDArray[np.int64]] = np.unique(
  50. tile_grid, return_counts=True
  51. )
  52. tile_catalog: dict[int, NDArray[np.integer]] = {}
  53. for i, j in enumerate(code_list):
  54. tile_catalog[j] = tile_list[i]
  55. return tile_catalog, tile_grid, code_list, unique_tiles