123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178 |
- #!/usr/bin/env python3
- # Copyright (c) Facebook, Inc. and its affiliates.
- # All rights reserved.
- #
- # The following code has been taken from
- # https://github.com/facebookresearch/DPR, which is CC-BY-NC 4.0
- # licensed as of now. More details on the license can be found
- # at https://github.com/facebookresearch/DPR/blob/master/LICENSE
- """
Set of utilities for Q&A results validation tasks - Retriever passage
- validation and Reader predicted answer validation
- """
- import collections
- import logging
- import string
- import unicodedata
- from functools import partial
- from multiprocessing import Pool as ProcessPool
- from typing import Tuple, List, Dict
- import regex as re
- from tasks.orqa.unsupervised.tokenizers import SimpleTokenizer
- logger = logging.getLogger(__name__)
# Aggregate retrieval-validation result: cumulative hit counts per top-k rank
# plus the per-question, per-document answer-match flags.
QAMatchStats = collections.namedtuple(
    'QAMatchStats', 'top_k_hits questions_doc_hits')
def calculate_matches(all_docs: Dict[object, Tuple[str, str]],
                      answers: List[List[str]],
                      closest_docs: List[Tuple[List[object], List[float]]],
                      workers_num: int, match_type: str) -> QAMatchStats:
    """
    Evaluates answers presence in the set of documents. This function is
    supposed to be used with a large collection of documents and results.
    It internally forks multiple sub-processes for evaluation and then
    merges results.

    :param all_docs: dictionary of the entire documents database.
        doc_id -> (doc_text, title)
    :param answers: list of answer lists. One list per question
    :param closest_docs: document ids of the top results along with their
        scores
    :param workers_num: amount of parallel worker processes to use
    :param match_type: type of answer matching. Refer to has_answer code for
        available options
    :return: matching information tuple.
        top_k_hits - a list where the index is the amount of top documents
        retrieved and the value is the total amount of valid matches across
        an entire dataset.
        questions_doc_hits - more detailed info with answer matches for every
        question and every retrieved document
    """
    # Workers read the documents via this module-level global so the large
    # dict is inherited (fork) rather than pickled per task.
    global dpr_all_documents
    dpr_all_documents = all_docs

    tokenizer = SimpleTokenizer()
    get_score_partial = partial(check_answer, match_type=match_type,
                                tokenizer=tokenizer)
    questions_answers_docs = zip(answers, closest_docs)

    logger.info('Matching answers in top docs...')
    # Context manager guarantees the worker processes are cleaned up even if
    # scoring raises; the original code leaked the pool.
    with ProcessPool(processes=workers_num) as processes:
        scores = processes.map(get_score_partial, questions_answers_docs)
    logger.info('Per question validation results len=%d', len(scores))

    n_docs = len(closest_docs[0][0])
    # top_k_hits[k] counts questions whose first matching doc is at rank <= k.
    top_k_hits = [0] * n_docs
    for question_hits in scores:
        best_hit = next((i for i, x in enumerate(question_hits) if x), None)
        if best_hit is not None:
            top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]]
    return QAMatchStats(top_k_hits, scores)
def check_answer(questions_answers_docs, tokenizer, match_type) -> List[bool]:
    """
    Search through all the top docs to see if they have any of the answers.
    Returns one boolean per retrieved document, in rank order.
    """
    global dpr_all_documents
    answers, (doc_ids, doc_scores) = questions_answers_docs

    hits = []
    for doc_id in doc_ids:
        text = dpr_all_documents[doc_id][0]
        if text is None:  # cannot find the document for some reason
            logger.warning("no doc in db")
            hits.append(False)
        else:
            hits.append(has_answer(answers, text, tokenizer, match_type))
    return hits
def has_answer(answers, text, tokenizer, match_type) -> bool:
    """
    Check if a document contains an answer string.

    If `match_type` is string, token matching is done between the text
    and answer.
    If `match_type` is regex, we search the whole text with the regex.
    """
    text = _normalize(text)

    if match_type == 'string':
        # Token-level containment: answer tokens must appear as a
        # contiguous subsequence of the document tokens.
        doc_words = tokenizer.tokenize(text).words(uncased=True)
        for candidate in answers:
            answer_words = tokenizer.tokenize(
                _normalize(candidate)).words(uncased=True)
            width = len(answer_words)
            if any(doc_words[start:start + width] == answer_words
                   for start in range(len(doc_words) - width + 1)):
                return True
    elif match_type == 'regex':
        # Each answer is interpreted as a regex pattern.
        return any(regex_match(text, _normalize(candidate))
                   for candidate in answers)
    return False
def regex_match(text, pattern):
    """Test if a regex pattern is contained within a text.

    An invalid pattern counts as a non-match rather than raising.
    """
    try:
        compiled = re.compile(
            pattern,
            flags=re.IGNORECASE | re.UNICODE | re.MULTILINE,
        )
    except re.error:
        # Only trap malformed patterns; the previous BaseException handler
        # would also have swallowed KeyboardInterrupt/SystemExit.
        return False
    return compiled.search(text) is not None
- # function for the reader model answer validation
def exact_match_score(prediction, ground_truth):
    """True when the normalized prediction equals the normalized reference."""
    pred_norm = _normalize_answer(prediction)
    gold_norm = _normalize_answer(ground_truth)
    return pred_norm == gold_norm
- def _normalize_answer(s):
- def remove_articles(text):
- return re.sub(r'\b(a|an|the)\b', ' ', text)
- def white_space_fix(text):
- return ' '.join(text.split())
- def remove_punc(text):
- exclude = set(string.punctuation)
- return ''.join(ch for ch in text if ch not in exclude)
- def lower(text):
- return text.lower()
- return white_space_fix(remove_articles(remove_punc(lower(s))))
- def _normalize(text):
- return unicodedata.normalize('NFD', text)
|