|
@@ -51,7 +51,6 @@ class GraphTransforms:
|
|
|
|
|
|
@staticmethod
|
|
|
def minigrid_to_bitmap(grids):
|
|
|
-
|
|
|
layout = grids[..., 0]
|
|
|
bitmap = np.zeros_like(layout)
|
|
|
bitmap[layout == 2] = 1
|
|
@@ -97,7 +96,6 @@ class GraphTransforms:
|
|
|
def minigrid_layout_to_dense_graph(
|
|
|
layouts: np.ndarray, remove_border=True, node_attr=None, edge_config=None
|
|
|
) -> tuple[list[nx.Graph], dict[str, list[nx.Graph]]]:
|
|
|
-
|
|
|
assert (
|
|
|
layouts.ndim == 3
|
|
|
), f"Wrong dimensions for minigrid layout, expected 3 dimensions, got {layouts.ndim}."
|
|
@@ -165,7 +163,6 @@ class GraphTransforms:
|
|
|
dim_grid: tuple,
|
|
|
edge_config: dict[str, EdgeDescriptor] = None,
|
|
|
) -> tuple[list[nx.Graph], dict[str, list[nx.Graph]]]:
|
|
|
-
|
|
|
graphs = []
|
|
|
edge_graphs = defaultdict(list)
|
|
|
for m in range(features[list(features.keys())[0]].shape[0]):
|
|
@@ -191,7 +188,6 @@ class GraphTransforms:
|
|
|
def graph_features_to_minigrid(
|
|
|
graph_features: dict[str, np.ndarray], shape: tuple[int, int], padding=1
|
|
|
) -> np.ndarray:
|
|
|
-
|
|
|
features = graph_features.copy()
|
|
|
node_attributes = list(features.keys())
|
|
|
|
|
@@ -285,7 +281,6 @@ class GraphTransforms:
|
|
|
def get_node_features(
|
|
|
graph: nx.Graph, pattern_shape, node_attributes: list[str] = None, reshape=True
|
|
|
) -> tuple[np.ndarray, list[str]]:
|
|
|
-
|
|
|
if node_attributes is None:
|
|
|
# Get node attributes from some node
|
|
|
node_attributes = list(next(iter(graph.nodes.data()))[1].keys())
|
|
@@ -313,7 +308,6 @@ class GraphTransforms:
|
|
|
def dense_graph_to_minigrid(
|
|
|
graph: nx.Graph, shape: tuple[int, int], padding=1
|
|
|
) -> np.ndarray:
|
|
|
-
|
|
|
pattern_shape = (shape[0] - 2 * padding, shape[1] - 2 * padding)
|
|
|
features, node_attributes = GraphTransforms.get_node_features(
|
|
|
graph, pattern_shape, node_attributes=None
|
|
@@ -340,7 +334,6 @@ class GraphTransforms:
|
|
|
node_attr: list[str],
|
|
|
dim_grid: tuple[int, int],
|
|
|
) -> dict[str, nx.Graph]:
|
|
|
-
|
|
|
navigable_nodes = ["empty", "start", "goal", "moss"]
|
|
|
non_navigable_nodes = ["wall", "lava"]
|
|
|
assert all([isinstance(n, tuple) for n in graph.nodes])
|