toposort.go 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. package toposort
  2. import (
  3. "bytes"
  4. "fmt"
  5. "sort"
  6. "strings"
  7. )
  8. // Reworked from https://github.com/philopon/go-toposort
  9. // Graph represents a directed acyclic graph.
  10. type Graph struct {
  11. // Outgoing connections for every node.
  12. outputs map[string]map[string]int
  13. // How many parents each node has.
  14. inputs map[string]int
  15. }
  16. // NewGraph initializes a new Graph.
  17. func NewGraph() *Graph {
  18. return &Graph{
  19. inputs: map[string]int{},
  20. outputs: map[string]map[string]int{},
  21. }
  22. }
  23. // Copy clones the graph and returns the independent copy.
  24. func (g *Graph) Copy() *Graph {
  25. clone := NewGraph()
  26. for k, v := range g.inputs {
  27. clone.inputs[k] = v
  28. }
  29. for k1, v1 := range g.outputs {
  30. m := map[string]int{}
  31. clone.outputs[k1] = m
  32. for k2, v2 := range v1 {
  33. m[k2] = v2
  34. }
  35. }
  36. return clone
  37. }
  38. // AddNode inserts a new node into the graph.
  39. func (g *Graph) AddNode(name string) bool {
  40. if _, exists := g.outputs[name]; exists {
  41. return false
  42. }
  43. g.outputs[name] = make(map[string]int)
  44. g.inputs[name] = 0
  45. return true
  46. }
  47. // AddNodes inserts multiple nodes into the graph at once.
  48. func (g *Graph) AddNodes(names ...string) bool {
  49. for _, name := range names {
  50. if ok := g.AddNode(name); !ok {
  51. return false
  52. }
  53. }
  54. return true
  55. }
  56. // AddEdge inserts the link from "from" node to "to" node.
  57. func (g *Graph) AddEdge(from, to string) int {
  58. m, ok := g.outputs[from]
  59. if !ok {
  60. return 0
  61. }
  62. m[to] = len(m) + 1
  63. ni := g.inputs[to] + 1
  64. g.inputs[to] = ni
  65. return ni
  66. }
  67. // ReindexNode updates the internal representation of the node after edge removals.
  68. func (g *Graph) ReindexNode(node string) {
  69. children, ok := g.outputs[node]
  70. if !ok {
  71. return
  72. }
  73. keys := []string{}
  74. for key := range children {
  75. keys = append(keys, key)
  76. }
  77. sort.Strings(keys)
  78. for i, key := range keys {
  79. children[key] = i + 1
  80. }
  81. }
  82. func (g *Graph) unsafeRemoveEdge(from, to string) {
  83. delete(g.outputs[from], to)
  84. g.inputs[to]--
  85. }
  86. // RemoveEdge deletes the link from "from" node to "to" node.
  87. // Call ReindexNode(from) after you finish modifying the edges.
  88. func (g *Graph) RemoveEdge(from, to string) bool {
  89. if _, ok := g.outputs[from]; !ok {
  90. return false
  91. }
  92. g.unsafeRemoveEdge(from, to)
  93. return true
  94. }
  95. // Toposort sorts the nodes in the graph in topological order.
  96. func (g *Graph) Toposort() ([]string, bool) {
  97. L := make([]string, 0, len(g.outputs))
  98. S := make([]string, 0, len(g.outputs))
  99. for n := range g.outputs {
  100. if g.inputs[n] == 0 {
  101. S = append(S, n)
  102. }
  103. }
  104. sort.Strings(S)
  105. for len(S) > 0 {
  106. var n string
  107. n, S = S[0], S[1:]
  108. L = append(L, n)
  109. ms := make([]string, len(g.outputs[n]))
  110. for m, i := range g.outputs[n] {
  111. ms[i-1] = m
  112. }
  113. for _, m := range ms {
  114. g.unsafeRemoveEdge(n, m)
  115. if g.inputs[m] == 0 {
  116. S = append(S, m)
  117. }
  118. }
  119. }
  120. N := 0
  121. for _, v := range g.inputs {
  122. N += v
  123. }
  124. if N > 0 {
  125. return L, false
  126. }
  127. return L, true
  128. }
  129. // BreadthSort sorts the nodes in the graph in BFS order.
  130. func (g *Graph) BreadthSort() []string {
  131. L := make([]string, 0, len(g.outputs))
  132. S := make([]string, 0, len(g.outputs))
  133. for n := range g.outputs {
  134. if g.inputs[n] == 0 {
  135. S = append(S, n)
  136. }
  137. }
  138. visited := map[string]bool{}
  139. for len(S) > 0 {
  140. node := S[0]
  141. S = S[1:]
  142. if _, exists := visited[node]; !exists {
  143. L = append(L, node)
  144. visited[node] = true
  145. for child := range g.outputs[node] {
  146. S = append(S, child)
  147. }
  148. }
  149. }
  150. return L
  151. }
  152. // FindCycle returns the cycle in the graph which contains "seed" node.
  153. func (g *Graph) FindCycle(seed string) []string {
  154. type edge struct {
  155. node string
  156. parent string
  157. }
  158. S := make([]edge, 0, len(g.outputs))
  159. S = append(S, edge{seed, ""})
  160. visited := map[string]string{}
  161. for len(S) > 0 {
  162. e := S[0]
  163. S = S[1:]
  164. if parent, exists := visited[e.node]; !exists || parent == "" {
  165. visited[e.node] = e.parent
  166. for child := range g.outputs[e.node] {
  167. S = append(S, edge{child, e.node})
  168. }
  169. }
  170. if e.node == seed && e.parent != "" {
  171. result := []string{}
  172. node := e.parent
  173. for node != seed {
  174. result = append(result, node)
  175. node = visited[node]
  176. }
  177. result = append(result, seed)
  178. // reverse
  179. for left, right := 0, len(result)-1; left < right; left, right = left+1, right-1 {
  180. result[left], result[right] = result[right], result[left]
  181. }
  182. return result
  183. }
  184. }
  185. return []string{}
  186. }
  187. // FindParents returns the other ends of incoming edges.
  188. func (g *Graph) FindParents(to string) []string {
  189. result := []string{}
  190. for node, children := range g.outputs {
  191. if _, exists := children[to]; exists {
  192. result = append(result, node)
  193. }
  194. }
  195. return result
  196. }
  197. // FindChildren returns the other ends of outgoing edges.
  198. func (g *Graph) FindChildren(from string) []string {
  199. result := []string{}
  200. for child := range g.outputs[from] {
  201. result = append(result, child)
  202. }
  203. sort.Strings(result)
  204. return result
  205. }
  206. // Serialize outputs the graph in Graphviz format.
  207. func (g *Graph) Serialize(sorted []string) string {
  208. node2index := map[string]int{}
  209. for index, node := range sorted {
  210. node2index[node] = index
  211. }
  212. var buffer bytes.Buffer
  213. buffer.WriteString("digraph Hercules {\n")
  214. nodesFrom := []string{}
  215. for nodeFrom := range g.outputs {
  216. nodesFrom = append(nodesFrom, nodeFrom)
  217. }
  218. sort.Strings(nodesFrom)
  219. for _, nodeFrom := range nodesFrom {
  220. links := []string{}
  221. for nodeTo := range g.outputs[nodeFrom] {
  222. links = append(links, nodeTo)
  223. }
  224. sort.Strings(links)
  225. for _, nodeTo := range links {
  226. buffer.WriteString(fmt.Sprintf(" \"%d %s\" -> \"%d %s\"\n",
  227. node2index[nodeFrom], nodeFrom, node2index[nodeTo], nodeTo))
  228. }
  229. }
  230. buffer.WriteString("}")
  231. return buffer.String()
  232. }
  233. // DebugDump converts the graph to a string. As the name suggests, useful for debugging.
  234. func (g *Graph) DebugDump() string {
  235. S := make([]string, 0, len(g.outputs))
  236. for n := range g.outputs {
  237. if g.inputs[n] == 0 {
  238. S = append(S, n)
  239. }
  240. }
  241. sort.Strings(S)
  242. var buffer bytes.Buffer
  243. buffer.WriteString(strings.Join(S, " ") + "\n")
  244. keys := []string{}
  245. vals := map[string][]string{}
  246. for key, val1 := range g.outputs {
  247. val2 := make([]string, len(val1))
  248. for name, idx := range val1 {
  249. val2[idx-1] = name
  250. }
  251. keys = append(keys, key)
  252. vals[key] = val2
  253. }
  254. sort.Strings(keys)
  255. for _, key := range keys {
  256. buffer.WriteString(fmt.Sprintf("%s %d = ", key, g.inputs[key]))
  257. outs := vals[key]
  258. buffer.WriteString(strings.Join(outs, " ") + "\n")
  259. }
  260. return buffer.String()
  261. }