cosine.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. import math
  2. from .shingle_based import ShingleBased
  3. from .string_distance import NormalizedStringDistance
  4. from .string_similarity import NormalizedStringSimilarity
  5. class Cosine(ShingleBased, NormalizedStringDistance,
  6. NormalizedStringSimilarity):
  7. def __init__(self, k):
  8. super().__init__(k)
  9. def distance(self, s0, s1):
  10. return 1.0 - self.similarity(s0, s1)
  11. def similarity(self, s0, s1):
  12. if s0 is None:
  13. raise TypeError("Argument s0 is NoneType.")
  14. if s1 is None:
  15. raise TypeError("Argument s1 is NoneType.")
  16. if s0 == s1:
  17. return 1.0
  18. if len(s0) < self.get_k() or len(s1) < self.get_k():
  19. return 0.0
  20. profile0 = self.get_profile(s0)
  21. profile1 = self.get_profile(s1)
  22. return self._dot_product(profile0, profile1) / (
  23. self._norm(profile0) * self._norm(profile1))
  24. def similarity_profiles(self, profile0, profile1):
  25. return self._dot_product(profile0, profile1) / (
  26. self._norm(profile0) * self._norm(profile1))
  27. @staticmethod
  28. def _dot_product(profile0, profile1):
  29. small = profile1
  30. large = profile0
  31. if len(profile0) < len(profile1):
  32. small = profile0
  33. large = profile1
  34. agg = 0.0
  35. for k, v in small.items():
  36. i = large.get(k)
  37. if not i:
  38. continue
  39. agg += 1.0 * v * i
  40. return agg
  41. @staticmethod
  42. def _norm(profile):
  43. agg = 0.0
  44. for k, v in profile.items():
  45. agg += 1.0 * v * v
  46. return math.sqrt(agg)