import sys
import ujson as json
import re
import string
from collections import Counter
import pickle

def normalize_answer(s):

    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))
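
# Illustrative sketch, not part of the original script: normalize_answer
# lower-cases, strips punctuation and the articles a/an/the, and collapses
# whitespace. The sample string below is hypothetical.
def _normalize_answer_example():
    assert normalize_answer('The  Eiffel Tower!') == 'eiffel tower'
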
def f1_score(prediction, ground_truth):
    # Token-level F1 between the normalized prediction and gold answer strings.
    normalized_prediction = normalize_answer(prediction)
    normalized_ground_truth = normalize_answer(ground_truth)

    ZERO_METRIC = (0, 0, 0)

    if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
        return ZERO_METRIC
    if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
        return ZERO_METRIC

    prediction_tokens = normalized_prediction.split()
    ground_truth_tokens = normalized_ground_truth.split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return ZERO_METRIC
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1, precision, recall
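
# Illustrative sketch, not part of the original script: token-level overlap F1
# on a hypothetical prediction/gold pair ('Barack Obama' vs. 'Obama').
def _f1_score_example():
    f1, prec, recall = f1_score('Barack Obama', 'Obama')
    assert prec == 0.5 and recall == 1.0 and abs(f1 - 2 / 3) < 1e-9
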
def exact_match_score(prediction, ground_truth):
    return (normalize_answer(prediction) == normalize_answer(ground_truth))

def update_answer(metrics, prediction, gold):
    em = exact_match_score(prediction, gold)
    f1, prec, recall = f1_score(prediction, gold)
    metrics['em'] += float(em)
    metrics['f1'] += f1
    metrics['prec'] += prec
    metrics['recall'] += recall
    return em, prec, recall
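
# Illustrative sketch, not part of the original script: update_answer folds one
# example's answer EM/F1 into a running metrics dict (keys as assumed here).
def _update_answer_example():
    metrics = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0}
    em, prec, recall = update_answer(metrics, 'the Moon', 'Moon')
    assert em and metrics['em'] == 1.0 and metrics['f1'] == 1.0
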
def update_sp(metrics, prediction, gold):
    # Set-based precision/recall/F1 over supporting facts, each a (title, sentence index) pair.
    cur_sp_pred = set(map(tuple, prediction))
    gold_sp_pred = set(map(tuple, gold))
    tp, fp, fn = 0, 0, 0
    for e in cur_sp_pred:
        if e in gold_sp_pred:
            tp += 1
        else:
            fp += 1
    for e in gold_sp_pred:
        if e not in cur_sp_pred:
            fn += 1
    prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0
    f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0
    em = 1.0 if fp + fn == 0 else 0.0
    metrics['sp_em'] += em
    metrics['sp_f1'] += f1
    metrics['sp_prec'] += prec
    metrics['sp_recall'] += recall
    return em, prec, recall
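
# Illustrative sketch, not part of the original script: supporting facts are
# compared as (title, sentence_index) pairs; the titles and indices below are
# made up. One of two predicted facts is correct and one gold fact is missed.
def _update_sp_example():
    metrics = {'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0}
    em, prec, recall = update_sp(
        metrics, [['Paris', 0], ['Paris', 2]], [['Paris', 0], ['France', 1]])
    assert em == 0.0 and prec == 0.5 and recall == 0.5
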
def eval(prediction, gold):
    # Aggregate answer, supporting-fact, and joint metrics over all gold examples.
    metrics = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0,
               'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0,
               'joint_em': 0, 'joint_f1': 0, 'joint_prec': 0, 'joint_recall': 0}
    for dp in gold:
        cur_id = dp['_id']
        can_eval_joint = True
        if cur_id not in prediction['answer']:
            print('missing answer {}'.format(cur_id))
            can_eval_joint = False
        else:
            em, prec, recall = update_answer(
                metrics, prediction['answer'][cur_id], dp['answer'])
        if cur_id not in prediction['sp']:
            print('missing sp fact {}'.format(cur_id))
            can_eval_joint = False
        else:
            sp_em, sp_prec, sp_recall = update_sp(
                metrics, prediction['sp'][cur_id], dp['supporting_facts'])
        if can_eval_joint:
            joint_prec = prec * sp_prec
            joint_recall = recall * sp_recall
            if joint_prec + joint_recall > 0:
                joint_f1 = 2 * joint_prec * joint_recall / (joint_prec + joint_recall)
            else:
                joint_f1 = 0.
            joint_em = em * sp_em
            metrics['joint_em'] += joint_em
            metrics['joint_f1'] += joint_f1
            metrics['joint_prec'] += joint_prec
            metrics['joint_recall'] += joint_recall

    N = len(gold)
    for k in metrics.keys():
        metrics[k] /= N

    return metrics
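
# Illustrative sketch, not part of the original script: the input shapes eval()
# expects. The id, answer, and supporting facts below are hypothetical; gold
# data follows the HotpotQA format, and predictions map each question id to an
# answer string and to a list of [title, sentence_index] pairs.
def _eval_example():
    gold = [{
        '_id': 'example-id',
        'answer': 'Paris',
        'supporting_facts': [['Paris', 0], ['France', 1]],
    }]
    prediction = {
        'answer': {'example-id': 'Paris'},
        'sp': {'example-id': [['Paris', 0], ['France', 1]]},
    }
    metrics = eval(prediction, gold)
    assert metrics['em'] == 1.0 and metrics['joint_f1'] == 1.0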