wfcenv.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. from __future__ import annotations
  2. import copy
  3. import networkx as nx
  4. import numpy as np
  5. from minigrid.core.constants import OBJECT_TO_IDX
  6. from minigrid.core.grid import Grid
  7. from minigrid.core.mission import MissionSpace
  8. from minigrid.envs.wfc.config import WFC_PRESETS, WFCConfig
  9. from minigrid.envs.wfc.graphtransforms import EdgeDescriptor, GraphTransforms
  10. from minigrid.envs.wfc.wfclogic.control import execute_wfc
  11. from minigrid.minigrid_env import MiniGridEnv
  # Node attributes requested for every cell when a layout is turned into a
  # dense graph (passed as ``node_attr`` by ``WFCEnv._gen_grid``).
  _OBJECT_FEATURES = {"empty", "wall", "lava", "start", "goal"}
  _NAVIGATION_FEATURES = {"navigable", "non_navigable"}
  FEATURE_DESCRIPTORS = _OBJECT_FEATURES | _NAVIGATION_FEATURES

  # Edge configurations, keyed by name, that can be handed to
  # ``GraphTransforms.minigrid_layout_to_dense_graph`` as ``edge_config``.
  # NOTE: weighted lava->goal / moss->goal edges (``weight="lava_prob"`` /
  # ``weight="moss_prob"``) were sketched previously but are not enabled.
  EDGE_CONFIG = {
      "navigable": EdgeDescriptor(between=("navigable",), structure="grid"),
      "non_navigable": EdgeDescriptor(between=("non_navigable",), structure="grid"),
      "start_goal": EdgeDescriptor(between=("start", "goal"), structure=None),
  }
  class WFCEnv(MiniGridEnv):
      """
      ## Description

      This environment procedurally generates a level using the Wave Function Collapse algorithm.
      The environment supports a variety of different level structures but the default is a simple maze.
      Requires the optional dependencies `imageio` and `networkx` to be installed with `pip install minigrid[wfc]`.

      ## Mission Space

      "traverse the maze to get to the goal"

      ## Action Space

      | Num | Name         | Action                    |
      |-----|--------------|---------------------------|
      | 0   | left         | Turn left                 |
      | 1   | right        | Turn right                |
      | 2   | forward      | Move forward              |
      | 3   | pickup       | Unused                    |
      | 4   | drop         | Unused                    |
      | 5   | toggle       | Unused                    |
      | 6   | done         | Unused                    |

      ## Observation Encoding

      - Each tile is encoded as a 3 dimensional tuple:
          `(OBJECT_IDX, COLOR_IDX, STATE)`
      - `OBJECT_TO_IDX` and `COLOR_TO_IDX` mapping can be found in
          [minigrid/core/constants.py](minigrid/core/constants.py)
      - `STATE` refers to the door state with 0=open, 1=closed and 2=locked

      ## Rewards

      A reward of '1 - 0.9 * (step_count / max_steps)' is given for success, and '0' for failure.

      ## Termination

      The episode ends if any one of the following conditions is met:

      1. The agent reaches the goal.
      2. Timeout (see `max_steps`).

      ## Registered Configurations

      S: size of the square map (S x S).
      """

      # RGB colour used for each tile type in the patterns produced by WFC.
      PATTERN_COLOR_CONFIG = {
          "wall": (0, 0, 0),  # black
          "empty": (255, 255, 255),  # white
      }

      def __init__(
          self,
          wfc_config: WFCConfig | str = "MazeSimple",
          size: int = 25,
          ensure_connected: bool = True,
          max_steps: int | None = None,
          **kwargs,
      ):
          """Create the environment.

          Args:
              wfc_config: Either a ``WFCConfig`` instance or the name of an
                  entry in ``WFC_PRESETS``.
              size: Width and height of the square grid. The WFC pattern is
                  generated for the interior only (``size - 2 * padding``
                  cells per side).
              ensure_connected: If True, only the largest connected navigable
                  region of each generated level is kept.
              max_steps: Step limit per episode; defaults to ``20 * size``.
              **kwargs: Forwarded to ``MiniGridEnv.__init__``.

          Raises:
              KeyError: If ``wfc_config`` is a string that is not a key of
                  ``WFC_PRESETS``.
              ValueError: If ``size`` is smaller than 3.
          """
          self.config = (
              wfc_config if isinstance(wfc_config, WFCConfig) else WFC_PRESETS[wfc_config]
          )
          # Width of the border added around the WFC pattern in ``_gen_grid``.
          self.padding = 1
          # This controls whether to process the level such that there is only
          # a single connected navigable area.
          self.ensure_connected = ensure_connected

          mission_space = MissionSpace(mission_func=self._gen_mission)

          # The padding is removed from both sides before running WFC, so a
          # smaller grid would leave no interior cells to generate.
          if size < 3:
              raise ValueError(f"Grid size must be at least 3 (currently {size})")
          self.size = size
          # Attempt limit handed to ``execute_wfc`` for each level.
          self.max_attempts = 1000

          if max_steps is None:
              max_steps = self.size * 20

          super().__init__(
              mission_space=mission_space,
              width=self.size,
              height=self.size,
              max_steps=max_steps,
              **kwargs,
          )

      @staticmethod
      def _gen_mission():
          """Return the fixed mission string of this environment."""
          return "traverse the maze to get to the goal"

      def _gen_grid(self, width, height):
          """Generate a new level: WFC pattern -> dense graph -> MiniGrid grid.

          Sets ``self.grid``, ``self.agent_pos``, ``self.agent_dir`` and
          ``self.mission``.

          Raises:
              RuntimeError: If WFC produces no pattern within
                  ``self.max_attempts`` attempts, or if the level does not
                  contain enough navigable cells to place the start and goal.
          """
          shape = (height, width)

          # Main call to generate a black and white pattern with WFC. Only the
          # interior is generated; the border is restored in Stage 3 through
          # ``padding=self.padding``.
          shape_unpadded = (shape[0] - 2 * self.padding, shape[1] - 2 * self.padding)
          pattern, _stats = execute_wfc(
              attempt_limit=self.max_attempts,
              output_size=shape_unpadded,
              np_random=self.np_random,
              **self.config.wfc_kwargs,
          )
          if pattern is None:
              raise RuntimeError(
                  f"Could not generate a valid pattern within {self.max_attempts} attempts"
              )

          grid_raw = self._pattern_to_minigrid_layout(pattern)

          # Stage 1: Make a navigable graph. Only "navigable" edges are built here.
          stage1_edge_config = {"navigable": EDGE_CONFIG["navigable"]}
          graph_raw, _edge_graphs = GraphTransforms.minigrid_layout_to_dense_graph(
              grid_raw[np.newaxis],  # leading axis added; result [0] is this layout
              remove_border=False,
              node_attr=FEATURE_DESCRIPTORS,
              edge_config=stage1_edge_config,
          )
          graph = graph_raw[0]

          # Stage 2: Graph processing.
          # Retain only the largest connected graph component.
          if self.ensure_connected:
              graph = self._get_largest_component(graph)
          # Add start and goal nodes.
          graph = self._place_start_and_goal_random(graph)

          # Stage 3: Convert graph back to a padded MiniGrid encoding.
          grid_array = GraphTransforms.dense_graph_to_minigrid(
              graph, shape=shape, padding=self.padding
          )

          # Decode to minigrid and set variables.
          self.agent_dir = self._rand_int(0, 4)
          # First cell whose object channel (channel 0) holds the agent.
          self.agent_pos = next(
              zip(*np.nonzero(grid_array[:, :, 0] == OBJECT_TO_IDX["agent"]))
          )
          self.grid, _vismask = Grid.decode(grid_array)
          self.mission = self._gen_mission()

      def _pattern_to_minigrid_layout(self, pattern: np.ndarray) -> np.ndarray:
          """Convert a colour WFC pattern into a 2-D array of MiniGrid object indices.

          A cell becomes a wall when *all* of its colour channels equal
          ``PATTERN_COLOR_CONFIG["wall"]``; every other cell becomes empty.

          Args:
              pattern: Array with 3 dimensions; the last axis holds the colour
                  channels (assumed RGB, to match ``PATTERN_COLOR_CONFIG``).

          Returns:
              ``uint8`` array of shape ``pattern.shape[:2]``.

          Raises:
              ValueError: If ``pattern`` does not have exactly 3 dimensions.
          """
          if pattern.ndim != 3:
              raise ValueError(
                  f"Expected pattern to have 3 dimensions, but got {pattern.ndim}"
              )
          wall_color = np.asarray(self.PATTERN_COLOR_CONFIG["wall"])
          # Match whole pixels, not single channels: a per-channel test would
          # label any pixel whose first channel happens to match as a wall.
          is_wall = np.all(pattern == wall_color, axis=-1)
          layout = np.full(pattern.shape[:2], OBJECT_TO_IDX["empty"], dtype=np.uint8)
          layout[is_wall] = OBJECT_TO_IDX["wall"]
          return layout

      @staticmethod
      def _get_largest_component(graph: nx.Graph) -> nx.Graph:
          """Return a new graph holding only the largest connected navigable component.

          Nodes whose ``"navigable"`` attribute is below 0.5 are removed from
          ``graph`` *in place* before the components are computed. Nodes outside
          the chosen component are simply absent from the result.
          NOTE(review): the caller's comment says the rest is filled with walls;
          presumably ``dense_graph_to_minigrid`` does that for missing nodes.

          Raises:
              RuntimeError: If no connected navigable region has more than one node.
          """
          inactive_nodes = [
              node for node, attrs in graph.nodes(data=True) if attrs["navigable"] < 0.5
          ]
          graph.remove_nodes_from(inactive_nodes)

          # ``max`` returns the first largest component in iteration order, which
          # is the same set a stable ``sorted(..., reverse=True)[0]`` would pick,
          # so node order (and therefore seeded start/goal placement) is kept.
          largest = max(nx.connected_components(graph), key=len, default=set())
          if len(largest) < 2:
              raise RuntimeError(
                  "Generated level has no connected navigable region with more than one cell"
              )
          component = graph.subgraph(largest).copy()
          graph = graph.subgraph(component)

          # Rebuild as a plain graph, detached from the subgraph views, and deep
          # copy it so attribute values are not shared with the input graph.
          # TODO: Check if the deep copy is necessary.
          g = nx.Graph()
          g.add_nodes_from(graph.nodes(data=True))
          g.add_edges_from(component.edges(data=True))
          return copy.deepcopy(g)

      def _place_start_and_goal_random(self, graph: nx.Graph) -> nx.Graph:
          """Mark two distinct, randomly chosen navigable nodes as start and goal.

          ``graph`` is modified in place and also returned.

          Raises:
              RuntimeError: If fewer than two navigable nodes are available.
          """
          node_set = "navigable"

          # Get two random navigable nodes.
          possible_nodes = [n for n, d in graph.nodes(data=True) if d[node_set]]
          if len(possible_nodes) < 2:
              raise RuntimeError(
                  f"Need at least 2 navigable cells to place start and goal, "
                  f"found {len(possible_nodes)}"
              )
          # Keep using ``permutation`` so the RNG stream (and thus seeded
          # levels) stays the same.
          inds = self.np_random.permutation(len(possible_nodes))[:2]
          start_node, goal_node = possible_nodes[inds[0]], possible_nodes[inds[1]]

          graph.nodes[start_node]["start"] = 1
          graph.nodes[goal_node]["goal"] = 1

          return graph