errorcounter.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Some simple tools for error counting.
  16. """
  17. import collections
  18. # Named tuple Error counts describes the counts needed to accumulate errors
  19. # over multiple trials:
  20. # false negatives (aka drops or deletions),
  21. # false positives: (aka adds or insertions),
  22. # truth_count: number of elements in ground truth = denominator for fn,
  23. # test_count: number of elements in test string = denominator for fp,
  24. # Note that recall = 1 - fn/truth_count, precision = 1 - fp/test_count,
  25. # accuracy = 1 - (fn + fp) / (truth_count + test_count).
  26. ErrorCounts = collections.namedtuple('ErrorCounts', ['fn', 'fp', 'truth_count',
  27. 'test_count'])
  28. # Named tuple for error rates, as a percentage. Accuracies are just 100-error.
  29. ErrorRates = collections.namedtuple('ErrorRates',
  30. ['label_error', 'word_recall_error',
  31. 'word_precision_error', 'sequence_error'])
  32. def CountWordErrors(ocr_text, truth_text):
  33. """Counts the word drop and add errors as a bag of words.
  34. Args:
  35. ocr_text: OCR text string.
  36. truth_text: Truth text string.
  37. Returns:
  38. ErrorCounts named tuple.
  39. """
  40. # Convert to lists of words.
  41. return CountErrors(ocr_text.split(), truth_text.split())
  42. def CountErrors(ocr_text, truth_text):
  43. """Counts the drops and adds between 2 bags of iterables.
  44. Simple bag of objects count returns the number of dropped and added
  45. elements, regardless of order, from anything that is iterable, eg
  46. a pair of strings gives character errors, and a pair of word lists give
  47. word errors.
  48. Args:
  49. ocr_text: OCR text iterable (eg string for chars, word list for words).
  50. truth_text: Truth text iterable.
  51. Returns:
  52. ErrorCounts named tuple.
  53. """
  54. counts = collections.Counter(truth_text)
  55. counts.subtract(ocr_text)
  56. drops = sum(c for c in counts.values() if c > 0)
  57. adds = sum(-c for c in counts.values() if c < 0)
  58. return ErrorCounts(drops, adds, len(truth_text), len(ocr_text))
  59. def AddErrors(counts1, counts2):
  60. """Adds the counts and returns a new sum tuple.
  61. Args:
  62. counts1: ErrorCounts named tuples to sum.
  63. counts2: ErrorCounts named tuples to sum.
  64. Returns:
  65. Sum of counts1, counts2.
  66. """
  67. return ErrorCounts(counts1.fn + counts2.fn, counts1.fp + counts2.fp,
  68. counts1.truth_count + counts2.truth_count,
  69. counts1.test_count + counts2.test_count)
  70. def ComputeErrorRates(label_counts, word_counts, seq_errors, num_seqs):
  71. """Returns an ErrorRates corresponding to the given counts.
  72. Args:
  73. label_counts: ErrorCounts for the character labels
  74. word_counts: ErrorCounts for the words
  75. seq_errors: Number of sequence errors
  76. num_seqs: Total sequences
  77. Returns:
  78. ErrorRates corresponding to the given counts.
  79. """
  80. label_errors = label_counts.fn + label_counts.fp
  81. num_labels = label_counts.truth_count + label_counts.test_count
  82. return ErrorRates(
  83. ComputeErrorRate(label_errors, num_labels),
  84. ComputeErrorRate(word_counts.fn, word_counts.truth_count),
  85. ComputeErrorRate(word_counts.fp, word_counts.test_count),
  86. ComputeErrorRate(seq_errors, num_seqs))
  87. def ComputeErrorRate(error_count, truth_count):
  88. """Returns a sanitized percent error rate from the raw counts.
  89. Prevents div by 0 and clips return to 100%.
  90. Args:
  91. error_count: Number of errors.
  92. truth_count: Number to divide by.
  93. Returns:
  94. 100.0 * error_count / truth_count clipped to 100.
  95. """
  96. if truth_count == 0:
  97. truth_count = 1
  98. error_count = 1
  99. elif error_count > truth_count:
  100. error_count = truth_count
  101. return error_count * 100.0 / truth_count