evaluation.py

# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """Parser evaluation utils."""
  16. from __future__ import division
  17. import tensorflow as tf
  18. from syntaxnet import sentence_pb2
  19. from syntaxnet.util import check


def calculate_parse_metrics(gold_corpus, annotated_corpus):
  """Calculate POS/UAS/LAS accuracy based on gold and annotated sentences."""
  check.Eq(len(gold_corpus), len(annotated_corpus), 'Corpora are not aligned')
  num_tokens = 0
  num_correct_pos = 0
  num_correct_uas = 0
  num_correct_las = 0
  for gold_str, annotated_str in zip(gold_corpus, annotated_corpus):
    gold = sentence_pb2.Sentence()
    annotated = sentence_pb2.Sentence()
    gold.ParseFromString(gold_str)
    annotated.ParseFromString(annotated_str)
    check.Eq(gold.text, annotated.text, 'Text is not aligned')
    check.Eq(len(gold.token), len(annotated.token), 'Tokens are not aligned')
    # In Python 2, zip() returns a list, so it can be sized and re-iterated
    # for each of the three counts below.
    tokens = zip(gold.token, annotated.token)
    num_tokens += len(tokens)
    num_correct_pos += sum(1 for x, y in tokens if x.tag == y.tag)
    num_correct_uas += sum(1 for x, y in tokens if x.head == y.head)
    num_correct_las += sum(1 for x, y in tokens
                           if x.head == y.head and x.label == y.label)

  tf.logging.info('Total num documents: %d', len(annotated_corpus))
  tf.logging.info('Total num tokens: %d', num_tokens)
  pos = num_correct_pos * 100.0 / num_tokens
  uas = num_correct_uas * 100.0 / num_tokens
  las = num_correct_las * 100.0 / num_tokens
  tf.logging.info('POS: %.2f%%', pos)
  tf.logging.info('UAS: %.2f%%', uas)
  tf.logging.info('LAS: %.2f%%', las)
  return pos, uas, las


def parser_summaries(gold_corpus, annotated_corpus):
  """Computes parser evaluation summaries for gold and annotated sentences."""
  pos, uas, las = calculate_parse_metrics(gold_corpus, annotated_corpus)
  return {'POS': pos, 'LAS': las, 'UAS': uas, 'eval_metric': las}
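

# ------------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): how the parse metrics
# above could be exercised on a hand-built gold/annotated pair of serialized
# Sentence protos.  Only proto fields already used above (text,
# token.word/start/end/tag/head/label) are assumed; the toy tag, head and
# label values are made up for illustration.
# ------------------------------------------------------------------------------
def _example_parser_eval():
  """Returns parser summaries for a tiny two-token sentence."""

  def make_sentence(tags, heads, labels):
    # Builds the two-token sentence "a b" with the given analyses.
    sentence = sentence_pb2.Sentence()
    sentence.text = 'a b'
    for i, (tag, head, label) in enumerate(zip(tags, heads, labels)):
      token = sentence.token.add()
      token.word = sentence.text[2 * i]
      token.start = 2 * i
      token.end = 2 * i
      token.tag = tag
      token.head = head
      token.label = label
    return sentence.SerializeToString()

  gold = make_sentence(['DT', 'NN'], [1, -1], ['det', 'root'])
  # The annotation gets the second head wrong, so POS stays at 100% while
  # UAS and LAS both drop to 50%.
  annotated = make_sentence(['DT', 'NN'], [1, 0], ['det', 'root'])
  return parser_summaries([gold], [annotated])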


def calculate_segmentation_metrics(gold_corpus, annotated_corpus):
  """Calculate precision/recall/f1 based on gold and annotated sentences."""
  check.Eq(len(gold_corpus), len(annotated_corpus), 'Corpora are not aligned')
  num_gold_tokens = 0
  num_test_tokens = 0
  num_correct_tokens = 0

  def token_span(token):
    check.Ge(token.end, token.start)
    return (token.start, token.end)

  def ratio(numerator, denominator):
    check.Ge(numerator, 0)
    check.Ge(denominator, 0)
    if denominator > 0:
      return numerator / denominator
    elif numerator == 0:
      return 0.0  # map 0/0 to 0
    else:
      return float('inf')  # map x/0 to inf

  for gold_str, annotated_str in zip(gold_corpus, annotated_corpus):
    gold = sentence_pb2.Sentence()
    annotated = sentence_pb2.Sentence()
    gold.ParseFromString(gold_str)
    annotated.ParseFromString(annotated_str)
    check.Eq(gold.text, annotated.text, 'Text is not aligned')
    gold_spans = set()
    test_spans = set()
    for token in gold.token:
      check.NotIn(token_span(token), gold_spans, 'Duplicate token')
      gold_spans.add(token_span(token))
    for token in annotated.token:
      check.NotIn(token_span(token), test_spans, 'Duplicate token')
      test_spans.add(token_span(token))
    num_gold_tokens += len(gold_spans)
    num_test_tokens += len(test_spans)
    num_correct_tokens += len(gold_spans.intersection(test_spans))

  tf.logging.info('Total num documents: %d', len(annotated_corpus))
  tf.logging.info('Total gold tokens: %d', num_gold_tokens)
  tf.logging.info('Total test tokens: %d', num_test_tokens)
  precision = 100 * ratio(num_correct_tokens, num_test_tokens)
  recall = 100 * ratio(num_correct_tokens, num_gold_tokens)
  f1 = ratio(2 * precision * recall, precision + recall)
  tf.logging.info('Precision: %.2f%%', precision)
  tf.logging.info('Recall: %.2f%%', recall)
  tf.logging.info('F1: %.2f%%', f1)
  return round(precision, 2), round(recall, 2), round(f1, 2)


def segmentation_summaries(gold_corpus, annotated_corpus):
  """Computes segmentation eval summaries for gold and annotated sentences."""
  prec, rec, f1 = calculate_segmentation_metrics(gold_corpus, annotated_corpus)
  return {'precision': prec, 'recall': rec, 'f1': f1, 'eval_metric': f1}
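

# ------------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the segmentation
# metrics compare gold vs. predicted token byte spans.  Below, the gold
# segmentation of 'ab c' is ["ab", "c"] while the prediction is
# ["a", "b", "c"]; only the span over "c" matches, so precision is ~33.33%,
# recall is 50.00% and F1 is 40.00.  Only proto fields already used above
# (text, token.word/start/end) are assumed; the example text is made up.
# ------------------------------------------------------------------------------
def _example_segmentation_eval():
  """Returns segmentation summaries for a hand-built gold/annotated pair."""

  def make_sentence(text, spans):
    # Builds a sentence whose tokens cover the given inclusive byte spans.
    sentence = sentence_pb2.Sentence()
    sentence.text = text
    for start, end in spans:
      token = sentence.token.add()
      token.word = text[start:end + 1]
      token.start = start
      token.end = end
    return sentence.SerializeToString()

  gold = make_sentence('ab c', [(0, 1), (3, 3)])
  annotated = make_sentence('ab c', [(0, 0), (1, 1), (3, 3)])
  return segmentation_summaries([gold], [annotated])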