# Quora-Duplicate-Search.py

import csv
import itertools

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
import tensorflow_hub as hub
  7. csv_fname = "q_quora.csv"
  8. question1 = {}
  9. question2 = {}
  10. print("Loading data from {}".format(csv_fname))
  11. numLines = int(input("Enter number of lines to read: "))
  12. with open(csv_fname,'r') as f:
  13. if numLines == -1:
  14. totalLines = f.readlines()[1:]
  15. else:
  16. totalLines = f.readlines()[1:numLines]
  17. for line in totalLines:
  18. try:
  19. qid1, qid2, q1, q2 = line.strip().split(',')[1:5]
  20. question1[qid1] = q1
  21. question2[qid2] = q2
  22. except:
  23. continue
  24. print("Data loaded successfully")
  25. module_url = "https://tfhub.dev/google/universal-sentence-encoder/2"
  26. print("Loading model from {}".format(module_url))
  27. embed = hub.Module(module_url)
  28. print("Model loaded successfully")
  29. def plot_similarity(labels1, labels2, features1, features2, rotation):
  30. corr = np.inner(features1, features2)
  31. corr2 = corr.copy()
  32. corr2[corr2<0.8]=0
  33. corr2[corr2>=0.8]=1
  34. #print(corr)
  35. sns.set(font_scale=0.6)
  36. #plt.figure(figsize=(100,100))
  37. g = sns.heatmap(corr,\
  38. #xticklabels=labels1,\
  39. #yticklabels=labels2,\
  40. vmin=0,\
  41. vmax=1,\
  42. #cmap="Greys")
  43. cmap="YlOrRd")
  44. #g.set_xticklabels(labels1, rotation=rotation)
  45. g.set_title("Semantic Textual Similarity")
  46. plt.tight_layout()
  47. plt.savefig("Quora.png")
  48. plt.show()
  49. #plt.figure(figsize=(100,100))
  50. #g.set_xticklabels(labels1, rotation=rotation)
  51. similar_qid = {}
  52. for i in range(len(labels1)):
  53. for j in range(len(labels2)):
  54. if corr2[i][j] == 1:
  55. similar_qid[labels1[i]]=labels2[j]
  56. return similar_qid
  57. def run_and_plot(session_, input_tensor_, messages1_, messages2_, labels1_,labels2_, encoding_tensor):
  58. print("Embeddings questions 1")
  59. message_embeddings1_ = session_.run(encoding_tensor, feed_dict={input_tensor_: messages1_})
  60. print("Embeddings questions 2")
  61. message_embeddings2_ = session_.run(encoding_tensor, feed_dict={input_tensor_: messages2_})
  62. similar_qid = plot_similarity(labels1_,labels2_, \
  63. message_embeddings1_,\
  64. message_embeddings2_, 90)
  65. return similar_qid
  66. similarity_input_placeholder = tf.placeholder(tf.string, shape=(None))
  67. similarity_message_encodings = embed(similarity_input_placeholder)
  68. with tf.Session() as session:
  69. session.run(tf.global_variables_initializer())
  70. session.run(tf.tables_initializer())
  71. similar_qid = run_and_plot(session, similarity_input_placeholder,\
  72. list(question1.values()),\
  73. list(question2.values()),\
  74. list(question1.keys()),\
  75. list(question2.keys()),\
  76. similarity_message_encodings)
  77. with open("similarity-results.txt",'w') as f:
  78. for i in list(similar_qid.keys()):
  79. f.write("{},{}\n=======================\n".format(question1[i], question2[similar_qid[i]]))