bert_dataset.py

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BERT Style dataset."""

import numpy as np
import torch

from megatron import (
    get_args,
    get_tokenizer,
    mpu,
    print_rank_0
)
from megatron.data.dataset_utils import (
    get_samples_mapping,
    get_a_and_b_segments,
    truncate_segments,
    create_tokens_and_tokentypes,
    create_masked_lm_predictions
)


class BertDataset(torch.utils.data.Dataset):

    def __init__(self, name, indexed_dataset, data_prefix,
                 num_epochs, max_num_samples, masked_lm_prob,
                 max_seq_length, short_seq_prob, seed, binary_head):

        # Params to store.
        self.name = name
        self.seed = seed
        self.masked_lm_prob = masked_lm_prob
        self.max_seq_length = max_seq_length
        self.binary_head = binary_head

        # Dataset.
        self.indexed_dataset = indexed_dataset

        # Build the samples mapping.
        self.samples_mapping = get_samples_mapping(self.indexed_dataset,
                                                   data_prefix,
                                                   num_epochs,
                                                   max_num_samples,
                                                   self.max_seq_length - 3,  # account for added tokens
                                                   short_seq_prob,
                                                   self.seed,
                                                   self.name,
                                                   self.binary_head)

        # Vocab stuff.
        tokenizer = get_tokenizer()
        self.vocab_id_list = list(tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_dict = tokenizer.inv_vocab
        self.cls_id = tokenizer.cls
        self.sep_id = tokenizer.sep
        self.mask_id = tokenizer.mask
        self.pad_id = tokenizer.pad

    def __len__(self):
        return self.samples_mapping.shape[0]
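
    # Each row of samples_mapping is (start_index, end_index, sequence_length):
    # the half-open range [start_index, end_index) selects the sentences of
    # indexed_dataset that form one sample, and sequence_length is the target
    # length drawn for that sample, as unpacked in __getitem__ below.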

    def __getitem__(self, idx):

        start_idx, end_idx, seq_length = self.samples_mapping[idx]
        sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)]
        # Note that this rng state should be numpy and not python since
        # python randint is inclusive whereas the numpy one is exclusive.
        # We % 2**32 since numpy requires the seed to be between 0 and 2**32 - 1
        np_rng = np.random.RandomState(seed=((self.seed + idx) % 2**32))
        return build_training_sample(sample, seq_length,
                                     self.max_seq_length,  # needed for padding
                                     self.vocab_id_list,
                                     self.vocab_id_to_token_dict,
                                     self.cls_id, self.sep_id,
                                     self.mask_id, self.pad_id,
                                     self.masked_lm_prob, np_rng,
                                     self.binary_head)
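

# Minimal usage sketch (illustrative only): in Megatron-LM this dataset is
# normally constructed by the data pipeline, but given an existing
# `indexed_dataset` and an initialized tokenizer it can be fed to a standard
# DataLoader. The concrete argument values below are assumptions for
# illustration, not values prescribed by this file:
#
#   dataset = BertDataset(name='train', indexed_dataset=indexed_dataset,
#                         data_prefix='my-bert_text_sentence', num_epochs=None,
#                         max_num_samples=1000, masked_lm_prob=0.15,
#                         max_seq_length=512, short_seq_prob=0.1, seed=1234,
#                         binary_head=True)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=8)
#   batch = next(iter(loader))  # dict with 'text', 'types', 'labels',
#                               # 'is_random', 'loss_mask', 'padding_mask',
#                               # 'truncated'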


def build_training_sample(sample,
                          target_seq_length, max_seq_length,
                          vocab_id_list, vocab_id_to_token_dict,
                          cls_id, sep_id, mask_id, pad_id,
                          masked_lm_prob, np_rng, binary_head):
    """Build training sample.

    Arguments:
        sample: A list of sentences in which each sentence is a list of token ids.
        target_seq_length: Desired sequence length.
        max_seq_length: Maximum length of the sequence. All values are padded to
            this length.
        vocab_id_list: List of vocabulary ids. Used to pick a random id.
        vocab_id_to_token_dict: A dictionary from vocab ids to text tokens.
        cls_id: Start of example id.
        sep_id: Separator id.
        mask_id: Mask token id.
        pad_id: Padding token id.
        masked_lm_prob: Probability to mask tokens.
        np_rng: Random number generator. Note that this rng state should be
            numpy and not python since python randint is inclusive for
            the upper bound whereas the numpy one is exclusive.
        binary_head: If true, split the sample into A and B segments for the
            next-sentence-prediction (binary) head.
    """

    if binary_head:
        # We assume that we have at least two sentences in the sample.
        assert len(sample) > 1
    assert target_seq_length <= max_seq_length

    # Divide sample into two segments (A and B).
    if binary_head:
        tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample,
                                                                  np_rng)
    else:
        tokens_a = []
        for j in range(len(sample)):
            tokens_a.extend(sample[j])
        tokens_b = []
        is_next_random = False

    # Truncate to `target_sequence_length`.
    max_num_tokens = target_seq_length
    truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a),
                                  len(tokens_b), max_num_tokens, np_rng)

    # Build tokens and tokentypes.
    tokens, tokentypes = create_tokens_and_tokentypes(tokens_a, tokens_b,
                                                      cls_id, sep_id)

    # Masking.
    max_predictions_per_seq = masked_lm_prob * max_num_tokens
    (tokens, masked_positions, masked_labels, _, _) = create_masked_lm_predictions(
        tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
        cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)

    # Padding.
    tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
        = pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
                                   masked_labels, pad_id, max_seq_length)

    train_sample = {
        'text': tokens_np,
        'types': tokentypes_np,
        'labels': labels_np,
        'is_random': int(is_next_random),
        'loss_mask': loss_mask_np,
        'padding_mask': padding_mask_np,
        'truncated': int(truncated)}
    return train_sample
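
# Note on build_training_sample's masking cap (illustrative numbers): with
# masked_lm_prob=0.15 and max_num_tokens=126 usable tokens,
# max_predictions_per_seq is 0.15 * 126 = 18.9; the actual number of masked
# positions is chosen inside create_masked_lm_predictions, bounded by this
# value. Every array in the returned train_sample has length max_seq_length,
# while 'is_random' and 'truncated' are scalar flags.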


def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
                             masked_labels, pad_id, max_seq_length):
    """Pad sequences and convert them to numpy."""

    # Some checks.
    num_tokens = len(tokens)
    padding_length = max_seq_length - num_tokens
    assert padding_length >= 0
    assert len(tokentypes) == num_tokens
    assert len(masked_positions) == len(masked_labels)

    # Tokens and token types.
    filler = [pad_id] * padding_length
    tokens_np = np.array(tokens + filler, dtype=np.int64)
    tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)

    # Padding mask.
    padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
                               dtype=np.int64)

    # Labels and loss mask.
    labels = [-1] * max_seq_length
    loss_mask = [0] * max_seq_length
    for i in range(len(masked_positions)):
        assert masked_positions[i] < num_tokens
        labels[masked_positions[i]] = masked_labels[i]
        loss_mask[masked_positions[i]] = 1

    labels_np = np.array(labels, dtype=np.int64)
    loss_mask_np = np.array(loss_mask, dtype=np.int64)
    return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
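

# Minimal sketch of pad_and_convert_to_numpy on toy inputs (the token ids below
# are made up and not tied to any real vocabulary):
#
#   tokens           = [101, 7, 103, 9, 102]   # 5 tokens, position 2 is masked
#   tokentypes       = [0, 0, 0, 0, 0]
#   masked_positions = [2]
#   masked_labels    = [8]
#   outputs = pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
#                                      masked_labels, pad_id=0, max_seq_length=8)
#   tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np = outputs
#   # tokens_np       -> [101, 7, 103, 9, 102, 0, 0, 0]
#   # padding_mask_np -> [1, 1, 1, 1, 1, 0, 0, 0]
#   # labels_np       -> [-1, -1, 8, -1, -1, -1, -1, -1]
#   # loss_mask_np    -> [0, 0, 1, 0, 0, 0, 0, 0]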