Parcourir la source

Simplify weight calculation

luozhouyang il y a 5 ans
Parent
commit
f37e391ea0
2 fichiers modifiés avec 20 ajouts et 36 suppressions
  1. 18 28
      strsimpy/weighted_levenshtein.py
  2. 2 8
      strsimpy/weighted_levenshtein_test.py

+ 18 - 28
strsimpy/weighted_levenshtein.py

@@ -21,28 +21,28 @@
 from .string_distance import StringDistance
 
 
-class CharacterInsDelInterface:
+def default_insertion_cost(char):
+    return 1.0
 
-    def deletion_cost(self, c):
-        raise NotImplementedError()
 
-    def insertion_cost(self, c):
-        raise NotImplementedError()
+def default_deletion_cost(char):
+    return 1.0
 
 
-class CharacterSubstitutionInterface:
-
-    def cost(self, c0, c1):
-        raise NotImplementedError()
+def default_substitution_cost(char_a, char_b):
+    return 1.0
 
 
 class WeightedLevenshtein(StringDistance):
 
-    def __init__(self, character_substitution, character_ins_del=None):
-        self.character_ins_del = character_ins_del
-        if character_substitution is None:
-            raise TypeError("Argument character_substitution is NoneType.")
-        self.character_substitution = character_substitution
+    def __init__(self,
+                 substitution_cost_fn=default_substitution_cost,
+                 insertion_cost_fn=default_insertion_cost,
+                 deletion_cost_fn=default_deletion_cost,
+                 ):
+        self.substitution_cost_fn = substitution_cost_fn
+        self.insertion_cost_fn = insertion_cost_fn
+        self.deletion_cost_fn = deletion_cost_fn
 
     def distance(self, s0, s1):
         if s0 is None:
@@ -60,30 +60,20 @@ class WeightedLevenshtein(StringDistance):
 
         v0[0] = 0
         for i in range(1, len(v0)):
-            v0[i] = v0[i - 1] + self._insertion_cost(s1[i - 1])
+            v0[i] = v0[i - 1] + self.insertion_cost_fn(s1[i - 1])
 
         for i in range(len(s0)):
             s1i = s0[i]
-            deletion_cost = self._deletion_cost(s1i)
+            deletion_cost = self.deletion_cost_fn(s1i)
             v1[0] = v0[0] + deletion_cost
 
             for j in range(len(s1)):
                 s2j = s1[j]
                 cost = 0
                 if s1i != s2j:
-                    cost = self.character_substitution.cost(s1i, s2j)
-                insertion_cost = self._insertion_cost(s2j)
+                    cost = self.substitution_cost_fn(s1i, s2j)
+                insertion_cost = self.insertion_cost_fn(s2j)
                 v1[j + 1] = min(v1[j] + insertion_cost, v0[j + 1] + deletion_cost, v0[j] + cost)
             v0, v1 = v1, v0
 
         return v0[len(s1)]
-
-    def _insertion_cost(self, c):
-        if self.character_ins_del is None:
-            return 1.0
-        return self.character_ins_del.insertion_cost(c)
-
-    def _deletion_cost(self, c):
-        if self.character_ins_del is None:
-            return 1.0
-        return self.character_ins_del.deletion_cost(c)

+ 2 - 8
strsimpy/weighted_levenshtein_test.py

@@ -20,19 +20,13 @@
 
 import unittest
 
-from .weighted_levenshtein import WeightedLevenshtein, CharacterSubstitutionInterface
-
-
-class CharSub(CharacterSubstitutionInterface):
-
-    def cost(self, c0, c1):
-        return 1.0
+from .weighted_levenshtein import WeightedLevenshtein
 
 
 class TestWeightedLevenshtein(unittest.TestCase):
 
     def test_weighted_levenshtein(self):
-        a = WeightedLevenshtein(character_substitution=CharSub())
+        a = WeightedLevenshtein()
         s0 = ""
         s1 = ""
         s2 = "上海"