# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
- """
- Deduplicate downstream tasks from training dataset. 13-grams have been used.
- All split documents with less than 200 characters got filtered. Any document
- with more than 10 splits got filtered as well.
- """

import argparse
from functools import partial
import json
import multiprocessing
import nltk
import pickle
import re
import string
import sys
import time

def get_words(text):
    # get all the lowercase words from text
    words, positions = [], []
    for match in re.finditer(r'\w+', text.lower()):
        words.append(match.group(0))
        positions.append(match.start())
    return words, positions
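
# Example for get_words() above (illustrative):
#   get_words("Hello, World!") -> (["hello", "world"], [0, 7])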

# splits the text
def split_text(text, start_position, remove_char_each_side, seq):
    # first part of the text
    punctuations = ".!?"
    pos = start_position - remove_char_each_side
    text_first = ""
    while pos > 0 and text[pos] not in punctuations:
        pos -= 1
    if pos > 0:
        text_first = text[0:pos + 1]

    # add length of seq and remove_char_each_side
    pos = start_position + len(seq) + remove_char_each_side

    # last part of the text
    text_second = ""
    while pos < len(text) and text[pos] not in punctuations:
        pos += 1
    if pos + 1 < len(text):
        text_second = text[pos + 1:len(text)]

    return text_first, text_second
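
# Example for split_text() above (illustrative): the matched sequence plus
# remove_char_each_side characters on either side are dropped; the first part
# is then cut back to the previous '.', '!' or '?', and the second part starts
# just after the next one, so the partial sentences around the match are
# discarded together with the match itself.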

def check_and_clean_text(args, words, ngrams, text, start_position,
                         text_buf_ngram_free, text_buf, local_ngram):

    seq = " ".join(words)
    if seq in ngrams:
        print(" [matched]: {}".format(seq), flush=True)

        if args.get_ngram_freq_only:
            # increase freq of this seq and then only consider the later part
            # of the text for further processing
            if seq in local_ngram:
                local_ngram[seq] += 1
            else:
                local_ngram[seq] = 1
            #print(" [increased]: {} {}".format(seq, ngrams[seq]), flush=True)
            if (start_position + len(seq) + 1) < len(text):
                text_buf.append(text[start_position + len(seq) + 1:len(text)])
            return False

        # split the text
        text_first, text_second = split_text(text, start_position,
                                             args.remove_char_each_side, seq)

        # first part is ngram free
        if len(text_first) > args.filter_text_char_len:
            text_buf_ngram_free.append(text_first)

        # add second part for further processing
        if len(text_second) > args.filter_text_char_len:
            text_buf.append(text_second)

        return False  # not ngram free

    # ngram free
    return True

def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
    # remove all the ngrams

    try:
        myjson = json.loads(line)
        text_buf = [myjson[key]]
    except Exception as e:
        print("Error: {}".format(e), flush=True)
        text_buf = []
        # keep myjson defined so the return below does not fail on bad input
        myjson = {}

    text_buf_ngram_free = []
    local_ngram = {}
    while len(text_buf) > 0:

        # get the first one from the buffer
        text = text_buf.pop(0)
        words, positions = get_words(text)

        ngram_free = True
        # find each max n-gram and check the dictionary
        for i in range(len(words) - args.max_ngram_size + 1):
            check_ngram_free = check_and_clean_text(args,
                words[i:i + args.max_ngram_size], ngrams, text, positions[i],
                text_buf_ngram_free, text_buf, local_ngram)

            # if the seq is not ngram free, break
            if not check_ngram_free:
                ngram_free = False
                break

            # if the max ngram doesn't match, check whether any lower-order
            # n-gram within the max ngram matches
            for ngram_len, _ in ngrams_freq_sorted:
                check_ngram_free = check_and_clean_text(args,
                    words[i:i + ngram_len], ngrams, text, positions[i],
                    text_buf_ngram_free, text_buf, local_ngram)

                # same check as above
                if not check_ngram_free:
                    ngram_free = False
                    break

            # check break from the lower-than-max ngram loop above
            if not ngram_free:
                break

        # for the last max n-gram, check all the lower ngrams in it
        if ngram_free and len(words) - args.max_ngram_size > 0:
            # get the words of the last max ngram
            last_seq_words = words[(len(words) - args.max_ngram_size):len(words)]
            last_seq_start_position = len(words) - args.max_ngram_size

            # check all n-grams lower than the max
            for pos, (ngram_len, _) in enumerate(ngrams_freq_sorted):

                # ignore the max ngram as it has been considered already
                if ngram_len == args.max_ngram_size:
                    continue

                # find each ngram of ngram_len in the max n-gram and check
                for i in range(len(last_seq_words) - ngram_len + 1):
                    check_ngram_free = check_and_clean_text(args,
                        last_seq_words[i:i + ngram_len], ngrams, text,
                        positions[last_seq_start_position + i],
                        text_buf_ngram_free, text_buf, local_ngram)

                    if not check_ngram_free:
                        ngram_free = False
                        break

                if not ngram_free:
                    break

        # the text is ngram free
        if ngram_free and not args.get_ngram_freq_only:
            text_buf_ngram_free.append(text)

    # check if the text has only been trimmed
    trimmed = 0
    if not args.get_ngram_freq_only and len(text_buf_ngram_free) == 1 and \
            len(text_buf_ngram_free[0]) < len(myjson[key]):
        trimmed = 1

    return text_buf_ngram_free, trimmed, myjson, local_ngram
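
# free_ngram() above returns:
#   text_buf_ngram_free - text segments that no longer contain any of the ngrams
#   trimmed             - 1 if the document was merely trimmed (a single segment
#                         shorter than the original), 0 otherwise
#   myjson              - the parsed json record of the input line
#   local_ngram         - per-document ngram match counts (frequency pass only)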

# insert word sequence into the dictionary
def insert_dict(words, ngrams, pos):
    seq = " ".join(words)
    if seq not in ngrams:
        ngrams[seq] = 0
        #ngrams[seq] = pos

# insert each ngram from text into the ngrams dictionary
def compute_ngrams_insert_dict(args, text, ngrams):
    words, positions = get_words(text)
    if len(words) < args.min_ngram_size:
        return

    if len(words) < args.max_ngram_size:
        insert_dict(words, ngrams, positions[0])

    for i in range(len(words) - args.max_ngram_size + 1):
        insert_dict(words[i:i + args.max_ngram_size], ngrams, positions[i])

# Build ngrams for the lambada dataset
def process_task_lambda(args, task_file, ngrams):
    print(' reading from {} and computing ngrams'.format(task_file))
    with open(task_file, 'r') as f:
        for line in f:
            try:
                myjson = json.loads(line)
                text = myjson['text']
                compute_ngrams_insert_dict(args, text, ngrams)
            except Exception as e:
                print('Error:', e)
    print(" Entities in ngrams {}".format(len(ngrams)), flush=True)

# Build ngrams for the dataset of the given task
def process_task(args, task_name, ngrams):

    print(' reading from {} and computing ngrams'.format('import datasets'))
    print(" Current entities in ngrams {}".format(len(ngrams)), flush=True)
    # using validation/test data from datasets
    from datasets import load_dataset
    entities_in_ngrams = len(ngrams)

    # load the dataset
    if task_name == 'squad':
        dataset = load_dataset('squad_v2', split='validation')
    elif task_name == 'natural_questions':
        dataset = load_dataset('natural_questions', split='validation')
    elif task_name == 'triviaqa':
        dataset = load_dataset('trivia_qa', 'unfiltered', split='test')
    elif task_name == 'webqa':
        dataset = load_dataset('web_questions', split='test')
    elif task_name == 'race':
        dataset = load_dataset('race', 'all', split='test')
    elif task_name == 'drop':
        dataset = load_dataset('drop', split='validation')
    elif task_name == 'coqa':
        dataset = load_dataset('coqa', split='validation')
    elif task_name == 'piqa':
        dataset = load_dataset('piqa', split='test')
    else:
        print("Invalid task name: {}".format(task_name), flush=True)
        return

    # read the dataset and add to ngrams
    for line in dataset:
        try:
            if task_name in ['squad', 'triviaqa', 'webqa', 'race', 'drop']:
                text = line['question']
                compute_ngrams_insert_dict(args, text, ngrams)
            elif task_name == 'natural_questions':
                text = line['question']['text']
                compute_ngrams_insert_dict(args, text, ngrams)
            elif task_name == 'coqa':
                all_questions = line['questions']
                for question in all_questions:
                    compute_ngrams_insert_dict(args, question, ngrams)
            elif task_name == 'piqa':
                text = line['goal']
                compute_ngrams_insert_dict(args, text, ngrams)
        except Exception as e:
            print('Error:', e)

    print(" After task {} entities in ngrams {}, added {}".format(
        task_name, len(ngrams), len(ngrams) - entities_in_ngrams), flush=True)

def compute_tasks_ngrams(args, ngrams):
    start_time = time.time()
    for _, task_name in enumerate(args.tasks):
        print('Task: {}'.format(task_name), flush=True)
        if task_name == 'lambada':
            assert args.lambada_path is not None
            process_task_lambda(args, args.lambada_path, ngrams)
        else:
            process_task(args, task_name, ngrams)
    print(" Time taken to compute ngrams {:.2f}".format(time.time() -
        start_time), flush=True)

def compute_ngram_freq_sorted(args, ngrams):
    ngrams_freq = {}
    for ngram_key in ngrams.keys():
        length = len(ngram_key.split())
        ngrams_freq[length] = ngrams_freq[length] + 1 if length in \
            ngrams_freq else 1

    ngrams_freq_sorted = sorted(ngrams_freq.items(), key=lambda item: item[0])
    print(" Ngram frequencies: {}".format(ngrams_freq_sorted), flush=True)
    print(" Entities in ngrams {} min_ngram_size {} max_ngram_size {}".format(
        len(ngrams), ngrams_freq_sorted[0][0],
        ngrams_freq_sorted[len(ngrams_freq_sorted) - 1][0]), flush=True)

    return ngrams_freq_sorted
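
# compute_ngram_freq_sorted() above returns a list of (ngram_length, count)
# pairs sorted by ngram length, e.g. [(8, 120), ..., (13, 4500)] (counts are
# illustrative).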

def get_ngrams_below_threshold(args, ngrams, ngrams_below_threshold,
                               dedup_file, dedup_key, ngrams_freq_sorted):

    start_time = time.time()
    # get the ngram frequencies
    args.get_ngram_freq_only = True

    # Open the large file to process in parallel
    num_workers = args.num_threads
    pool = multiprocessing.Pool(num_workers)
    fin = open(dedup_file, 'r', encoding='utf-8')
    free_ngram_abt_partial = partial(free_ngram, args=args, key=dedup_key,
        ngrams=ngrams, ngrams_freq_sorted=ngrams_freq_sorted)
    free_ngrams_abt = pool.imap(free_ngram_abt_partial, fin, 500)

    counter = 0
    for _, _, _, local_ngram in free_ngrams_abt:
        counter += 1
        if counter % 1000 == 0:
            print(' [compute_stat]> processed {} documents in {:.2f} seconds ...'.
                  format(counter, time.time() - start_time), flush=True)
        for local_key in local_ngram:
            if local_key in ngrams:
                ngrams[local_key] += 1
        local_ngram = {}

    print(' Time taken to compute statistics {:.2f} seconds'.format(time.time() -
          start_time), flush=True)
    pool.close()
    pool.join()

    start_time = time.time()
    counter_threshold = 0
    # Get ngrams below the threshold
    for local_key, local_val in ngrams.items():
        if ngrams[local_key] < args.key_threshold:
            print(" [threshold] {} {}".format(local_key, local_val), flush=True)
            counter_threshold += 1
            ngrams_below_threshold[local_key] = 1

    print(' Ngrams below threshold {}'.format(counter_threshold), flush=True)
    fin.close()
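
# Two-pass scheme: get_ngrams_below_threshold() above (with
# args.get_ngram_freq_only = True) only counts how often each task ngram occurs
# in the file being deduplicated and collects the ngrams seen fewer than
# --key-threshold times; clean_ngrams_below_threshold() below (with
# args.get_ngram_freq_only = False) then removes those ngrams from the
# documents and writes out the result.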

def clean_ngrams_below_threshold(args, ngrams_below_threshold, dedup_file,
                                 dedup_key):

    start_time = time.time()
    # Now actually filter the dataset
    args.get_ngram_freq_only = False
    #id_prefix = '-'.join(args.tasks[::2])
    id_prefix = '-'.join(args.tasks[::1])

    # get the range of the size of the ngrams
    ngrams_freq_sorted = compute_ngram_freq_sorted(args, ngrams_below_threshold)

    # Open the large file to process in parallel
    counter = splitted = ignored = split_mt_thld = trimmed_count = 0
    num_workers = args.num_threads
    pool = multiprocessing.Pool(num_workers)
    fin = open(dedup_file, 'r', encoding='utf-8')
    free_ngram_clean_partial = partial(free_ngram, args=args, key=dedup_key,
        ngrams=ngrams_below_threshold, ngrams_freq_sorted=ngrams_freq_sorted)
    free_ngrams_clean = pool.imap(free_ngram_clean_partial, fin, 500)

    out_f = open(args.output, 'wb')

    for text_buf_ngram_free, trimmed, myjson, _ in free_ngrams_clean:
        counter += 1
        try:
            trimmed_count += trimmed

            if len(text_buf_ngram_free) > 1:
                splitted += 1
            if len(text_buf_ngram_free) == 0:
                ignored += 1
            # documents with more than the allowed number of splits are ignored
            if len(text_buf_ngram_free) > args.splits_count:
                text_buf_ngram_free = []
                split_mt_thld += 1

            if args.output is not None:
                if "split_id" in myjson:
                    use_prefix = myjson["split_id"] + "-"
                else:
                    use_prefix = ""

                for i in range(len(text_buf_ngram_free)):
                    split_id_string = id_prefix + '-{:010d}'.format(int(
                        counter)) + '-{:04d}'.format(int(i))
                    myjson[dedup_key] = text_buf_ngram_free[i]
                    myjson["split_id"] = use_prefix + split_id_string
                    outjson = json.dumps(myjson, ensure_ascii=False)
                    #outjson = json.dumps({"text":text_buf_ngram_free[i],
                    #    id_prefix+"_split_id":split_id_string},
                    #    ensure_ascii=False)
                    out_f.write(outjson.encode('utf-8'))
                    out_f.write('\n'.encode('utf-8'))

            if counter % 1000 == 0:
                print(' [final]> processed {} documents in {:.2f} seconds ...'.
                      format(counter, time.time() - start_time), flush=True)
        except Exception as e:
            print('Error:', e)

    print(' [final]> processed {} documents in {:.2f} seconds ...'.
          format(counter, time.time() - start_time), flush=True)

    print(' Total docs {} split {} ignored {} splits > threshold {} trimmed'
          ' {}'.format(counter, splitted, ignored, split_mt_thld,
          trimmed_count), flush=True)

    pool.close()
    pool.join()

    out_f.close()
    fin.close()

if __name__ == '__main__':

    # we use 13-grams by default; any text shorter than 200 characters is
    # removed, as is any text that is split more than 10 times
    print('parsing the arguments ...')

    parser = argparse.ArgumentParser()
    parser.add_argument('--tasks', nargs='*', required=True, default=None,
                        help='Tasks to use for deduplication: currently '
                        'supports [lambada, squad, natural_questions, '
                        'triviaqa, webqa, race, drop, coqa, and piqa]')
    parser.add_argument('--lambada-path', type=str, default=None,
                        help='Only the lambada task needs this path')
    parser.add_argument('--dedup-dataset', nargs='*', default=None,
                        help='Dataset to deduplicate, followed by the json key '
                        'to use, e.g. cc.json text')
    parser.add_argument('--output', type=str, default=None,
                        help='Output file name to save the dedup dataset')
    parser.add_argument('--num-threads', type=int, default=40,
                        help='Number of threads to use')
    # Default dedup values
    parser.add_argument('--max-ngram-size', type=int, default=13,
                        help='Maximum size of ngram to use.')
    parser.add_argument('--min-ngram-size', type=int, default=8,
                        help='Minimum size of ngram to use.')
    parser.add_argument('--filter-text-char-len', type=int, default=200,
                        help='Remove any text below this length.')
    parser.add_argument('--key-threshold', type=int, default=10,
                        help='Frequency threshold; task ngrams that occur fewer '
                        'than this many times in the dedup dataset are removed '
                        'from it.')
    parser.add_argument('--save-dictionary', type=str, default=None,
                        help='Save the dictionary')
    parser.add_argument('--load-dictionary', type=str, default=None,
                        help='Load the dictionary')
    parser.add_argument('--splits-count', type=int, default=10,
                        help='Remove any document with more than this many splits')
    parser.add_argument('--remove-char-each-side', type=int, default=200,
                        help='Number of characters to remove on each side of a '
                        'matched ngram.')

    args = parser.parse_args()

    assert len(args.dedup_dataset) == 2
    dedup_file = args.dedup_dataset[0]
    dedup_key = args.dedup_dataset[1]

    # Setup multi-processing
    num_workers = args.num_threads
    if args.load_dictionary is None:

        # Build ngrams
        ngrams = {}
        compute_tasks_ngrams(args, ngrams)

        # get the range of the size of the ngrams
        ngrams_freq_sorted = compute_ngram_freq_sorted(args, ngrams)

        # get the ngram frequencies from the large file in parallel
        # and keep the ngrams below the threshold
        ngrams_below_threshold = {}
        get_ngrams_below_threshold(args, ngrams, ngrams_below_threshold,
                                   dedup_file, dedup_key, ngrams_freq_sorted)

        # save the dictionary if needed
        if args.save_dictionary is not None:
            with open(args.save_dictionary, 'wb') as save_dict_handle:
                pickle.dump(ngrams_below_threshold, save_dict_handle)
    else:
        with open(args.load_dictionary, 'rb') as load_dict_handle:
            ngrams_below_threshold = pickle.load(load_dict_handle)

    # filter the large file
    if args.output is not None:
        clean_ngrams_below_threshold(args, ngrams_below_threshold,
                                     dedup_file, dedup_key)

    print('done :-)')