filter_ngrams.py

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Deduplicate downstream task data from the training dataset using 13-grams.
Any split document shorter than 200 characters is filtered out, as is any
document that is split into more than 10 pieces.
"""

import argparse
from functools import partial
import json
import multiprocessing
import nltk
import pickle
import re
import string
import sys
import time

def get_words(text):
    # get all the lowercase words from text
    words, positions = [], []
    for match in re.finditer(r'\w+', text.lower()):
        words.append(match.group(0))
        positions.append(match.start())
    return words, positions
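
# Example (illustrative): get_words("Hello, World!") returns
#     (['hello', 'world'], [0, 7])
# i.e. the lowercased word tokens and their character offsets in the text.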

# splits the text
def split_text(text, start_position, remove_char_each_side, seq):

    # first part of the text
    punctuations = ".!?"
    pos = start_position - remove_char_each_side
    text_first = ""
    while pos > 0 and not text[pos] in punctuations:
        pos -= 1
    if pos > 0:
        text_first = text[0:pos+1]

    # add length of seq and remove_char_each_side
    pos = start_position + len(seq) + remove_char_each_side

    # last part of the text
    text_second = ""
    while pos < len(text) and not text[pos] in punctuations:
        pos += 1
    if pos + 1 < len(text):
        text_second = text[pos+1:len(text)]

    return text_first, text_second
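
# check_and_clean_text looks up the space-joined word sequence in the ngrams
# dictionary. On a match it either (a) only counts the match and queues the
# remainder of the text when args.get_ngram_freq_only is set, or (b) calls
# split_text to cut the match (plus remove_char_each_side characters on each
# side, out to the nearest '.', '!' or '?') out of the text, keeping the first
# piece if it is long enough and queueing the second piece for another pass.
# It returns True only when the sequence is ngram free.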
def check_and_clean_text(args, words, ngrams, text, start_position, \
                         text_buf_ngram_free, text_buf, local_ngram):

    seq = " ".join(words)
    if seq in ngrams:
        print(" [matched]: {}".format(seq), flush=True)

        if args.get_ngram_freq_only:
            # increase freq of this seq and then only consider the later part
            # of the text for further processing
            if seq in local_ngram:
                local_ngram[seq] += 1
            else:
                local_ngram[seq] = 1
            #print(" [increased]: {} {}".format(seq, ngrams[seq]), flush=True)
            if (start_position + len(seq) + 1) < len(text):
                text_buf.append(text[start_position + len(seq) + 1:len(text)])
            return False

        # split the text
        text_first, text_second = split_text(text, start_position, \
            args.remove_char_each_side, seq)

        # first part of ngrams free
        if len(text_first) > args.filter_text_char_len:
            text_buf_ngram_free.append(text_first)

        # add second part for further processing
        if len(text_second) > args.filter_text_char_len:
            text_buf.append(text_second)

        return False  # not ngram free

    # ngram free
    return True
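
# free_ngram drives the per-document cleanup. It keeps a work list (text_buf)
# of text pieces; each piece is scanned with a sliding window of
# max_ngram_size words (plus the smaller ngram sizes present in the
# dictionary), and any match is cut out, with the surviving pieces either kept
# as ngram free or pushed back onto the work list. It returns the ngram-free
# pieces, a flag indicating whether the document was merely trimmed, the
# parsed JSON record, and the local ngram match counts.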
def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
    # remove all the ngrams
    # initialize defaults so a malformed line does not raise a NameError below
    myjson = {}
    text_buf = []
    try:
        myjson = json.loads(line)
        text_buf = [myjson[key]]
    except Exception as e:
        print("Error: {}".format(e), flush=True)

    text_buf_ngram_free = []
    local_ngram = {}
    while len(text_buf) > 0:

        # get the first one from the buffer
        text = text_buf.pop(0)
        words, positions = get_words(text)

        ngram_free = True
        # find each max n-gram and check the dictionary
        for i in range(len(words) - args.max_ngram_size + 1):
            check_ngram_free = check_and_clean_text(args, words[i:\
                i+args.max_ngram_size], ngrams, text, positions[i], \
                text_buf_ngram_free, text_buf, local_ngram)

            # if the seq is not ngram free, break
            if not check_ngram_free:
                ngram_free = False
                break

            # if the max ngram doesn't match, check if any other lower n-gram
            # within the max ngram matches
            for ngram_len, _ in ngrams_freq_sorted:
                check_ngram_free = check_and_clean_text(args, words[i:\
                    i+ngram_len], ngrams, text, positions[i], \
                    text_buf_ngram_free, text_buf, local_ngram)

                # same check as above
                if not check_ngram_free:
                    ngram_free = False
                    break

            # check break from the lower-than-max ngram loop above
            if not ngram_free:
                break

        # for the last max n-gram, check all the lower ngrams in it
        if ngram_free and len(words) - args.max_ngram_size > 0:
            # get the words of the last max ngram
            last_seq_words = words[(len(words)-args.max_ngram_size):len(words)]
            last_seq_start_position = len(words) - args.max_ngram_size

            # check all n-grams lower than the max
            for pos, (ngram_len, _) in enumerate(ngrams_freq_sorted):

                # ignore the max ngram as it has been considered already
                if ngram_len == args.max_ngram_size:
                    continue

                # find each ngram of ngram_len in the max n-gram and check
                for i in range(len(last_seq_words) - ngram_len + 1):
                    check_ngram_free = check_and_clean_text(args, \
                        last_seq_words[i:i+ngram_len], ngrams, text, \
                        positions[last_seq_start_position+i], \
                        text_buf_ngram_free, text_buf, local_ngram)

                    if not check_ngram_free:
                        ngram_free = False
                        break

                if not ngram_free:
                    break

        # texts are ngram free
        if ngram_free and not args.get_ngram_freq_only:
            text_buf_ngram_free.append(text)

    # check if the text has only been trimmed
    trimmed = 0
    if not args.get_ngram_freq_only and len(text_buf_ngram_free) == 1 and \
        len(text_buf_ngram_free[0]) < len(myjson[key]):
        trimmed = 1

    return text_buf_ngram_free, trimmed, myjson, local_ngram
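
# The two helpers below build the ngram dictionary from task text: insert_dict
# stores a space-joined word sequence as a key with a zero count, and
# compute_ngrams_insert_dict slides a window of max_ngram_size words over the
# text (texts with fewer than min_ngram_size words are skipped; texts shorter
# than max_ngram_size words are inserted as a single ngram).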
# insert word sequence into dictionary
def insert_dict(words, ngrams, pos):
    seq = " ".join(words)
    if seq not in ngrams:
        ngrams[seq] = 0
        #ngrams[seq] = pos

# insert each ngram from text into the ngrams dictionary
def compute_ngrams_insert_dict(args, text, ngrams):
    words, positions = get_words(text)
    if len(words) < args.min_ngram_size:
        return

    if len(words) < args.max_ngram_size:
        insert_dict(words, ngrams, positions[0])

    for i in range(len(words) - args.max_ngram_size+1):
        insert_dict(words[i:i+args.max_ngram_size], ngrams, positions[i])

# Build ngrams for the lambada dataset
def process_task_lambda(args, task_file, ngrams):
    print(' reading from {} and computing ngrams'.format(task_file))
    with open(task_file, 'r') as f:
        for line in f:
            try:
                myjson = json.loads(line)
                text = myjson['text']
                compute_ngrams_insert_dict(args, text, ngrams)
            except Exception as e:
                print('Error:', e)
    print(" Entities in ngrams {}".format(len(ngrams)), flush=True)

# Build ngrams for the dataset of the given task
def process_task(args, task_name, ngrams):

    print(' reading {} from the datasets library and computing ngrams'.format(
        task_name))
    print(" Current entities in ngrams {}".format(len(ngrams)), flush=True)

    # using validation/test data from datasets
    from datasets import load_dataset
    entities_in_ngrams = len(ngrams)

    # load the dataset
    if task_name == 'squad':
        dataset = load_dataset('squad_v2', split='validation')
    elif task_name == 'natural_questions':
        dataset = load_dataset('natural_questions', split='validation')
    elif task_name == 'triviaqa':
        dataset = load_dataset('trivia_qa', 'unfiltered', split='test')
    elif task_name == 'webqa':
        dataset = load_dataset('web_questions', split='test')
    elif task_name == 'race':
        dataset = load_dataset('race', 'all', split='test')
    elif task_name == 'drop':
        dataset = load_dataset('drop', split='validation')
    elif task_name == 'coqa':
        dataset = load_dataset('coqa', split='validation')
    elif task_name == 'piqa':
        dataset = load_dataset('piqa', split='test')
    else:
        print("Invalid task name: {}".format(task_name), flush=True)
        return

    # read the dataset and add to ngrams
    for line in dataset:
        try:
            if task_name in ['squad', 'triviaqa', 'webqa', 'race', 'drop']:
                text = line['question']
                compute_ngrams_insert_dict(args, text, ngrams)
            elif task_name == 'natural_questions':
                text = line['question']['text']
                compute_ngrams_insert_dict(args, text, ngrams)
            elif task_name == 'coqa':
                all_questions = line['questions']
                for question in all_questions:
                    compute_ngrams_insert_dict(args, question, ngrams)
            elif task_name == 'piqa':
                text = line['goal']
                compute_ngrams_insert_dict(args, text, ngrams)
        except Exception as e:
            print('Error:', e)

    print(" After task {} entities in ngrams {}, added {}".format(task_name, \
        len(ngrams), len(ngrams) - entities_in_ngrams), flush=True)

def compute_tasks_ngrams(args, ngrams):
    start_time = time.time()
    for _, task_name in enumerate(args.tasks):
        print('Task: {}'.format(task_name), flush=True)
        if task_name == 'lambada':
            assert args.lambada_path is not None
            process_task_lambda(args, args.lambada_path, ngrams)
        else:
            process_task(args, task_name, ngrams)
    print(" Time taken to compute ngrams {:.2f}".format(time.time() - \
        start_time), flush=True)

def compute_ngram_freq_sorted(args, ngrams):
    ngrams_freq = {}
    for ngram_key in ngrams.keys():
        length = len(ngram_key.split())
        ngrams_freq[length] = ngrams_freq[length] + 1 if length in \
            ngrams_freq else 1

    ngrams_freq_sorted = sorted(ngrams_freq.items(), key=lambda item: item[0])
    print(" Ngram frequencies: {}".format(ngrams_freq_sorted), flush=True)
    print(" Entities in ngrams {} min_ngram_size {} max_ngram_size {}".format(\
        len(ngrams), ngrams_freq_sorted[0][0], ngrams_freq_sorted[len(\
        ngrams_freq_sorted) - 1][0]), flush=True)
    return ngrams_freq_sorted
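
# The deduplication runs in two passes over the large training file. The first
# pass (get_ngrams_below_threshold) only counts how often each task ngram
# actually appears in the training data (args.get_ngram_freq_only is set) and
# keeps the ngrams seen fewer than args.key_threshold times. The second pass
# (clean_ngrams_below_threshold) removes exactly those rare ngrams from the
# documents, splitting and trimming the text as needed.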
def get_ngrams_below_threshold(args, ngrams, ngrams_below_threshold, \
                               dedup_file, dedup_key, ngrams_freq_sorted):

    start_time = time.time()
    # get the ngrams frequency
    args.get_ngram_freq_only = True

    # Open the large file to process in parallel
    num_workers = args.num_threads
    pool = multiprocessing.Pool(num_workers)
    fin = open(dedup_file, 'r', encoding='utf-8')
    free_ngram_abt_partial = partial(free_ngram, args=args, key=dedup_key, \
        ngrams=ngrams, ngrams_freq_sorted=ngrams_freq_sorted)
    free_ngrams_abt = pool.imap(free_ngram_abt_partial, fin, 500)

    counter = 0
    for _, _, _, local_ngram in free_ngrams_abt:
        counter += 1
        if counter % 1000 == 0:
            print(' [compute_stat]> processed {} documents in {:.2f} seconds ...'.
                format(counter, time.time() - start_time), flush=True)
        for local_key in local_ngram:
            if local_key in ngrams:
                ngrams[local_key] += 1
        local_ngram = {}

    print(' Time taken to compute statistics {:.2f} seconds'.format(time.time() - \
        start_time), flush=True)
    pool.close()
    pool.join()

    start_time = time.time()
    counter_threshold = 0
    # Get ngrams below threshold
    for local_key, local_val in ngrams.items():
        if ngrams[local_key] < args.key_threshold:
            print(" [threshold] {} {}".format(local_key, local_val), flush=True)
            counter_threshold += 1
            ngrams_below_threshold[local_key] = 1

    print(' Ngrams below threshold {}'.format(counter_threshold), flush=True)
    fin.close()

def clean_ngrams_below_threshold(args, ngrams_below_threshold, dedup_file, \
                                 dedup_key):

    start_time = time.time()
    # Now actually filter the dataset
    args.get_ngram_freq_only = False
    #id_prefix = '-'.join(args.tasks[::2])
    id_prefix = '-'.join(args.tasks[::1])

    # get the range of the size of the ngrams
    ngrams_freq_sorted = compute_ngram_freq_sorted(args, ngrams_below_threshold)

    # Open the large file to process in parallel
    counter = splitted = ignored = split_mt_thld = trimmed_count = 0
    num_workers = args.num_threads
    pool = multiprocessing.Pool(num_workers)
    fin = open(dedup_file, 'r', encoding='utf-8')
    free_ngram_clean_partial = partial(free_ngram, args=args, key=dedup_key, \
        ngrams=ngrams_below_threshold, ngrams_freq_sorted=ngrams_freq_sorted)
    free_ngrams_clean = pool.imap(free_ngram_clean_partial, fin, 500)

    out_f = open(args.output, 'wb')

    for text_buf_ngram_free, trimmed, myjson, _ in free_ngrams_clean:
        counter += 1
        try:
            trimmed_count += trimmed

            if len(text_buf_ngram_free) > 1:
                splitted += 1
            if len(text_buf_ngram_free) == 0:
                ignored += 1
            # documents with more than --splits-count splits are ignored
            if len(text_buf_ngram_free) > args.splits_count:
                text_buf_ngram_free = []
                split_mt_thld += 1

            if args.output is not None:
                if "split_id" in myjson:
                    use_prefix = myjson["split_id"] + "-"
                else:
                    use_prefix = ""

                for i in range(len(text_buf_ngram_free)):
                    split_id_string = id_prefix + '-{:010d}'.format(int(\
                        counter)) + '-{:04d}'.format(int(i))
                    myjson[dedup_key] = text_buf_ngram_free[i]
                    myjson["split_id"] = use_prefix + split_id_string
                    outjson = json.dumps(myjson, ensure_ascii=False)
                    #outjson = json.dumps({"text":text_buf_ngram_free[i],
                    #    id_prefix+"_split_id":split_id_string},
                    #    ensure_ascii=False)
                    out_f.write(outjson.encode('utf-8'))
                    out_f.write('\n'.encode('utf-8'))

            if counter % 1000 == 0:
                print(' [final]> processed {} documents in {:.2f} seconds ...'.
                    format(counter, time.time() - start_time), flush=True)
        except Exception as e:
            print('Error:', e)

    print(' [final]> processed {} documents in {:.2f} seconds ...'.
        format(counter, time.time() - start_time), flush=True)
    print(' Total docs {} splitted {} ignored {} splits > threshold {} trimmed'\
        ' {}'.format(counter, splitted, ignored, split_mt_thld, trimmed_count)\
        , flush=True)

    pool.close()
    pool.join()

    out_f.close()
    fin.close()
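
# Driver: build the task ngram dictionary (or load a previously saved one),
# count ngram frequencies over the training file, and write out the filtered
# dataset.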
if __name__ == '__main__':

    # we use 13-grams; any split text shorter than 200 characters is removed,
    # as is any document split into more than 10 pieces
    print('parsing the arguments ...')

    parser = argparse.ArgumentParser()
    parser.add_argument('--tasks', nargs='*', required=True, default=None,
                        help='Tasks to use for deduplication: currently'
                        ' supports [lambada, squad, natural_questions,'
                        ' triviaqa, webqa, race, drop, coqa, and piqa]')
    parser.add_argument('--lambada-path', type=str, default=None,
                        help='Only the lambada task needs this path')
    parser.add_argument('--dedup-dataset', nargs='*', default=None,
                        help='Dataset to deduplicate, followed by the JSON key'
                        ' to use, e.g. cc.json text')
    parser.add_argument('--output', type=str, default=None,
                        help='Output file name to save the dedup dataset')
    parser.add_argument('--num-threads', type=int, default=40,
                        help='Number of worker processes to use')
    # Default dedup values
    parser.add_argument('--max-ngram-size', type=int, default=13,
                        help='Maximum size of ngram to use.')
    parser.add_argument('--min-ngram-size', type=int, default=8,
                        help='Minimum size of ngram to use.')
    parser.add_argument('--filter-text-char-len', type=int, default=200,
                        help='Remove any text below this length.')
    parser.add_argument('--key-threshold', type=int, default=10,
                        help='Only ngrams seen fewer than this many times in'
                        ' the dedup dataset are used for filtering.')
    parser.add_argument('--save-dictionary', type=str, default=None,
                        help='Save the ngram dictionary to this file')
    parser.add_argument('--load-dictionary', type=str, default=None,
                        help='Load the ngram dictionary from this file')
    parser.add_argument('--splits-count', type=int, default=10,
                        help='Remove any document split into more than this'
                        ' many pieces')
    parser.add_argument('--remove-char-each-side', type=int, default=200,
                        help='Number of characters to remove on each side of'
                        ' a matched ngram')
    args = parser.parse_args()

    assert len(args.dedup_dataset) == 2
    dedup_file = args.dedup_dataset[0]
    dedup_key = args.dedup_dataset[1]

    # Setup multi-processing
    num_workers = args.num_threads

    if args.load_dictionary is None:
        # Build ngrams
        ngrams = {}
        compute_tasks_ngrams(args, ngrams)

        # get the range of the size of the ngrams
        ngrams_freq_sorted = compute_ngram_freq_sorted(args, ngrams)

        # get ngram freq from the large file in parallel
        # get ngrams below threshold
        ngrams_below_threshold = {}
        get_ngrams_below_threshold(args, ngrams, ngrams_below_threshold, \
            dedup_file, dedup_key, ngrams_freq_sorted)

        # save the dictionary if needed
        if args.save_dictionary is not None:
            with open(args.save_dictionary, 'wb') as save_dict_handle:
                pickle.dump(ngrams_below_threshold, save_dict_handle)
    else:
        with open(args.load_dictionary, 'rb') as load_dict_handle:
            ngrams_below_threshold = pickle.load(load_dict_handle)

    # filter the large file
    if args.output is not None:
        clean_ngrams_below_threshold(args, ngrams_below_threshold, \
            dedup_file, dedup_key)

    print('done :-)')