jarowinkler.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. from .string_distance import NormalizedStringDistance
  2. from .string_similarity import NormalizedStringSimilarity
  3. class JaroWinkler(NormalizedStringSimilarity, NormalizedStringDistance):
  4. def __init__(self, threshold=0.7):
  5. self.threshold = threshold
  6. self.three = 3
  7. self.jw_coef = 0.1
  8. def get_threshold(self):
  9. return self.threshold
  10. def similarity(self, s0, s1):
  11. if s0 is None:
  12. raise TypeError("Argument s0 is NoneType.")
  13. if s1 is None:
  14. raise TypeError("Argument s1 is NoneType.")
  15. if s0 == s1:
  16. return 1.0
  17. mtp = self.matches(s0, s1)
  18. m = mtp[0]
  19. if m == 0:
  20. return 0.0
  21. j = (m / len(s0) + m / len(s1) + (m - mtp[1]) / m) / self.three
  22. jw = j
  23. if j > self.get_threshold():
  24. jw = j + min(self.jw_coef, 1.0 / mtp[self.three]) * mtp[2] * (1 - j)
  25. return jw
  26. def distance(self, s0, s1):
  27. return 1.0 - self.similarity(s0, s1)
  28. @staticmethod
  29. def matches(s0, s1):
  30. if len(s0) > len(s1):
  31. max_str = s0
  32. min_str = s1
  33. else:
  34. max_str = s1
  35. min_str = s0
  36. ran = int(max(len(max_str) / 2 - 1, 0))
  37. match_indexes = [-1] * len(min_str)
  38. match_flags = [False] * len(max_str)
  39. matches = 0
  40. for mi in range(len(min_str)):
  41. c1 = min_str[mi]
  42. for xi in range(max(mi - ran, 0), min(mi + ran + 1, len(max_str))):
  43. if not match_flags[xi] and c1 == max_str[xi]:
  44. match_indexes[mi] = xi
  45. match_flags[xi] = True
  46. matches += 1
  47. break
  48. ms0, ms1 = [0] * matches, [0] * matches
  49. si = 0
  50. for i in range(len(min_str)):
  51. if match_indexes[i] != -1:
  52. ms0[si] = min_str[i]
  53. si += 1
  54. si = 0
  55. for j in range(len(max_str)):
  56. if match_flags[j]:
  57. ms1[si] = max_str[j]
  58. si += 1
  59. transpositions = 0
  60. for mi in range(len(ms0)):
  61. if ms0[mi] != ms1[mi]:
  62. transpositions += 1
  63. prefix = 0
  64. for mi in range(len(min_str)):
  65. if s0[mi] == s1[mi]:
  66. prefix += 1
  67. else:
  68. break
  69. return [matches, int(transpositions / 2), prefix, len(max_str)]