config.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. from __future__ import annotations
  2. from collections import ChainMap
  3. from dataclasses import asdict, dataclass
  4. from pathlib import Path
  5. from typing_extensions import Literal
  6. PATTERN_PATH = Path(__file__).parent / "patterns"
  7. @dataclass
  8. class WFCConfig:
  9. """Dataclass for holding WFC configuration parameters.
  10. This controls the behavior of the WFC algorithm. The parameters are passed directly to the WFC solver.
  11. Attributes:
  12. pattern_path: Path to the pattern image that will be automatically loaded.
  13. tile_size: Size of the tiles in pixels to create from the pattern image.
  14. pattern_width: Size of the patterns in tiles to take from the pattern image. (greater than 3 is quite slow)
  15. rotations: Number of rotations for each tile.
  16. output_periodic: Whether the output should be periodic (wraps over edges).
  17. input_periodic: Whether the input should be periodic (wraps over edges).
  18. loc_heuristic: Heuristic for choosing the next tile location to collapse.
  19. choice_heuristic: Heuristic for choosing the next tile to use between possible tiles.
  20. backtracking: Whether to backtrack when contradictions are discovered.
  21. """
  22. pattern_path: Path
  23. tile_size: int = 1
  24. pattern_width: int = 2
  25. rotations: int = 8
  26. output_periodic: bool = False
  27. input_periodic: bool = False
  28. loc_heuristic: Literal[
  29. "lexical", "spiral", "entropy", "anti-entropy", "simple", "random"
  30. ] = "entropy"
  31. choice_heuristic: Literal["lexical", "rarest", "weighted", "random"] = "weighted"
  32. backtracking: bool = False
  33. @property
  34. def wfc_kwargs(self):
  35. try:
  36. from imageio.v2 import imread
  37. except ImportError as e:
  38. from gymnasium.error import DependencyNotInstalled
  39. raise DependencyNotInstalled(
  40. 'imageio is missing, please run `pip install "minigrid[wfc]"`'
  41. ) from e
  42. kwargs = asdict(self)
  43. kwargs["image"] = imread(kwargs.pop("pattern_path"))[:, :, :3]
  44. return kwargs
  45. # Basic presets for WFC configurations (that should generate in <1 min)
  46. WFC_PRESETS = {
  47. "MazeSimple": WFCConfig(
  48. pattern_path=PATTERN_PATH / "SimpleMaze.png",
  49. tile_size=1,
  50. pattern_width=2,
  51. output_periodic=False,
  52. input_periodic=False,
  53. ),
  54. "DungeonMazeScaled": WFCConfig(
  55. pattern_path=PATTERN_PATH / "ScaledMaze.png",
  56. tile_size=1,
  57. pattern_width=2,
  58. output_periodic=True,
  59. input_periodic=True,
  60. ),
  61. "RoomsFabric": WFCConfig(
  62. pattern_path=PATTERN_PATH / "Fabric.png",
  63. tile_size=1,
  64. pattern_width=3,
  65. output_periodic=False,
  66. input_periodic=False,
  67. ),
  68. "ObstaclesBlackdots": WFCConfig(
  69. pattern_path=PATTERN_PATH / "Blackdots.png",
  70. tile_size=1,
  71. pattern_width=2,
  72. output_periodic=False,
  73. input_periodic=False,
  74. ),
  75. "ObstaclesAngular": WFCConfig(
  76. pattern_path=PATTERN_PATH / "Angular.png",
  77. tile_size=1,
  78. pattern_width=3,
  79. output_periodic=True,
  80. input_periodic=True,
  81. ),
  82. "ObstaclesHogs3": WFCConfig(
  83. pattern_path=PATTERN_PATH / "Hogs.png",
  84. tile_size=1,
  85. pattern_width=3,
  86. output_periodic=True,
  87. input_periodic=True,
  88. ),
  89. }
  90. # Presets that take a large number of attempts to generate a consistent environment
  91. WFC_PRESETS_INCONSISTENT = {
  92. "MazeKnot": WFCConfig(
  93. pattern_path=PATTERN_PATH / "Knot.png",
  94. tile_size=1,
  95. pattern_width=3,
  96. output_periodic=True,
  97. input_periodic=True,
  98. ), # This is not too inconsistent (often 10 attempts is enough)
  99. "MazeWall": WFCConfig(
  100. pattern_path=PATTERN_PATH / "SimpleWall.png",
  101. tile_size=1,
  102. pattern_width=2,
  103. output_periodic=True,
  104. input_periodic=True,
  105. ),
  106. "RoomsOffice": WFCConfig(
  107. pattern_path=PATTERN_PATH / "Office.png",
  108. tile_size=1,
  109. pattern_width=3,
  110. output_periodic=True,
  111. input_periodic=True,
  112. ),
  113. "ObstaclesHogs2": WFCConfig(
  114. pattern_path=PATTERN_PATH / "Hogs.png",
  115. tile_size=1,
  116. pattern_width=2,
  117. output_periodic=True,
  118. input_periodic=True,
  119. ),
  120. "Skew2": WFCConfig(
  121. pattern_path=PATTERN_PATH / "Skew2.png",
  122. tile_size=1,
  123. pattern_width=3,
  124. output_periodic=True,
  125. input_periodic=True,
  126. ),
  127. }
  128. # Slow presets for WFC configurations (Most take about 2-4 min but some take 10+ min)
  129. WFC_PRESETS_SLOW = {
  130. "Maze": WFCConfig(
  131. pattern_path=PATTERN_PATH / "Maze.png",
  132. tile_size=1,
  133. pattern_width=3,
  134. output_periodic=True,
  135. input_periodic=True,
  136. ), # This is unusually slow: ~20min per 25x25 room
  137. "MazeSpirals": WFCConfig(
  138. pattern_path=PATTERN_PATH / "Spirals.png",
  139. tile_size=1,
  140. pattern_width=3,
  141. output_periodic=True,
  142. input_periodic=True,
  143. ),
  144. "MazePaths": WFCConfig(
  145. pattern_path=PATTERN_PATH / "Paths.png",
  146. tile_size=1,
  147. pattern_width=3,
  148. output_periodic=True,
  149. input_periodic=True,
  150. ),
  151. "Mazelike": WFCConfig(
  152. pattern_path=PATTERN_PATH / "Mazelike.png",
  153. tile_size=1,
  154. pattern_width=3,
  155. output_periodic=True,
  156. input_periodic=True,
  157. ),
  158. "Dungeon": WFCConfig(
  159. pattern_path=PATTERN_PATH / "DungeonExtr.png",
  160. tile_size=1,
  161. pattern_width=3,
  162. output_periodic=True,
  163. input_periodic=True,
  164. ), # ~10 mins
  165. "DungeonRooms": WFCConfig(
  166. pattern_path=PATTERN_PATH / "Rooms.png",
  167. tile_size=1,
  168. pattern_width=3,
  169. output_periodic=True,
  170. input_periodic=True,
  171. ),
  172. "DungeonLessRooms": WFCConfig(
  173. pattern_path=PATTERN_PATH / "LessRooms.png",
  174. tile_size=1,
  175. pattern_width=3,
  176. output_periodic=True,
  177. input_periodic=True,
  178. ),
  179. "DungeonSpirals": WFCConfig(
  180. pattern_path=PATTERN_PATH / "SpiralsNeg.png",
  181. tile_size=1,
  182. pattern_width=3,
  183. output_periodic=True,
  184. input_periodic=True,
  185. ),
  186. "RoomsMagicOffice": WFCConfig(
  187. pattern_path=PATTERN_PATH / "MagicOffice.png",
  188. tile_size=1,
  189. pattern_width=3,
  190. output_periodic=True,
  191. input_periodic=True,
  192. ),
  193. "SkewCave": WFCConfig(
  194. pattern_path=PATTERN_PATH / "Cave.png",
  195. tile_size=1,
  196. pattern_width=3,
  197. output_periodic=False,
  198. input_periodic=False,
  199. ),
  200. "SkewLake": WFCConfig(
  201. pattern_path=PATTERN_PATH / "Lake.png",
  202. tile_size=1,
  203. pattern_width=3,
  204. output_periodic=True,
  205. input_periodic=True,
  206. ), # ~10 mins
  207. }
  208. WFC_PRESETS_ALL = ChainMap(WFC_PRESETS, WFC_PRESETS_INCONSISTENT, WFC_PRESETS_SLOW)
  209. def register_wfc_presets(wfc_presets: dict[str, WFCConfig], register_fn):
  210. # Register fn needs to be provided to avoid a circular import
  211. for name in wfc_presets.keys():
  212. register_fn(
  213. id=f"MiniGrid-WFC-{name}-v0",
  214. entry_point="minigrid.envs.wfc:WFCEnv",
  215. kwargs={"wfc_config": name},
  216. )