123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306 |
- # coding=utf-8
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import argparse
- from functools import partial
- import itertools
- import json
- from lsh import cache, minhash
- import multiprocessing
- import numpy as np
- import time
- import pickle
- import sys
- import os
# This function is adapted from:
# https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
def shingles(text, char_ngram=5):
    """Return the set of all contiguous ``char_ngram``-character substrings.

    The start positions run up to and including ``len(text) - char_ngram``,
    so the final n-gram is part of the set. For example, a text of exactly
    ``char_ngram`` characters yields ``{text}``. Texts shorter than
    ``char_ngram`` yield an empty set.

    Args:
        text: document text to shingle.
        char_ngram: length of each character n-gram.

    Returns:
        A ``set`` of substrings of length ``char_ngram``.
    """
    return set(text[head:head + char_ngram]
               for head in range(0, len(text) - char_ngram + 1))
# This function is adapted from:
# https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
def jaccard(set_a, set_b, args):
    """Score the overlap of two shingle sets.

    The numerator is always the size of the intersection. ``args.jaccard``
    picks the denominator:

    - ``'min'``: size of the smaller set.
    - ``'max'``: size of the larger set.
    - any other value (the CLI default is ``'union'``): size of the union.

    Returns 0.0 when either set is empty.
    """
    if not set_a or not set_b:
        return 0.0
    overlap = len(set_a & set_b)
    mode = args.jaccard
    if mode == 'min':
        denominator = min(len(set_a), len(set_b))
    elif mode == 'max':
        denominator = max(len(set_a), len(set_b))
    else:
        denominator = len(set_a | set_b)
    return overlap / denominator
def compute_fingerprint(line, key):
    """Parse one JSON line and compute the MinHash fingerprint of its text.

    Args:
        line: one line of a JSON-lines input file.
        key: name of the JSON field used as the document id.

    Returns:
        ``(url, text, fingerprint, True)`` on success. On any failure it
        prints the error and returns ``(None, None, None, False)``. Failures
        include malformed JSON, a missing ``key`` or ``'text'`` field, or an
        error from the hasher.

    The function uses the module-level ``hasher`` created in the
    ``__main__`` block. Pool workers are expected to inherit it.
    """
    try:
        record = json.loads(line)
        doc_id = record[key]
        doc_text = record['text']
        doc_fingerprint = hasher.fingerprint(doc_text)
    except Exception as err:
        # Best-effort: one bad line must not abort the whole pass.
        print('Error:', err)
        return None, None, None, False
    else:
        return doc_id, doc_text, doc_fingerprint, True
def url_pairs_to_remove(args, bucket_urls, url_doc, threshold=0.5):
    """Greedily find near-duplicate documents inside one LSH bucket.

    Each iteration works as follows:

    1. Pick a random "main" URL from ``bucket_urls`` using the global NumPy
       RNG.
    2. Compare the main URL's shingles with every other URL still in the
       bucket.
    3. Remove from ``bucket_urls`` the main URL and every URL whose
       similarity to it is greater than ``threshold``.

    The loop stops when at most one URL remains, or after
    ``args.heuristic_iter`` iterations (``-1`` means no iteration limit).

    Args:
        args: parsed CLI namespace. Reads ``heuristic_iter``, and ``jaccard``
            via ``jaccard()``.
        bucket_urls: set of candidate URLs. It is consumed in place, so
            callers pass a copy.
        url_doc: mapping from URL to document text.
        threshold: a similarity strictly greater than this marks a duplicate.
            The default of 0.5 matches the previously hard-coded value.

    Returns:
        A tuple ``(remove_urls_list, deduped_local, counter_local)``:

        - ``remove_urls_list``: a list of ``{main_url: [{dup_url: sim}, ...]}``
          entries.
        - ``deduped_local``: the number of duplicates found.
        - ``counter_local``: the number of URLs visited across iterations,
          including each main URL itself.
    """
    remove_urls_list = []
    deduped_local, counter_local = 0, 0
    iteration = 0
    while len(bucket_urls) > 1:
        if args.heuristic_iter != -1 and \
                iteration == args.heuristic_iter:
            break
        # Snapshot the bucket so it can be shrunk while iterating.
        items = list(bucket_urls)
        remove_urls = []
        main_url = items[np.random.randint(0, len(items))]
        main_shingles = shingles(url_doc[main_url])
        for other_url in items:
            counter_local += 1
            if other_url == main_url:
                continue
            other_shingles = shingles(url_doc[other_url])
            try:
                jaccard_sim = jaccard(main_shingles, other_shingles, args)
            except Exception as e:
                # Best-effort: a failed comparison is treated as "distinct".
                print('Error:', e)
                jaccard_sim = 0.0
            if jaccard_sim > threshold:
                remove_urls.append({other_url: jaccard_sim})
                deduped_local += 1
                bucket_urls.remove(other_url)
        bucket_urls.remove(main_url)
        if remove_urls:
            remove_urls_list.append({main_url: remove_urls})
        iteration += 1
    return remove_urls_list, deduped_local, counter_local
def write_remove_urls_list(remove_urls_list, f_out):
    """Write each removal record to ``f_out`` as one UTF-8 JSON line.

    ``f_out`` must be a binary-mode file. Records are serialized with
    ``ensure_ascii=False``, so non-ASCII ids are written verbatim as UTF-8.
    An empty list writes nothing.
    """
    for record in remove_urls_list:
        line = json.dumps(record, ensure_ascii=False) + '\n'
        f_out.write(line.encode('utf-8'))
def compute_jaccard(each_bin, num_bins, start_time_local):
    """Run ``url_pairs_to_remove`` on every multi-URL bucket of one LSH bin.

    This is the pool worker for the parallel path. It reads the module-level
    globals ``args`` and ``url_doc`` instead of receiving them as arguments.
    A progress line is printed every 100000 buckets, but only by workers
    whose pid is divisible by ``num_bins``.

    Returns:
        ``(remove_urls_list, deduped_local, counter_local)``, with values
        summed over all buckets of ``each_bin``.
    """
    collected = []
    total_deduped = total_compared = 0
    num_buckets = len(each_bin)
    for position, bucket_id in enumerate(each_bin, start=1):
        if os.getpid() % num_bins == 0 and position % 100000 == 0:
            print("Counter {}, progress {:.2f} time {:.2f}".format(
                position, float(position) / float(num_buckets),
                time.time() - start_time_local), flush=True)
        candidates = each_bin[bucket_id]
        if len(candidates) <= 1:
            continue
        found, deduped_sub, compared_sub = url_pairs_to_remove(
            args, candidates.copy(), url_doc)
        total_deduped += deduped_sub
        total_compared += compared_sub
        collected.extend(found)
    return collected, total_deduped, total_compared
def find_pair_urls_parallel(args, lshcache, url_doc):
    """Check all LSH bins for duplicates in parallel and write the results.

    One pool worker per bin runs ``compute_jaccard``. Results are appended
    to ``args.output`` as JSON lines in bin order, because ``imap`` preserves
    ordering. The output file and the pool are released even if a worker
    raises.

    NOTE(review): ``compute_jaccard`` reads ``args`` and ``url_doc`` as
    module-level globals, not from this function's parameters. This relies
    on workers inheriting the parent's globals, as with the ``fork`` start
    method. It presumably fails under ``spawn``; confirm before porting.
    """
    start_time = time.time()
    deduped, counter = 0, 0
    # compute jaccards of buckets in bin in parallel (parallelism
    # limited to # of bins)
    num_bins = len(lshcache.bins)
    compute_jaccard_partial = partial(compute_jaccard, num_bins=num_bins,
                                      start_time_local=start_time)
    with open(args.output, 'wb') as f_out, \
            multiprocessing.Pool(num_bins) as pool:
        # don't need to pass args and url_doc as they are already shared
        compute_jaccard_iter = pool.imap(compute_jaccard_partial,
                                         lshcache.bins)
        print("multiprocessing init took {:.2f}".format(
            time.time() - start_time), flush=True)
        for remove_urls_list, deduped_local, counter_local in \
                compute_jaccard_iter:
            deduped += deduped_local
            counter += counter_local
            write_remove_urls_list(remove_urls_list, f_out)
            print(' [write]> processed {} documents in {:.2f} '
                  'seconds and deduped {} documents ...'.format(
                      counter, time.time() - start_time, deduped),
                  flush=True)
        # Graceful shutdown; the context manager's terminate() is then a no-op.
        pool.close()
        pool.join()
    print(' Taken time for jaccard similarities {:.2f} seconds'.format(
        time.time() - start_time), flush=True)
def find_pair_urls_sequential(args, lshcache, url_doc):
    """Check every LSH bucket for duplicates in one process and write results.

    Each bucket with more than one URL is passed, as a copy, to
    ``url_pairs_to_remove``. The found records are appended to
    ``args.output`` as JSON lines. Progress is printed each time the
    processed-document counter crosses a multiple of 10000, and once more
    at the end. The output file is closed even if processing raises.
    """
    start_time = time.time()
    deduped, counter = 0, 0
    report_every = 10000
    next_report = report_every

    def _report():
        """Print the running processed/deduped totals."""
        print(' [write]> processed {} documents in {:.2f} '
              'seconds and deduped {} documents ...'.format(
                  counter, time.time() - start_time, deduped), flush=True)

    with open(args.output, 'wb') as f_out:
        for b in lshcache.bins:
            for bucket_id in b:
                if len(b[bucket_id]) <= 1:
                    continue
                bucket_urls = b[bucket_id].copy()
                remove_urls_list_sub, deduped_local_sub, counter_local_sub = \
                    url_pairs_to_remove(args, bucket_urls, url_doc)
                deduped += deduped_local_sub
                counter += counter_local_sub
                write_remove_urls_list(remove_urls_list_sub, f_out)
                # counter grows by whole bucket sizes, so it rarely lands
                # exactly on a multiple of report_every. Report whenever
                # a multiple is crossed instead.
                if counter >= next_report:
                    _report()
                    next_report = (counter // report_every + 1) * report_every
    _report()
if __name__ == '__main__':
    # NOTE: this stays at module level on purpose. compute_fingerprint and
    # compute_jaccard read the module-level globals `hasher`, `args` and
    # `url_doc`, which pool workers inherit. Moving this code into a main()
    # function would hide those names from the workers.

    print('parsing the arguments ...')

    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=1234,
                        help='Random seed used for python, numpy')
    parser.add_argument('--inputs', nargs='*', default=None, help=
                        'Pairwise list of the input files and keys, '
                        'e.g. --inputs cc.json cc_id news.json news_id')
    parser.add_argument('--load-fingerprints', nargs='*', default=None,
                        help='Load fingerprints from a list of pickle files,'
                        ' e.g. cc.pkl news.pkl')
    parser.add_argument('--save-fingerprints', type=str, default=None,
                        help='Save the fingerprints of the inputs.')
    parser.add_argument('--output', type=str, default=None,
                        help='Output file name that consists of all ids'
                        ' with matching similarities')
    parser.add_argument('--jaccard', type=str, default='union',
                        choices=['union', 'min', 'max'], help='Jaccard'
                        ' similarity computation')
    parser.add_argument('--heuristic-iter', type=int, default=1,
                        help='Number of iterations to run the heuristics'
                        ': use -1 for exact')
    parser.add_argument('--num-bands', type=int, default=10,
                        help='Number of bands to use in cache')
    parser.add_argument('--num-seeds', type=int, default=100,
                        help='Number of seeds to use for minhash. Note that'
                        ' this value should be divisible by num-bands')
    parser.add_argument('--jaccard-parallel', action='store_true',
                        help='Use this to process large number of documents.')
    parser.add_argument('--num-workers', type=int, default=40,
                        help='Number of worker processes used to compute '
                        'fingerprints of the inputs.')
    args = parser.parse_args()

    # Validate here rather than with `assert`. An assert is stripped under
    # -O, and the old check only ran after all pickles had been loaded.
    if args.inputs is not None and len(args.inputs) % 2 != 0:
        parser.error('--inputs expects pairs of <file> <key>, got an odd '
                     'number of values')

    print('finding possible duplicate content ...')

    # set seed and draw args.num_seeds integer seeds for the minhash
    # functions (integer bound: same draws as the former 1e6 float bound)
    np.random.seed(args.seed)
    seeds = np.random.randint(0, 1000000, size=args.num_seeds)

    # initialize minhash and lsh cache
    hasher = minhash.MinHasher(seeds=seeds, char_ngram=5, hashbytes=4)
    lshcache = cache.Cache(num_bands=args.num_bands, hasher=hasher)

    url_doc = {}

    # load fingerprints from pickle file if needed
    # SECURITY: pickle.load can execute arbitrary code. Only load files
    # written by a trusted run of this script (--save-fingerprints).
    # NOTE(review): fingerprints computed below for --inputs still use
    # `hasher` built from --seed/--num-seeds. Presumably these must match the
    # run that produced the pickle; confirm before mixing sources.
    if args.load_fingerprints is not None:
        for count_fp, fp_file_name in enumerate(args.load_fingerprints):
            print("Loading fingerprints from pickle file {}".format(
                fp_file_name), flush=True)
            with open(fp_file_name, "rb") as fp:
                if count_fp == 0:
                    # assign directly for the first pkl
                    lshcache = pickle.load(fp)
                    url_doc = pickle.load(fp)
                else:
                    # append these to lshcache and url_doc
                    local_lshcache = pickle.load(fp)
                    local_url_doc = pickle.load(fp)
                    for url in local_lshcache.fingerprints.keys():
                        url_doc[url] = local_url_doc[url]
                        lshcache.add_fingerprint(
                            local_lshcache.fingerprints[url], url)

    counter = 0
    start_time = time.time()

    # compute fingerprints of the inputs if any
    # input file and the key to use as id
    if args.inputs is not None:
        print("Computing fingerprints", flush=True)
        for input_file, key in zip(args.inputs[::2], args.inputs[1::2]):
            print(' document processing {} with key {}'.format(input_file, key),
                  flush=True)
            compute_fingerprint_partial = partial(compute_fingerprint, key=key)
            # compute fingerprints in parallel. The `with` statement closes
            # the file and tears down the pool even if reading fails.
            with multiprocessing.Pool(args.num_workers) as pool, \
                    open(input_file, 'r', encoding='utf-8') as fin:
                compute_fingerprint_iter = pool.imap(
                    compute_fingerprint_partial, fin, 512)
                # traverse all the texts and add fingerprints
                for url, text, fingerprint, flag in compute_fingerprint_iter:
                    counter += 1
                    if flag:
                        url_doc[url] = text
                        lshcache.add_fingerprint(fingerprint, url)
                    if counter % 10000 == 0:
                        print(' [read]> processed {} documents in {:.2f} '
                              'seconds ...'.format(counter, time.time() -
                                                   start_time), flush=True)
                pool.close()
                pool.join()

    # Save the fingerprints if needed
    if args.save_fingerprints is not None:
        print("Saving fingerprints to pickle file {}".format(
            args.save_fingerprints), flush=True)
        with open(args.save_fingerprints, 'wb') as f_save:
            pickle.dump(lshcache, f_save)
            pickle.dump(url_doc, f_save)

    # compute jaccard index of the input texts and write to file if needed
    if args.output is not None:
        print("Compute jaccard similarity", flush=True)
        if args.jaccard_parallel:
            find_pair_urls_parallel(args, lshcache, url_doc)
        else:
            find_pair_urls_sequential(args, lshcache, url_doc)

    print('done :-)')
-
|