realm_dataset_utils.py

import os
import time

import numpy as np
import torch

from megatron import get_args, get_tokenizer, mpu, print_rank_0
from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy


def get_one_epoch_dataloader(dataset, micro_batch_size=None):
    """Build a dataloader that runs exactly one epoch, specifically for use in an indexing job."""
    args = get_args()

    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    if micro_batch_size is None:
        micro_batch_size = args.micro_batch_size
    global_batch_size = micro_batch_size * world_size
    num_workers = args.num_workers

    sampler = torch.utils.data.SequentialSampler(dataset)
    # Importantly, drop_last must be False to get all the data.
    # NOTE: this path is intentionally disabled until the sampler is replaced.
    assert False, 'DistributedBatchSampler deprecated, change the implementation'
    from megatron.data.samplers import DistributedBatchSampler
    batch_sampler = DistributedBatchSampler(sampler,
                                            batch_size=global_batch_size,
                                            drop_last=False,
                                            rank=rank,
                                            world_size=world_size)

    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)


def get_ict_batch(data_iterator):
    """Get an ICT batch from the data iterator and broadcast it within the model-parallel group."""
    # Items and their type.
    keys = ['query_tokens', 'query_pad_mask',
            'block_tokens', 'block_pad_mask', 'block_data']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is None:
        data = None
    else:
        data = next(data_iterator)
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    query_tokens = data_b['query_tokens'].long()
    query_pad_mask = data_b['query_pad_mask'].long()
    block_tokens = data_b['block_tokens'].long()
    block_pad_mask = data_b['block_pad_mask'].long()
    block_indices = data_b['block_data'].long()

    return query_tokens, query_pad_mask, \
        block_tokens, block_pad_mask, block_indices
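
# Illustrative usage sketch (not part of the original module), assuming Megatron and
# torch.distributed are already initialized and `ict_dataloader` iterates over an ICT
# dataset that yields the keys listed above. On ranks that do not load data the
# iterator may be None, since the batch is broadcast within the model-parallel group:
#
#   data_iterator = iter(ict_dataloader)
#   query_tokens, query_pad_mask, block_tokens, block_pad_mask, block_indices = \
#       get_ict_batch(data_iterator)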


def join_str_list(str_list):
    """Join a list of WordPiece-style string tokens, handling "##" continuation pieces appropriately."""
    result = ""
    for s in str_list:
        if s.startswith("##"):
            result += s[2:]
        else:
            result += " " + s
    return result
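
# Illustrative example (not part of the original module): "##" marks a WordPiece
# continuation token, which is glued onto the previous piece; every other token is
# prefixed with a space, so the joined string carries a leading space:
#
#   join_str_list(["hello", "##world", "again"])  # -> " helloworld again"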


class BlockSampleData(object):
    """A struct for fully describing a fixed-size block of data as used in REALM.

    :param start_idx: index of the first sentence of the block
    :param end_idx: index of the last sentence of the block (may be partially truncated in sample construction)
    :param doc_idx: the index of the document from which the block comes in the original indexed dataset
    :param block_idx: a unique integer identifier given to every block
    """
    def __init__(self, start_idx, end_idx, doc_idx, block_idx):
        self.start_idx = start_idx
        self.end_idx = end_idx
        self.doc_idx = doc_idx
        self.block_idx = block_idx

    def as_array(self):
        return np.array([self.start_idx, self.end_idx, self.doc_idx, self.block_idx]).astype(np.int64)

    def as_tuple(self):
        return self.start_idx, self.end_idx, self.doc_idx, self.block_idx


class BlockSamplesMapping(object):
    def __init__(self, mapping_array):
        # Make sure that the array is compatible with BlockSampleData.
        assert mapping_array.shape[1] == 4
        self.mapping_array = mapping_array

    def __len__(self):
        return self.mapping_array.shape[0]

    def __getitem__(self, idx):
        """Get the data associated with an indexed sample."""
        sample_data = BlockSampleData(*self.mapping_array[idx])
        return sample_data
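
# Illustrative example (not part of the original module): wrapping a toy mapping
# array and reading one sample back as a BlockSampleData. The real mapping array
# is produced by get_block_samples_mapping() below, with one row per block:
#
#   mapping = BlockSamplesMapping(np.array([[0, 3, 7, 0],
#                                           [3, 5, 7, 1]], dtype=np.int64))
#   len(mapping)            # 2
#   mapping[1].as_tuple()   # (3, 5, 7, 1) as int64 values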


def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
                              max_num_samples, max_seq_length, seed, name, use_one_sent_docs=False):
    """Get a samples mapping for a dataset over fixed-size blocks. This function also requires
    a dataset of the titles for the source documents, since their lengths must be taken into account.

    :return: samples_mapping (BlockSamplesMapping)
    """
    if not num_epochs:
        if not max_num_samples:
            raise ValueError("Need to specify either max_num_samples "
                             "or num_epochs")
        num_epochs = np.iinfo(np.int32).max - 1
    if not max_num_samples:
        max_num_samples = np.iinfo(np.int64).max - 1

    # Filename of the index mapping.
    indexmap_filename = data_prefix
    indexmap_filename += '_{}_indexmap'.format(name)
    if num_epochs != (np.iinfo(np.int32).max - 1):
        indexmap_filename += '_{}ep'.format(num_epochs)
    if max_num_samples != (np.iinfo(np.int64).max - 1):
        indexmap_filename += '_{}mns'.format(max_num_samples)
    indexmap_filename += '_{}msl'.format(max_seq_length)
    indexmap_filename += '_{}s'.format(seed)
    if use_one_sent_docs:
        indexmap_filename += '_1sentok'
    indexmap_filename += '.npy'
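
    # Illustrative example (values are hypothetical): with data_prefix='my-corpus',
    # name='train', num_epochs=None, max_num_samples=1000000, max_seq_length=288,
    # seed=1234 and use_one_sent_docs=False, this yields:
    #   'my-corpus_train_indexmap_1000000mns_288msl_1234s.npy'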

    # Build the indexed mapping if it does not exist.
    if mpu.get_data_parallel_rank() == 0 and \
       not os.path.isfile(indexmap_filename):
        print(' > WARNING: could not find index map file {}, building '
              'the indices on rank 0 ...'.format(indexmap_filename))

        # Make sure the types match the helpers' input types.
        assert block_dataset.doc_idx.dtype == np.int64
        assert block_dataset.sizes.dtype == np.int32

        # Build the samples mapping.
        verbose = torch.distributed.get_rank() == 0
        start_time = time.time()
        print_rank_0(' > building samples index mapping for {} ...'.format(
            name))

        from megatron.data import helpers
        mapping_array = helpers.build_blocks_mapping(
            block_dataset.doc_idx,
            block_dataset.sizes,
            title_dataset.sizes,
            num_epochs,
            max_num_samples,
            max_seq_length - 3,  # account for added tokens
            seed,
            verbose,
            use_one_sent_docs)

        print_rank_0(' > done building samples index mapping')
        np.save(indexmap_filename, mapping_array, allow_pickle=True)
        print_rank_0(' > saved the index mapping in {}'.format(
            indexmap_filename))
        print_rank_0(' > elapsed time to build and save samples mapping '
                     '(seconds): {:4f}'.format(
                         time.time() - start_time))

    # Make sure all the ranks have built the mapping.
    # This should be a barrier, but the NCCL barrier assumes
    # device_index == rank, which does not hold in the
    # model-parallel case.
    counts = torch.cuda.LongTensor([1])
    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
    assert counts[0].item() == torch.distributed.get_world_size(
        group=mpu.get_data_parallel_group())

    # Load the indexed mapping.
    print_rank_0(' > loading indexed mapping from {}'.format(
        indexmap_filename))
    start_time = time.time()
    mapping_array = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
    samples_mapping = BlockSamplesMapping(mapping_array)
    print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
        time.time() - start_time))
    print_rank_0(' total number of samples: {}'.format(
        mapping_array.shape[0]))

    return samples_mapping
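
# Illustrative usage sketch (not part of the original module): building the samples
# mapping for an ICT/REALM dataset. The dataset objects and argument values are
# hypothetical; block_dataset and title_dataset are assumed to be indexed datasets
# exposing .doc_idx and .sizes, as required by the assertions above:
#
#   samples_mapping = get_block_samples_mapping(
#       block_dataset, title_dataset,
#       data_prefix='my-corpus',
#       num_epochs=None,
#       max_num_samples=1000000,
#       max_seq_length=288,
#       seed=1234,
#       name='train',
#       use_one_sent_docs=False)
#   start_idx, end_idx, doc_idx, block_idx = samples_mapping[0].as_tuple()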