cosine.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. import math
  2. from .shingle_based import ShingleBased
  3. from .string_distance import NormalizedStringDistance
  4. from .string_similarity import NormalizedStringSimilarity
  5. class Cosine(ShingleBased, NormalizedStringDistance, NormalizedStringSimilarity):
  6. def __init__(self, k):
  7. super().__init__(k)
  8. def distance(self, s0, s1):
  9. return 1.0 - self.similarity(s0, s1)
  10. def similarity(self, s0, s1):
  11. if s0 is None:
  12. raise TypeError("Argument s0 is NoneType.")
  13. if s1 is None:
  14. raise TypeError("Argument s1 is NoneType.")
  15. if s0 == s1:
  16. return 1.0
  17. if len(s0) < self.get_k() or len(s1) < self.get_k():
  18. return 0.0
  19. profile0 = self.get_profile(s0)
  20. profile1 = self.get_profile(s1)
  21. return self._dot_product(profile0, profile1) / (self._norm(profile0) * self._norm(profile1))
  22. def similarity_profiles(self, profile0, profile1):
  23. return self._dot_product(profile0, profile1) / (self._norm(profile0) * self._norm(profile1))
  24. @staticmethod
  25. def _dot_product(profile0, profile1):
  26. small = profile1
  27. large = profile0
  28. if len(profile0) < len(profile1):
  29. small = profile0
  30. large = profile1
  31. agg = 0.0
  32. for k, v in small.items():
  33. i = large.get(k)
  34. if not i:
  35. continue
  36. agg += 1.0 * v * i
  37. return agg
  38. @staticmethod
  39. def _norm(profile):
  40. agg = 0.0
  41. for k, v in profile.items():
  42. agg += 1.0 * v * v
  43. return math.sqrt(agg)
  44. if __name__ == "__main__":
  45. cosine = Cosine(1)
  46. str0 = "上海市宝山区 你好"
  47. str1 = "上海浦东新区 你好吗"
  48. d = cosine.distance(str0, str1)
  49. s = cosine.similarity(str0, str1)
  50. print(d)
  51. print(s)