wfcenv.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. from __future__ import annotations
  2. import copy
  3. import networkx as nx
  4. import numpy as np
  5. from minigrid.core.constants import OBJECT_TO_IDX
  6. from minigrid.core.grid import Grid
  7. from minigrid.core.mission import MissionSpace
  8. from minigrid.envs.wfc.config import WFC_PRESETS_ALL, WFCConfig
  9. from minigrid.envs.wfc.graphtransforms import EdgeDescriptor, GraphTransforms
  10. from minigrid.envs.wfc.wfclogic.control import execute_wfc
  11. from minigrid.minigrid_env import MiniGridEnv
  12. FEATURE_DESCRIPTORS = {"empty", "wall", "lava", "start", "goal"} | {
  13. "navigable",
  14. "non_navigable",
  15. }
  16. EDGE_CONFIG = {
  17. "navigable": EdgeDescriptor(between=("navigable",), structure="grid"),
  18. "non_navigable": EdgeDescriptor(between=("non_navigable",), structure="grid"),
  19. "start_goal": EdgeDescriptor(between=("start", "goal"), structure=None),
  20. # "lava_goal": EdgeDescriptor(between=("lava", "goal"), weight="lava_prob"),
  21. # "moss_goal": EdgeDescriptor(between=("moss", "goal"), weight="moss_prob"),
  22. }
  23. class WFCEnv(MiniGridEnv):
  24. """
  25. ## Description
  26. This environment procedurally generates a level using the Wave Function Collapse algorithm.
  27. The environment supports a variety of different level structures but the default is a simple maze.
  28. See [WFC module page](index) for sample images of the available presets.
  29. Requires the optional dependencies `imageio` and `networkx` to be installed with `pip install minigrid[wfc]`.
  30. ## Mission Space
  31. "traverse the maze to get to the goal"
  32. ## Action Space
  33. | Num | Name | Action |
  34. |-----|--------------|---------------------------|
  35. | 0 | left | Turn left |
  36. | 1 | right | Turn right |
  37. | 2 | forward | Move forward |
  38. | 3 | pickup | Unused |
  39. | 4 | drop | Unused |
  40. | 5 | toggle | Unused |
  41. | 6 | done | Unused |
  42. ## Observation Encoding
  43. - Each tile is encoded as a 3 dimensional tuple:
  44. `(OBJECT_IDX, COLOR_IDX, STATE)`
  45. - `OBJECT_TO_IDX` and `COLOR_TO_IDX` mapping can be found in
  46. [minigrid/core/constants.py](minigrid/core/constants.py)
  47. - `STATE` refers to the door state with 0=open, 1=closed and 2=locked (unused)
  48. ## Rewards
  49. A reward of '1 - 0.9 * (step_count / max_steps)' is given for success, and '0' for failure.
  50. ## Termination
  51. The episode ends if any one of the following conditions is met:
  52. 1. The agent reaches the goal.
  53. 2. Timeout (see `max_steps`).
  54. ## Registered Configurations
  55. - `MiniGrid-WFC-MazeSimple-v0`
  56. - `MiniGrid-WFC-DungeonMazeScaled-v0`
  57. - `MiniGrid-WFC-RoomsFabric-v0`
  58. - `MiniGrid-WFC-ObstaclesBlackdots-v0`
  59. - `MiniGrid-WFC-ObstaclesAngular-v0`
  60. - `MiniGrid-WFC-ObstaclesHogs3-v0`
  61. Note: There are many more unregistered configuration presets but they may take a long time to generate a consistent environment.
  62. They can be registered with the following snippet:
  63. ```python
  64. import gymnasium
  65. from minigrid.envs.wfc.config import register_wfc_presets, WFC_PRESETS_INCONSISTENT, WFC_PRESETS_SLOW
  66. register_wfc_presets(WFC_PRESETS_INCONSISTENT, gymnasium.register)
  67. register_wfc_presets(WFC_PRESETS_SLOW, gymnasium.register)
  68. ```
  69. ## Research
  70. Adapted for `Minigrid` by the following work.
  71. ```bibtex
  72. @inproceedings{garcin2024dred,
  73. title = {DRED: Zero-Shot Transfer in Reinforcement Learning via Data-Regularised Environment Design},
  74. author = {Garcin, Samuel and Doran, James and Guo, Shangmin and Lucas, Christopher G and Albrecht, Stefano V},
  75. booktitle = {Forty-first International Conference on Machine Learning},
  76. year = {2024},
  77. }
  78. ```
  79. """
  80. PATTERN_COLOR_CONFIG = {
  81. "wall": (0, 0, 0), # black
  82. "empty": (255, 255, 255), # white
  83. }
  84. def __init__(
  85. self,
  86. wfc_config: WFCConfig | str = "MazeSimple",
  87. size: int = 25,
  88. ensure_connected: bool = True,
  89. max_steps: int | None = None,
  90. **kwargs,
  91. ):
  92. self.config = (
  93. wfc_config
  94. if isinstance(wfc_config, WFCConfig)
  95. else WFC_PRESETS_ALL[wfc_config]
  96. )
  97. self.padding = 1
  98. # This controls whether to process the level such that there is only a single connected navigable area
  99. self.ensure_connected = ensure_connected
  100. mission_space = MissionSpace(mission_func=self._gen_mission)
  101. if size < 3:
  102. raise ValueError(f"Grid size must be at least 3 (currently {size})")
  103. self.size = size
  104. self.max_attempts = 1000
  105. if max_steps is None:
  106. max_steps = self.size * 20
  107. super().__init__(
  108. mission_space=mission_space,
  109. width=self.size,
  110. height=self.size,
  111. max_steps=max_steps,
  112. **kwargs,
  113. )
  114. @staticmethod
  115. def _gen_mission():
  116. return "traverse the maze to get to the goal"
  117. def _gen_grid(self, width, height):
  118. shape = (height, width)
  119. # Main call to generate a black and white pattern with WFC
  120. shape_unpadded = (shape[0] - 2 * self.padding, shape[1] - 2 * self.padding)
  121. pattern, _stats = execute_wfc(
  122. attempt_limit=self.max_attempts,
  123. output_size=shape_unpadded,
  124. np_random=self.np_random,
  125. **self.config.wfc_kwargs,
  126. )
  127. if pattern is None:
  128. raise RuntimeError(
  129. f"Could not generate a valid pattern within {self.max_attempts} attempts"
  130. )
  131. grid_raw = self._pattern_to_minigrid_layout(pattern)
  132. # Stage 1: Make a navigable graph with only one main cavern
  133. stage1_edge_config = {k: v for k, v in EDGE_CONFIG.items() if k == "navigable"}
  134. graph_raw, _edge_graphs = GraphTransforms.minigrid_layout_to_dense_graph(
  135. grid_raw[np.newaxis],
  136. remove_border=False,
  137. node_attr=FEATURE_DESCRIPTORS,
  138. edge_config=stage1_edge_config,
  139. )
  140. graph = graph_raw[0]
  141. # Stage 2: Graph processing
  142. # Retain only the largest connected graph component, fill in the rest with walls
  143. if self.ensure_connected:
  144. graph = self._get_largest_component(graph)
  145. # Add start and goal nodes
  146. graph = self._place_start_and_goal_random(graph)
  147. # Convert graph back to grid
  148. grid_array = GraphTransforms.dense_graph_to_minigrid(
  149. graph, shape=shape, padding=self.padding
  150. )
  151. # Decode to minigrid and set variables
  152. self.agent_dir = self._rand_int(0, 4)
  153. self.agent_pos = next(
  154. zip(*np.nonzero(grid_array[:, :, 0] == OBJECT_TO_IDX["agent"]))
  155. )
  156. self.grid, _vismask = Grid.decode(grid_array)
  157. self.mission = self._gen_mission()
  158. def _pattern_to_minigrid_layout(self, pattern: np.ndarray):
  159. if pattern.ndim != 3:
  160. raise ValueError(
  161. f"Expected pattern to have 3 dimensions, but got {pattern.ndim}"
  162. )
  163. layout = np.ones(pattern.shape, dtype=np.uint8) * OBJECT_TO_IDX["empty"]
  164. wall_ids = np.where(pattern == self.PATTERN_COLOR_CONFIG["wall"])
  165. layout[wall_ids] = OBJECT_TO_IDX["wall"]
  166. layout = layout[..., 0]
  167. return layout
  168. @staticmethod
  169. def _get_largest_component(graph: nx.Graph) -> nx.Graph:
  170. wall_graph_attr = GraphTransforms.OBJECT_TO_DENSE_GRAPH_ATTRIBUTE["wall"]
  171. # Prepare graph
  172. inactive_nodes = [x for x, y in graph.nodes(data=True) if y["navigable"] < 0.5]
  173. graph.remove_nodes_from(inactive_nodes)
  174. components = [
  175. graph.subgraph(c).copy()
  176. for c in sorted(nx.connected_components(graph), key=len, reverse=True)
  177. if len(c) > 1
  178. ]
  179. component = components[0]
  180. graph = graph.subgraph(component)
  181. for node in graph.nodes():
  182. if node not in component.nodes():
  183. for feat in graph.nodes[node]:
  184. if feat in wall_graph_attr:
  185. graph.nodes[node][feat] = 1.0
  186. else:
  187. graph.nodes[node][feat] = 0.0
  188. # TODO: Check if this is necessary
  189. g = nx.Graph()
  190. g.add_nodes_from(graph.nodes(data=True))
  191. g.add_edges_from(component.edges(data=True))
  192. g_out = copy.deepcopy(g)
  193. return g_out
  194. def _place_start_and_goal_random(self, graph: nx.Graph) -> nx.Graph:
  195. node_set = "navigable"
  196. # Get two random navigable nodes
  197. possible_nodes = [n for n, d in graph.nodes(data=True) if d[node_set]]
  198. inds = self.np_random.permutation(len(possible_nodes))[:2]
  199. start_node, goal_node = possible_nodes[inds[0]], possible_nodes[inds[1]]
  200. graph.nodes[start_node]["start"] = 1
  201. graph.nodes[goal_node]["goal"] = 1
  202. return graph