# find_duplicates.py

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from functools import partial
import itertools
import json
from lsh import cache, minhash
import multiprocessing
import numpy as np
import time
import pickle
import sys
import os
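
# Example invocations (a minimal sketch; the file names below are placeholders
# taken from the argparse help text, not fixed requirements):
#
#   # fingerprint two loose-JSON corpora and write candidate duplicates
#   python find_duplicates.py --inputs cc.json cc_id news.json news_id \
#       --output duplicate_urls.json --jaccard-parallel
#
#   # fingerprint once, save the fingerprints, and reuse them later
#   python find_duplicates.py --inputs cc.json cc_id --save-fingerprints cc.pkl
#   python find_duplicates.py --load-fingerprints cc.pkl --output duplicate_urls.json
#
# Each line of --output is a JSON object that maps a kept document id to the
# ids judged near-duplicates of it, together with their Jaccard similarities.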


# This function is adapted from:
#   https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
def shingles(text, char_ngram=5):
    return set(text[head:head + char_ngram]
               for head in range(0, len(text) - char_ngram))
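
# A quick illustration of shingles() as written above: with the default
# char_ngram=5,
#   shingles("hello world")
# returns {"hello", "ello ", "llo w", "lo wo", "o wor", " worl"}; note that
# range(0, len(text) - char_ngram) stops one head short, so the 5-gram ending
# at the last character ("world") is not included.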


# This function is adapted from:
#   https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
def jaccard(set_a, set_b, args):
    if len(set_a) < 1 or len(set_b) < 1:
        return 0.0

    intersection = set_a & set_b
    union = set_a | set_b

    if args.jaccard == 'min':
        return len(intersection) / min(len(set_a), len(set_b))
    elif args.jaccard == 'max':
        return len(intersection) / max(len(set_a), len(set_b))
    else:
        return len(intersection) / len(union)
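
# Worked example for jaccard(): with set_a = {"ab", "bc", "cd"} and
# set_b = {"bc", "cd", "de"} the intersection has 2 elements and the union 4,
# so --jaccard union gives 2/4 = 0.5, while --jaccard min and --jaccard max
# both give 2/3 here because the two sets are the same size.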


def compute_fingerprint(line, key):
    try:
        myjson = json.loads(line)
        url = myjson[key]
        text = myjson['text']
        fingerprint = hasher.fingerprint(text)
    except Exception as e:
        print('Error:', e)
        return None, None, None, False

    return url, text, fingerprint, True
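
# Illustrative input record for compute_fingerprint (the field names are only
# an example; the id field is whatever key is paired with the file in --inputs,
# e.g. "cc_id" for `--inputs cc.json cc_id`). Each line of the input file is
# one loose-JSON document:
#   {"cc_id": "doc-0001", "text": "the quick brown fox ..."}
# On success the function returns (url, text, fingerprint, True); on any
# parsing or hashing error it returns (None, None, None, False).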


def url_pairs_to_remove(args, bucket_urls, url_doc):
    remove_urls_list = []
    deduped_local, counter_local = 0, 0
    iteration = 0
    while len(bucket_urls) > 1:
        if args.heuristic_iter != -1 and \
                iteration == args.heuristic_iter:
            break

        items = list(bucket_urls)
        remove_urls = []
        main_url = items[np.random.randint(0, len(items))]
        main_shingles = shingles(url_doc[main_url])

        for i in range(0, len(items)):
            counter_local += 1
            other_url = items[i]
            if other_url == main_url:
                continue
            other_shingles = shingles(url_doc[other_url])
            try:
                jaccard_sim = jaccard(main_shingles, other_shingles, args)
            except Exception as e:
                print('Error:', e)
                jaccard_sim = 0.0
            if jaccard_sim > 0.5:
                remove_urls.append({other_url: jaccard_sim})
                deduped_local += 1
                bucket_urls.remove(other_url)

        bucket_urls.remove(main_url)
        if len(remove_urls) > 0:
            remove_urls_list.append({main_url: remove_urls})
        iteration += 1

    return remove_urls_list, deduped_local, counter_local
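
# Sketch of what url_pairs_to_remove returns (the ids and similarity values
# here are made up for illustration): a list with one entry per randomly
# chosen "main" document, mapping it to the bucket members that scored a
# Jaccard similarity above 0.5 against it, e.g.
#   [{"url-A": [{"url-B": 0.83}, {"url-C": 0.61}]}, ...]
# deduped_local counts the documents marked as duplicates and counter_local
# counts the bucket entries examined.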


def write_remove_urls_list(remove_urls_list, f_out):
    if len(remove_urls_list) > 0:
        for each_url_remove in remove_urls_list:
            myjson = json.dumps(each_url_remove, ensure_ascii=False)
            f_out.write(myjson.encode('utf-8'))
            f_out.write('\n'.encode('utf-8'))


def compute_jaccard(each_bin, num_bins, start_time_local):
    remove_urls_list = []
    deduped_local, counter_local, bucket_local = 0, 0, 0

    for bucket_id in each_bin:
        bucket_local += 1
        if os.getpid() % num_bins == 0 and bucket_local % 100000 == 0:
            print("Counter {}, progress {:.2f} time {:.2f}".
                  format(bucket_local,
                         float(bucket_local) / float(len(each_bin)),
                         time.time() - start_time_local), flush=True)

        if len(each_bin[bucket_id]) <= 1:
            continue

        bucket_urls = each_bin[bucket_id].copy()
        remove_urls_list_sub, deduped_local_sub, counter_local_sub = \
            url_pairs_to_remove(args, bucket_urls, url_doc)

        deduped_local += deduped_local_sub
        counter_local += counter_local_sub
        if len(remove_urls_list_sub) > 0:
            remove_urls_list.extend(remove_urls_list_sub)

    return remove_urls_list, deduped_local, counter_local
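
# Note on globals (a description of how the code above is wired, not new
# behaviour): compute_jaccard reads `args` and `url_doc` as module-level
# globals that the pool workers inherit via the fork start method on Linux,
# and it treats each bin as a mapping from bucket id to a collection of
# document ids, which is what each_bin[bucket_id].copy() relies on.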


def find_pair_urls_parallel(args, lshcache, url_doc):
    start_time = time.time()
    f_out = open(args.output, 'wb')
    deduped, counter = 0, 0

    # compute jaccards of the buckets in each bin in parallel (parallelism
    # limited to the number of bins)
    num_bins = len(lshcache.bins)
    pool = multiprocessing.Pool(num_bins)
    compute_jaccard_partial = partial(compute_jaccard, num_bins=num_bins,
                                      start_time_local=start_time)
    # no need to pass args and url_doc as they are already shared
    compute_jaccard_iter = pool.imap(compute_jaccard_partial, lshcache.bins)

    print("multiprocessing init took {:.2f}".format(time.time() - start_time),
          flush=True)
    for remove_urls_list, deduped_local, counter_local in compute_jaccard_iter:
        deduped += deduped_local
        counter += counter_local
        write_remove_urls_list(remove_urls_list, f_out)
        print(' [write]> processed {} documents in {:.2f} '
              'seconds and deduped {} documents ...'.format(
                  counter, time.time() - start_time, deduped), flush=True)

    pool.close()
    pool.join()
    f_out.close()

    print(' Time taken for jaccard similarities {:.2f} seconds'.format(
        time.time() - start_time), flush=True)
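
# A note on the parallel path above (an observation, not a guarantee of the
# lsh library's API): the pool is sized to one worker per band
# (num_bins = len(lshcache.bins), i.e. --num-bands), each band is handed to
# compute_jaccard as an independent task, and the per-band results stream back
# through pool.imap and are written to --output as they arrive.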


def find_pair_urls_sequential(args, lshcache, url_doc):
    start_time = time.time()
    f_out = open(args.output, 'wb')
    deduped, counter = 0, 0
    for b in lshcache.bins:
        for bucket_id in b:
            if len(b[bucket_id]) <= 1:
                continue

            bucket_urls = b[bucket_id].copy()
            remove_urls_list_sub, deduped_local_sub, counter_local_sub = \
                url_pairs_to_remove(args, bucket_urls, url_doc)

            deduped += deduped_local_sub
            counter += counter_local_sub
            write_remove_urls_list(remove_urls_list_sub, f_out)
            if counter % 10000 == 0:
                print(' [write]> processed {} documents in {:.2f} '
                      'seconds and deduped {} documents ...'.format(
                          counter, time.time() - start_time, deduped),
                      flush=True)

    f_out.close()
    print(' [write]> processed {} documents in {:.2f} '
          'seconds and deduped {} documents ...'.format(
              counter, time.time() - start_time, deduped), flush=True)


if __name__ == '__main__':

    print('parsing the arguments ...')

    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=1234,
                        help='Random seed used for python, numpy')
    parser.add_argument('--inputs', nargs='*', default=None,
                        help='Pairwise list of the input files and keys, '
                        'e.g. --inputs cc.json cc_id news.json news_id')
    parser.add_argument('--load-fingerprints', nargs='*', default=None,
                        help='Load fingerprints from a list of pickle files,'
                        ' e.g. cc.pkl news.pkl')
    parser.add_argument('--save-fingerprints', type=str, default=None,
                        help='Save the fingerprints of the inputs.')
    parser.add_argument('--output', type=str, default=None,
                        help='Output file name that consists of all ids'
                        ' with matching similarities')
    parser.add_argument('--jaccard', type=str, default='union',
                        choices=['union', 'min', 'max'],
                        help='Jaccard similarity computation')
    parser.add_argument('--heuristic-iter', type=int, default=1,
                        help='Number of iterations to run the heuristics'
                        ': use -1 for exact')
    parser.add_argument('--num-bands', type=int, default=10,
                        help='Number of bands to use in cache')
    parser.add_argument('--num-seeds', type=int, default=100,
                        help='Number of seeds to use for minhash. Note that'
                        ' this value should be divisible by num-bands')
    parser.add_argument('--jaccard-parallel', action='store_true',
                        help='Use this to process a large number of documents.')
    args = parser.parse_args()

    print('finding possible duplicate content ...')

    # set the seed and draw the minhash seeds
    np.random.seed(args.seed)
    seeds = np.random.randint(0, 1e6, size=args.num_seeds)

    # initialize minhash and lsh cache
    hasher = minhash.MinHasher(seeds=seeds, char_ngram=5, hashbytes=4)
    lshcache = cache.Cache(num_bands=args.num_bands, hasher=hasher)

    url_doc = {}

    # load fingerprints from pickle files if needed
    if args.load_fingerprints is not None:
        for count_fp, fp_file_name in enumerate(args.load_fingerprints):
            print("Loading fingerprints from pickle file {}".format(
                fp_file_name), flush=True)
            fp = open(fp_file_name, "rb")
            if count_fp == 0:
                # assign directly for the first pkl
                lshcache = pickle.load(fp)
                url_doc = pickle.load(fp)
            else:
                # append these to lshcache and url_doc
                local_lshcache = pickle.load(fp)
                local_url_doc = pickle.load(fp)
                for url in local_lshcache.fingerprints.keys():
                    url_doc[url] = local_url_doc[url]
                    lshcache.add_fingerprint(local_lshcache.fingerprints[url],
                                             url)
            fp.close()

    counter = 0
    start_time = time.time()

    # compute fingerprints of the inputs if any; each input is a pair of
    # (input file, key to use as the document id)
    if args.inputs is not None:
        print("Computing fingerprints", flush=True)
        assert len(args.inputs) % 2 == 0
        for input_file, key in zip(args.inputs[::2], args.inputs[1::2]):
            print(' document processing {} with key {}'.format(input_file, key),
                  flush=True)

            # compute fingerprints in parallel
            num_workers = 40
            pool = multiprocessing.Pool(num_workers)
            fin = open(input_file, 'r', encoding='utf-8')
            compute_fingerprint_partial = partial(compute_fingerprint, key=key)
            compute_fingerprint_iter = pool.imap(compute_fingerprint_partial,
                                                 fin, 512)
            # traverse all the texts and add fingerprints
            for url, text, fingerprint, flag in compute_fingerprint_iter:
                counter += 1
                if flag:
                    url_doc[url] = text
                    lshcache.add_fingerprint(fingerprint, url)
                if counter % 10000 == 0:
                    print(' [read]> processed {} documents in {:.2f} '
                          'seconds ...'.format(counter,
                                               time.time() - start_time),
                          flush=True)

            fin.close()
            pool.close()
            pool.join()

    # save the fingerprints if needed
    if args.save_fingerprints is not None:
        print("Saving fingerprints to pickle file {}".format(
            args.save_fingerprints), flush=True)
        with open(args.save_fingerprints, 'wb') as f_save:
            pickle.dump(lshcache, f_save)
            pickle.dump(url_doc, f_save)

    # compute the jaccard index of the input texts and write to file if needed
    if args.output is not None:
        print("Compute jaccard similarity", flush=True)
        if args.jaccard_parallel:
            find_pair_urls_parallel(args, lshcache, url_doc)
        else:
            find_pair_urls_sequential(args, lshcache, url_doc)

    print('done :-)')