// toposort.go (1.6 KB)

  1. package toposort
  2. // Copied from https://github.com/philopon/go-toposort
  3. type Graph struct {
  4. nodes []string
  5. outputs map[string]map[string]int
  6. inputs map[string]int
  7. }
  8. func NewGraph() *Graph {
  9. return &Graph{
  10. nodes: []string{},
  11. inputs: map[string]int{},
  12. outputs: map[string]map[string]int{},
  13. }
  14. }
  15. func (g *Graph) AddNode(name string) bool {
  16. g.nodes = append(g.nodes, name)
  17. if _, ok := g.outputs[name]; ok {
  18. return false
  19. }
  20. g.outputs[name] = make(map[string]int)
  21. g.inputs[name] = 0
  22. return true
  23. }
  24. func (g *Graph) AddNodes(names ...string) bool {
  25. for _, name := range names {
  26. if ok := g.AddNode(name); !ok {
  27. return false
  28. }
  29. }
  30. return true
  31. }
  32. func (g *Graph) AddEdge(from, to string) bool {
  33. m, ok := g.outputs[from]
  34. if !ok {
  35. return false
  36. }
  37. m[to] = len(m) + 1
  38. g.inputs[to]++
  39. return true
  40. }
  41. func (g *Graph) unsafeRemoveEdge(from, to string) {
  42. delete(g.outputs[from], to)
  43. g.inputs[to]--
  44. }
  45. func (g *Graph) RemoveEdge(from, to string) bool {
  46. if _, ok := g.outputs[from]; !ok {
  47. return false
  48. }
  49. g.unsafeRemoveEdge(from, to)
  50. return true
  51. }
  52. func (g *Graph) Toposort() ([]string, bool) {
  53. L := make([]string, 0, len(g.nodes))
  54. S := make([]string, 0, len(g.nodes))
  55. for _, n := range g.nodes {
  56. if g.inputs[n] == 0 {
  57. S = append(S, n)
  58. }
  59. }
  60. for len(S) > 0 {
  61. var n string
  62. n, S = S[0], S[1:]
  63. L = append(L, n)
  64. ms := make([]string, len(g.outputs[n]))
  65. for m, i := range g.outputs[n] {
  66. ms[i-1] = m
  67. }
  68. for _, m := range ms {
  69. g.unsafeRemoveEdge(n, m)
  70. if g.inputs[m] == 0 {
  71. S = append(S, m)
  72. }
  73. }
  74. }
  75. N := 0
  76. for _, v := range g.inputs {
  77. N += v
  78. }
  79. if N > 0 {
  80. return L, false
  81. }
  82. return L, true
  83. }