eval_utils.py

import re
import string
import sys
from collections import Counter

import ujson as json  # drop-in, faster replacement for the stdlib json module

def normalize_answer(s):
    """Lower-case, strip punctuation and articles, and collapse whitespace."""

    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))

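# A worked example of the normalization above (traced by hand, not program
# output): "The Eiffel Tower!" is lower-cased to "the eiffel tower!",
# punctuation-stripped to "the eiffel tower", article-stripped (leaving
# stray spaces), and whitespace-collapsed to "eiffel tower".
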
def f1_score(prediction, ground_truth):
    """Token-level F1 between a predicted and a gold answer string.

    Returns (f1, precision, recall). A yes/no/noanswer prediction only
    scores when it matches the gold answer exactly.
    """
    normalized_prediction = normalize_answer(prediction)
    normalized_ground_truth = normalize_answer(ground_truth)

    ZERO_METRIC = (0, 0, 0)

    if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
        return ZERO_METRIC
    if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
        return ZERO_METRIC

    prediction_tokens = normalized_prediction.split()
    ground_truth_tokens = normalized_ground_truth.split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return ZERO_METRIC
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1, precision, recall

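# A worked F1 example (traced by hand): for prediction "Barack Obama" and
# gold "Obama", the normalized token bags share one token, so
# precision = 1/2, recall = 1/1, and f1 = 2 * 0.5 * 1.0 / 1.5 = 2/3.
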
def exact_match_score(prediction, ground_truth):
    """Exact match on the normalized answer strings."""
    return normalize_answer(prediction) == normalize_answer(ground_truth)

def update_answer(metrics, prediction, gold):
    """Accumulate answer EM/F1/precision/recall sums into the metrics dict."""
    em = exact_match_score(prediction, gold)
    f1, prec, recall = f1_score(prediction, gold)
    metrics['em'] += float(em)
    metrics['f1'] += f1
    metrics['prec'] += prec
    metrics['recall'] += recall
    return em, prec, recall

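# Usage sketch for the accumulation pattern (hypothetical metrics dict):
#   metrics = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0}
#   update_answer(metrics, 'Obama', 'Barack Obama')
# metrics now holds running sums; eval() below divides them by the number
# of gold examples to produce averages.
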
def update_sp(metrics, prediction, gold):
    """Accumulate supporting-fact EM/F1/precision/recall into the metrics dict.

    Both prediction and gold are lists of [title, sentence_id] pairs,
    compared as sets.
    """
    cur_sp_pred = set(map(tuple, prediction))
    gold_sp_pred = set(map(tuple, gold))
    tp, fp, fn = 0, 0, 0
    for e in cur_sp_pred:
        if e in gold_sp_pred:
            tp += 1
        else:
            fp += 1
    for e in gold_sp_pred:
        if e not in cur_sp_pred:
            fn += 1
    prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0
    f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0
    em = 1.0 if fp + fn == 0 else 0.0
    metrics['sp_em'] += em
    metrics['sp_f1'] += f1
    metrics['sp_prec'] += prec
    metrics['sp_recall'] += recall
    return em, prec, recall

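# A worked supporting-fact example (traced by hand): with
# prediction [['Title A', 0], ['Title B', 2]] and gold [['Title A', 0]],
# tp = 1, fp = 1, fn = 0, so prec = 0.5, recall = 1.0, f1 = 2/3, em = 0.0.
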
def eval(prediction, gold):
    """Compute answer, supporting-fact, and joint metrics averaged over gold.

    prediction: {'answer': {id: answer_str}, 'sp': {id: [[title, sent_id], ...]}}
    gold: list of examples with '_id', 'answer', and 'supporting_facts' fields.
    Note: shadows the built-in eval(); kept to match the original interface.
    """
    metrics = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0,
               'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0,
               'joint_em': 0, 'joint_f1': 0, 'joint_prec': 0, 'joint_recall': 0}
    for dp in gold:
        cur_id = dp['_id']
        can_eval_joint = True
        if cur_id not in prediction['answer']:
            print('missing answer {}'.format(cur_id))
            can_eval_joint = False
        else:
            em, prec, recall = update_answer(
                metrics, prediction['answer'][cur_id], dp['answer'])
        if cur_id not in prediction['sp']:
            print('missing sp fact {}'.format(cur_id))
            can_eval_joint = False
        else:
            sp_em, sp_prec, sp_recall = update_sp(
                metrics, prediction['sp'][cur_id], dp['supporting_facts'])
        if can_eval_joint:
            # Joint precision/recall multiply the answer and supporting-fact
            # scores, so both tasks must succeed to earn joint credit.
            joint_prec = prec * sp_prec
            joint_recall = recall * sp_recall
            if joint_prec + joint_recall > 0:
                joint_f1 = 2 * joint_prec * joint_recall / (joint_prec + joint_recall)
            else:
                joint_f1 = 0.
            joint_em = em * sp_em
            metrics['joint_em'] += joint_em
            metrics['joint_f1'] += joint_f1
            metrics['joint_prec'] += joint_prec
            metrics['joint_recall'] += joint_recall

    # Average the accumulated sums over the number of gold examples.
    N = len(gold)
    for k in metrics.keys():
        metrics[k] /= N
    return metrics
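
# Minimal CLI sketch, an assumption rather than part of the original module:
# the file layout mirrors what eval() expects, i.e. a prediction JSON of the
# form {"answer": {id: str}, "sp": {id: [[title, sent_id], ...]}} and a gold
# JSON that is a list of examples with "_id", "answer", "supporting_facts".
if __name__ == '__main__':
    # usage: python eval_utils.py <prediction_file> <gold_file>
    with open(sys.argv[1]) as f:
        prediction = json.load(f)
    with open(sys.argv[2]) as f:
        gold = json.load(f)
    print(eval(prediction, gold))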