graphtransforms.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397
  1. from __future__ import annotations
  2. from collections import OrderedDict, defaultdict
  3. from dataclasses import dataclass
  4. from itertools import product
  5. import networkx as nx
  6. import numpy as np
  7. from minigrid.core.constants import COLOR_TO_IDX, IDX_TO_OBJECT, OBJECT_TO_IDX
  8. from minigrid.minigrid_env import MiniGridEnv
  9. @dataclass
  10. class EdgeDescriptor:
  11. between: tuple[str, str] | tuple[str]
  12. structure: str | None = None
  13. # This is maybe general enough to be in utils
  14. class GraphTransforms:
  15. OBJECT_TO_DENSE_GRAPH_ATTRIBUTE = {
  16. "empty": ("navigable", "empty"),
  17. "start": ("navigable", "start"),
  18. "agent": ("navigable", "start"),
  19. "goal": ("navigable", "goal"),
  20. "moss": ("navigable", "moss"),
  21. "wall": ("non_navigable", "wall"),
  22. "lava": ("non_navigable", "lava"),
  23. }
  24. DENSE_GRAPH_ATTRIBUTE_TO_OBJECT = {
  25. "empty": "empty",
  26. "start": "start",
  27. "goal": "goal",
  28. "moss": "moss",
  29. "wall": "wall",
  30. "lava": "lava",
  31. "navigable": None,
  32. "non_navigable": None,
  33. }
  34. MINIGRID_COLOR_CONFIG = {
  35. "empty": None,
  36. "wall": "grey",
  37. "agent": "blue",
  38. "goal": "green",
  39. "lava": "red",
  40. "moss": "purple",
  41. }
  42. @staticmethod
  43. def minigrid_to_bitmap(grids):
  44. layout = grids[..., 0]
  45. bitmap = np.zeros_like(layout)
  46. bitmap[layout == 2] = 1
  47. bitmap = list(bitmap)
  48. start_pos_id = np.where(layout == 10)
  49. goal_pos_id = np.where(layout == 8)
  50. start_pos = []
  51. goal_pos = []
  52. for i in range(len(bitmap)):
  53. bitmap[i] = bitmap[i][1:-1, 1:-1]
  54. start_pos.append(np.array([start_pos_id[2][i], start_pos_id[1][i]]))
  55. goal_pos.append(np.array([goal_pos_id[2][i], goal_pos_id[1][i]]))
  56. return bitmap, start_pos, goal_pos
  57. @staticmethod
  58. def minigrid_to_dense_graph(
  59. minigrids: np.ndarray | list[MiniGridEnv],
  60. node_attr=None,
  61. edge_config=None,
  62. ) -> list[nx.Graph]:
  63. if isinstance(minigrids[0], np.ndarray):
  64. minigrids = np.array(minigrids)
  65. layouts = minigrids[..., 0]
  66. elif isinstance(minigrids[0], MiniGridEnv):
  67. layouts = [minigrid.grid.encode()[..., 0] for minigrid in minigrids]
  68. for i in range(len(minigrids)):
  69. layouts[i][tuple(minigrids[i].agent_pos)] = OBJECT_TO_IDX["agent"]
  70. layouts = np.array(layouts)
  71. else:
  72. raise TypeError(
  73. f"minigrids must be of type List[np.ndarray], List[MiniGridEnv], "
  74. f"List[MultiGridEnv], not {type(minigrids[0])}"
  75. )
  76. graphs, _ = GraphTransforms.minigrid_layout_to_dense_graph(
  77. layouts, remove_border=True, node_attr=node_attr, edge_config=edge_config
  78. )
  79. return graphs
  80. @staticmethod
  81. def minigrid_layout_to_dense_graph(
  82. layouts: np.ndarray, remove_border=True, node_attr=None, edge_config=None
  83. ) -> tuple[list[nx.Graph], dict[str, list[nx.Graph]]]:
  84. assert (
  85. layouts.ndim == 3
  86. ), f"Wrong dimensions for minigrid layout, expected 3 dimensions, got {layouts.ndim}."
  87. node_attr = [] if node_attr is None else node_attr
  88. # Remove borders
  89. if remove_border:
  90. layouts = layouts[:, 1:-1, 1:-1] # remove edges
  91. dim_grid = layouts.shape[1:]
  92. # Get the objects present in the layout
  93. objects_idx = np.unique(layouts)
  94. object_instances = [IDX_TO_OBJECT[obj] for obj in objects_idx]
  95. assert set(object_instances).issubset(
  96. {"empty", "wall", "start", "goal", "agent", "lava", "moss"}
  97. ), (
  98. f"Unsupported object(s) in minigrid layout. Supported objects are: "
  99. f"empty, wall, start, goal, agent, lava, moss. Got {object_instances}."
  100. )
  101. # Get location of each object in the layout
  102. object_locations = {}
  103. for obj in object_instances:
  104. object_locations[obj] = defaultdict(list)
  105. ids = list(zip(*np.where(layouts == OBJECT_TO_IDX[obj])))
  106. for tup in ids:
  107. object_locations[obj][tup[0]].append(tup[1:])
  108. for m in range(layouts.shape[0]):
  109. if m not in object_locations[obj]:
  110. object_locations[obj][m] = []
  111. object_locations[obj] = OrderedDict(sorted(object_locations[obj].items()))
  112. if "start" not in object_instances and "agent" in object_instances:
  113. object_locations["start"] = object_locations["agent"]
  114. if "agent" not in object_instances and "start" in object_instances:
  115. object_locations["agent"] = object_locations["start"]
  116. # Create one-hot graph feature tensor
  117. graph_feats = {}
  118. object_to_attr = GraphTransforms.OBJECT_TO_DENSE_GRAPH_ATTRIBUTE
  119. for obj in object_instances:
  120. for attr in object_to_attr[obj]:
  121. if attr not in graph_feats and attr in node_attr:
  122. graph_feats[attr] = np.zeros(layouts.shape)
  123. loc = list(object_locations[obj].values())
  124. assert len(loc) == layouts.shape[0]
  125. for m in range(layouts.shape[0]):
  126. if loc[m]:
  127. loc_m = np.array(loc[m])
  128. graph_feats[attr][m][loc_m[:, 0], loc_m[:, 1]] = 1
  129. for attr in node_attr:
  130. if attr not in graph_feats:
  131. graph_feats[attr] = np.zeros(layouts.shape)
  132. graph_feats[attr] = graph_feats[attr].reshape(layouts.shape[0], -1)
  133. graphs, edge_graphs = GraphTransforms.features_to_dense_graph(
  134. graph_feats, dim_grid, edge_config
  135. )
  136. return graphs, edge_graphs
  137. @staticmethod
  138. def features_to_dense_graph(
  139. features: dict[str, np.ndarray],
  140. dim_grid: tuple,
  141. edge_config: dict[str, EdgeDescriptor] = None,
  142. ) -> tuple[list[nx.Graph], dict[str, list[nx.Graph]]]:
  143. graphs = []
  144. edge_graphs = defaultdict(list)
  145. for m in range(features[list(features.keys())[0]].shape[0]):
  146. g_temp = nx.grid_2d_graph(*dim_grid)
  147. g = nx.Graph()
  148. g.add_nodes_from(sorted(g_temp.nodes(data=True)))
  149. for attr in features:
  150. nx.set_node_attributes(
  151. g, {k: v for k, v in zip(g.nodes, features[attr][m].tolist())}, attr
  152. )
  153. if edge_config is not None:
  154. edge_layers = GraphTransforms.get_edge_layers(
  155. g, edge_config, list(features.keys()), dim_grid
  156. )
  157. for edge_n, edge_g in edge_layers.items():
  158. g.add_edges_from(edge_g.edges(data=True), label=edge_n)
  159. edge_graphs[edge_n].append(edge_g)
  160. graphs.append(g)
  161. return graphs, edge_graphs
  162. @staticmethod
  163. def graph_features_to_minigrid(
  164. graph_features: dict[str, np.ndarray], shape: tuple[int, int], padding=1
  165. ) -> np.ndarray:
  166. features = graph_features.copy()
  167. node_attributes = list(features.keys())
  168. color_config = GraphTransforms.MINIGRID_COLOR_CONFIG
  169. # shape_no_padding = (features[node_attributes[0]].shape[-2], shape[0] - 2, shape[1] - 2, 3)
  170. shape_no_padding = (shape[0] - 2 * padding, shape[1] - 2 * padding, 3)
  171. for attr in node_attributes:
  172. features[attr] = features[attr].reshape(*shape_no_padding[:-1])
  173. grids = np.ones(shape_no_padding, dtype=np.uint8) * OBJECT_TO_IDX["empty"]
  174. minigrid_object_to_encoding_map = {} # [object_id, color, state]
  175. for feature in node_attributes:
  176. obj_type = GraphTransforms.DENSE_GRAPH_ATTRIBUTE_TO_OBJECT[feature]
  177. if (
  178. obj_type is not None
  179. and obj_type not in minigrid_object_to_encoding_map.keys()
  180. ):
  181. if obj_type == "empty":
  182. minigrid_object_to_encoding_map[obj_type] = [
  183. OBJECT_TO_IDX["empty"],
  184. 0,
  185. 0,
  186. ]
  187. elif obj_type == "agent":
  188. minigrid_object_to_encoding_map[obj_type] = [
  189. OBJECT_TO_IDX["agent"],
  190. 0,
  191. 0,
  192. ]
  193. elif obj_type == "start":
  194. color_str = color_config["agent"]
  195. minigrid_object_to_encoding_map[obj_type] = [
  196. OBJECT_TO_IDX["agent"],
  197. COLOR_TO_IDX[color_str],
  198. 0,
  199. ]
  200. else:
  201. color_str = color_config[obj_type]
  202. minigrid_object_to_encoding_map[obj_type] = [
  203. OBJECT_TO_IDX[obj_type],
  204. COLOR_TO_IDX[color_str],
  205. 0,
  206. ]
  207. if (
  208. "start" not in minigrid_object_to_encoding_map.keys()
  209. and "agent" in minigrid_object_to_encoding_map.keys()
  210. ):
  211. minigrid_object_to_encoding_map["start"] = minigrid_object_to_encoding_map[
  212. "agent"
  213. ]
  214. if (
  215. "agent" not in minigrid_object_to_encoding_map.keys()
  216. and "start" in minigrid_object_to_encoding_map.keys()
  217. ):
  218. minigrid_object_to_encoding_map["agent"] = minigrid_object_to_encoding_map[
  219. "start"
  220. ]
  221. for i, attr in enumerate(node_attributes):
  222. if "wall" not in node_attributes:
  223. if attr == "navigable" and "wall" not in node_attributes:
  224. mapping = minigrid_object_to_encoding_map["wall"]
  225. grids[features[attr] == 0] = np.array(mapping, dtype=np.uint8)
  226. else:
  227. mapping = minigrid_object_to_encoding_map[attr]
  228. grids[features[attr] == 1] = np.array(mapping, dtype=np.uint8)
  229. else:
  230. try:
  231. mapping = minigrid_object_to_encoding_map[attr]
  232. grids[features[attr] == 1] = np.array(mapping, dtype=np.uint8)
  233. except KeyError:
  234. pass
  235. wall_encoding = np.array(
  236. minigrid_object_to_encoding_map["wall"], dtype=np.uint8
  237. )
  238. padded_grid = np.pad(
  239. grids,
  240. ((padding, padding), (padding, padding), (0, 0)),
  241. "constant",
  242. constant_values=-1,
  243. )
  244. padded_grid = np.where(
  245. padded_grid == -np.ones(3, dtype=np.uint8), wall_encoding, padded_grid
  246. )
  247. return padded_grid
  248. @staticmethod
  249. def get_node_features(
  250. graph: nx.Graph, pattern_shape, node_attributes: list[str] = None, reshape=True
  251. ) -> tuple[np.ndarray, list[str]]:
  252. if node_attributes is None:
  253. # Get node attributes from some node
  254. node_attributes = list(next(iter(graph.nodes.data()))[1].keys())
  255. # Get node features
  256. Fx = []
  257. for attr in node_attributes:
  258. if attr == "non_navigable" or attr == "wall":
  259. # The graph we are getting is only the navigable nodes so those that
  260. # are not present should be assumed to be walls and non-navigable
  261. f = np.ones(pattern_shape)
  262. else:
  263. f = np.zeros(pattern_shape)
  264. for node, data in graph.nodes.data(attr):
  265. f[node] = data
  266. if reshape:
  267. f = f.ravel()
  268. Fx.append(f)
  269. # Fx = torch.stack(Fx, dim=-1).to(device)
  270. Fx = np.stack(Fx, axis=-1)
  271. return Fx, node_attributes
  272. @staticmethod
  273. def dense_graph_to_minigrid(
  274. graph: nx.Graph, shape: tuple[int, int], padding=1
  275. ) -> np.ndarray:
  276. pattern_shape = (shape[0] - 2 * padding, shape[1] - 2 * padding)
  277. features, node_attributes = GraphTransforms.get_node_features(
  278. graph, pattern_shape, node_attributes=None
  279. )
  280. # num_zeros = features[features == 0.0].numel()
  281. # num_ones = features[features == 1.0].numel()
  282. num_zeros = (features == 0.0).sum()
  283. num_ones = (features == 1.0).sum()
  284. assert num_zeros + num_ones == features.size, "Graph features should be binary"
  285. features_dict = {}
  286. for i, key in enumerate(node_attributes):
  287. features_dict[key] = features[..., i]
  288. grids = GraphTransforms.graph_features_to_minigrid(
  289. features_dict, shape=shape, padding=padding
  290. )
  291. return grids
  292. @staticmethod
  293. def get_edge_layers(
  294. graph: nx.Graph,
  295. edge_config: dict[str, EdgeDescriptor],
  296. node_attr: list[str],
  297. dim_grid: tuple[int, int],
  298. ) -> dict[str, nx.Graph]:
  299. navigable_nodes = ["empty", "start", "goal", "moss"]
  300. non_navigable_nodes = ["wall", "lava"]
  301. assert all([isinstance(n, tuple) for n in graph.nodes])
  302. assert all([len(n) == 2 for n in graph.nodes])
  303. def partial_grid(graph, nodes, dim_grid):
  304. non_grid_nodes = [n for n in graph.nodes if n not in nodes]
  305. g_temp = nx.grid_2d_graph(*dim_grid)
  306. g_temp.remove_nodes_from(non_grid_nodes)
  307. g_temp.add_nodes_from(non_grid_nodes)
  308. g = nx.Graph()
  309. g.add_nodes_from(graph.nodes(data=True))
  310. g.add_edges_from(g_temp.edges)
  311. return g
  312. def pair_edges(graph, node_types):
  313. all_nodes = []
  314. for n_type in node_types:
  315. all_nodes.append(
  316. [n for n, a in graph.nodes.items() if a[n_type] >= 1.0]
  317. )
  318. edges = list(product(*all_nodes))
  319. edged_graph = nx.create_empty_copy(graph, with_data=True)
  320. edged_graph.add_edges_from(edges)
  321. return edged_graph
  322. edge_graphs = {}
  323. for edge_ in edge_config.keys():
  324. if edge_ == "navigable" and "navigable" not in node_attr:
  325. edge_config[edge_].between = navigable_nodes
  326. elif edge_ == "non_navigable" and "non_navigable" not in node_attr:
  327. edge_config[edge_].between = non_navigable_nodes
  328. elif not set(edge_config[edge_].between).issubset(set(node_attr)):
  329. # TODO: remove
  330. # logger.warning(f"Edge {edge_} not compatible with node attributes {node_attr}. Skipping.")
  331. continue
  332. if edge_config[edge_].structure is None:
  333. edge_graphs[edge_] = pair_edges(graph, edge_config[edge_].between)
  334. elif edge_config[edge_].structure == "grid":
  335. nodes = []
  336. for n_type in edge_config[edge_].between:
  337. nodes += [
  338. n
  339. for n, a in graph.nodes.items()
  340. if a[n_type] >= 1.0 and n not in nodes
  341. ]
  342. edge_graphs[edge_] = partial_grid(graph, nodes, dim_grid)
  343. else:
  344. raise NotImplementedError(
  345. f"Edge structure {edge_config[edge_].structure} not supported."
  346. )
  347. return edge_graphs