preprocess_data.py

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Processing data for pretraining."""

import argparse
import json
import multiprocessing
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                             os.path.pardir)))
import time

import torch
try:
    import nltk
    nltk_available = True
except ImportError:
    nltk_available = False

from megatron.tokenizer import build_tokenizer
from megatron.data import indexed_dataset


# https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer
# Only defined when NLTK is installed, so the module still imports (and
# document-level preprocessing still works) without it.
if nltk_available:
    class CustomLanguageVars(nltk.tokenize.punkt.PunktLanguageVars):

        _period_context_fmt = r"""
            \S*                          # some word material
            %(SentEndChars)s             # a potential sentence ending
            \s*                          # <-- THIS is what I changed
            (?=(?P<after_tok>
                %(NonWord)s              # either other punctuation
                |
                (?P<next_tok>\S+)        # <-- Normally you would have \s+ here
            ))"""

class IdentitySplitter(object):
    def tokenize(self, *text):
        return text

class Encoder(object):
    def __init__(self, args):
        self.args = args

    def initializer(self):
        # Use Encoder class as a container for global data
        Encoder.tokenizer = build_tokenizer(self.args)
        if self.args.split_sentences:
            if not nltk_available:
                print("NLTK is not available to split sentences.")
                exit()
            splitter = nltk.load("tokenizers/punkt/english.pickle")
            if self.args.keep_newlines:
                # this prevents punkt from eating newlines after sentences
                Encoder.splitter = nltk.tokenize.punkt.PunktSentenceTokenizer(
                    train_text=splitter._params,
                    lang_vars=CustomLanguageVars())
            else:
                Encoder.splitter = splitter
        else:
            Encoder.splitter = IdentitySplitter()

    def encode(self, json_line):
        data = json.loads(json_line)
        ids = {}
        for key in self.args.json_keys:
            text = data[key]
            doc_ids = []
            for sentence in Encoder.splitter.tokenize(text):
                sentence_ids = Encoder.tokenizer.tokenize(sentence)
                if len(sentence_ids) > 0:
                    doc_ids.append(sentence_ids)
            if len(doc_ids) > 0 and self.args.append_eod:
                doc_ids[-1].append(Encoder.tokenizer.eod)
            ids[key] = doc_ids
        return ids, len(json_line)
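
# Each line of --input is expected to be a standalone JSON object; encode()
# pulls out the fields named by --json-keys (default: "text"). An illustrative
# input line (the contents are placeholders):
#
#   {"text": "First sentence of a document. Second sentence."}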


def get_args():
    parser = argparse.ArgumentParser()
    group = parser.add_argument_group(title='input data')
    group.add_argument('--input', type=str, required=True,
                       help='Path to input JSON')
    group.add_argument('--json-keys', nargs='+', default=['text'],
                       help='space separated list of keys to extract from json')
    group.add_argument('--split-sentences', action='store_true',
                       help='Split documents into sentences.')
    group.add_argument('--keep-newlines', action='store_true',
                       help='Keep newlines between sentences when splitting.')

    group = parser.add_argument_group(title='tokenizer')
    group.add_argument('--tokenizer-type', type=str, required=True,
                       choices=['BertWordPieceLowerCase', 'BertWordPieceCase',
                                'GPT2BPETokenizer'],
                       help='What type of tokenizer to use.')
    group.add_argument('--vocab-file', type=str, default=None,
                       help='Path to the vocab file')
    group.add_argument('--merge-file', type=str, default=None,
                       help='Path to the BPE merge file (if necessary).')
    group.add_argument('--append-eod', action='store_true',
                       help='Append an <eod> token to the end of a document.')

    group = parser.add_argument_group(title='output data')
    group.add_argument('--output-prefix', type=str, required=True,
                       help='Path to binary output file without suffix')
    group.add_argument('--dataset-impl', type=str, default='mmap',
                       choices=['lazy', 'cached', 'mmap'])

    group = parser.add_argument_group(title='runtime')
    group.add_argument('--workers', type=int, default=1,
                       help='Number of worker processes to launch')
    group.add_argument('--log-interval', type=int, default=100,
                       help='Interval between progress updates')
    args = parser.parse_args()
    args.keep_empty = False

    if args.tokenizer_type.lower().startswith('bert'):
        if not args.split_sentences:
            print("Bert tokenizer detected, are you sure you don't want to split sentences?")

    # some default/dummy values for the tokenizer
    args.rank = 0
    args.make_vocab_size_divisible_by = 128
    args.tensor_model_parallel_size = 1
    args.vocab_extra_ids = 0

    return args
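
# main() streams the input file through a multiprocessing.Pool of encoder
# workers and writes one "{output_prefix}_{key}_{level}.bin" / ".idx" pair per
# key in --json-keys, where level is "document" or, with --split-sentences,
# "sentence".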


def main():
    args = get_args()
    startup_start = time.time()

    print("Opening", args.input)
    fin = open(args.input, 'r', encoding='utf-8')

    if nltk_available and args.split_sentences:
        nltk.download("punkt", quiet=True)

    encoder = Encoder(args)
    tokenizer = build_tokenizer(args)
    pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)
    encoded_docs = pool.imap(encoder.encode, fin, 25)
    #encoded_docs = map(encoder.encode, fin)

    level = "document"
    if args.split_sentences:
        level = "sentence"

    print(f"Vocab size: {tokenizer.vocab_size}")
    print(f"Output prefix: {args.output_prefix}")
    output_bin_files = {}
    output_idx_files = {}
    builders = {}
    for key in args.json_keys:
        output_bin_files[key] = "{}_{}_{}.bin".format(args.output_prefix,
                                                      key, level)
        output_idx_files[key] = "{}_{}_{}.idx".format(args.output_prefix,
                                                      key, level)
        builders[key] = indexed_dataset.make_builder(output_bin_files[key],
                                                     impl=args.dataset_impl,
                                                     vocab_size=tokenizer.vocab_size)

    startup_end = time.time()
    proc_start = time.time()
    total_bytes_processed = 0
    print("Time to startup:", startup_end - startup_start)

    for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1):
        total_bytes_processed += bytes_processed
        for key, sentences in doc.items():
            if len(sentences) == 0:
                continue
            for sentence in sentences:
                builders[key].add_item(torch.IntTensor(sentence))
            builders[key].end_document()
        if i % args.log_interval == 0:
            current = time.time()
            elapsed = current - proc_start
            mbs = total_bytes_processed / elapsed / 1024 / 1024
            print(f"Processed {i} documents",
                  f"({i/elapsed} docs/s, {mbs} MB/s).",
                  file=sys.stderr)

    for key in args.json_keys:
        builders[key].finalize(output_idx_files[key])


if __name__ == '__main__':
    main()