#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# The following code has been taken from
# https://github.com/facebookresearch/DPR, which is CC-BY-NC 4.0
# licensed as of now. More details on the license can be found
# at https://github.com/facebookresearch/DPR/blob/master/LICENSE
  9. """
  10. Set of utilities for Q&A results validation tasks - Retriver passage
  11. validation and Reader predicted answer validation
  12. """

import collections
import logging
import string
import unicodedata
from functools import partial
from multiprocessing import Pool as ProcessPool
from typing import Tuple, List, Dict

import regex as re

from tasks.orqa.unsupervised.tokenizers import SimpleTokenizer

logger = logging.getLogger(__name__)

QAMatchStats = collections.namedtuple(
    'QAMatchStats', ['top_k_hits', 'questions_doc_hits'])

def calculate_matches(all_docs: Dict[object, Tuple[str, str]],
                      answers: List[List[str]],
                      closest_docs: List[Tuple[List[object], List[float]]],
                      workers_num: int, match_type: str) -> QAMatchStats:
    """
    Evaluates answer presence in the set of documents. This function is
    intended for a large collection of documents and results; it internally
    forks multiple sub-processes for evaluation and then merges the results.

    :param all_docs: dictionary of the entire document database,
        doc_id -> (doc_text, title)
    :param answers: list of answer lists, one list per question
    :param closest_docs: document ids of the top results along with their
        scores
    :param workers_num: number of parallel worker processes
    :param match_type: type of answer matching; refer to has_answer for the
        available options
    :return: matching information tuple.
        top_k_hits - a list where the value at index k is the number of
        questions across the entire dataset whose answer was found within
        the top k+1 retrieved documents.
        questions_doc_hits - detailed answer-match information for every
        question and every retrieved document
    """
    # Share the document database with forked worker processes via a
    # module-level global instead of pickling it for every task.
    global dpr_all_documents
    dpr_all_documents = all_docs

    tok_opts = {}
    tokenizer = SimpleTokenizer(**tok_opts)

    processes = ProcessPool(
        processes=workers_num,
    )

    logger.info('Matching answers in top docs...')
    get_score_partial = partial(check_answer, match_type=match_type,
                                tokenizer=tokenizer)
    questions_answers_docs = zip(answers, closest_docs)
    scores = processes.map(get_score_partial, questions_answers_docs)
    logger.info('Per question validation results len=%d', len(scores))

    n_docs = len(closest_docs[0][0])
    top_k_hits = [0] * n_docs
    for question_hits in scores:
        # Rank of the first retrieved document containing an answer, if any.
        best_hit = next((i for i, x in enumerate(question_hits) if x), None)
        if best_hit is not None:
            top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]]

    return QAMatchStats(top_k_hits, scores)
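
# A minimal usage sketch for calculate_matches (hypothetical data, added here
# for illustration; not part of the original DPR code). Note that sharing
# `dpr_all_documents` with the workers relies on fork-style multiprocessing.
#
#   all_docs = {0: ('Paris is the capital of France.', 'Paris'),
#               1: ('Berlin is the capital of Germany.', 'Berlin')}
#   answers = [['Paris']]                  # one answer list per question
#   closest_docs = [([1, 0], [0.9, 0.8])]  # (doc_ids, scores) per question
#   stats = calculate_matches(all_docs, answers, closest_docs,
#                             workers_num=1, match_type='string')
#   # stats.top_k_hits == [0, 1]: the first hit is the rank-2 document, so
#   # only top-2 retrieval (and beyond) counts as a match for this question.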

def check_answer(questions_answers_docs, tokenizer, match_type) -> List[bool]:
    """
    Search through all the top docs to see if they have any of the answers.
    """
    answers, (doc_ids, doc_scores) = questions_answers_docs

    global dpr_all_documents
    hits = []
    for i, doc_id in enumerate(doc_ids):
        doc = dpr_all_documents[doc_id]
        text = doc[0]

        answer_found = False
        if text is None:  # cannot find the document for some reason
            logger.warning("no doc in db")
            hits.append(False)
            continue
        if has_answer(answers, text, tokenizer, match_type):
            answer_found = True
        hits.append(answer_found)
    return hits

def has_answer(answers, text, tokenizer, match_type) -> bool:
    """
    Check if a document contains an answer string.
    If `match_type` is string, token matching is done between the text
    and answer.
    If `match_type` is regex, we search the whole text with the regex.
    """
    text = _normalize(text)

    if match_type == 'string':
        # Answer is a list of possible strings
        text = tokenizer.tokenize(text).words(uncased=True)
        for single_answer in answers:
            single_answer = _normalize(single_answer)
            single_answer = tokenizer.tokenize(single_answer)
            single_answer = single_answer.words(uncased=True)
            # Slide a window of the answer's length over the document tokens.
            for i in range(0, len(text) - len(single_answer) + 1):
                if single_answer == text[i: i + len(single_answer)]:
                    return True
    elif match_type == 'regex':
        # Answer is a regex
        for single_answer in answers:
            single_answer = _normalize(single_answer)
            if regex_match(text, single_answer):
                return True
    return False
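
# For example (hypothetical values, assuming SimpleTokenizer lower-cases
# words and splits off punctuation):
#   has_answer(['Paris'], 'The capital is Paris.', SimpleTokenizer(), 'string')
# returns True, because the answer tokens ['paris'] occur as a contiguous
# window within the tokenized document.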

def regex_match(text, pattern):
    """Test if a regex pattern is contained within a text."""
    try:
        pattern = re.compile(
            pattern,
            flags=re.IGNORECASE | re.UNICODE | re.MULTILINE,
        )
    except BaseException:
        return False
    return pattern.search(text) is not None
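
# For example, regex_match('It opened in 1889.', r'\d{4}') is True, while an
# invalid pattern such as '(' fails to compile and is reported as no match
# rather than raising.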

# Function for the reader model answer validation
def exact_match_score(prediction, ground_truth):
    return _normalize_answer(prediction) == _normalize_answer(ground_truth)


def _normalize_answer(s):
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))
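
# Both sides are lower-cased, stripped of punctuation and English articles,
# and whitespace-normalized before comparison, so for example:
#   exact_match_score('The Eiffel Tower!', 'eiffel tower')  # True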

def _normalize(text):
    return unicodedata.normalize('NFD', text)