# weighted_levenshtein.py (2.0 KB)

  1. from .string_distance import StringDistance
  2. class CharacterInsDelInterface:
  3. def deletion_cost(self, c):
  4. raise NotImplementedError()
  5. def insertion_cost(self, c):
  6. raise NotImplementedError()
  7. class CharacterSubstitutionInterface:
  8. def cost(self, c0, c1):
  9. raise NotImplementedError()
  10. class WeightedLevenshtein(StringDistance):
  11. def __init__(self, character_substitution, character_ins_del=None):
  12. self.character_ins_del = character_ins_del
  13. if character_substitution is None:
  14. raise TypeError("Argument character_substitution is NoneType.")
  15. self.character_substitution = character_substitution
  16. def distance(self, s0, s1):
  17. if s0 is None:
  18. raise TypeError("Argument s0 is NoneType.")
  19. if s1 is None:
  20. raise TypeError("Argument s1 is NoneType.")
  21. if s0 == s1:
  22. return 0.0
  23. if len(s0) == 0:
  24. return len(s1)
  25. if len(s1) == 0:
  26. return len(s0)
  27. v0, v1 = [0.0] * (len(s1) + 1), [0.0] * (len(s1) + 1)
  28. v0[0] = 0
  29. for i in range(1, len(v0)):
  30. v0[i] = v0[i - 1] + self._insertion_cost(s1[i - 1])
  31. for i in range(len(s0)):
  32. s1i = s0[i]
  33. deletion_cost = self._deletion_cost(s1i)
  34. v1[0] = v0[0] + deletion_cost
  35. for j in range(len(s1)):
  36. s2j = s1[j]
  37. cost = 0
  38. if s1i != s2j:
  39. cost = self.character_substitution.cost(s1i, s2j)
  40. insertion_cost = self._insertion_cost(s2j)
  41. v1[j + 1] = min(v1[j] + insertion_cost, v0[j + 1] + deletion_cost, v0[j] + cost)
  42. v0, v1 = v1, v0
  43. return v0[len(s1)]
  44. def _insertion_cost(self, c):
  45. if self.character_ins_del is None:
  46. return 1.0
  47. return self.character_ins_del.insertion_cost(c)
  48. def _deletion_cost(self, c):
  49. if self.character_ins_del is None:
  50. return 1.0
  51. return self.character_ins_del.deletion_cost(c)