# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""GPT style dataset."""

import os
import time

import numpy as np
import torch

from megatron import mpu, print_rank_0
from megatron.data.blendable_dataset import BlendableDataset
from megatron.data.dataset_utils import get_datasets_weights_and_num_samples
from megatron.data.dataset_utils import get_train_valid_test_split_
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset


def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                    train_valid_test_num_samples,
                                    seq_length, seed, skip_warmup):
    """Build train, valid, and test datasets."""

    # Single dataset.
    if len(data_prefix) == 1:
        return _build_train_valid_test_datasets(data_prefix[0],
                                                data_impl, splits_string,
                                                train_valid_test_num_samples,
                                                seq_length, seed, skip_warmup)

    # Blending dataset.
    # Parse the values.
    output = get_datasets_weights_and_num_samples(data_prefix,
                                                  train_valid_test_num_samples)
    prefixes, weights, datasets_train_valid_test_num_samples = output

    # Build individual datasets.
    train_datasets = []
    valid_datasets = []
    test_datasets = []
    for i in range(len(prefixes)):
        train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
            prefixes[i], data_impl, splits_string,
            datasets_train_valid_test_num_samples[i],
            seq_length, seed, skip_warmup)
        if train_ds:
            train_datasets.append(train_ds)
        if valid_ds:
            valid_datasets.append(valid_ds)
        if test_ds:
            test_datasets.append(test_ds)

    # Blend.
    blending_train_dataset = None
    if train_datasets:
        blending_train_dataset = BlendableDataset(train_datasets, weights)
    blending_valid_dataset = None
    if valid_datasets:
        blending_valid_dataset = BlendableDataset(valid_datasets, weights)
    blending_test_dataset = None
    if test_datasets:
        blending_test_dataset = BlendableDataset(test_datasets, weights)

    return (blending_train_dataset, blending_valid_dataset,
            blending_test_dataset)
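
# Usage sketch (illustrative, not part of the original module): building a
# blended train/valid/test triple from two preprocessed corpora. The prefixes,
# weights, and sample counts below are made-up assumptions; in practice the
# `data_prefix` list is derived from the training arguments.
#
# train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
#     data_prefix=['0.7', '/data/webtext_text_document',
#                  '0.3', '/data/books_text_document'],
#     data_impl='mmap',
#     splits_string='949,50,1',
#     train_valid_test_num_samples=[1000000, 10000, 1000],
#     seq_length=1024,
#     seed=1234,
#     skip_warmup=True)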


def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                     train_valid_test_num_samples,
                                     seq_length, seed, skip_warmup):
    """Build train, valid, and test datasets."""

    # Indexed dataset.
    indexed_dataset = get_indexed_dataset_(data_prefix,
                                           data_impl,
                                           skip_warmup)

    total_num_of_documents = indexed_dataset.sizes.shape[0]
    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)

    # Print stats about the splits.
    print_rank_0(' > dataset split:')

    def print_split_stats(name, index):
        print_rank_0(' {}:'.format(name))
        print_rank_0(' document indices in [{}, {}) total of {} '
                     'documents'.format(splits[index], splits[index + 1],
                                        splits[index + 1] - splits[index]))
    print_split_stats('train', 0)
    print_split_stats('validation', 1)
    print_split_stats('test', 2)

    def build_dataset(index, name):
        dataset = None
        if splits[index + 1] > splits[index]:
            documents = np.arange(start=splits[index], stop=splits[index + 1],
                                  step=1, dtype=np.int32)
            dataset = GPTDataset(name, data_prefix,
                                 documents, indexed_dataset,
                                 train_valid_test_num_samples[index],
                                 seq_length, seed)
        return dataset

    train_dataset = build_dataset(0, 'train')
    valid_dataset = build_dataset(1, 'valid')
    test_dataset = build_dataset(2, 'test')

    return (train_dataset, valid_dataset, test_dataset)
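
# Illustrative note (an assumption about `get_train_valid_test_split_`, not
# part of the original module): with splits_string='949,50,1' and 1000
# documents, the boundaries come out approximately as
# splits == [0, 949, 999, 1000], i.e. documents [0, 949) go to train,
# [949, 999) to validation, and [999, 1000) to test.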


def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
    """Build indexed dataset."""
    print_rank_0(' > building dataset index ...')

    start_time = time.time()
    indexed_dataset = make_indexed_dataset(data_prefix,
                                           data_impl,
                                           skip_warmup)
    print_rank_0(' > finished creating indexed dataset in {:4f} '
                 'seconds'.format(time.time() - start_time))
    print_rank_0(' number of documents: {}'.format(
        indexed_dataset.sizes.shape[0]))

    return indexed_dataset


class GPTDataset(torch.utils.data.Dataset):
    """Fixed-length GPT training samples drawn from an indexed document
    collection; each item is a dict with a 'text' array of token ids."""

    def __init__(self, name, data_prefix, documents, indexed_dataset,
                 num_samples, seq_length, seed):

        self.name = name
        self.indexed_dataset = indexed_dataset

        # Checks
        assert np.min(documents) >= 0
        assert np.max(documents) < indexed_dataset.sizes.shape[0]

        # Build index mappings.
        self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings(
            self.name, data_prefix, documents, self.indexed_dataset.sizes,
            num_samples, seq_length, seed)

    def __len__(self):
        # -1 is due to the data structure used to retrieve the index:
        #    sample i --> [sample_idx[i], sample_idx[i+1])
        return self.sample_idx.shape[0] - 1

    def __getitem__(self, idx):
        # Get the shuffled index.
        idx = self.shuffle_idx[idx]
        # Start and end documents and offsets.
        doc_index_f = self.sample_idx[idx][0]
        doc_index_l = self.sample_idx[idx + 1][0]
        offset_f = self.sample_idx[idx][1]
        offset_l = self.sample_idx[idx + 1][1]
        # If we are within the same document, just extract the chunk.
        if doc_index_f == doc_index_l:
            sample = self.indexed_dataset.get(self.doc_idx[doc_index_f],
                                              offset=offset_f,
                                              length=offset_l - offset_f + 1)
        else:
            # Otherwise, get the rest of the initial document.
            sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f],
                                                    offset=offset_f)]
            # Loop over all in-between documents and add each entire document.
            for i in range(doc_index_f + 1, doc_index_l):
                sample_list.append(self.indexed_dataset.get(self.doc_idx[i]))
            # And finally add the relevant portion of the last document.
            sample_list.append(self.indexed_dataset.get(
                self.doc_idx[doc_index_l],
                length=offset_l + 1))
            sample = np.concatenate(sample_list)

        return {'text': np.array(sample, dtype=np.int64)}
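
# Worked example (illustrative, not part of the original module): with
# seq_length = 4, suppose sample_idx[i] == [0, 4] and sample_idx[i + 1] == [2, 0],
# and the documents behind doc_idx[0], doc_idx[1], doc_idx[2] have 5, 3 and 7
# tokens. __getitem__ then concatenates the tail of doc_idx[0] from offset 4
# (1 token), all of doc_idx[1] (3 tokens), and the first offset_l + 1 = 1 token
# of doc_idx[2], giving seq_length + 1 = 5 tokens; the extra token lets the
# caller shift the sequence by one to form inputs and labels.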


def _build_index_mappings(name, data_prefix, documents, sizes,
                          num_samples, seq_length, seed):
    """Build doc-idx, sample-idx, and shuffle-idx.
    doc-idx: is an array (ordered) of documents to be used in training.
    sample-idx: is the start document index and document offset for each
        training sample.
    shuffle-idx: maps the sample index into a random index into sample-idx.
    """
    # Number of tokens in each epoch and number of required epochs.
    tokens_per_epoch = _num_tokens(documents, sizes)
    num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples)
    # rng state
    np_rng = np.random.RandomState(seed=seed)

    # Filename of the index mappings.
    _filename = data_prefix
    _filename += '_{}_indexmap'.format(name)
    _filename += '_{}ns'.format(num_samples)
    _filename += '_{}sl'.format(seq_length)
    _filename += '_{}s'.format(seed)
    doc_idx_filename = _filename + '_doc_idx.npy'
    sample_idx_filename = _filename + '_sample_idx.npy'
    shuffle_idx_filename = _filename + '_shuffle_idx.npy'

    # Build the indexed mapping if it does not exist.
    if torch.distributed.get_rank() == 0:
        if (not os.path.isfile(doc_idx_filename)) or \
           (not os.path.isfile(sample_idx_filename)) or \
           (not os.path.isfile(shuffle_idx_filename)):

            print_rank_0(' > WARNING: could not find index map files, building '
                         'the indices on rank 0 ...')

            # For the last epoch, decide whether to include the entire epoch
            # in the global shuffle or not.

            # If we need only one epoch, then separating out the last epoch
            # does not mean anything.
            if num_epochs == 1:
                separate_last_epoch = False
                print(' > only one epoch required, setting '
                      'separate_last_epoch to False', flush=True)

            else:
                # Get the number of samples for the last epoch.
                num_samples_from_epochs_minus_one = (
                    (num_epochs - 1) * tokens_per_epoch - 1) // seq_length
                last_epoch_num_samples = num_samples - \
                    num_samples_from_epochs_minus_one
                assert last_epoch_num_samples >= 0, \
                    'last epoch number of samples should be non-negative.'
                num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length
                assert last_epoch_num_samples < (num_samples_per_epoch + 1), \
                    'last epoch number of samples exceeded max value.'
                # If we have less than 80% of the samples for the last epoch,
                # separate out the epoch and treat it differently.
                # Note: the 80% number is just based on common sense and can
                # be adjusted if needed.
                separate_last_epoch = (last_epoch_num_samples <
                                       int(0.80 * num_samples_per_epoch))
                if separate_last_epoch:
                    string = ' > last epoch number of samples ({}) is smaller '\
                             'than 80% of number of samples per epoch ({}), '\
                             'setting separate_last_epoch to True'
                else:
                    string = ' > last epoch number of samples ({}) is larger '\
                             'than 80% of number of samples per epoch ({}), '\
                             'setting separate_last_epoch to False'
                print(string.format(last_epoch_num_samples,
                                    num_samples_per_epoch), flush=True)

            # doc-idx.
            start_time = time.time()
            doc_idx = _build_doc_idx(documents, num_epochs, np_rng,
                                     separate_last_epoch)
            np.save(doc_idx_filename, doc_idx, allow_pickle=True)
            print_rank_0(' > elapsed time to build and save doc-idx mapping '
                         '(seconds): {:4f}'.format(time.time() - start_time))
            # sample-idx.
            start_time = time.time()
            # Use C++ implementation for speed.
            # First compile and then import.
            from megatron.data import helpers
            assert doc_idx.dtype == np.int32
            assert sizes.dtype == np.int32
            sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length,
                                                  num_epochs, tokens_per_epoch)
            # sample_idx = _build_sample_idx(sizes, doc_idx, seq_length,
            #                                num_epochs, tokens_per_epoch)
            np.save(sample_idx_filename, sample_idx, allow_pickle=True)
            print_rank_0(' > elapsed time to build and save sample-idx mapping '
                         '(seconds): {:4f}'.format(time.time() - start_time))
            # shuffle-idx.
            start_time = time.time()
            # -1 is due to the data structure used to retrieve the index:
            #    sample i --> [sample_idx[i], sample_idx[i+1])
            if separate_last_epoch:
                num_samples_ = num_samples_from_epochs_minus_one
            else:
                num_samples_ = sample_idx.shape[0] - 1
            shuffle_idx = _build_shuffle_idx(num_samples_,
                                             sample_idx.shape[0] - 1, np_rng)
            np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
            print_rank_0(' > elapsed time to build and save shuffle-idx mapping'
                         ' (seconds): {:4f}'.format(time.time() - start_time))

    # This should be a barrier, but the nccl barrier assumes
    # device_index=rank, which is not the case for the model
    # parallel case.
    counts = torch.cuda.LongTensor([1])
    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
    torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
    assert counts[0].item() == (
        torch.distributed.get_world_size() //
        torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))

    # Load mappings.
    start_time = time.time()
    print_rank_0(' > loading doc-idx mapping from {}'.format(
        doc_idx_filename))
    doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode='r')
    print_rank_0(' > loading sample-idx mapping from {}'.format(
        sample_idx_filename))
    sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode='r')
    print_rank_0(' > loading shuffle-idx mapping from {}'.format(
        shuffle_idx_filename))
    shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode='r')
    print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
        time.time() - start_time))
    print_rank_0(' total number of samples: {}'.format(
        sample_idx.shape[0]))
    print_rank_0(' total number of epochs: {}'.format(num_epochs))

    return doc_idx, sample_idx, shuffle_idx
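
# Filename example (illustrative, not part of the original module; the path is
# a made-up assumption): for data_prefix='/data/webtext_text_document',
# name='train', num_samples=1000, seq_length=1024 and seed=1234, the cached
# mappings are written next to the data as
#   /data/webtext_text_document_train_indexmap_1000ns_1024sl_1234s_doc_idx.npy
#   /data/webtext_text_document_train_indexmap_1000ns_1024sl_1234s_sample_idx.npy
#   /data/webtext_text_document_train_indexmap_1000ns_1024sl_1234s_shuffle_idx.npy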


def _num_tokens(documents, sizes):
    """Total number of tokens in the dataset."""
    return np.sum(sizes[documents])


def _num_epochs(tokens_per_epoch, seq_length, num_samples):
    """Based on the number of samples and the sequence length, calculate how
    many epochs will be needed."""
    num_epochs = 0
    total_tokens = 0
    while True:
        num_epochs += 1
        total_tokens += tokens_per_epoch
        # -1 is because we need to retrieve seq_length + 1 tokens each time
        # but the last token will overlap with the first token of the next
        # sample except for the last sample.
        if ((total_tokens - 1) // seq_length) >= num_samples:
            return num_epochs
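
# Worked example (illustrative, not part of the original module): with
# tokens_per_epoch = 15, seq_length = 4 and num_samples = 10, one epoch yields
# (15 - 1) // 4 = 3 samples and two epochs yield (30 - 1) // 4 = 7, both short
# of 10, while three epochs yield (45 - 1) // 4 = 11 >= 10, so _num_epochs
# returns 3.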


def _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch):
    """Build an array with length = number-of-epochs * number-of-documents.
    Each index is mapped to a corresponding document."""
    if not separate_last_epoch or num_epochs == 1:
        doc_idx = np.mgrid[0:num_epochs, 0:len(documents)][1]
        doc_idx[:] = documents
        doc_idx = doc_idx.reshape(-1)
        doc_idx = doc_idx.astype(np.int32)
        np_rng.shuffle(doc_idx)
        return doc_idx

    doc_idx_first = _build_doc_idx(documents, num_epochs - 1, np_rng, False)
    doc_idx_last = _build_doc_idx(documents, 1, np_rng, False)
    return np.concatenate((doc_idx_first, doc_idx_last))
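
# Worked example (illustrative, not part of the original module): for
# documents = [0, 1, 2] and num_epochs = 3 with separate_last_epoch=False,
# doc_idx is one random permutation of [0, 1, 2, 0, 1, 2, 0, 1, 2]; with
# separate_last_epoch=True, the first two epochs are shuffled together and the
# last epoch is shuffled on its own before the pieces are concatenated, so the
# final three entries are always a permutation of [0, 1, 2].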


def _build_sample_idx(sizes, doc_idx, seq_length,
                      num_epochs, tokens_per_epoch):
    """Sample index mapping is a 2D array with sizes
    [number-of-samples + 1, 2] where [..., 0] contains
    the index into `doc_idx` and [..., 1] is the
    starting offset in that document."""

    # Total number of samples. For -1 see comments in `_num_epochs`.
    num_samples = (num_epochs * tokens_per_epoch - 1) // seq_length
    sample_idx = np.zeros([num_samples + 1, 2], dtype=np.int32)

    # Index into sample_idx.
    sample_index = 0
    # Index into doc_idx.
    doc_idx_index = 0
    # Beginning offset for each document.
    doc_offset = 0
    # Start with the first document and no offset.
    sample_idx[sample_index][0] = doc_idx_index
    sample_idx[sample_index][1] = doc_offset
    sample_index += 1
    while sample_index <= num_samples:
        # Start with a fresh sequence.
        remaining_seq_length = seq_length + 1
        while remaining_seq_length != 0:
            # Get the document length.
            doc_id = doc_idx[doc_idx_index]
            doc_length = sizes[doc_id] - doc_offset
            # And add it to the current sequence.
            remaining_seq_length -= doc_length
            # If we have more than a full sequence, adjust offset and set
            # remaining length to zero so we return from the while loop.
            # Note that -1 here is for the same reason we have -1 in
            # `_num_epochs` calculations.
            if remaining_seq_length <= 0:
                doc_offset += (remaining_seq_length + doc_length - 1)
                remaining_seq_length = 0
            else:
                # Otherwise, start from the beginning of the next document.
                doc_idx_index += 1
                doc_offset = 0
        # Record the sequence.
        sample_idx[sample_index][0] = doc_idx_index
        sample_idx[sample_index][1] = doc_offset
        sample_index += 1

    return sample_idx
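
# Worked example (illustrative, not part of the original module): with
# sizes = [5, 3, 7], doc_idx = [0, 1, 2], seq_length = 4, num_epochs = 1 and
# tokens_per_epoch = 15, there are (15 - 1) // 4 = 3 samples and the result is
#   sample_idx = [[0, 0], [0, 4], [2, 0], [2, 4]]
# i.e. sample 0 covers tokens 0..4 of document 0, sample 1 runs from token 4 of
# document 0 through token 0 of document 2, and sample 2 covers tokens 0..4 of
# document 2. Consecutive samples share one boundary token, which is the reason
# for the -1 / +1 bookkeeping above.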


def _build_shuffle_idx(num_samples, total_size, np_rng):
    """Build the range [0, total_size) and shuffle."""
    print(' > building shuffle index with split [0, {}) and [{}, {}) '
          '...'.format(num_samples, num_samples, total_size), flush=True)

    dtype_ = np.uint32
    if total_size >= (np.iinfo(np.uint32).max - 1):
        dtype_ = np.int64

    shuffle_idx_first = np.arange(start=0, stop=num_samples,
                                  step=1, dtype=dtype_)
    np_rng.shuffle(shuffle_idx_first)
    if num_samples == total_size:
        return shuffle_idx_first

    shuffle_idx_last = np.arange(start=num_samples, stop=total_size,
                                 step=1, dtype=dtype_)
    np_rng.shuffle(shuffle_idx_last)

    return np.concatenate((shuffle_idx_first, shuffle_idx_last))
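
# Example (illustrative, not part of the original module): _build_shuffle_idx
# shuffles [0, num_samples) and [num_samples, total_size) independently, so
# when the last epoch is kept separate its samples are permuted only among
# themselves. The guarded block below is a small smoke test of the pure NumPy
# helpers on a made-up toy corpus; it only runs when this file is executed
# directly (which still requires the `megatron` package for the imports above)
# and never during training.
if __name__ == '__main__':
    toy_sizes = np.array([5, 3, 7], dtype=np.int32)   # tokens per document (assumed)
    toy_documents = np.arange(3, dtype=np.int32)      # use every document
    toy_seq_length = 4
    toy_num_samples = 3
    toy_rng = np.random.RandomState(seed=1234)

    toy_tokens = _num_tokens(toy_documents, toy_sizes)                      # 15
    toy_epochs = _num_epochs(toy_tokens, toy_seq_length, toy_num_samples)   # 1
    toy_doc_idx = _build_doc_idx(toy_documents, toy_epochs, toy_rng,
                                 separate_last_epoch=False)
    toy_sample_idx = _build_sample_idx(toy_sizes, toy_doc_idx, toy_seq_length,
                                       toy_epochs, toy_tokens)
    toy_shuffle_idx = _build_shuffle_idx(toy_sample_idx.shape[0] - 1,
                                         toy_sample_idx.shape[0] - 1, toy_rng)
    print('tokens per epoch:', toy_tokens)
    print('epochs needed:', toy_epochs)
    print('doc_idx:', toy_doc_idx)
    print('sample_idx:')
    print(toy_sample_idx)
    print('shuffle_idx:', toy_shuffle_idx)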