evaluate_utils.py

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from megatron import get_args, print_rank_0
from megatron.checkpointing import load_biencoder_checkpoint
from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
from megatron.data.realm_index import OpenRetreivalDataStore, FaissMIPSIndex
from megatron.model.biencoder_model import get_model_provider
from megatron.training import get_model
from tasks.orqa.unsupervised.nq import get_nq_dataset
from tasks.orqa.unsupervised.nq import get_one_epoch_nq_dataloader
from tasks.orqa.unsupervised.nq import process_nq_batch
from tasks.orqa.unsupervised.qa_utils import calculate_matches


class ORQAEvaluator(object):
    """Evaluates ORQA retrieval: encodes NQ queries with the biencoder query
    model and searches a FAISS MIPS index built over Wikipedia evidence."""

    def __init__(self):
        args = get_args()
        self.embedding_size = args.hidden_size
        self.faiss_use_gpu = args.faiss_use_gpu
        self.evidence_embedder_obj = None
        self.evidence_dataset = None
        self.mips_index = None
        self.eval_dataset = None

        # Get Evidence (Wikipedia) dataset
        self.get_evidence_dataset()

        # Load query encoder checkpoint
        only_query_model = True
        if args.biencoder_shared_query_context_model:
            only_query_model = False

        model = get_model(get_model_provider(
            only_query_model=only_query_model,
            biencoder_shared_query_context_model=args.biencoder_shared_query_context_model))

        self.model = load_biencoder_checkpoint(
            model, only_query_model=only_query_model)

        assert len(self.model) == 1
        self.model[0].eval()

        # Load faiss indexer
        self.faiss_wrapper()

    def get_evidence_embedding(self):
        # This will load the embedding from the embedding path
        self.evidence_embedder_obj = OpenRetreivalDataStore(load_from_path=True)

    def get_evidence_dataset(self):
        self.evidence_dataset = get_open_retrieval_wiki_dataset()

    def faiss_wrapper(self):
        # Initialize the FAISS wrapper on local rank 0 only, as the evidence
        # embeddings are distributed over all the GPUs in a node and FAISS is
        # not thread-safe.
        args = get_args()
        if args.local_rank == 0:
            # Get evidence embeddings computed using the context encoder
            self.get_evidence_embedding()

            assert self.evidence_embedder_obj is not None
            self.mips_index = FaissMIPSIndex(embed_size=self.embedding_size,
                                             embed_data=self.evidence_embedder_obj,
                                             use_gpu=self.faiss_use_gpu)

        # Wait for the FAISS index to be initialized on all the nodes
        torch.distributed.barrier()

    def generate_query_vectors(self, qa_data, split):
        self.eval_dataset = get_nq_dataset(qa_data, split)
        dataloader = get_one_epoch_nq_dataloader(self.eval_dataset)

        query_vectors = []
        reference_list = []

        for batch in dataloader:
            # batch also has query_tokens and query_pad_data
            query_tokens, query_mask, query_types, \
                query_len, reference = process_nq_batch(batch)

            assert len(self.model) == 1
            unwrapped_model = self.model[0]
            while not hasattr(unwrapped_model, 'embed_text'):
                unwrapped_model = unwrapped_model.module

            with torch.no_grad():
                query_logits = unwrapped_model.embed_text(
                    unwrapped_model.query_model, query_tokens,
                    query_mask, query_types)

            reference_list.extend(reference)
            query_vectors.extend(query_logits.split(1, dim=0))
            if len(query_vectors) % 100 == 0:
                print_rank_0('Encoded queries {}'.format(len(query_vectors)))

        query_tensor = torch.cat(query_vectors, dim=0)
        print_rank_0('Total encoded queries tensor {}'.format(query_tensor.size()))

        assert query_tensor.size(0) == len(self.eval_dataset)
        return query_tensor, reference_list

    def evaluate(self, qa_data, split):
        args = get_args()
        query_tensor, reference_list = self.generate_query_vectors(qa_data,
                                                                    split)
        local_rank = args.local_rank
        rank = torch.distributed.get_rank()
        device_count = torch.cuda.device_count()
        num_nodes = torch.distributed.get_world_size() // device_count
        node_id = rank // device_count

        # Build one process group per node. new_group() is collective, so
        # every rank participates; each rank keeps the group for its own node.
        for node in range(num_nodes):
            start_rank = node * device_count
            end_rank = (node + 1) * device_count
            ranks_list = list(range(start_rank, end_rank))
            node_group = torch.distributed.new_group(ranks=ranks_list)

            if node_id == node:
                device_start_rank = start_rank
                group = node_group

        # Gather the query vectors from all GPUs on this node.
        input_ = torch.empty_like(query_tensor).copy_(query_tensor).detach_()
        tensor_list = [torch.empty_like(input_) for _ in range(device_count)]
        torch.distributed.all_gather(tensor_list, query_tensor, group=group)

        # Only local rank 0 holds the FAISS index; it searches the index for
        # all of the node's gathered queries.
        if local_rank == 0 and self.mips_index is not None:
            all_query_tensor = torch.cat(tensor_list, dim=0).contiguous()

            distance, topkindex = self.mips_index.search_mips_index(
                all_query_tensor, top_k=args.faiss_topk_retrievals,
                reconstruct=False)
            distance = torch.from_numpy(distance).cuda()
            topkindex = torch.LongTensor(topkindex).cuda()

        # The other ranks allocate buffers to receive the broadcast results.
        if local_rank != 0:
            distance = torch.empty(device_count * len(query_tensor),
                args.faiss_topk_retrievals, dtype=torch.float32).cuda()
            topkindex = torch.empty(device_count * len(query_tensor),
                args.faiss_topk_retrievals, dtype=torch.int64).cuda()

        torch.distributed.broadcast(distance, src=device_start_rank,
                                    group=group)
        torch.distributed.broadcast(topkindex, src=device_start_rank,
                                    group=group)

        # Keep only this rank's slice of the search results.
        distance = torch.split(distance, len(query_tensor), dim=0)[local_rank]
        topkindex = torch.split(topkindex, len(query_tensor), dim=0)[local_rank]

        top_ids_and_scores = []
        for darray, topkarray in zip(distance, topkindex):
            top_ids_and_scores.append((topkarray.tolist(), darray.tolist()))

        passages = self.evidence_dataset.id2text
        match_stats = calculate_matches(passages,
                                        reference_list,
                                        top_ids_and_scores,
                                        workers_num=args.num_workers,
                                        match_type=args.faiss_match)
        top_k_hits = match_stats.top_k_hits

        print_rank_0("{} SET RESULTS".format(split))
        print_rank_0("topk-{} documents hits {}".format(
            args.faiss_topk_retrievals, top_k_hits))
        top_k_hits = [v / len(top_ids_and_scores) for v in top_k_hits]
        print_rank_0("top-k documents hits accuracy {}".format(top_k_hits))

        for i in args.retriever_report_topk_accuracies:
            print_rank_0("top-{}: {:.2f}".format(i, top_k_hits[i - 1] * 100))

        return
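

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original
# evaluate_utils.py). A driver script such as tasks/orqa/evaluate_orqa.py is
# expected to build the evaluator once and call evaluate() per split; the
# qa_data_dev / qa_data_test argument names below are assumptions inferred
# from the get_args() flags referenced above.
# ---------------------------------------------------------------------------
def _example_main():
    args = get_args()

    # Builds the query encoder, loads evidence embeddings, and sets up FAISS.
    evaluator = ORQAEvaluator()

    # Run retrieval evaluation on whichever NQ splits were provided.
    if getattr(args, 'qa_data_dev', None) is not None:
        evaluator.evaluate(args.qa_data_dev, "DEV")
    if getattr(args, 'qa_data_test', None) is not None:
        evaluator.evaluate(args.qa_data_test, "TEST")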