123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252 |
- # coding=utf-8
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """ORQA finetuning/evaluation."""
- from functools import partial
- import sys
- import math
- import torch
- import torch.nn.functional as F
- from megatron import get_args, get_timers, get_tokenizer
- from megatron import mpu, print_rank_0
- from megatron.indexer import IndexBuilder
- from megatron.model.biencoder_model import biencoder_model_provider
- from megatron.utils import average_losses_across_data_parallel_group
- from pretrain_ict import get_group_world_size_rank
- from tasks.finetune_utils import finetune
- from tasks.orqa.supervised.eval_utils import accuracy_func_provider
- from tasks.orqa.supervised.eval_utils import process_batch, task_collate_fn
- from tasks.orqa.evaluate_utils import ORQAEvaluator
- # input_ is a 2D tensor
def check_and_append_tensor_for_gather(group, rank, world_size, input_):
    """Zero-pad ``input_`` along dim 0 to the largest dim-0 size in ``group``.

    Each rank contributes the size of its tensor's first dimension through
    an all-gather. A rank whose tensor is shorter than the maximum appends
    zero-filled entries at the end of dim 0 so that every rank ends up with
    the same first-dimension size.

    Args:
        group: process group used for the all-gather.
        rank: index of the calling process in ``group``. Unused here because
            ``all_gather`` fills every slot (including this rank's), but kept
            so existing callers keep working.
        world_size: number of processes in ``group``.
        input_: tensor to pad (callers in this file pass 2D tensors).

    Returns:
        ``input_`` itself when it already has the maximum length, otherwise
        a zero-padded tensor whose first dimension equals the maximum.
    """
    # gather the size of the first dimension of the tensor from all ranks
    current_length = input_.size(0)
    first_dim = torch.tensor([[current_length]],
                             device=torch.cuda.current_device())
    gathered_sizes = [torch.empty_like(first_dim) for _ in range(world_size)]
    torch.distributed.all_gather(gathered_sizes, first_dim, group=group)

    # Convert once to a Python int so the comparison and the pad tuple below
    # use plain integers instead of a 0-dim device tensor.
    max_length = int(torch.cat(gathered_sizes, dim=0).max().item())
    missing = max_length - current_length

    # if the size is smaller than the max, extend the tensor accordingly
    if missing > 0:
        # F.pad lists pad amounts from the last dimension backwards, so the
        # final entry is the "after" padding of dim 0.
        padding = (0,) * (input_.dim() * 2 - 1) + (missing,)
        input_ = F.pad(input=input_, pad=padding)

    return input_
def orqa(Dataset):
    """Fine-tune/evaluate a biencoder retriever on a supervised dataset.

    Builds the forward step, loss function, dataset providers, model provider
    and metrics provider as closures, then hands them to ``finetune``.

    Args:
        Dataset: dataset class; it is called as
            ``Dataset(name, datapaths, tokenizer, retriever_seq_length,
            evaluate=...)``.
    """

    def all_gather_keep_local_grad(local_logits, group, rank, world_size):
        """All-gather ``local_logits`` over ``group``, concatenated on dim 0.

        The gathered copies are detached from the autograd graph; the slot
        belonging to ``rank`` is then replaced by the original
        ``local_logits`` so that gradients still flow through this rank's
        own contribution.
        """
        detached = local_logits.detach().clone()
        gathered = [torch.empty_like(detached) for _ in range(world_size)]
        torch.distributed.all_gather(gathered, detached, group=group)

        # Check if all-gather happens in order
        assert gathered[rank].sum().item() == local_logits.sum().item()

        # Preserves the gradient
        gathered[rank] = local_logits
        return torch.cat(gathered, dim=0).contiguous()

    def cross_entropy_forward_step(batch, model):
        """Simple forward step with cross-entropy loss.

        Args:
            batch: either an iterator yielding batches or a batch itself.
            model: callable taking query tokens/mask/types followed by
                context tokens/mask/types.

        Returns:
            Tuple of the model output and the loss function with
            ``query_tokens`` and ``context_tokens`` already bound.
        """
        timers = get_timers()

        # Get the batch.
        timers('batch generator').start()
        try:
            batch_ = next(batch)
        except TypeError:
            # ``batch`` is not an iterator, so it already is the batch.
            batch_ = batch

        group, rank, world_size = get_group_world_size_rank()

        query_tokens, query_mask, query_types, query_pad_mask, \
            context_tokens, context_mask, context_types, context_pad_mask, \
            neg_context_tokens, neg_context_mask, neg_context_types, \
            reference = process_batch(batch_)

        timers('batch generator').stop()

        if neg_context_tokens is not None:
            # Pad the negatives to a common dim-0 size across ranks, then
            # append them after the positive contexts.
            neg_context_tokens = check_and_append_tensor_for_gather(
                group, rank, world_size, neg_context_tokens)
            neg_context_mask = check_and_append_tensor_for_gather(
                group, rank, world_size, neg_context_mask)
            neg_context_types = check_and_append_tensor_for_gather(
                group, rank, world_size, neg_context_types)

            context_tokens = torch.cat([context_tokens, neg_context_tokens])
            context_mask = torch.cat([context_mask, neg_context_mask])
            context_types = torch.cat([context_types, neg_context_types])

        # Forward model.
        output_tensor = model(query_tokens, query_mask,
                              query_types, context_tokens,
                              context_mask, context_types)
        return output_tensor, partial(cross_entropy_loss_func,
                                      query_tokens, context_tokens)

    def cross_entropy_loss_func(query_tokens, context_tokens, output_tensor):
        """Compute the retrieval cross-entropy loss over the gathered batch.

        Args:
            query_tokens: local query tokens; only ``shape[0]`` (the local
                batch size) is read.
            context_tokens: local context tokens (positives, plus negatives
                when present); only ``shape[0]`` is read.
            output_tensor: pair ``(query_logits, context_logits)``.

        Returns:
            Tuple of the loss multiplied by the data-parallel world size and
            a dict with keys ``'lm loss'`` and ``'correct_prediction_count'``
            holding the values returned by
            ``average_losses_across_data_parallel_group``.
        """
        args = get_args()

        local_batch_size = query_tokens.shape[0]
        group, rank, world_size = get_group_world_size_rank()
        # recall we assert that model_parallel_size == 1
        global_batch_size = world_size * local_batch_size

        query_logits, context_logits = output_tensor

        if world_size > 1:
            all_context_logits = all_gather_keep_local_grad(
                context_logits, group, rank, world_size)
            all_query_logits = all_gather_keep_local_grad(
                query_logits, group, rank, world_size)
        else:
            all_query_logits = query_logits
            all_context_logits = context_logits

        retrieval_scores = torch.matmul(
            all_query_logits, torch.transpose(all_context_logits, 0, 1))

        # Scaling the retrieval scores
        if args.retriever_score_scaling:
            retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size)

        if args.train_with_neg:
            # if the world size is 3, local batch size is 4, and
            # local context size is 8, what we want is
            # labels = [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19]
            local_context_size = context_tokens.shape[0]
            labels = []
            for i in range(world_size):
                start = i * local_context_size
                labels.extend(range(start, start + local_batch_size))
            labels = torch.LongTensor(labels).cuda()
            assert len(labels) == global_batch_size
        else:
            labels = torch.arange(global_batch_size).long().cuda()

        # Cross-entropy loss.
        softmax_scores = F.log_softmax(retrieval_scores, dim=1)
        loss = F.nll_loss(softmax_scores, labels, reduction='mean')

        _, max_idxs = torch.max(softmax_scores, 1)
        correct_predictions_count = (max_idxs == labels).sum().float()

        # Reduce loss for logging.
        reduced_loss = average_losses_across_data_parallel_group(
            [loss, correct_predictions_count])

        # Loss scaling for correct losses in Supervised Retrieval
        loss = loss * mpu.get_data_parallel_world_size()

        return loss, {'lm loss': reduced_loss[0],
                      'correct_prediction_count': reduced_loss[1]}

    def train_valid_datasets_provider():
        """Build train and validation dataset."""
        args = get_args()
        tokenizer = get_tokenizer()

        train_dataset = Dataset('training',
                                args.train_data,
                                tokenizer,
                                args.retriever_seq_length,
                                evaluate=False)
        valid_dataset = Dataset('validation',
                                args.valid_data,
                                tokenizer,
                                args.retriever_seq_length,
                                evaluate=True)
        return train_dataset, valid_dataset

    def model_provider(pre_process=True, post_process=True):
        """Build the model."""
        args = get_args()
        print_rank_0('building retriever model for {} ...'.format(args.task))

        model = biencoder_model_provider(
            only_context_model=False,
            only_query_model=False,
            biencoder_shared_query_context_model=
            args.biencoder_shared_query_context_model,
            pre_process=pre_process, post_process=post_process)
        return model

    def single_dataset_provider(datapath):
        """Build an evaluation dataset named after the first data file.

        The name is the basename of ``datapath[0]`` up to its first ``.``.
        """
        args = get_args()
        tokenizer = get_tokenizer()

        name = datapath[0].split('/')[-1].split('.')[0]
        return Dataset(name,
                       datapath,
                       tokenizer,
                       args.retriever_seq_length,
                       evaluate=True)

    def metrics_func_provider():
        """Provide metrics callback function."""
        return accuracy_func_provider(single_dataset_provider)

    # Finetune/evaluate.
    finetune(train_valid_datasets_provider,
             model_provider,
             forward_step=cross_entropy_forward_step,
             end_of_epoch_callback_provider=metrics_func_provider,
             task_collate_fn=task_collate_fn)
def main():
    """Run ORQA fine-tuning for the task selected in the global arguments.

    Raises:
        NotImplementedError: if ``args.task`` is not ``'RET-FINETUNE-NQ'``.
    """
    task = get_args().task

    # Guard clause: reject unsupported tasks before importing any dataset.
    if task != 'RET-FINETUNE-NQ':
        raise NotImplementedError(
            'ORQA task {} is not implemented.'.format(task))

    # Imported lazily, only once the task is known to be supported.
    from tasks.orqa.supervised.data import NQSupervisedDataset as Dataset

    orqa(Dataset)
|