123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189 |
- # coding=utf-8
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import torch
- from megatron import get_args, print_rank_0
- from megatron.checkpointing import load_biencoder_checkpoint
- from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
- from megatron.data.realm_index import OpenRetreivalDataStore, FaissMIPSIndex
- from megatron.model.biencoder_model import get_model_provider
- from megatron.training import get_model
- from tasks.orqa.unsupervised.nq import get_nq_dataset
- from tasks.orqa.unsupervised.nq import get_one_epoch_nq_dataloader
- from tasks.orqa.unsupervised.nq import process_nq_batch
- from tasks.orqa.unsupervised.qa_utils import calculate_matches
class ORQAEvaluator(object):
    """Open-retrieval QA (ORQA) evaluator.

    Encodes Natural Questions queries with a biencoder query model and
    retrieves evidence (Wikipedia) passages with a FAISS
    maximum-inner-product-search (MIPS) index built over precomputed
    evidence embeddings, then reports top-k retrieval accuracy.
    """

    def __init__(self):
        args = get_args()
        # Dimensionality of the query/context embeddings (model hidden size).
        self.embedding_size = args.hidden_size
        self.faiss_use_gpu = args.faiss_use_gpu
        self.evidence_embedder_obj = None
        self.evidence_dataset = None
        self.mips_index = None
        self.eval_dataset = None

        # Get Evidence (Wikipedia) dataset
        self.get_evidence_dataset()

        # Load query encoder checkpoint. When the biencoder shares one
        # model for query and context, the full shared model must be
        # loaded rather than the query tower alone.
        only_query_model = True
        if args.biencoder_shared_query_context_model:
            only_query_model = False

        model = get_model(get_model_provider(only_query_model=only_query_model,
            biencoder_shared_query_context_model=args.biencoder_shared_query_context_model))

        self.model = load_biencoder_checkpoint(model,
            only_query_model=only_query_model)

        # load_biencoder_checkpoint returns a list of model chunks; this
        # evaluator expects exactly one (no model parallelism splitting here).
        assert len(self.model) == 1
        # Inference only — disable dropout etc.
        self.model[0].eval()

        # Load faiss indexer
        self.faiss_wrapper()

    def get_evidence_embedding(self):
        """Load precomputed evidence embeddings into a data store."""
        # This will load the embedding from the embedding path
        self.evidence_embedder_obj = OpenRetreivalDataStore(load_from_path=True)

    def get_evidence_dataset(self):
        """Load the open-retrieval Wikipedia (evidence) dataset."""
        self.evidence_dataset = get_open_retrieval_wiki_dataset()

    def faiss_wrapper(self):
        """Build the FAISS MIPS index, on local rank 0 only.

        Non-zero local ranks keep ``self.mips_index`` as None and simply
        wait at the barrier.
        """
        # Initialize FAISS wrapper on local rank = 0 as the evidence embeddings
        # is distributed over all the GPUs in a node and FAISS is not
        # thread-safe
        args = get_args()
        if args.local_rank == 0:
            # Get evidence embeddings computed using context encoder
            self.get_evidence_embedding()
            assert self.evidence_embedder_obj is not None
            self.mips_index = FaissMIPSIndex(embed_size=self.embedding_size,
                                             embed_data=self.evidence_embedder_obj,
                                             use_gpu=self.faiss_use_gpu)

        # Wait for the FAISS index to be initialized in all the nodes
        torch.distributed.barrier()

    def generate_query_vectors(self, qa_data, split):
        """Encode all queries of ``split`` from ``qa_data``.

        Returns a tuple ``(query_tensor, reference_list)`` where
        ``query_tensor`` stacks one embedding row per query and
        ``reference_list`` holds the matching gold answers/references.
        """
        self.eval_dataset = get_nq_dataset(qa_data, split)
        dataloader = get_one_epoch_nq_dataloader(self.eval_dataset)

        query_vectors = []
        reference_list = []

        for batch in dataloader:
            # batch also has query_tokens and query_pad_data
            query_tokens, query_mask, query_types, \
                query_len, reference = process_nq_batch(batch)

            assert len(self.model) == 1
            unwrapped_model = self.model[0]
            # Peel off wrappers (e.g. DDP/FP16 modules) until the object
            # exposing embed_text is reached.
            while not hasattr(unwrapped_model, 'embed_text'):
                unwrapped_model = unwrapped_model.module

            with torch.no_grad():
                query_logits = unwrapped_model.embed_text(
                    unwrapped_model.query_model, query_tokens,
                    query_mask, query_types)

            reference_list.extend(reference)
            # Keep per-query rows so progress can be counted in queries.
            query_vectors.extend(query_logits.split(1, dim=0))
            if len(query_vectors) % 100 == 0:
                print_rank_0('Encoded queries {}'.format(len(query_vectors)))

        query_tensor = torch.cat(query_vectors, dim=0)
        print_rank_0('Total encoded queries tensor {}'.format(query_tensor.size()))

        assert query_tensor.size(0) == len(self.eval_dataset)
        return query_tensor, reference_list

    def evaluate(self, qa_data, split):
        """Run retrieval evaluation for ``split`` and print top-k accuracy.

        Gathers query embeddings across the GPUs of each node, searches
        the FAISS index on local rank 0, broadcasts the results back to
        the node's ranks, and scores the retrieved passages against the
        references. Prints results; returns None.
        """
        args = get_args()
        query_tensor, reference_list = self.generate_query_vectors(qa_data, \
            split)
        local_rank = args.local_rank
        rank = torch.distributed.get_rank()
        device_count = torch.cuda.device_count()
        # NOTE(review): assumes every node has the same GPU count — the
        # world size is divided evenly by the local device count.
        num_nodes = torch.distributed.get_world_size() // device_count
        node_id = rank // device_count

        # new_group must be entered by every rank for every group, so all
        # ranks loop over all nodes; each rank keeps only its own group.
        for node in range(num_nodes):
            start_rank = node * device_count
            end_rank = (node + 1) * device_count
            ranks_list = list(range(start_rank, end_rank))
            node_group = torch.distributed.new_group(ranks=ranks_list)

            if node_id == node:
                device_start_rank = start_rank
                group = node_group

        # Gather all query embeddings of this node onto every rank of the node.
        input_ = torch.empty_like(query_tensor).copy_(query_tensor).detach_()
        tensor_list = [torch.empty_like(input_) for _ in range(device_count)]
        torch.distributed.all_gather(tensor_list, query_tensor, group=group)

        # Only local rank 0 holds a FAISS index (see faiss_wrapper).
        if local_rank == 0 and self.mips_index is not None:
            all_query_tensor = torch.cat(tensor_list, dim=0).contiguous()

            distance, topkindex = self.mips_index.search_mips_index(
                all_query_tensor, top_k=args.faiss_topk_retrievals,
                reconstruct=False)
            distance = torch.from_numpy(distance).cuda()
            topkindex = torch.LongTensor(topkindex).cuda()

        # Other ranks allocate receive buffers of the matching shape.
        if local_rank != 0:
            distance = torch.empty(device_count * len(query_tensor), \
                args.faiss_topk_retrievals, dtype=torch.float32).cuda()
            topkindex = torch.empty(device_count * len(query_tensor), \
                args.faiss_topk_retrievals, dtype=torch.int64).cuda()

        # Broadcast the search results from the node's first rank.
        torch.distributed.broadcast(distance, src=device_start_rank, \
            group=group)
        torch.distributed.broadcast(topkindex, src=device_start_rank, \
            group=group)

        # Each rank keeps only the slice corresponding to its own queries.
        distance = torch.split(distance, len(query_tensor), dim=0)\
            [local_rank]
        topkindex = torch.split(topkindex, len(query_tensor), dim=0)\
            [local_rank]

        top_ids_and_scores = []
        for darray, topkarray in zip(distance, topkindex):
            top_ids_and_scores.append((topkarray.tolist(), darray.tolist()))

        passages = self.evidence_dataset.id2text
        match_stats = calculate_matches(passages,
                                        reference_list,
                                        top_ids_and_scores,
                                        workers_num=args.num_workers,
                                        match_type=args.faiss_match)
        top_k_hits = match_stats.top_k_hits

        print_rank_0("{} SET RESULTS".format(split))
        print_rank_0("topk-{} documents hits {}".format(
            args.faiss_topk_retrievals, top_k_hits))
        # Convert raw hit counts to per-query accuracy.
        top_k_hits = [v / len(top_ids_and_scores) for v in top_k_hits]
        print_rank_0("top-k documents hits accuracy {}".format(top_k_hits))

        for i in args.retriever_report_topk_accuracies:
            print_rank_0("top-{}: {:.2f}".format(i, top_k_hits[i-1] * 100))

        return
|