evaluate.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. # coding=utf-8
  2. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """GPT zero-shot evaluation."""
  16. import math
  17. import torch
  18. from megatron import get_args
  19. from megatron import print_rank_0, is_last_rank
  20. from megatron import get_tokenizer
  21. from megatron import mpu
  22. from megatron.checkpointing import load_checkpoint
  23. from megatron.model import GPTModel
  24. from megatron.training import get_model
  25. from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
  26. from megatron.p2p_communication import recv_forward, send_forward
  27. from tasks.finetune_utils import build_data_loader
  28. from .datasets import build_dataset
  29. # These are needed to unwrap the model; it would be nice to move them into megatron.utils if possible.
  30. from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
  31. from megatron.model import DistributedDataParallel as LocalDDP
  32. from megatron.model import Float16Module
  33. def get_model_provider(eval_metric):
  34. """Based on evaluation metric set the parallel-output flag and
  35. return the model provider."""
  36. def model_provider(pre_process=True, post_process=True):
  37. """Build the model."""
  38. if eval_metric == 'loss':
  39. parallel_output = True
  40. elif eval_metric == 'accuracy':
  41. parallel_output = False
  42. else:
  43. raise NotImplementedError('output type for {} evaluation metric '
  44. 'is not supported.'.format(eval_metric))
  45. print_rank_0('building GPT model ...')
  46. model = GPTModel(num_tokentypes=0, parallel_output=parallel_output,
  47. pre_process=pre_process, post_process=post_process)
  48. return model
  49. return model_provider
  50. def process_batch(batch):
  51. """Process batch and produce inputs for the model."""
  52. args = get_args()
  53. tokenizer = get_tokenizer()
  54. loss_mask = batch['pad_mask'].long().cuda().contiguous().byte()
  55. tokens_ = batch['text'].long().cuda().contiguous()
  56. labels = tokens_[:, 1:].contiguous()
  57. tokens = tokens_[:, :-1].contiguous()
  58. # Get the masks and postition ids.
  59. attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
  60. tokens,
  61. tokenizer.eod,
  62. args.reset_position_ids,
  63. args.reset_attention_mask,
  64. args.eod_mask_loss)
  65. return tokens, labels, attention_mask, position_ids, loss_mask
  66. def forward_step(batch, model, eval_metric):
  67. """Forward step."""
  68. # Get the batch.
  69. tokens, labels, attention_mask, position_ids, loss_mask = process_batch(
  70. batch)
  71. # Tell the model what our actual batch size will be
  72. args = get_args()
  73. args.micro_batch_size = len(labels)
  74. input_tensor = recv_forward()
  75. # Forward pass through the model.
  76. unwrapped_model = unwrap_model(
  77. model, (torchDDP, LocalDDP, Float16Module))
  78. unwrapped_model.set_input_tensor(input_tensor)
  79. output = model(tokens, position_ids, attention_mask)
  80. send_forward(output)
  81. if mpu.is_pipeline_last_stage():
  82. # For loss, return the unreduced loss.
  83. if eval_metric == 'loss':
  84. losses = mpu.vocab_parallel_cross_entropy(
  85. output.contiguous().float(), labels.contiguous())
  86. loss = torch.sum(
  87. losses.view(-1) * loss_mask.contiguous().view(-1).float())
  88. return loss
  89. # For accuracy, return the number of correctly predicted samples.
  90. if eval_metric == 'accuracy':
  91. outputs = torch.argmax(output, -1)
  92. correct = (outputs == labels).float()
  93. correct[(1 - loss_mask).bool()] = 1
  94. correct = correct.prod(-1)
  95. return correct.sum()
  96. raise NotImplementedError('forward method for evaluation metric {} '
  97. 'is not implemented.'.format(eval_metric))
  98. return None
  99. def evaluate(data_loader, model, eval_metric):
  100. """Evaluation."""
  101. args = get_args()
  102. # Turn on evaluation mode which disables dropout.
  103. model.eval()
  104. total_output = 0.0
  105. with torch.no_grad():
  106. # For all the batches in the dataset.
  107. for iteration, batch in enumerate(data_loader):
  108. if iteration % args.log_interval == 0:
  109. print_rank_0('> working on iteration: {}'.format(iteration))
  110. # Forward evaluation.
  111. output = forward_step(batch, model, eval_metric)
  112. # Reduce across processes.
  113. if mpu.is_pipeline_last_stage():
  114. torch.distributed.all_reduce(output,
  115. group=mpu.get_data_parallel_group())
  116. total_output += output
  117. return total_output
  118. def evaluate_and_print_results(task, data_loader, model, eval_metric):
  119. """Evaluate and print results on screen."""
  120. # Evaluate and get results.
  121. output = evaluate(data_loader, model, eval_metric)
  122. string = ' validation results on {} | '.format(task)
  123. if is_last_rank():
  124. if eval_metric == 'loss':
  125. num_tokenized_tokens = data_loader.dataset.num_tokenized_tokens
  126. num_original_tokens = data_loader.dataset.num_original_tokens
  127. val_loss = output / (num_tokenized_tokens - 1)
  128. ppl = math.exp(min(20, val_loss))
  129. token_ratio = (num_tokenized_tokens - 1) / (num_original_tokens - 1)
  130. adjusted_ppl = math.exp(min(20, val_loss * token_ratio))
  131. string += 'avg loss: {:.4E} | '.format(val_loss)
  132. string += 'ppl: {:.4E} | '.format(ppl)
  133. string += 'adjusted ppl: {:.4E} | '.format(adjusted_ppl)
  134. string += 'token ratio: {} |'.format(token_ratio)
  135. elif eval_metric == 'accuracy':
  136. num_examples = len(data_loader.dataset)
  137. acc = output / num_examples
  138. string += 'number correct: {:.4E} | '.format(output)
  139. string += 'total examples: {:.4E} | '.format(num_examples)
  140. string += 'avg accuracy: {:.4E}'.format(acc)
  141. else:
  142. raise NotImplementedError('evaluation method for {} metric is not '
  143. 'implemented yet.'.format(eval_metric))
  144. length = len(string) + 1
  145. print('-' * length)
  146. print(string)
  147. print('-' * length)
  148. def main():
  149. """Main program."""
  150. args = get_args()
  151. if args.num_layers_per_virtual_pipeline_stage is not None:
  152. print("Interleaved pipeline schedule is not yet supported for text generation.")
  153. exit()
  154. if args.task == 'LAMBADA':
  155. eval_metric = 'accuracy'
  156. elif args.task == 'WIKITEXT103':
  157. eval_metric = 'loss'
  158. else:
  159. raise NotImplementedError('{} task is not implemented.'.format(
  160. args.task))
  161. # Set up model and load checkpoint.
  162. model = get_model(get_model_provider(eval_metric))
  163. if args.load is not None:
  164. _ = load_checkpoint(model, None, None)
  165. assert len(model) == 1, "Above condition should have caught this"
  166. model = model[0]
  167. # Data stuff.
  168. dataset = build_dataset(args.task)
  169. dataloader = build_data_loader(dataset, args.micro_batch_size,
  170. args.num_workers, drop_last=False)
  171. # Run evaluation.
  172. evaluate_and_print_results(args.task, dataloader, model, eval_metric)
  173. print_rank_0('done :-)')