errorcounter_test.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for errorcounter."""
  16. import tensorflow as tf
  17. import errorcounter as ec
  18. class ErrorcounterTest(tf.test.TestCase):
  19. def testComputeErrorRate(self):
  20. """Tests that the percent calculation works as expected.
  21. """
  22. rate = ec.ComputeErrorRate(error_count=0, truth_count=0)
  23. self.assertEqual(rate, 100.0)
  24. rate = ec.ComputeErrorRate(error_count=1, truth_count=0)
  25. self.assertEqual(rate, 100.0)
  26. rate = ec.ComputeErrorRate(error_count=10, truth_count=1)
  27. self.assertEqual(rate, 100.0)
  28. rate = ec.ComputeErrorRate(error_count=0, truth_count=1)
  29. self.assertEqual(rate, 0.0)
  30. rate = ec.ComputeErrorRate(error_count=3, truth_count=12)
  31. self.assertEqual(rate, 25.0)
  32. def testCountErrors(self):
  33. """Tests that the error counter works as expected.
  34. """
  35. truth_str = 'farm barn'
  36. counts = ec.CountErrors(ocr_text=truth_str, truth_text=truth_str)
  37. self.assertEqual(
  38. counts, ec.ErrorCounts(
  39. fn=0, fp=0, truth_count=9, test_count=9))
  40. # With a period on the end, we get a char error.
  41. dot_str = 'farm barn.'
  42. counts = ec.CountErrors(ocr_text=dot_str, truth_text=truth_str)
  43. self.assertEqual(
  44. counts, ec.ErrorCounts(
  45. fn=0, fp=1, truth_count=9, test_count=10))
  46. counts = ec.CountErrors(ocr_text=truth_str, truth_text=dot_str)
  47. self.assertEqual(
  48. counts, ec.ErrorCounts(
  49. fn=1, fp=0, truth_count=10, test_count=9))
  50. # Space is just another char.
  51. no_space = 'farmbarn'
  52. counts = ec.CountErrors(ocr_text=no_space, truth_text=truth_str)
  53. self.assertEqual(
  54. counts, ec.ErrorCounts(
  55. fn=1, fp=0, truth_count=9, test_count=8))
  56. counts = ec.CountErrors(ocr_text=truth_str, truth_text=no_space)
  57. self.assertEqual(
  58. counts, ec.ErrorCounts(
  59. fn=0, fp=1, truth_count=8, test_count=9))
  60. # Lose them all.
  61. counts = ec.CountErrors(ocr_text='', truth_text=truth_str)
  62. self.assertEqual(
  63. counts, ec.ErrorCounts(
  64. fn=9, fp=0, truth_count=9, test_count=0))
  65. counts = ec.CountErrors(ocr_text=truth_str, truth_text='')
  66. self.assertEqual(
  67. counts, ec.ErrorCounts(
  68. fn=0, fp=9, truth_count=0, test_count=9))
  69. def testCountWordErrors(self):
  70. """Tests that the error counter works as expected.
  71. """
  72. truth_str = 'farm barn'
  73. counts = ec.CountWordErrors(ocr_text=truth_str, truth_text=truth_str)
  74. self.assertEqual(
  75. counts, ec.ErrorCounts(
  76. fn=0, fp=0, truth_count=2, test_count=2))
  77. # With a period on the end, we get a word error.
  78. dot_str = 'farm barn.'
  79. counts = ec.CountWordErrors(ocr_text=dot_str, truth_text=truth_str)
  80. self.assertEqual(
  81. counts, ec.ErrorCounts(
  82. fn=1, fp=1, truth_count=2, test_count=2))
  83. counts = ec.CountWordErrors(ocr_text=truth_str, truth_text=dot_str)
  84. self.assertEqual(
  85. counts, ec.ErrorCounts(
  86. fn=1, fp=1, truth_count=2, test_count=2))
  87. # Space is special.
  88. no_space = 'farmbarn'
  89. counts = ec.CountWordErrors(ocr_text=no_space, truth_text=truth_str)
  90. self.assertEqual(
  91. counts, ec.ErrorCounts(
  92. fn=2, fp=1, truth_count=2, test_count=1))
  93. counts = ec.CountWordErrors(ocr_text=truth_str, truth_text=no_space)
  94. self.assertEqual(
  95. counts, ec.ErrorCounts(
  96. fn=1, fp=2, truth_count=1, test_count=2))
  97. # Lose them all.
  98. counts = ec.CountWordErrors(ocr_text='', truth_text=truth_str)
  99. self.assertEqual(
  100. counts, ec.ErrorCounts(
  101. fn=2, fp=0, truth_count=2, test_count=0))
  102. counts = ec.CountWordErrors(ocr_text=truth_str, truth_text='')
  103. self.assertEqual(
  104. counts, ec.ErrorCounts(
  105. fn=0, fp=2, truth_count=0, test_count=2))
  106. # With a space in ba rn, there is an extra add.
  107. sp_str = 'farm ba rn'
  108. counts = ec.CountWordErrors(ocr_text=sp_str, truth_text=truth_str)
  109. self.assertEqual(
  110. counts, ec.ErrorCounts(
  111. fn=1, fp=2, truth_count=2, test_count=3))
  112. counts = ec.CountWordErrors(ocr_text=truth_str, truth_text=sp_str)
  113. self.assertEqual(
  114. counts, ec.ErrorCounts(
  115. fn=2, fp=1, truth_count=3, test_count=2))
  116. if __name__ == '__main__':
  117. tf.test.main()