tiles.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
"""Break an image into its constituent tiles. Implementation based on https://github.com/ikarth/wfc_2019f"""
  2. from __future__ import annotations
  3. import numpy as np
  4. from numpy.typing import NDArray
  5. from minigrid.envs.wfc.wfclogic.utilities import hash_downto
  6. def image_to_tiles(img: NDArray[np.integer], tile_size: int) -> NDArray[np.integer]:
  7. """
  8. Takes an images, divides it into tiles, return an array of tiles.
  9. """
  10. padding_argument = [(0, 0), (0, 0), (0, 0)]
  11. for input_dim in [0, 1]:
  12. padding_argument[input_dim] = (
  13. 0,
  14. (tile_size - img.shape[input_dim]) % tile_size,
  15. )
  16. img = np.pad(img, padding_argument, mode="constant")
  17. tiles = img.reshape(
  18. (
  19. img.shape[0] // tile_size,
  20. tile_size,
  21. img.shape[1] // tile_size,
  22. tile_size,
  23. img.shape[2],
  24. )
  25. ).swapaxes(1, 2)
  26. return tiles
  27. def make_tile_catalog(
  28. image_data: NDArray[np.integer], tile_size: int
  29. ) -> tuple[
  30. dict[int, NDArray[np.integer]],
  31. NDArray[np.int64],
  32. NDArray[np.int64],
  33. tuple[NDArray[np.int64], NDArray[np.int64]],
  34. ]:
  35. """
  36. Takes an image and tile size and returns the following:
  37. tile_catalog is a dictionary tiles, with the hashed ID as the key
  38. tile_grid is the original image, expressed in terms of hashed tile IDs
  39. code_list is the original image, expressed in terms of hashed tile IDs and reduced to one dimension
  40. unique_tiles is the set of tiles, plus the frequency of their occurrence
  41. """
  42. channels = image_data.shape[2] # Number of color channels in the image
  43. tiles = image_to_tiles(image_data, tile_size)
  44. tile_list: NDArray[np.integer] = tiles.reshape(
  45. (tiles.shape[0] * tiles.shape[1], tile_size, tile_size, channels)
  46. )
  47. code_list: NDArray[np.int64] = hash_downto(tiles, 2).reshape(
  48. tiles.shape[0] * tiles.shape[1]
  49. )
  50. tile_grid: NDArray[np.int64] = hash_downto(tiles, 2)
  51. unique_tiles: tuple[NDArray[np.int64], NDArray[np.int64]] = np.unique(
  52. tile_grid, return_counts=True
  53. )
  54. tile_catalog: dict[int, NDArray[np.integer]] = {}
  55. for i, j in enumerate(code_list):
  56. tile_catalog[j] = tile_list[i]
  57. return tile_catalog, tile_grid, code_list, unique_tiles