pretrain_ict.py

# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pretrain BERT for Inverse Cloze Task."""
from functools import partial
import math

import torch
import torch.distributed as dist
import torch.nn.functional as F

from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron.data.biencoder_dataset_utils import get_ict_batch
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model.biencoder_model import biencoder_model_provider
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group

def pretrain_ict_model_provider(pre_process=True, post_process=True):
    """Build the biencoder model (query and context encoders) for ICT."""
    args = get_args()

    model = biencoder_model_provider(
        only_context_model=False,
        only_query_model=False,
        biencoder_shared_query_context_model=\
            args.biencoder_shared_query_context_model,
        pre_process=pre_process, post_process=post_process)

    return model

def get_group_world_size_rank():
    """Return the data-parallel group together with this rank and its size."""
    group = mpu.get_data_parallel_group()
    rank = torch.distributed.get_rank(group=group)
    world_size = torch.distributed.get_world_size(group=group)

    return group, rank, world_size

class AllgatherFromDataParallelRegion(torch.autograd.Function):
    """All-gather a 2D tensor across the data-parallel group.

    The forward pass concatenates inputs from all data-parallel ranks along
    dimension 0; the backward pass returns only the gradient slice that
    corresponds to this rank's original input.
    """

    @staticmethod
    def forward(ctx, input_):
        assert input_.dim() == 2
        group, rank, world_size = get_group_world_size_rank()

        tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
        tensor_list[rank] = input_
        torch.distributed.all_gather(tensor_list, input_, group=group)

        output = torch.cat(tensor_list, dim=0).contiguous()
        return output

    @staticmethod
    def backward(ctx, grad_output):
        group, rank, world_size = get_group_world_size_rank()
        assert grad_output.shape[0] % world_size == 0
        dim_size = grad_output.shape[0] // world_size
        output_list = torch.split(grad_output, dim_size, dim=0)

        # Get the gradient chunk belonging to this rank.
        output = output_list[rank].contiguous()
        return output

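# Shape illustration (comment only; assumes a data-parallel world size of 4 and
# per-rank embeddings of shape [B, H]): the forward pass above returns a
# [4*B, H] tensor holding every rank's rows, while the backward pass hands this
# rank only its own [B, H] slice of the incoming gradient.
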
def loss_func(output_tensor):
    """In-batch-negative retrieval loss and top-k accuracies for ICT."""
    args = get_args()
    query_logits, context_logits = output_tensor

    micro_batch_size = query_logits.shape[0]
    # Recall we assert that tensor_model_parallel_size == 1.
    assert mpu.get_tensor_model_parallel_world_size() == 1, \
        "Model parallel size > 1 not supported for ICT"

    global_batch_size = dist.get_world_size() * micro_batch_size
    all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
    all_context_logits = AllgatherFromDataParallelRegion.apply(context_logits)

    # Scores are inner products between query and context embeddings.
    retrieval_scores = torch.matmul(all_query_logits,
                                    torch.transpose(all_context_logits, 0, 1))

    # Scale the retriever scores.
    if args.retriever_score_scaling:
        retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size)

    softmax_scores = F.log_softmax(retrieval_scores, dim=1)
    sorted_vals, sorted_indices = torch.topk(softmax_scores,
                                             k=softmax_scores.shape[1],
                                             sorted=True)

    def topk_accuracy(k):
        return torch.cuda.FloatTensor(
            [sum([int(i in sorted_indices[i, :k])
                  for i in range(global_batch_size)]) / global_batch_size])

    topk_accs = [topk_accuracy(int(k))
                 for k in args.retriever_report_topk_accuracies]

    labels = torch.arange(global_batch_size).long().cuda()
    loss = F.nll_loss(softmax_scores, labels, reduction='mean')
    reduced_losses = average_losses_across_data_parallel_group(
        [loss, *topk_accs])

    # Scale the retrieval loss by the data-parallel world size.
    loss = loss * mpu.get_data_parallel_world_size()

    # Create stats_dict with retrieval loss and all specified top-k accuracies.
    topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in
                     zip(args.retriever_report_topk_accuracies,
                         reduced_losses[1:])}
    stats_dict = dict(loss=reduced_losses[0], **topk_acc_dict)

    return loss, stats_dict

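# Single-process sketch (comments only, illustrative): with one data-parallel
# rank and micro-batch size B, the loss above reduces to plain in-batch-negative
# cross-entropy over a B x B score matrix whose diagonal holds the positive
# (query, context) pairs:
#
#     scores = torch.matmul(query_logits, context_logits.t())  # [B, B]
#     labels = torch.arange(B, device=scores.device)           # diagonal positives
#     loss = F.nll_loss(F.log_softmax(scores, dim=1), labels)
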
def forward_step(data_iterator, model):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    query_tokens, query_mask, \
        context_tokens, context_mask, context_indices = \
        get_ict_batch(data_iterator)
    timers('batch-generator').stop()

    # Query and context token types (all zeros: single-segment inputs).
    query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
    context_types = torch.cuda.LongTensor(*context_tokens.shape).fill_(0)

    # Forward model.
    output_tensor = model(query_tokens, query_mask, query_types,
                          context_tokens, context_mask, context_types)

    return output_tensor, partial(loss_func)

def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid, and test datasets."""
    args = get_args()

    print_rank_0('> building train, validation, and test datasets '
                 'for BERT ICT...')

    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        data_prefix=args.data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        max_seq_length=args.seq_length,
        masked_lm_prob=args.mask_prob,
        short_seq_prob=args.short_seq_prob,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup),
        binary_head=False,
        dataset_type='ict')
    print_rank_0("> finished creating BERT ICT datasets ...")

    return train_ds, valid_ds, test_ds

if __name__ == "__main__":
    pretrain(train_valid_test_datasets_provider,
             pretrain_ict_model_provider,
             forward_step,
             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})