# weighted_levenshtein.py (3.1 KB)
# Copyright (c) 2018 luozhouyang
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
  20. from .string_distance import StringDistance
  21. class CharacterInsDelInterface:
  22. def deletion_cost(self, c):
  23. raise NotImplementedError()
  24. def insertion_cost(self, c):
  25. raise NotImplementedError()
  26. class CharacterSubstitutionInterface:
  27. def cost(self, c0, c1):
  28. raise NotImplementedError()
  29. class WeightedLevenshtein(StringDistance):
  30. def __init__(self, character_substitution, character_ins_del=None):
  31. self.character_ins_del = character_ins_del
  32. if character_substitution is None:
  33. raise TypeError("Argument character_substitution is NoneType.")
  34. self.character_substitution = character_substitution
  35. def distance(self, s0, s1):
  36. if s0 is None:
  37. raise TypeError("Argument s0 is NoneType.")
  38. if s1 is None:
  39. raise TypeError("Argument s1 is NoneType.")
  40. if s0 == s1:
  41. return 0.0
  42. if len(s0) == 0:
  43. return len(s1)
  44. if len(s1) == 0:
  45. return len(s0)
  46. v0, v1 = [0.0] * (len(s1) + 1), [0.0] * (len(s1) + 1)
  47. v0[0] = 0
  48. for i in range(1, len(v0)):
  49. v0[i] = v0[i - 1] + self._insertion_cost(s1[i - 1])
  50. for i in range(len(s0)):
  51. s1i = s0[i]
  52. deletion_cost = self._deletion_cost(s1i)
  53. v1[0] = v0[0] + deletion_cost
  54. for j in range(len(s1)):
  55. s2j = s1[j]
  56. cost = 0
  57. if s1i != s2j:
  58. cost = self.character_substitution.cost(s1i, s2j)
  59. insertion_cost = self._insertion_cost(s2j)
  60. v1[j + 1] = min(v1[j] + insertion_cost, v0[j + 1] + deletion_cost, v0[j] + cost)
  61. v0, v1 = v1, v0
  62. return v0[len(s1)]
  63. def _insertion_cost(self, c):
  64. if self.character_ins_del is None:
  65. return 1.0
  66. return self.character_ins_del.insertion_cost(c)
  67. def _deletion_cost(self, c):
  68. if self.character_ins_del is None:
  69. return 1.0
  70. return self.character_ins_del.deletion_cost(c)