rbtree_test.go 9.3 KB


  1. package rbtree
  2. import (
  3. "fmt"
  4. "math/rand"
  5. "sort"
  6. "testing"
  7. "github.com/stretchr/testify/assert"
  8. )
  9. // Create a tree storing a set of integers
  10. func testNewIntSet() *RBTree {
  11. return NewRBTree(NewAllocator())
  12. }
  13. func testAssert(t *testing.T, b bool, message string) {
  14. assert.True(t, b, message)
  15. }
  16. func boolInsert(tree *RBTree, item int) bool {
  17. status, _ := tree.Insert(Item{uint32(item), uint32(item)})
  18. return status
  19. }
  20. func TestEmpty(t *testing.T) {
  21. tree := testNewIntSet()
  22. testAssert(t, tree.Len() == 0, "len!=0")
  23. testAssert(t, tree.Max().NegativeLimit(), "neglimit")
  24. testAssert(t, tree.Min().Limit(), "limit")
  25. testAssert(t, tree.FindGE(10).Limit(), "Not empty")
  26. testAssert(t, tree.FindLE(10).NegativeLimit(), "Not empty")
  27. testAssert(t, tree.Get(10) == nil, "Not empty")
  28. testAssert(t, tree.Limit().Equal(tree.Min()), "iter")
  29. }
  30. func TestFindGE(t *testing.T) {
  31. tree := testNewIntSet()
  32. testAssert(t, boolInsert(tree, 10), "Insert1")
  33. testAssert(t, !boolInsert(tree, 10), "Insert2")
  34. testAssert(t, tree.Len() == 1, "len==1")
  35. testAssert(t, tree.FindGE(10).Item().Key == 10, "FindGE 10")
  36. testAssert(t, tree.FindGE(11).Limit(), "FindGE 11")
  37. assert.Equal(t, tree.FindGE(9).Item().Key, uint32(10), "FindGE 10")
  38. }
  39. func TestFindLE(t *testing.T) {
  40. tree := testNewIntSet()
  41. testAssert(t, boolInsert(tree, 10), "insert1")
  42. testAssert(t, tree.FindLE(10).Item().Key == 10, "FindLE 10")
  43. testAssert(t, tree.FindLE(11).Item().Key == 10, "FindLE 11")
  44. testAssert(t, tree.FindLE(9).NegativeLimit(), "FindLE 9")
  45. }
  46. func TestGet(t *testing.T) {
  47. tree := testNewIntSet()
  48. testAssert(t, boolInsert(tree, 10), "insert1")
  49. assert.Equal(t, *tree.Get(10), uint32(10), "Get 10")
  50. testAssert(t, tree.Get(9) == nil, "Get 9")
  51. testAssert(t, tree.Get(11) == nil, "Get 11")
  52. }
  53. func TestDelete(t *testing.T) {
  54. tree := testNewIntSet()
  55. testAssert(t, !tree.DeleteWithKey(10), "del")
  56. testAssert(t, tree.Len() == 0, "dellen")
  57. testAssert(t, boolInsert(tree, 10), "ins")
  58. testAssert(t, tree.DeleteWithKey(10), "del")
  59. testAssert(t, tree.Len() == 0, "dellen")
  60. // delete was deleting after the request if request not found
  61. // ensure this does not regress:
  62. testAssert(t, boolInsert(tree, 10), "ins")
  63. testAssert(t, !tree.DeleteWithKey(9), "del")
  64. testAssert(t, tree.Len() == 1, "dellen")
  65. }
  66. func iterToString(i Iterator) string {
  67. s := ""
  68. for ; !i.Limit(); i = i.Next() {
  69. if s != "" {
  70. s = s + ","
  71. }
  72. s = s + fmt.Sprintf("%d", i.Item().Key)
  73. }
  74. return s
  75. }
  76. func reverseIterToString(i Iterator) string {
  77. s := ""
  78. for ; !i.NegativeLimit(); i = i.Prev() {
  79. if s != "" {
  80. s = s + ","
  81. }
  82. s = s + fmt.Sprintf("%d", i.Item().Key)
  83. }
  84. return s
  85. }
  86. func TestIterator(t *testing.T) {
  87. tree := testNewIntSet()
  88. for i := 0; i < 10; i = i + 2 {
  89. boolInsert(tree, i)
  90. }
  91. assert.Equal(t, iterToString(tree.FindGE(3)), "4,6,8")
  92. assert.Equal(t, iterToString(tree.FindGE(4)), "4,6,8")
  93. assert.Equal(t, iterToString(tree.FindGE(8)), "8")
  94. assert.Equal(t, iterToString(tree.FindGE(9)), "")
  95. assert.Equal(t, reverseIterToString(tree.FindLE(3)), "2,0")
  96. assert.Equal(t, reverseIterToString(tree.FindLE(2)), "2,0")
  97. assert.Equal(t, reverseIterToString(tree.FindLE(0)), "0")
  98. }
//
// Randomized tests
//

// oracle provides an interface similar to RBTree's, but stores its keys in a
// plain slice. Insert rejects duplicates and re-sorts the slice, so data is
// kept in strictly increasing order. The randomized tests apply the same
// operations to an oracle and a tree and compare the results.
type oracle struct {
	data []int
}
  107. func newOracle() *oracle {
  108. return &oracle{data: make([]int, 0)}
  109. }
  110. func (o *oracle) Len() int {
  111. return len(o.data)
  112. }
  113. // interface needed for sorting
  114. func (o *oracle) Less(i, j int) bool {
  115. return o.data[i] < o.data[j]
  116. }
  117. func (o *oracle) Swap(i, j int) {
  118. e := o.data[j]
  119. o.data[j] = o.data[i]
  120. o.data[i] = e
  121. }
  122. func (o *oracle) Insert(key int) bool {
  123. for _, e := range o.data {
  124. if e == key {
  125. return false
  126. }
  127. }
  128. n := len(o.data) + 1
  129. newData := make([]int, n)
  130. copy(newData, o.data)
  131. newData[n-1] = key
  132. o.data = newData
  133. sort.Sort(o)
  134. return true
  135. }
  136. func (o *oracle) RandomExistingKey(rand *rand.Rand) int {
  137. index := rand.Int31n(int32(len(o.data)))
  138. return o.data[index]
  139. }
  140. func (o *oracle) FindGE(t *testing.T, key int) oracleIterator {
  141. prev := int(-1)
  142. for i, e := range o.data {
  143. if e <= prev {
  144. t.Fatal("Nonsorted oracle ", e, prev)
  145. }
  146. if e >= key {
  147. return oracleIterator{o: o, index: i}
  148. }
  149. }
  150. return oracleIterator{o: o, index: len(o.data)}
  151. }
  152. func (o *oracle) FindLE(t *testing.T, key int) oracleIterator {
  153. iter := o.FindGE(t, key)
  154. if !iter.Limit() && o.data[iter.index] == key {
  155. return iter
  156. }
  157. return oracleIterator{o, iter.index - 1}
  158. }
  159. func (o *oracle) Delete(key int) bool {
  160. for i, e := range o.data {
  161. if e == key {
  162. newData := make([]int, len(o.data)-1)
  163. copy(newData, o.data[0:i])
  164. copy(newData[i:], o.data[i+1:])
  165. o.data = newData
  166. return true
  167. }
  168. }
  169. return false
  170. }
  171. //
  172. // Test iterator
  173. //
  174. type oracleIterator struct {
  175. o *oracle
  176. index int
  177. }
  178. func (oiter oracleIterator) Limit() bool {
  179. return oiter.index >= len(oiter.o.data)
  180. }
  181. func (oiter oracleIterator) Min() bool {
  182. return oiter.index == 0
  183. }
  184. func (oiter oracleIterator) NegativeLimit() bool {
  185. return oiter.index < 0
  186. }
  187. func (oiter oracleIterator) Max() bool {
  188. return oiter.index == len(oiter.o.data)-1
  189. }
  190. func (oiter oracleIterator) Item() int {
  191. return oiter.o.data[oiter.index]
  192. }
  193. func (oiter oracleIterator) Next() oracleIterator {
  194. return oracleIterator{oiter.o, oiter.index + 1}
  195. }
  196. func (oiter oracleIterator) Prev() oracleIterator {
  197. return oracleIterator{oiter.o, oiter.index - 1}
  198. }
  199. func compareContents(t *testing.T, oiter oracleIterator, titer Iterator) {
  200. oi := oiter
  201. ti := titer
  202. // Test forward iteration
  203. testAssert(t, oi.NegativeLimit() == ti.NegativeLimit(), "rend")
  204. if oi.NegativeLimit() {
  205. oi = oi.Next()
  206. ti = ti.Next()
  207. }
  208. for !oi.Limit() && !ti.Limit() {
  209. // log.Print("Item: ", oi.Item(), ti.Item())
  210. if ti.Item().Key != uint32(oi.Item()) {
  211. t.Fatal("Wrong item", ti.Item(), oi.Item())
  212. }
  213. oi = oi.Next()
  214. ti = ti.Next()
  215. }
  216. if !ti.Limit() {
  217. t.Fatal("!ti.done", ti.Item())
  218. }
  219. if !oi.Limit() {
  220. t.Fatal("!oi.done", oi.Item())
  221. }
  222. // Test reverse iteration
  223. oi = oiter
  224. ti = titer
  225. testAssert(t, oi.Limit() == ti.Limit(), "end")
  226. if oi.Limit() {
  227. oi = oi.Prev()
  228. ti = ti.Prev()
  229. }
  230. for !oi.NegativeLimit() && !ti.NegativeLimit() {
  231. if ti.Item().Key != uint32(oi.Item()) {
  232. t.Fatal("Wrong item", ti.Item(), oi.Item())
  233. }
  234. oi = oi.Prev()
  235. ti = ti.Prev()
  236. }
  237. if !ti.NegativeLimit() {
  238. t.Fatal("!ti.done", ti.Item())
  239. }
  240. if !oi.NegativeLimit() {
  241. t.Fatal("!oi.done", oi.Item())
  242. }
  243. }
  244. func compareContentsFull(t *testing.T, o *oracle, tree *RBTree) {
  245. compareContents(t, o.FindGE(t, -1), tree.FindGE(0))
  246. }
  247. func TestRandomized(t *testing.T) {
  248. const numKeys = 1000
  249. o := newOracle()
  250. tree := testNewIntSet()
  251. r := rand.New(rand.NewSource(0))
  252. for i := 0; i < 10000; i++ {
  253. op := r.Int31n(100)
  254. if op < 50 {
  255. key := r.Int31n(numKeys)
  256. o.Insert(int(key))
  257. boolInsert(tree, int(key))
  258. compareContentsFull(t, o, tree)
  259. } else if op < 90 && o.Len() > 0 {
  260. key := o.RandomExistingKey(r)
  261. o.Delete(key)
  262. if !tree.DeleteWithKey(uint32(key)) {
  263. t.Fatal("DeleteExisting", key)
  264. }
  265. compareContentsFull(t, o, tree)
  266. } else if op < 95 {
  267. key := int(r.Int31n(numKeys))
  268. compareContents(t, o.FindGE(t, key), tree.FindGE(uint32(key)))
  269. } else {
  270. key := int(r.Int31n(numKeys))
  271. compareContents(t, o.FindLE(t, key), tree.FindLE(uint32(key)))
  272. }
  273. }
  274. }
  275. func TestCloneShallow(t *testing.T) {
  276. alloc1 := NewAllocator()
  277. alloc1.malloc()
  278. tree := NewRBTree(alloc1)
  279. tree.Insert(Item{7, 7})
  280. assert.Equal(t, alloc1.storage, []node{{}, {}, {color: black, item: Item{7, 7}}})
  281. assert.Equal(t, tree.minNode, uint32(2))
  282. assert.Equal(t, tree.maxNode, uint32(2))
  283. alloc2 := alloc1.Clone()
  284. clone := tree.CloneShallow(alloc2)
  285. assert.Equal(t, alloc2.storage, []node{{}, {}, {color: black, item: Item{7, 7}}})
  286. assert.Equal(t, clone.minNode, uint32(2))
  287. assert.Equal(t, clone.maxNode, uint32(2))
  288. assert.Equal(t, alloc2.Size(), 3)
  289. tree.Insert(Item{10, 10})
  290. alloc3 := alloc1.Clone()
  291. clone = tree.CloneShallow(alloc3)
  292. assert.Equal(t, alloc3.storage, []node{
  293. {}, {},
  294. {right: 3, color: black, item: Item{7, 7}},
  295. {parent: 2, color: red, item: Item{10, 10}}})
  296. assert.Equal(t, clone.minNode, uint32(2))
  297. assert.Equal(t, clone.maxNode, uint32(3))
  298. assert.Equal(t, alloc3.Size(), 4)
  299. assert.Equal(t, alloc2.Size(), 3)
  300. }
  301. func TestCloneDeep(t *testing.T) {
  302. alloc1 := NewAllocator()
  303. alloc1.malloc()
  304. tree := NewRBTree(alloc1)
  305. tree.Insert(Item{7, 7})
  306. assert.Equal(t, alloc1.storage, []node{{}, {}, {color: black, item: Item{7, 7}}})
  307. assert.Equal(t, tree.minNode, uint32(2))
  308. assert.Equal(t, tree.maxNode, uint32(2))
  309. alloc2 := NewAllocator()
  310. clone := tree.CloneDeep(alloc2)
  311. assert.Equal(t, alloc2.storage, []node{{}, {color: black, item: Item{7, 7}}})
  312. assert.Equal(t, clone.minNode, uint32(1))
  313. assert.Equal(t, clone.maxNode, uint32(1))
  314. assert.Equal(t, alloc2.Size(), 2)
  315. tree.Insert(Item{10, 10})
  316. alloc2 = NewAllocator()
  317. clone = tree.CloneDeep(alloc2)
  318. assert.Equal(t, alloc2.storage, []node{
  319. {},
  320. {right: 2, color: black, item: Item{7, 7}},
  321. {parent: 1, color: red, item: Item{10, 10}}})
  322. assert.Equal(t, clone.minNode, uint32(1))
  323. assert.Equal(t, clone.maxNode, uint32(2))
  324. assert.Equal(t, alloc2.Size(), 3)
  325. }
  326. func TestErase(t *testing.T) {
  327. alloc := NewAllocator()
  328. tree := NewRBTree(alloc)
  329. for i := 0; i < 10; i++ {
  330. tree.Insert(Item{uint32(i), uint32(i)})
  331. }
  332. assert.Equal(t, alloc.Used(), 11)
  333. tree.Erase()
  334. assert.Equal(t, alloc.Used(), 1)
  335. assert.Equal(t, alloc.Size(), 11)
  336. }