cleanup_dataset.py 4.1 KB

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ftfy
import json
from langdetect import detect
import numpy as np
import time
import os
import sys

from tokenizer import Tokenizer

MIN_DOCUMENT_LENGHT = 128


def print_progress(prefix, start_time, num_docs, num_fixed_text,
                   num_non_english_docs, chars_non_english_docs,
                   num_small_docs, chars_small_docs):
    string = prefix + ' | '
    string += 'elapsed time: {:.2f} | '.format(time.time() - start_time)
    string += 'documents: {} | '.format(num_docs)
    string += 'fixed text: {} | '.format(num_fixed_text)
    string += 'non-english: {} | '.format(num_non_english_docs)
    string += 'non-english chars: {} | '.format(chars_non_english_docs)
    string += 'small docs: {} | '.format(num_small_docs)
    string += 'small docs chars: {}'.format(chars_small_docs)
    print(string, flush=True)


def filter_corpus(filename, out_filename, print_interval=10000):
    """Read loose-json documents from `filename`, fix their text encoding,
    drop non-English and very short documents, and write the remaining
    documents to `out_filename` as utf-8 encoded loose json."""

    print(' > filtering {}'.format(filename))

    tokenizer = Tokenizer(cache_dir='./cache')

    num_docs = 0
    num_written_docs = 0
    num_small_docs = 0
    num_fixed_text = 0
    num_non_english_docs = 0
    chars_non_english_docs = 0
    chars_small_docs = 0
    start_time = time.time()
    with open(out_filename, 'wb') as f:
        with open(filename, 'r') as fin:
            for line in fin:
                try:
                    num_docs += 1
                    myjson = json.loads(line)
                    # Fix text.
                    text = ftfy.fix_text(myjson['text'])
                    if text != myjson['text']:
                        num_fixed_text += 1
                        myjson['text'] = text
                    # Detect language.
                    if detect(text) != 'en':
                        print('[non-english text]', myjson)
                        num_non_english_docs += 1
                        chars_non_english_docs += len(text)
                        continue
                    # On average each token is 5 characters, so 8 characters
                    # per token is an upper bound. Only documents short enough
                    # to possibly fall below the token threshold are tokenized;
                    # longer documents are kept without the tokenization cost.
                    if len(text) < (8 * MIN_DOCUMENT_LENGHT):
                        tokens = tokenizer.tokenize_document(text)
                        if len(tokens) < MIN_DOCUMENT_LENGHT:
                            print('[small document, skipping]:', myjson)
                            num_small_docs += 1
                            chars_small_docs += len(text)
                            continue
                    myjson = json.dumps(myjson, ensure_ascii=False)
                    f.write(myjson.encode('utf-8'))
                    f.write('\n'.encode('utf-8'))
                    num_written_docs += 1
                    if num_docs % print_interval == 0:
                        print_progress('[PROGRESS]', start_time, num_docs,
                                       num_fixed_text, num_non_english_docs,
                                       chars_non_english_docs,
                                       num_small_docs, chars_small_docs)
                except Exception as e:
                    print(' skipping ', line, e)

    print_progress('[FINAL]', start_time, num_docs,
                   num_fixed_text, num_non_english_docs,
                   chars_non_english_docs,
                   num_small_docs, chars_small_docs)


if __name__ == '__main__':

    print('building gpt2 dataset ...')

    input_filename = sys.argv[1]
    output_filename = sys.argv[2]

    print('will be reading {}'.format(input_filename))
    print('and will write the results to {}'.format(output_filename))

    filter_corpus(input_filename, output_filename)
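
# Example usage (a sketch; the file names below are placeholders and are not
# part of the original script). Each input line is expected to be a standalone
# JSON object with at least a "text" field, for example:
#
#   {"text": "An example document that is long enough to keep ..."}
#
# The script is then run as:
#
#   python cleanup_dataset.py raw_docs.json filtered_docs.json
#
# Documents that pass the encoding fix, language check, and minimum-length
# check are written back out as utf-8 encoded loose json, one per line.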