finetune.py

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""ORQA finetuning/evaluation."""

from functools import partial
import math
import sys

import torch
import torch.nn.functional as F

from megatron import get_args, get_timers, get_tokenizer
from megatron import mpu, print_rank_0
from megatron.indexer import IndexBuilder
from megatron.model.biencoder_model import biencoder_model_provider
from megatron.utils import average_losses_across_data_parallel_group
from pretrain_ict import get_group_world_size_rank
from tasks.finetune_utils import finetune
from tasks.orqa.supervised.eval_utils import accuracy_func_provider
from tasks.orqa.supervised.eval_utils import process_batch, task_collate_fn
from tasks.orqa.evaluate_utils import ORQAEvaluator


# input_ is a 2D tensor
def check_and_append_tensor_for_gather(group, rank, world_size, input_):
    """Zero-pad input_ along its first dimension so that every rank in the
    group holds a tensor of the same size before an all-gather."""

    # Gather the size of the first dimension of the tensor from all ranks.
    current_length = input_.size()[0]
    first_dim = torch.tensor([[current_length]],
                             device=torch.cuda.current_device())
    input_list = [torch.empty_like(first_dim) for _ in range(world_size)]
    input_list[rank].copy_(first_dim)
    torch.distributed.all_gather(input_list, first_dim, group=group)
    all_input_list = torch.cat(input_list, dim=0).contiguous()
    max_length = torch.max(all_input_list)

    # If this rank's size differs from the max, extend the tensor
    # accordingly.
    if max_length > current_length:
        padding = tuple([0] * (input_.dim() * 2 - 1)) + \
                  tuple([max_length - current_length])
        input_ = F.pad(input=input_, pad=padding)

    return input_
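
# A worked example of the padding above (illustrative values, not from the
# source): with world_size=2, if this rank holds a (3, 64) tensor while the
# other rank holds (5, 64), the pad tuple is (0, 0, 0, 2). F.pad consumes
# pairs from the last dimension backwards, so only the final entry is
# non-zero and two zero rows are appended, giving a (5, 64) tensor that is
# now safe to all-gather.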


def orqa(Dataset):

    def cross_entropy_forward_step(batch, model):
        """Simple forward step with cross-entropy loss."""
        timers = get_timers()
        tokenizer = get_tokenizer()

        # Get the batch. `batch` may be an iterator or an already-fetched
        # batch, so fall back to using it directly if next() fails.
        timers('batch generator').start()
        try:
            batch_ = next(batch)
        except BaseException:
            batch_ = batch

        group, rank, world_size = get_group_world_size_rank()

        query_tokens, query_mask, query_types, query_pad_mask, \
            context_tokens, context_mask, context_types, context_pad_mask, \
            neg_context_tokens, neg_context_mask, neg_context_types, \
            reference = process_batch(batch_)
        timers('batch generator').stop()

        local_batch_size = query_tokens.shape[0]

        # Text representation of the queries and contexts.
        query_list, context_list = [], []
        for i in range(local_batch_size):
            query_list.append(tokenizer.decode(query_tokens[i].tolist()))
            context_list.append(tokenizer.decode(context_tokens[i].tolist()))

        if neg_context_tokens is not None:
            # Ranks may hold different numbers of hard negatives, so pad
            # them to a common length before their logits are gathered in
            # the loss function, then fold them into the context tensors.
            neg_context_tokens = check_and_append_tensor_for_gather(
                group, rank, world_size, neg_context_tokens)
            neg_context_mask = check_and_append_tensor_for_gather(
                group, rank, world_size, neg_context_mask)
            neg_context_types = check_and_append_tensor_for_gather(
                group, rank, world_size, neg_context_types)

            context_tokens = torch.cat([context_tokens, neg_context_tokens])
            context_mask = torch.cat([context_mask, neg_context_mask])
            context_types = torch.cat([context_types, neg_context_types])

        # Forward model.
        output_tensor = model(query_tokens, query_mask, query_types,
                              context_tokens, context_mask, context_types)
        return output_tensor, partial(cross_entropy_loss_func,
                                      query_tokens, context_tokens)
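
    # Note (inferred from the code): the biencoder returns a
    # (query_logits, context_logits) pair, and the loss closure receives the
    # query/context token tensors only to recover the local batch and
    # context sizes.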

    def cross_entropy_loss_func(query_tokens, context_tokens, output_tensor):
        args = get_args()

        local_batch_size = query_tokens.shape[0]
        group, rank, world_size = get_group_world_size_rank()
        # Recall we assert that model_parallel_size == 1.
        global_batch_size = world_size * local_batch_size

        query_logits, context_logits = output_tensor

        if world_size > 1:
            input_ = torch.empty_like(context_logits).copy_(
                context_logits).detach_()
            tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
            tensor_list[rank].copy_(input_)
            torch.distributed.all_gather(tensor_list, input_, group=group)

            # Check that the all-gather placed this rank's slice in order.
            assert tensor_list[rank].sum().item() == \
                context_logits.sum().item()

            # Put the local tensor back in place of the gathered copy so
            # the gradient is preserved.
            tensor_list[rank] = context_logits
            all_context_logits = torch.cat(tensor_list, dim=0).contiguous()

            # Query tensors.
            input_ = torch.empty_like(query_logits).copy_(
                query_logits).detach_()
            tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
            tensor_list[rank].copy_(input_)
            torch.distributed.all_gather(tensor_list, input_, group=group)

            # Check that the all-gather placed this rank's slice in order.
            assert tensor_list[rank].sum().item() == query_logits.sum().item()

            # Put the local tensor back in place of the gathered copy so
            # the gradient is preserved.
            tensor_list[rank] = query_logits
            all_query_logits = torch.cat(tensor_list, dim=0).contiguous()
        else:
            all_query_logits = query_logits
            all_context_logits = context_logits

        retrieval_scores = torch.matmul(
            all_query_logits, torch.transpose(all_context_logits, 0, 1))
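
        # Shape sketch (inferred, with H = hidden size): all_query_logits is
        # (world_size * local_batch_size, H), all_context_logits is
        # (world_size * local_context_size, H), so retrieval_scores is a
        # global-queries x global-contexts similarity matrix.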

        # Scaling the retrieval scores.
        if args.retriever_score_scaling:
            retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size)

        if args.train_with_neg:
            # If the world size is 3, the local batch size is 4, and the
            # local context size is 8, what we want is
            # labels = [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19].
            labels = []
            local_context_size = context_tokens.shape[0]
            for i in range(world_size):
                j = i * local_context_size
                labels.extend(list(range(j, j + local_batch_size)))
            labels = torch.LongTensor(labels).cuda()
            assert len(labels) == global_batch_size
        else:
            labels = torch.arange(global_batch_size).long().cuda()
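
        # To restate the label layouts above: without negatives the target
        # for query i is context i; with negatives, each rank contributes
        # local_batch_size positives followed by its padded negatives, hence
        # the strided pattern.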

        # Cross-entropy loss.
        softmax_scores = F.log_softmax(retrieval_scores, dim=1)
        loss = F.nll_loss(softmax_scores, labels, reduction='mean')

        max_score, max_idxs = torch.max(softmax_scores, 1)
        correct_predictions_count = (max_idxs == labels).sum().float()

        # Reduce the losses for logging.
        reduced_loss = average_losses_across_data_parallel_group(
            [loss, correct_predictions_count])

        # Loss scaling for correct losses in supervised retrieval.
        loss = loss * mpu.get_data_parallel_world_size()
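        # (An inference, not stated in the source: every data-parallel rank
        # computes the same global loss from the gathered logits, so scaling
        # by the data-parallel world size compensates for the averaging done
        # by the data-parallel gradient all-reduce.)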

        return loss, {'lm loss': reduced_loss[0],
                      'correct_prediction_count': reduced_loss[1]}

    def train_valid_datasets_provider():
        """Build train and validation datasets."""
        args = get_args()
        tokenizer = get_tokenizer()

        train_dataset = Dataset('training',
                                args.train_data,
                                tokenizer,
                                args.retriever_seq_length,
                                evaluate=False)
        valid_dataset = Dataset('validation',
                                args.valid_data,
                                tokenizer,
                                args.retriever_seq_length,
                                evaluate=True)
        return train_dataset, valid_dataset

    def model_provider(pre_process=True, post_process=True):
        """Build the model."""
        args = get_args()

        print_rank_0('building retriever model for {} ...'.format(args.task))

        shared_query_context = args.biencoder_shared_query_context_model
        model = biencoder_model_provider(
            only_context_model=False,
            only_query_model=False,
            biencoder_shared_query_context_model=shared_query_context,
            pre_process=pre_process,
            post_process=post_process)
        return model

    def single_dataset_provider(datapath):
        args = get_args()
        tokenizer = get_tokenizer()

        # Name the dataset after its file, e.g. 'dev' for 'data/dev.csv'.
        name = datapath[0].split('/')[-1].split('.')[0]
        return Dataset(name,
                       datapath,
                       tokenizer,
                       args.retriever_seq_length,
                       evaluate=True)

    def metrics_func_provider():
        """Provide a metrics callback function."""
        return accuracy_func_provider(single_dataset_provider)

    # Finetune/evaluate.
    finetune(train_valid_datasets_provider,
             model_provider,
             forward_step=cross_entropy_forward_step,
             end_of_epoch_callback_provider=metrics_func_provider,
             task_collate_fn=task_collate_fn)


def main():
    args = get_args()

    if args.task == 'RET-FINETUNE-NQ':
        from tasks.orqa.supervised.data import NQSupervisedDataset as Dataset
    else:
        raise NotImplementedError('ORQA task {} is not implemented.'.format(
            args.task))

    orqa(Dataset)
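
# A minimal launch sketch (an illustration, not from the source: the flag
# spellings assume Megatron's usual argparse convention of dashed versions
# of the args.* attribute names used above, and paths/values are
# placeholders):
#
#   python tasks/main.py \
#       --task RET-FINETUNE-NQ \
#       --train-data <path/to/nq-train.csv> \
#       --valid-data <path/to/nq-dev.csv> \
#       --retriever-seq-length 256 \
#       --train-with-neg \
#       --retriever-score-scaling
#
# There is no `if __name__ == '__main__'` guard here; main() is presumably
# invoked by a task-level driver such as tasks/main.py.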